In [21]:
# a way to edit and run code and see the effects in the notebook without having to restart the kernel
%load_ext autoreload
%autoreload 2

# importing necessary libraries
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import mediapy

from jax import random as jr
from pymdp.envs import TMaze, rollout
from pymdp.agent import Agent
from PIL import Image

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
# setting the parameters for the environment
batch_size = 1 # number of environments to run in parallel
reward_condition = None # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 1.0 # 100% chance of reward in the correct arm
punishment_probability = 1.0 # 100% chance of punishment in the other arm
cue_validity = 1.0 # 100% valid cues

# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze( 
    batch_size=batch_size, 
    reward_probability=reward_probability,     
    punishment_probability=punishment_probability, 
    cue_validity=cue_validity,          
    reward_condition=reward_condition   
)

# you may print the environment parameters to see the shapes of the tensors and the values by editing and uncommenting the following lines and running the code: 

# print([a.shape for a in env.params["A"]]) # shape of all A tensors; the shape should start with the batch_size, then the rows, columns, and additional dimensions for the dependencies
# print(env.params["A"][1][0][:,:,1]) # likelihood of observing no outcome, reward, or punishment (rows), in each location (columns), when the reward condition is 1 (right arm)
# print(env.params["A"][2][0][:,:,0]) # likelihood of observing no cue, left arm cued, or right arm cued (rows), in each location (columns), when the reward condition is 0 (left arm)

# print([b.shape for b in env.params["B"]]) # shape of all B tensors
# print(env.params["B"][0][0][:,:,4]) # probability of transitioning to each location (rows), from each location (columns), when the agent wants to move to the middle of the arms (location 4)

In [23]:
key = jr.PRNGKey(0)  # initialize random key

A = []
shapes = [(batch_size, 5, 5, 2), (batch_size, 3, 5, 2), (batch_size, 3, 5, 2)]

for shape in shapes:
    key, subkey = jr.split(key)
    random_a = jr.uniform(subkey, shape=shape)
    random_a = random_a / jnp.sum(random_a, axis=1, keepdims=True)
    A.append(jnp.array(random_a, dtype=jnp.float32))

A_dependencies = [[0, 1], [0, 1], [0, 1]]
pA = A

# setting B tensors from the environment parameters
B = [jnp.array(b, dtype=jnp.float32) for b in env.params["B"]]
B_dependencies = env.dependencies["B"]

# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros((batch_size, a.shape[1]), dtype=jnp.float32) for a in A] 
# setting preferences for outcomes only
C[1] = C[1].at[:,1].set(2.0)    # prefer reward
C[1] = C[1].at[:,2].set(-3.0)   # avoid punishment


# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros((batch_size, B[0].shape[1]), dtype=jnp.float32) 
D_loc = D_loc.at[0,0].set(1.0)  # set centre location to 1.0
D.append(D_loc)

# D[1]: reward location - uniform distribution
D_reward = jnp.ones((batch_size, B[1].shape[1]), dtype=jnp.float32) 
D_reward = D_reward / jnp.sum(D_reward, axis=1, keepdims=True)  # normalise to get uniform distribution
D.append(D_reward)


# initialising the agent
agent = Agent(
    A, B, C, D, 
    pA=pA,
    # pB=pB, # adding the noisy A tensor for learning
    policy_len=5, # how long the action sequence is that the agent is evaluating
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    apply_batch=False, 
    learn_A=True,
    learn_B=False,
    gamma=0.1,
    action_selection="stochastic"
)

In [24]:
key = jr.PRNGKey(0) # random key for the aif loop
T = 50 # number of timesteps to rollout the aif loop for
_, info, _ = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop


In [25]:
print("at first timestep")
print(agent.A[1][0][:,:,1])
print()
print(agent.A[2][0][:,:,1])
print()
print("at last timestep")
print(info["agent"].A[1][-1,0,:,:,1])
print()
print(info["agent"].A[2][-1,0,:,:,1])
print()
print("the environment's A tensor")
print(env.params["A"][1][0][:,:,1])
print(env.params["A"][2][0][:,:,1])

at first timestep
[[0.5509876  0.31461224 0.2522971  0.39348188 0.59899706]
 [0.25669733 0.23397164 0.41215912 0.45787704 0.2677251 ]
 [0.19231506 0.4514161  0.33554378 0.14864108 0.13327783]]

[[0.384104   0.28346223 0.67107844 0.16076607 0.56186813]
 [0.19267914 0.2668223  0.04772528 0.4034486  0.01521517]
 [0.42321685 0.44971547 0.2811963  0.43578532 0.4229167 ]]

at last timestep
[[0.9647639  0.02103086 0.03604244 0.79782724 0.9749153 ]
 [0.02014425 0.94879335 0.05887987 0.15262568 0.01674751]
 [0.01509187 0.03017577 0.90507764 0.04954703 0.00833718]]

[[9.5166773e-01 9.5210171e-01 9.5301121e-01 5.3588688e-02 9.7259277e-01]
 [1.5120438e-02 1.7836249e-02 6.8178973e-03 8.0114955e-01 9.5178297e-04]
 [3.3211816e-02 3.0062092e-02 4.0170901e-02 1.4526178e-01 2.6455507e-02]]

the environment's A tensor
[[1. 0. 0. 1. 1.]
 [0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]]
[[1. 1. 1. 0. 1.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0.]]


In [19]:
frames = []
for t in range(info["observation"][0].shape[0]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][t, :, :],
        info["observation"][1][t, :, :],  
        info["observation"][2][t, :, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)

0
This browser does not support the video tag.


### A and B Learning

In [31]:
###### NOT WORKING YET - RUNS FOREVER. I think the B is not shaped correctly for exhaustive dependencies.######

B = []
shapes = [(batch_size, 5, 5, 2, 5), (batch_size, 2, 2, 2, 5)]

for shape in shapes:
    key, subkey = jr.split(key)
    random_B = jr.uniform(subkey, shape=shape)
    random_B = random_B / jnp.sum(random_B, axis=1, keepdims=True)
    B.append(jnp.array(random_B, dtype=jnp.float32))

B_dependencies = [[0, 1], [0, 1]]
pB = B

# initialising the agent
agent = Agent(
    A, B, C, D, 
    pA=pA,
    pB=pB, # adding the noisy B tensor for learning
    policy_len=5, # how long the action sequence is that the agent is evaluating
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    apply_batch=False, 
    learn_A=True,
    learn_B=True,
    gamma=0.1,
    action_selection="stochastic"
)

KeyboardInterrupt: 

In [26]:
print("B shapes:", [b.shape for b in B])
print("Number of states in each factor:", [b.shape[0] for b in B])

B shapes: [(1, 5, 5, 5), (1, 2, 2, 1)]
Number of states in each factor: [1, 1]
