In [1]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import random as jr
from pymdp.jax.agent import Agent as AIFAgent
from pymdp.utils import random_A_matrix, random_B_matrix

In [7]:
def scan(f, init, xs, length=None, unroll=1):
  """Pure-Python stand-in for `jax.lax.scan` (eager, easy to debug).

  Applies `f(carry, x) -> (carry, y)` exactly once per element of `xs`,
  threading the carry through, and stacks the non-None `y` outputs leaf-wise
  along a new leading axis.

  Args:
    f: step function `(carry, x) -> (carry, y)`.
    init: initial carry.
    xs: iterable of per-step inputs, or None to run `length` steps with x=None.
    length: number of steps to run when `xs` is None.
    unroll: accepted only for signature compatibility with `jax.lax.scan`,
      where it is a loop-unrolling hint that never changes results. It is
      therefore ignored here; it must not re-apply `f` to the same input.

  Returns:
    (carry, ys): the final carry and the stacked outputs, or `ys=None` when
    every step returned `y=None`.
  """
  del unroll  # no-op: unrolling is a compilation hint, not a semantic change

  if xs is None:
    xs = [None] * length

  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    # Steps that only update the carry return None; those are not collected.
    if y is not None:
      ys.append(y)

  stacked = jtu.tree_map(lambda *leaves: jnp.stack(leaves), *ys) if ys else None
  return carry, stacked

def evolve_trials(agent, env, block_idx, num_timesteps):
    """Roll out one block of `num_timesteps` agent-environment interaction steps.

    Each step:
      1. runs state inference on the outcome/action history so far,
      2. picks an action (policy inference is currently stubbed out; see below),
      3. steps the environment and appends the new outcome and action to the history.

    Args:
        agent: batched agent exposing `infer_states`, `sample_action`, `policies` and `D`.
        env: environment whose `step(actions=None)` returns a pytree (here a list) of
            per-modality observations with a leading batch axis.
        block_idx: index of the current block; currently unused, kept so the signature
            matches the outer block-level step function.
        num_timesteps: number of interaction steps in this block.

    Returns:
        (last, env): `last` is the final carry dict with keys 'args', 'outcomes',
        'beliefs' and 'actions'; `env` is the (stateful) environment.
    """

    def step_fn(carry, xs):
        actions = carry['actions']
        outcomes = carry['outcomes']
        beliefs = agent.infer_states(outcomes, actions, *carry['args'])

        # Read the batch size off the outcome history (leading axis of every leaf)
        # instead of depending on a notebook-level `batch_size` global.
        num_batch = jtu.tree_leaves(outcomes)[0].shape[0]
        num_policies = len(agent.policies)
        # TODO: replace with `q_pi, _ = agent.infer_policies(beliefs)` once policy
        # inference is wired in; for now actions come from a uniform q(pi).
        q_pi = jnp.ones((num_batch, num_policies)) / num_policies
        actions_t = agent.sample_action(q_pi)

        outcome_t = env.step(actions_t)
        # Append the new outcome of every modality along the trailing (time) axis.
        outcomes = jtu.tree_map(
            lambda prev_o, new_o: jnp.concatenate([prev_o, jnp.expand_dims(new_o, -1)], -1),
            outcomes,
            outcome_t,
        )

        # Actions are accumulated along axis -2. The history starts empty (None),
        # so the first action initialises it.
        new_action = jnp.expand_dims(actions_t, -2)
        if actions is None:
            actions = new_action
        else:
            actions = jnp.concatenate([actions, new_action], -2)

        # NOTE(review): the first element is the *latest filtering posterior*
        # (last entry along axis 1 of each belief leaf), not a prediction pushed
        # through B. `agent.update_empirical_prior(actions_t, beliefs)` is the
        # intended replacement. Assumes belief leaves are (batch, time, ...), so
        # TODO confirm against `infer_states`.
        args = (jtu.tree_map(lambda x: x[:, -1], beliefs), beliefs)

        return {'args': args, 'outcomes': outcomes, 'beliefs': beliefs, 'actions': actions}, None

    # The initial observation becomes a length-1 time axis; it is observed before any action.
    outcome_0 = jtu.tree_map(lambda x: jnp.expand_dims(x, -1), env.step())
    init = {
        'args': (agent.D, None),
        'outcomes': outcome_0,
        'beliefs': [],
        'actions': None,
    }
    last, _ = scan(step_fn, init, range(num_timesteps))

    return last, env

def step_fn(carry, block_idx):
    """Block-level `scan` step: roll out one block of trials and return its record.

    Args:
        carry: `(agent, env)` tuple threaded through consecutive blocks.
        block_idx: index of the current block, forwarded to `evolve_trials`.

    Returns:
        `((agent, env), block_output)`, where `block_output` is the final carry dict
        of `evolve_trials` with its 'args' entry removed. The block length is taken
        from the notebook-level `num_timesteps`.
    """
    agent, env = carry
    block_output, env = evolve_trials(agent, env, block_idx, num_timesteps)
    # 'args' only feeds state inference inside the block, so it is dropped from the
    # per-block record that the outer scan collects.
    del block_output['args']

    # Open question: how should consecutive blocks of trials be chained?
    #   A) Blocks reset time. Use the final posterior (over current and past timesteps)
    #      to compute the smoothed q(s_{t=0}), update pD, and pass
    #      D = E_{pD}[q(s_{t=0})] as the next block's initial-state prior.
    #   B) Blocks are adjacent chunks of one long sequence. The next initial-state
    #      prior is the final belief (`block_output['beliefs']`) propagated through
    #      the transition model for the action taken at the previous block's last step.

    # agent = agent.learning(**block_output)

    return (agent, env), block_output

