### Imports

In [10]:
%load_ext autoreload
%autoreload 2

from jax import nn, vmap, lax, jit
from jax import numpy as jnp, random as jr
from jax import tree_util as jtu
from jax.scipy.special import gammaln, digamma

import numpy as np

from pymdp.envs import GridWorld, rollout
from pymdp.agent import Agent

import matplotlib.pyplot as plt
import seaborn as sns


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Grid world generative model
 
Here we will explore learning of the generative model inside a simple 7x7 grid world environment, where in each state the agent can move in one of 4 possible directions. The agent explores the environment for a fixed number of time steps per block (50 in the cells below), after which it is returned to its original position. We will start with an example where the likelihood is fixed and the state transitions are unknown. Next we will explore the example where the likelihood is unknown but the state transitions are known, and finally we will look at learning under joint uncertainty over likelihood and transitions (we would expect this case not to work in general with flat priors on both components).


In [11]:
# Shape of the grid world.
grid_shape = (7, 7)

# Number of agents simulated in parallel (the batch dimension).
batch_size = 20

# Batched grid-world environment without a "stay" action.
env = GridWorld(
    shape=grid_shape,
    include_stay=False,
    batch_size=batch_size,
)

### Define KL divergence between Dirichlet distributions

In [12]:
@jit
def kl_div_dirichlet(alpha1, alpha2):
    """KL divergence KL[Dir(alpha1) || Dir(alpha2)] between Dirichlet distributions.

    The categories of each Dirichlet are laid out along axis 1 of the
    concentration arrays; that axis is reduced, every other axis is kept.

    Args:
        alpha1: concentration parameters of the first distribution.
        alpha2: concentration parameters of the second distribution
            (same shape as ``alpha1``).

    Returns:
        Array of divergences with axis 1 removed.
    """
    total1 = jnp.sum(alpha1, axis=1, keepdims=True)
    total2 = jnp.sum(alpha2, axis=1)

    # Difference of the log-normalisers' "total concentration" terms.
    log_norm_diff = gammaln(total1.squeeze(1)) - gammaln(total2)

    # Per-category contributions, reduced over the category axis below.
    per_category = (
        gammaln(alpha2)
        - gammaln(alpha1)
        + (alpha1 - alpha2) * (digamma(alpha1) - digamma(total1))
    )

    return log_norm_diff + jnp.sum(per_category, 1)

### Initialize different sets of agents using the `Agent()` class, with a different `alpha` (action selection precision, i.e. inverse temperature) for each set

In [52]:
# Agents whose A matrix is fixed to the A of the generative process, while the
# transitions B start from a flat Dirichlet prior (pB) and have to be learned.
num_obs = [a.shape[1] for a in env.params['A']]
num_states = [b.shape[-2] for b in env.params['B']]

_A = [jnp.array(a) for a in env.params['A']]
C = [jnp.zeros((batch_size, num_obs[0]))]
pB = [jnp.ones_like(env.params['B'][0]) / num_states[0]]
# expected transition probabilities under the prior (normalised over axis 1)
_B = jtu.tree_map(lambda x: x / x.sum(1, keepdims=True), pB)
# flat beliefs over the initial state, normalised so each row is a proper distribution
_D = [jnp.ones((batch_size, num_states[0])) / num_states[0]]

# One batch of agents per alpha value: alpha = 0, 0.2, 0.4, 0.6, 0.8.
agents = [
    Agent(
        _A,
        _B,
        C,
        _D,
        E=None,
        pA=None,
        pB=pB,
        policy_len=3,
        use_utility=False,
        use_states_info_gain=True,
        use_param_info_gain=True,
        gamma=jnp.ones(batch_size),
        alpha=jnp.ones(batch_size) * i * .2,
        onehot_obs=False,
        action_selection="stochastic",
        inference_algo="ovf",
        num_iter=1,
        learn_A=False,
        # B is the unknown component in this experiment, so its learning must be on;
        # with learn_B=False the tracked KL divergence of pB could never change.
        learn_B=True,
        learn_D=False,
        batch_size=batch_size,
    )
    for i in range(5)
]

