In [21]:
import numpy as np
import jax
from jax import numpy as jnp 

from pymdp.jax import Distribution
from pymdp.jax.agent import Agent

# Compact array printing: two decimals, no scientific notation.
np.set_printoptions(precision=2, suppress=True)

# Labels for the toy model: four positions on a line, an observation label
# for each position, and two controls for moving between neighbours.
observations = ["A", "B", "C", "D"]
states = ["A", "B", "C", "D"]
controls = ["up", "down"]

# Observation model with rows labelled by observation and columns by state.
# Every state emits the observation carrying the same label, so the
# resulting matrix is the identity.
data = np.zeros((len(observations), len(states)))
likelihood = Distribution(
    data,
    {"observations": observations},
    {"states": states},
)
for obs_label, state_label in zip(observations, states):
    likelihood[obs_label, state_label] = 1.0

# Transition model over (states, states, controls). Following the usual
# pymdp layout (assumed here), the first axis is the next state and the
# remaining axes are the current state and the control taken.
data = np.zeros((len(states), len(states), len(controls)))
transition = Distribution(
    data,
    {"states": states},
    {"states": states, "controls": controls},
)

# Deterministic moves along the chain A-B-C-D, as (next, current) pairs.
# "up" steps towards D and "down" steps towards A; at either end of the
# chain the agent stays where it is.
moves = {
    "up": [("B", "A"), ("C", "B"), ("D", "C"), ("D", "D")],
    "down": [("A", "A"), ("A", "B"), ("B", "C"), ("C", "D")],
}
for control, pairs in moves.items():
    for next_state, current_state in pairs:
        transition[next_state, current_state, control] = 1.0

# Add a leading axis of size 1 to each model array (presumably the batch
# dimension pymdp.jax expects — confirm against the Agent docs).
A = [jnp.broadcast_to(likelihood.data, (1,) + likelihood.data.shape)]
B = [jnp.broadcast_to(transition.data, (1,) + transition.data.shape)]

# Preferences over observations: only observing "A" carries a positive
# preference; every other observation is neutral.
C = [jnp.zeros((1, len(observations)))]
C[0] = C[0].at[0, observations.index("A")].set(1.0)

# Candidate policies, shape (2, 4, 1) after expand_dims, read here as
# (num_policies, horizon, num_factors). Policy 0 always picks control 0
# ("up") and policy 1 always picks control 1 ("down").
policies = jnp.expand_dims(jnp.array([[0, 0, 0, 0], [1, 1, 1, 1]]), -1)

# Uniform priors over initial states (D) and over policies (E). Each is
# normalised by its own size so it sums to 1 and is a proper probability
# distribution.
D = jnp.ones((1, len(states))) / len(states)
num_policies = policies.shape[0]
E = jnp.ones((1, num_policies)) / num_policies

# NOTE(review): the two positional arguments after E reuse A and B;
# presumably these fill the pA / pB (Dirichlet parameter) slots of
# Agent — confirm against the Agent signature.
agent = Agent(A, B, C, D, E, A, B, policies=policies)

# One perception/action step for a batch of one agent. Indices are looked
# up from the label lists so they stay correct if the labels change.

# Current observation: the agent sees "D".
observation = [jnp.array([[observations.index("D")]])]
# Action taken on the previous step: "up".
action = jnp.array([[controls.index("up")]])

# Previous belief over states: certain the agent was in "D".
qs = [jnp.zeros((1, 1, len(states)))]
qs[0] = qs[0].at[0, 0, states.index("D")].set(1.0)

# Predict the prior for this step from the previous belief and action
# (presumably by applying B — see pymdp's update_empirical_prior).
prior, _ = agent.update_empirical_prior(action, qs)

# Posterior over states given the new observation and the empirical prior.
qs = agent.infer_states(observation, None, prior, None)
print(qs)

# Posterior over policies; G is the per-policy score returned alongside it
# (presumably (negative) expected free energy — confirm).
q_pi, G = agent.infer_policies(qs)
print(q_pi)

# NOTE(review): `key` is not passed to sample_action below; if stochastic
# action selection is enabled, sample_action presumably needs an rng key —
# confirm.
key = jax.random.PRNGKey(0)
action = agent.sample_action(q_pi)
print(action)

[Array([[[0., 0., 0., 1.]]], dtype=float32)]
[[0. 1.]]
[[1]]