# --- Configuration: sizes of the batched generative model and rollout ----------
batch_size = 10          # parallel agents / environment realisations
num_obs = [3, 3]         # outcome levels per observation modality
num_states = [3, 3]      # levels per hidden-state factor
num_controls = [2, 2]    # actions per control factor
num_blocks = 2           # blocks of trials to roll out
num_timesteps = 5        # interaction steps per block

A_np = random_A_matrix(num_obs=num_obs, num_states=num_states)
B_np = random_B_matrix(num_states=num_states, num_controls=num_controls)
# Give every model component a leading batch axis of size `batch_size`.
A = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))
B = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))
C = [jnp.zeros((batch_size, no)) for no in num_obs]            # all-zero preferences
D = [jnp.ones((batch_size, ns)) / ns for ns in num_states]    # uniform initial-state priors
# Uniform prior over policies. The policy count is derived from `num_controls`
# rather than a hard-coded 4, so changing the controls keeps E consistent.
# Assumes one-step policies over every control factor (the count is then the
# product of `num_controls`); TODO confirm against `len(agent.policies)`.
num_policies = int(jnp.prod(jnp.array(num_controls)))
E = jnp.ones((batch_size, num_policies)) / num_policies

class TestEnv:
    """Toy environment that ignores actions and emits uniformly random outcomes.

    Args:
        num_obs: number of outcome levels for each observation modality.
        prng_key: initial JAX PRNG key; advanced on every `step` call.
        batch_size: number of parallel realisations per step. If None (default),
            the notebook-level `batch_size` global is read at each `step` call,
            which matches the original behaviour.
    """

    def __init__(self, num_obs, prng_key=jr.PRNGKey(0), batch_size=None):
        self.num_obs = num_obs
        self.key = prng_key
        self.batch_size = batch_size

    def step(self, actions=None):
        """Return a list with one random observation array of shape (batch,) per modality.

        `actions` is accepted for interface compatibility and is ignored.
        """
        n_batch = batch_size if self.batch_size is None else self.batch_size
        # Split first and give each modality its own fresh subkey. Sampling every
        # modality from one shared key made equal-sized modalities identical, and
        # splitting a key after it has been consumed is discouraged by JAX.
        self.key, *subkeys = jr.split(self.key, len(self.num_obs) + 1)
        return [jr.randint(k, (n_batch,), 0, no) for k, no in zip(subkeys, self.num_obs)]

# Build the batched agent and the toy environment, then roll out `num_blocks`
# blocks of `num_timesteps` trials each with the block-level `step_fn`.
agents = AIFAgent(A, B, C, D, E, inference_algo='mmp')
env = TestEnv(num_obs)

init = (agents, env)
(agents, env), sequences = scan(step_fn, init, range(num_blocks))

# `scan` stacks the per-block records along a new leading axis, so leaves come out
# as (blocks, batch, trials, ...); swapping axes 1 and 2 puts trials before batch.
sequences = jtu.tree_map(
    lambda leaf: leaf.swapaxes(1, 2),
    sequences,
)

# NOTE: every element of `sequences` is now laid out as (blocks, trials, batch_size, ...).
# 'outcomes' has num_timesteps + 1 trials (the initial observation plus one per step).


In [4]:
# def scan(f, init, xs, length=None):
#   if xs is None:
#     xs = [None] * length
#   carry = init
#   ys = []
#   for x in xs:
#     carry, y = f(carry, x)
#     ys.append(y)
  
#   return carry, jnp.stack(ys)

# def evolve_trials(agent, env, block_idx, num_timesteps):

#     def step_fn(carry, xs):
#         actions = carry['actions']
#         outcomes = carry['outcomes']
#         beliefs = agent.infer_states(outcomes, actions, *carry['args'])
#         q_pi, _ = agent.infer_policies(beliefs)
#         actions_t = agent.sample_action(q_pi)

#         outcome_t = env.step(actions_t)
#         outcomes = jtu.tree_map(lambda prev_o, new_o: jnp.stack([prev_o, jnp.expand_dims(new_o, 0)], 0), outcomes, outcome_t)

#         actions = jnp.stack([actions, jnp.expand_dims(actions_t, 0)], 0) if actions is not None else actions_t
#         args = agent.update_empirical_prior(actions_t, beliefs)
#         # (pred, [cond_1, ..., cond_{t-1}])

#         # ovf beliefs = (post_T, [cond_1, cond_2, ..., cond_{T-1}])
#         # else beliefs = (post_T, post_{T-1}, ..., post_1)
#         return {'args': args, 'outcomes': outcomes, 'beliefs': beliefs, 'actions': actions}, None

#     outcome_0  = env.step()
#     init = ((agent.D, None), outcome_0, None, None)
#     last, _ = scan(step_fn, init, range(num_timesteps))

#     return last, env

# def step_fn(carry, block_idx):
#     agent, env = carry
#     output, env = evolve_trials(agent, env, block_idx, num_timesteps)

#     # How to deal with contiguous blocks of trials? Two options we can imagine: 
#     # A) you use final posterior (over current and past timesteps) to compute the smoothing distribution over qs_{t=0} and update pD, and then pass pD as the initial state prior ($D = \mathbb{E}_{pD}[qs_{t=0}]$);
#     # B) we don't assume that blocks 'reset time', and are really just adjacent chunks of one long sequence, so you set the initial state prior to be the final output (`output['beliefs']`) passed through
#     # the transition model entailed by the action taken at the last timestep of the previous block.
    
#     agent = agent.learning(**output)
    
#     return (agent, env), output

# init = (agent, env)
# agent, squences = scan(step_fn, init, range(num_blocks) )