### Run active inference for each batch of agents with different `alpha` values, and do so for multiple independent blocks or trials (different initial conditions for each block)


In [None]:
# Reference Dirichlet: sharply peaked on the true transitions of the generative
# process; the small offset keeps every concentration parameter strictly positive.
pB_ground_truth = 1e4 * env.params['B'][0] + 1e-4
num_timesteps = 50
num_blocks = 40
key = jr.PRNGKey(0)
# Per block: `batch_size` keys for resetting the batched environment plus one key for
# the rollout. The same keys are reused for every agent set within a block.
block_and_batch_keys = jr.split(key, num_blocks * (batch_size+1)).reshape((num_blocks, batch_size+1, -1))

# Wrap the rollout once instead of re-wrapping it on every iteration.
jitted_rollout = jit(rollout, static_argnums=[2,])

divs = {i : [] for i in range(len(agents))}
for block in range(num_blocks):
    block_keys = block_and_batch_keys[block]
    for i, agent in enumerate(agents):
        _, env = env.reset(block_keys[:-1])

        last, info, env = jitted_rollout(agent, env, num_timesteps, block_keys[-1])
        env = last['env']

        beliefs = info['qs']
        actions = info['actions']
        outcomes = info['observations']

        # Update the parameters from the block just experienced and keep the updated agent.
        agents[i] = agent.infer_parameters(beliefs, outcomes, actions)
        # Score the agent that is carried into the next block (agents[i]), so the
        # recorded divergence reflects the update that was just applied.
        divs[i].append(kl_div_dirichlet(agents[i].pB[0], pB_ground_truth).sum(-1).mean(-1))

KeyboardInterrupt: 

In [None]:
# KL divergence per block: bold line = batch average, faint lines = individual agents.
fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharex=True, sharey=True)
for idx, agent in enumerate(agents):
    curves = jnp.stack(divs[idx])
    (mean_line,) = axes.plot(curves.mean(-1), lw=3, label=agent.alpha.mean())
    axes.plot(curves, color=mean_line.get_color(), alpha=.2)

axes.set_xlabel('epoch')
axes.set_ylabel('KL divergence')
axes.legend(title='alpha')
fig.tight_layout()

In [None]:
# Compare the transition matrices of the first two agent sets (rows 0-1) with the
# true transitions of the generative process (row 2), one column per action.
# Only the first element of each batch is shown.
true_B = env.params['B'][0]
num_actions = true_B.shape[-1]

fig, axes = plt.subplots(3, num_actions, figsize=(16, 8), sharex=True, sharey=True, squeeze=False)

for i in range(num_actions):
    for j, agent in enumerate(agents[:2]):
        sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(true_B[0, ..., i], ax=axes[2, i], cmap='viridis', vmax=1., vmin=0.)

fig.tight_layout()

In [None]:
# Agents whose B matrix is fixed to the B of the generative process, while the
# likelihood A starts from a flat Dirichlet prior (pA) and has to be learned.
A = [jnp.array(a) for a in env.params['A']]
B = [jnp.array(b) for b in env.params['B']]
# NOTE(review): assumes the environment exposes its initial-state prior as
# env.params['D'] (batched like A and B) — confirm against the installed pymdp.
D = [jnp.array(d) for d in env.params['D']]

C = [jnp.zeros((batch_size, num_obs[0]))]
pA = [jnp.ones_like(A[0]) / num_obs[0]]
# expected likelihood under the prior (normalised over axis 1)
_A = jtu.tree_map(lambda x: x / x.sum(1, keepdims=True), pA)

# One batch of agents per alpha value: alpha = 0, 0.2, 0.4, 0.6, 0.8.
agents = [
    Agent(
        _A,
        B,
        C,
        D,
        E=None,
        pA=pA,
        pB=None,
        policy_len=3,
        use_utility=False,
        use_states_info_gain=True,
        use_param_info_gain=True,
        gamma=jnp.ones(batch_size),
        alpha=jnp.ones(batch_size) * i * .2,
        onehot_obs=False,
        action_selection="stochastic",
        inference_algo="ovf",
        num_iter=1,
        learn_A=True,
        learn_B=False,
        learn_D=False,
        batch_size=batch_size,
    )
    for i in range(5)
]

