In [None]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import random as jr
from pymdp.jax.agent import Agent as AIFAgent

In [None]:
def scan(f, init, xs, length=None):
    """Eager, pure-Python stand-in for a ``jax.lax.scan``-style loop.

    Calls ``carry, y = f(carry, x)`` once for every ``x`` obtained by
    iterating over ``xs``, threading ``carry`` from one call to the next.

    Args:
        f: Callable taking ``(carry, x)`` and returning ``(new_carry, y)``.
        init: Carry passed to the first call of ``f``.
        xs: Iterable of per-step inputs, or ``None`` to call ``f`` with
            ``x=None`` for ``length`` steps.
        length: Number of steps; only consulted when ``xs`` is ``None``.

    Returns:
        A ``(carry, ys)`` pair. ``carry`` is the final carry. ``ys`` holds the
        per-step outputs stacked leaf-wise along a new leading axis, so each
        ``y`` may be an array or any pytree of arrays. ``ys`` is ``None`` when
        no step was run or when every step returned ``None``.

    Raises:
        TypeError: If both ``xs`` and ``length`` are ``None``.
    """
    if xs is None:
        if length is None:
            raise TypeError("scan() needs `length` when `xs` is None")
        xs = [None] * length

    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)

    if not ys:
        # jnp.stack rejects an empty sequence, so there is nothing to return.
        return carry, None

    # Stack leaf-wise instead of calling jnp.stack on the raw list: this lets
    # the per-step output be a dict/tuple of arrays, and a step output of None
    # (a pytree without leaves) simply yields None.
    return carry, jtu.tree_map(lambda *leaves: jnp.stack(leaves), *ys)

def evolve_trials(agent, env, num_timesteps):
    """Roll ``agent`` forward in ``env`` for ``num_timesteps`` steps.

    Every step calls, in order, ``agent.infer_states`` on the outcome and
    action histories, ``agent.infer_policies``, ``agent.sample_action``,
    ``env.step`` with the sampled action, and ``agent.update_empirical_prior``.
    The new outcome and action are appended to their histories along axis 0.

    Args:
        agent: Object exposing ``D``, ``infer_states``, ``infer_policies``,
            ``sample_action`` and ``update_empirical_prior`` (presumably a
            pymdp ``AIFAgent`` -- confirm).
        env: Object whose ``step()`` returns the initial outcome and whose
            ``step(action)`` returns the outcome that follows ``action``.
        num_timesteps: Number of steps to run.

    Returns:
        The final carry: a dict with keys ``'args'``, ``'outcomes'``,
        ``'beliefs'`` and ``'actions'``. ``'beliefs'`` and ``'actions'`` are
        ``None`` if no step was run.
    """

    def step_fn(carry, xs):
        actions = carry['actions']
        outcomes = carry['outcomes']
        beliefs = agent.infer_states(outcomes, actions, *carry['args'])
        q_pi, _ = agent.infer_policies(beliefs)
        actions_t = agent.sample_action(q_pi)

        outcome_t = env.step(actions_t)
        # Append along the time axis. Concatenation is required here:
        # jnp.stack needs identically shaped inputs, which a growing history
        # and a single new entry are not.
        outcomes = jtu.tree_map(
            lambda prev_o, new_o: jnp.concatenate([prev_o, jnp.expand_dims(new_o, 0)], 0),
            outcomes,
            outcome_t,
        )

        # The first action also gets a leading time axis so that later
        # actions can be concatenated onto it.
        new_action = jnp.expand_dims(actions_t, 0)
        actions = new_action if actions is None else jnp.concatenate([actions, new_action], 0)

        args = agent.update_empirical_prior(actions_t, beliefs)
        # (pred, [cond_1, ..., cond_{t-1}])

        # ovf beliefs = (post_T, [cond_1, cond_2, ..., cond_{T-1}])
        # else beliefs = (post_T, post_{T-1}, ..., post_1)
        new_carry = {'args': args, 'outcomes': outcomes, 'beliefs': beliefs, 'actions': actions}
        # Emit the sampled action as the per-step output so the scan helper
        # always has an array to stack; the stacked result is discarded below.
        return new_carry, actions_t

    # Give the initial outcome a leading time axis to concatenate onto.
    outcome_0 = jtu.tree_map(lambda o: jnp.expand_dims(o, 0), env.step())
    # The carry must be a dict because step_fn reads it by key.
    init = {'args': (agent.D, None), 'outcomes': outcome_0, 'beliefs': None, 'actions': None}
    last, _ = scan(step_fn, init, range(num_timesteps))

    return last

def step_fn(carry, b):
    """Run one block of trials and update the agent from its result.

    Args:
        carry: The agent carried from the previous block.
        b: Per-block input supplied by the scan; not used by this function.

    Returns:
        ``(agent, output)`` where ``agent`` is whatever ``agent.learning``
        returns and ``output`` is the dict returned by ``evolve_trials``.
    """
    agent = carry
    # evolve_trials takes (agent, env, num_timesteps).
    # NOTE(review): `env` and `num_timesteps` are read from the enclosing
    # notebook scope and are assumed to be defined in another cell, like the
    # `agent` and `num_blocks` names used where this function is scanned --
    # confirm.
    output = evolve_trials(agent, env, num_timesteps)

    # How to deal with contiguous blocks of trials? Two options we can imagine:
    # A) use the final posterior (over current and past timesteps) to compute
    #    the smoothing distribution over qs_{t=0} and update pD, then pass pD
    #    as the initial state prior ($D = \mathbb{E}_{pD}[qs_{t=0}]$);
    # B) do not assume that blocks 'reset time': treat them as adjacent chunks
    #    of one long sequence, and set the initial state prior to the final
    #    output (`output['beliefs']`) passed through the transition model
    #    entailed by the action taken at the last timestep of the previous
    #    block.

    # NOTE(review): assumes agent.learning accepts the keys of `output`
    # ('args', 'outcomes', 'beliefs', 'actions') as keyword arguments -- confirm.
    agent = agent.learning(**output)

    return agent, output

# Run the block-level loop: step_fn is applied once per block index, carrying
# the agent from one block to the next.
init_agent = agent
agent, sequences = scan(step_fn, init_agent, range(num_blocks))
# Alias kept so code that still refers to the original misspelt name works.
squences = sequences