In [3]:
import numpy as np
import jax
from jax import numpy as jnp 

from pymdp.jax import Distribution
from pymdp.jax.agent import Agent

np.set_printoptions(precision=2, suppress=True)

# Labels shared by the single observation modality, the single hidden-state
# factor and the single control factor used throughout this example.
observations = ["A", "B", "C", "D"]
states = ["A", "B", "C", "D"]
controls = ["up", "down"]

# Observation model, indexed as A[observation, state] (the order in which
# the label dicts are handed to Distribution). Only the entries where the
# observation label equals the state label are set to 1.
data = np.zeros((len(observations), len(states)))
A = Distribution(data, {"observations": observations}, {"states": states})
for label in states:
    A[label, label] = 1.0

# Transition model, indexed as B[state, state, control].
# NOTE(review): presumably the first axis is the next state and the second
# the current state -- confirm against the pymdp convention. Under that
# reading "up" steps A -> B -> C -> D (D stays at D) and "down" steps
# D -> C -> B -> A (A stays at A).
data = np.zeros((len(states), len(states), len(controls)))
B = Distribution(data, {"states": states}, {"states": states, "controls": controls})
transitions = (
    ("B", "A", "up"),
    ("C", "B", "up"),
    ("D", "C", "up"),
    ("D", "D", "up"),
    ("A", "A", "down"),
    ("A", "B", "down"),
    ("B", "C", "down"),
    ("C", "D", "down"),
)
for first, second, control in transitions:
    B[first, second, control] = 1.0

# Two candidate policies of four steps each: always control index 0, or
# always control index 1. The trailing axis of size 1 is appended so the
# array has shape (2, 4, 1).
policies = jnp.array([[0, 0, 0, 0], [1, 1, 1, 1]])[..., jnp.newaxis]

# Vector of shape (1, 4) that is zero everywhere except a 1.0 at index 0.
C = jnp.zeros((1, 4)).at[0, 0].set(1.0)

# Wrap each model component in a one-element list and build the agent with
# the explicit policy set defined earlier.
agent = Agent([A], [B], [C], policies=policies)

# Inputs for a single step: the observation index, the previous action
# index, and a belief of shape (1, 1, 4) with all mass on index 0.
observation = [jnp.array([[0]])]
action = jnp.array([[0]])
qs = [jnp.zeros((1, 1, 4)).at[0, 0, 0].set(1.0)]

# Derive the prior from the previous action and belief, then update the
# belief using the new observation.
prior, _ = agent.infer_empirical_prior(action, qs)
qs = agent.infer_states(observation, None, prior, None)
print(qs)

# Evaluate the policies given the updated belief.
q_pi, G = agent.infer_policies(qs)
print(q_pi)

# NOTE(review): `key` is created here but is not passed to sample_action.
key = jax.random.PRNGKey(0)
action = agent.sample_action(q_pi)
print(action)

[Array([[[0., 0., 0., 1.]]], dtype=float32)]
[[0. 1.]]
[[1]]