In [None]:
# Scratch cell: run a single rollout with the first agent set and compute smoothed
# marginal/joint beliefs from the filtered beliefs and the chosen actions.
# NOTE(review): `grid_world` is not defined anywhere in the visible notebook, and the
# rollout call below uses a different argument order and return arity than the rollout
# call in the transition-learning loop earlier in this notebook — this cell looks stale.
# NOTE(review): the module path `pymdp.jax.inference` differs from the `pymdp.*` paths
# used in the imports cell — verify it exists in the installed version, and move the
# import to the imports cell once confirmed.
from pymdp.jax.inference import smoothing_ovf

key, _key = jr.split(key)
grid_world = grid_world.reset(_key)

key, _key = jr.split(key)
last, info = jit(rollout, static_argnums=[3,] )(_key, agents[0], grid_world, num_timesteps)

beliefs = info['qs']
actions = info['actions']
# vmap applies the smoother independently along the leading axis of each argument
smoothed_marginals_and_joints = vmap(smoothing_ovf)(beliefs, agents[0].B, actions)

In [None]:
# Reference Dirichlet: sharply peaked on the true likelihood of the generative
# process; the small offset keeps every concentration parameter strictly positive.
pA0 = 1e4 * env.params['A'][0] + 1e-4
num_timesteps = 50
num_blocks = 20
key = jr.PRNGKey(0)

# Wrap the rollout once; the number of time steps (argument 2) is static.
jitted_rollout = jit(rollout, static_argnums=[2,])

divs = {i: [] for i in range(len(agents))}
for block in range(num_blocks):
    for i, agent in enumerate(agents):
        # fresh randomness for every (block, agent set): one reset key per batch element
        key, _key = jr.split(key)
        _, env = env.reset(jr.split(_key, batch_size))

        key, _key = jr.split(key)
        last, info, env = jitted_rollout(agent, env, num_timesteps, _key)
        env = last['env']

        beliefs = info['qs']
        actions = info['actions']
        outcomes = info['observations']

        # Update the parameters from the block just experienced and score the updated agent.
        agents[i] = agent.infer_parameters(beliefs, outcomes, actions)
        divs[i].append(kl_div_dirichlet(agents[i].pA[0], pA0).mean(-1))

In [None]:
# KL divergence per block: bold line = batch average, faint lines = individual agents.
fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharex=True, sharey=True)
for idx, agent in enumerate(agents):
    curves = jnp.stack(divs[idx])
    (mean_line,) = axes.plot(curves.mean(-1), lw=3, label=agent.alpha.mean())
    axes.plot(curves, color=mean_line.get_color(), alpha=.2)

axes.set_xlabel('epoch')
axes.set_ylabel('KL divergence')
axes.legend(title='alpha')
fig.tight_layout()

In [None]:
# Compare the likelihood matrices of the first two agent sets (rows 0-1) with the
# true likelihood of the generative process (row 2) for the first 5 batch elements.
true_A = env.params['A'][0]
num_shown = 5

fig, axes = plt.subplots(3, num_shown, figsize=(16, 8), sharex=True, sharey=True)

for i in range(num_shown):
    for j, agent in enumerate(agents[:2]):
        sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(true_A[i], ax=axes[2, i], cmap='viridis', vmax=1., vmin=0.)

    axes[0, i].set_title(f'batch={i+1}')

fig.tight_layout()

In [None]:
# Agents whose B matrix is fixed to the B of the generative process, but with flat
# beliefs over initial states (_D); the likelihood A starts from a flat Dirichlet
# prior (pA) and has to be learned.
B = [jnp.array(b) for b in env.params['B']]

C = [jnp.zeros((batch_size, num_obs[0]))]
pA = [jnp.ones_like(env.params['A'][0]) / num_obs[0]]
# expected likelihood under the prior (normalised over axis 1)
_A = jtu.tree_map(lambda x: x / x.sum(1, keepdims=True), pA)

# One batch of agents per alpha value: alpha = 0, 0.2, 0.4, 0.6, 0.8.
agents = [
    Agent(
        _A,
        B,
        C,
        _D,
        E=None,
        pA=pA,
        pB=None,
        policy_len=3,
        use_utility=False,
        use_states_info_gain=True,
        use_param_info_gain=True,
        gamma=jnp.ones(batch_size),
        alpha=jnp.ones(batch_size) * i * .2,
        onehot_obs=False,
        action_selection="stochastic",
        inference_algo="ovf",
        num_iter=1,
        learn_A=True,
        learn_B=False,
        learn_D=False,
        batch_size=batch_size,
    )
    for i in range(5)
]

In [None]:
# Reference Dirichlet: sharply peaked on the true likelihood of the generative
# process; the small offset keeps every concentration parameter strictly positive.
pA0 = 1e4 * env.params['A'][0] + 1e-4
num_timesteps = 50
num_blocks = 20
key = jr.PRNGKey(0)

# Wrap the rollout once; the number of time steps (argument 2) is static.
jitted_rollout = jit(rollout, static_argnums=[2,])

divs = {i: [] for i in range(len(agents))}
for block in range(num_blocks):
    for i, agent in enumerate(agents):
        # fresh randomness for every (block, agent set): one reset key per batch element
        key, _key = jr.split(key)
        _, env = env.reset(jr.split(_key, batch_size))

        key, _key = jr.split(key)
        last, info, env = jitted_rollout(agent, env, num_timesteps, _key)
        env = last['env']

        beliefs = info['qs']
        actions = info['actions']
        outcomes = info['observations']

        # Update the parameters from the block just experienced and score the updated agent.
        agents[i] = agent.infer_parameters(beliefs, outcomes, actions)
        divs[i].append(kl_div_dirichlet(agents[i].pA[0], pA0).mean(-1))

In [None]:
# KL divergence per block: bold line = batch average, faint lines = individual agents.
fig, axes = plt.subplots(1, 1, figsize=(10, 5), sharex=True, sharey=True)
for idx, agent in enumerate(agents):
    curves = jnp.stack(divs[idx])
    (mean_line,) = axes.plot(curves.mean(-1), lw=3, label=agent.alpha.mean())
    axes.plot(curves, color=mean_line.get_color(), alpha=.2)

axes.set_xlabel('epoch')
axes.set_ylabel('KL divergence')
axes.legend(title='alpha')
fig.tight_layout()

In [None]:
# Compare the likelihood matrices of the first two agent sets (rows 0-1) with the
# true likelihood of the generative process (row 2) for the first 5 batch elements.
true_A = env.params['A'][0]
num_shown = 5

fig, axes = plt.subplots(3, num_shown, figsize=(16, 8), sharex=True, sharey=True)

for i in range(num_shown):
    for j, agent in enumerate(agents[:2]):
        sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(true_A[i], ax=axes[2, i], cmap='viridis', vmax=1., vmin=0.)

    axes[0, i].set_title(f'batch={i+1}')

fig.tight_layout()

In [None]:
# Agents facing joint uncertainty: both the likelihood (A) and the transitions (B)
# start from flat Dirichlet priors (pA, pB) and both have to be learned.
C = [jnp.zeros((batch_size, num_obs[0]))]
pA = [jnp.ones_like(env.params['A'][0]) / num_obs[0]]
_A = jtu.tree_map(lambda x: x / x.sum(1, keepdims=True), pA)
pB = [jnp.ones_like(env.params['B'][0]) / num_states[0]]
_B = jtu.tree_map(lambda x: x / x.sum(1, keepdims=True), pB)
# NOTE(review): assumes the environment exposes its initial-state prior as
# env.params['D'] (batched like A and B) — confirm against the installed pymdp.
D = [jnp.array(d) for d in env.params['D']]

# One batch of agents per alpha value: alpha = 0, 0.2, 0.4, 0.6, 0.8.
agents = [
    Agent(
        _A,
        _B,
        C,
        D,
        E=None,
        pA=pA,
        pB=pB,
        policy_len=3,
        use_utility=False,
        use_states_info_gain=True,
        use_param_info_gain=True,
        gamma=jnp.ones(batch_size),
        alpha=jnp.ones(batch_size) * i * .2,
        onehot_obs=False,
        action_selection="stochastic",
        inference_algo="ovf",
        num_iter=1,
        learn_A=True,
        # transitions are unknown here as well, so B learning must be enabled
        learn_B=True,
        learn_D=False,
        batch_size=batch_size,
    )
    for i in range(5)
]

In [None]:
# Reference Dirichlets: sharply peaked on the true likelihood / transitions of the
# generative process; the small offset keeps every concentration strictly positive.
pA0 = 1e4 * env.params['A'][0] + 1e-4
pB0 = 1e4 * env.params['B'][0] + 1e-4
num_timesteps = 50
num_blocks = 100
key = jr.PRNGKey(0)

# Wrap the rollout once; the number of time steps (argument 2) is static.
jitted_rollout = jit(rollout, static_argnums=[2,])

divs1 = {i: [] for i in range(len(agents))}  # divergence of pA per block
divs2 = {i: [] for i in range(len(agents))}  # divergence of pB per block
for block in range(num_blocks):
    for i, agent in enumerate(agents):
        # fresh randomness for every (block, agent set): one reset key per batch element
        key, _key = jr.split(key)
        _, env = env.reset(jr.split(_key, batch_size))

        key, _key = jr.split(key)
        last, info, env = jitted_rollout(agent, env, num_timesteps, _key)
        env = last['env']

        beliefs = info['qs']
        actions = info['actions']
        outcomes = info['observations']

        # Update the parameters from the block just experienced and score the updated agent.
        agents[i] = agent.infer_parameters(beliefs, outcomes, actions)
        divs1[i].append(kl_div_dirichlet(agents[i].pA[0], pA0).mean(-1))
        divs2[i].append(kl_div_dirichlet(agents[i].pB[0], pB0).sum(-1).mean(-1))

In [None]:
# Left panel: divergence of the likelihood prior; right panel: divergence of the
# transition prior. Bold line = batch average, faint lines = individual agents.
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=False)
histories = (divs1, divs2)
for idx, agent in enumerate(agents):
    for ax, history in zip(axes, histories):
        curves = jnp.stack(history[idx])
        (mean_line,) = ax.plot(curves.mean(-1), lw=3, label=agent.alpha.mean())
        ax.plot(curves, color=mean_line.get_color(), alpha=.2)

for ax, title in zip(axes, ('A matrix', 'B matrix')):
    ax.set_title(title)
    ax.set_xlabel('epoch')
axes[0].set_ylabel('KL divergence')
axes[0].legend(title='alpha')
fig.tight_layout()

In [None]:
# Compare the likelihood matrices of the first two agent sets (rows 0-1) with the
# true likelihood of the generative process (row 2) for the first 5 batch elements.
true_A = env.params['A'][0]
num_shown = 5

fig, axes = plt.subplots(3, num_shown, figsize=(16, 8), sharex=True, sharey=True)

for i in range(num_shown):
    for j, agent in enumerate(agents[:2]):
        sns.heatmap(agent.A[0][i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(true_A[i], ax=axes[2, i], cmap='viridis', vmax=1., vmin=0.)

    axes[0, i].set_title(f'batch={i+1}')

fig.tight_layout()

In [None]:
# Compare the transition matrices of the first two agent sets (rows 0-1) with the
# true transitions of the generative process (row 2), one column per action.
# Only the first element of each batch is shown.
true_B = env.params['B'][0]
num_actions = true_B.shape[-1]

fig, axes = plt.subplots(3, num_actions, figsize=(16, 8), sharex=True, sharey=True, squeeze=False)

for i in range(num_actions):
    for j, agent in enumerate(agents[:2]):
        sns.heatmap(agent.B[0][0, ..., i], ax=axes[j, i], cmap='viridis', vmax=1., vmin=0.)
    sns.heatmap(true_B[0, ..., i], ax=axes[2, i], cmap='viridis', vmax=1., vmin=0.)
    axes[0, i].set_title(f'action {i+1}')

fig.tight_layout()

In [None]:
# Agents facing joint uncertainty over A and B, with flat beliefs over initial states
# (_D) and a non-flat prior over transitions.
C = [jnp.zeros((batch_size, num_obs[0]))]
pA = [jnp.ones_like(env.params['A'][0]) / num_obs[0]]
_A = jtu.tree_map(lambda x: x / x.sum(1, keepdims=True), pA)

true_B = env.params['B'][0]
# Sum the true transitions over the last (action) axis and cap at 1; the result is
# added to the flat prior for every action.
# NOTE(review): this appears to encode which state transitions are possible under
# some action without revealing which action causes them — confirm intent.
tmpB = jnp.clip(true_B.sum(-1), max=1)
pB = [jnp.expand_dims(tmpB, -1) + jnp.ones_like(true_B) / num_states[0]]
_B = jtu.tree_map(lambda x: x / x.sum(1, keepdims=True), pB)

# One batch of agents per alpha value: alpha = 0, 0.2, 0.4, 0.6, 0.8.
agents = [
    Agent(
        _A,
        _B,
        C,
        _D,
        E=None,
        pA=pA,
        pB=pB,
        policy_len=3,
        use_utility=False,
        use_states_info_gain=True,
        use_param_info_gain=True,
        gamma=jnp.ones(batch_size),
        alpha=jnp.ones(batch_size) * i * .2,
        onehot_obs=False,
        action_selection="stochastic",
        inference_algo="ovf",
        num_iter=1,
        learn_A=True,
        # transitions are unknown here as well, so B learning must be enabled
        learn_B=True,
        learn_D=False,
        batch_size=batch_size,
    )
    for i in range(5)
]

In [None]:
# Reference Dirichlets: sharply peaked on the true likelihood / transitions of the
# generative process; the small offset keeps every concentration strictly positive.
pA0 = 1e4 * env.params['A'][0] + 1e-4
pB0 = 1e4 * env.params['B'][0] + 1e-4
num_timesteps = 50
num_blocks = 100
key = jr.PRNGKey(0)

# Wrap the rollout once; the number of time steps (argument 2) is static.
jitted_rollout = jit(rollout, static_argnums=[2,])

divs1 = {i: [] for i in range(len(agents))}  # divergence of pA per block
divs2 = {i: [] for i in range(len(agents))}  # divergence of pB per block
for block in range(num_blocks):
    for i, agent in enumerate(agents):
        # fresh randomness for every (block, agent set): one reset key per batch element
        key, _key = jr.split(key)
        _, env = env.reset(jr.split(_key, batch_size))

        key, _key = jr.split(key)
        last, info, env = jitted_rollout(agent, env, num_timesteps, _key)
        env = last['env']

        beliefs = info['qs']
        actions = info['actions']
        outcomes = info['observations']

        # Update the parameters from the block just experienced and score the updated agent.
        agents[i] = agent.infer_parameters(beliefs, outcomes, actions)
        divs1[i].append(kl_div_dirichlet(agents[i].pA[0], pA0).mean(-1))
        divs2[i].append(kl_div_dirichlet(agents[i].pB[0], pB0).sum(-1).mean(-1))

In [None]:
# Left panel: divergence of the likelihood prior; right panel: divergence of the
# transition prior. Bold line = batch average, faint lines = individual agents.
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=False)
histories = (divs1, divs2)
for idx, agent in enumerate(agents):
    for ax, history in zip(axes, histories):
        curves = jnp.stack(history[idx])
        (mean_line,) = ax.plot(curves.mean(-1), lw=3, label=agent.alpha.mean())
        ax.plot(curves, color=mean_line.get_color(), alpha=.2)

for ax, title in zip(axes, ('A matrix', 'B matrix')):
    ax.set_title(title)
    ax.set_xlabel('epoch')
axes[0].set_ylabel('KL divergence')
axes[0].legend(title='alpha')
fig.tight_layout()