In [1]:
import numpy as np
import copy

from pymdp.utils import plot_beliefs, plot_likelihood
from pymdp import utils
from pymdp.envs import TMazeEnv, MultiArmedBanditEnv
from pymdp.pdo_agent import PDOAgent, EVAgent

In [2]:
# Reward contingencies handed to the T-maze environment.
reward_probabilities = [0.98, 0.02]  # probabilities used in the original SPM T-maze demo
# reward_probabilities = [0.5, 0.5]  # alternative setting: both entries equal
env = TMazeEnv(reward_probs=reward_probabilities)
# env = MultiArmedBanditEnv(4)  # alternative environment handled by the experiment cell

# Values attached to the environment; they are read by the MultiArmedBandit branch of the
# experiment cell. The last one is the initial state value.
env.values = [1.0, 2.0, -1.0, 0.0, 0.0]

# "gp": the environment's own observation-likelihood (A) and transition (B) distributions.
A_gp, B_gp = env.get_likelihood_dist(), env.get_transition_dist()

# "gm": the agent's observation and transition models, initialised as independent deep
# copies of the true distributions so the environment's arrays are not shared with the agent.
A_gm, B_gm = (copy.deepcopy(true_dist) for true_dist in (A_gp, B_gp))

In [3]:
T = 3  # planning/rollout horizon in timesteps: 2 or 3 for TMaze, 1 or more for MultiArmedBandit

agent = EVAgent(
    A=A_gm,
    B=B_gm,
    time_horizon=T,
    env=env,
    policy_lr=10.0,
    policy_iterations=100,
)
obs = env.reset()  # start from a freshly reset environment and keep its first observation

# Compare the number of observation sequences the agent reports as consistent against
# the naive count of every possible observation raised to the horizon length.
print(f"Consistent observation sequences: {len(agent.generate_consistent_observation_seqs())} (out of approx {len(agent.possible_observations) ** T})")

if isinstance(env, TMazeEnv):
    # Initial-state prior for factor 0: one-hot on index 0
    # (presumably the CENTER location -- TODO confirm against TMazeEnv).
    agent.D[0] = utils.onehot(0, agent.num_states[0])
    # Preferences over modality 1: index 1 is favoured, index 2 is penalised
    # (these indices line up with 'Reward!' and 'Loss!' in reward_observations below).
    agent.C[1][1], agent.C[1][2] = 1.0, -1.0

    # Human-readable labels used for the read-outs printed during the loop over time.
    reward_conditions = ["Right", "Left"]
    location_observations = ['CENTER','RIGHT ARM','LEFT ARM','CUE LOCATION']
    reward_observations = ['No reward','Reward!','Loss!']
    cue_observations = ['Cue Right','Cue Left']

    # Message templates, defined once instead of being rebuilt on every iteration.
    episode_header = """ === Starting experiment === \n Reward condition: {}, Observation: [{}, {}, {}]"""
    action_line = """[Step {}] Action: [Move to {}]"""
    observation_line = """[Step {}] Observation: [{},  {}, {}]"""

    for episode in range(10):
        # Each episode restarts both the agent and the environment.
        agent.reset()
        obs = env.reset()
        condition = reward_conditions[env.reward_condition]
        print(episode_header.format(condition, location_observations[obs[0]], reward_observations[obs[1]], cue_observations[obs[2]]))

        for t in range(T):
            # Perceive, plan, act -- in that order.
            agent.infer_states(obs)
            agent.infer_policies()
            action = agent.sample_action()
            destination = location_observations[int(action[0])]
            print(action_line.format(t, destination))

            # Apply the action and report what the environment returns.
            obs = env.step(action)
            seen = (location_observations[obs[0]], reward_observations[obs[1]], cue_observations[obs[2]])
            print(observation_line.format(t, *seen))

elif isinstance(env, MultiArmedBanditEnv):
    # Initial-state prior is one-hot on the environment's declared initial state.
    agent.D[0] = utils.onehot(env.INITIAL_STATE, agent.num_states[0])
    # Preferences over modality 1 are copied element-wise from the environment's values.
    agent.C[1][:] = env.values

    print(f"Initial observation: {obs}")
    for t in range(T):
        agent.infer_states(obs)
        agent.infer_policies()
        action = agent.sample_action()
        obs = env.step(action)
        print(f"Action: {np.array(action)}, Observation: {np.array(obs)}, Reward: {env.values[obs[0]]}")

else:
    # Only the two environment classes above have a rollout defined in this cell.
    raise ValueError("Unknown environment type")


Consistent observation sequences: 1816 (out of approx 13824)
 === Starting experiment === 
 Reward condition: Right, Observation: [CENTER, No reward, Cue Left]


100%|██████████| 100/100 [02:59<00:00,  1.79s/it, G=-1.8944035] 


[Step 0] Action: [Move to CUE LOCATION]
[Step 0] Observation: [CUE LOCATION,  No reward, Cue Right]
[Step 1] Action: [Move to RIGHT ARM]
[Step 1] Observation: [RIGHT ARM,  Reward!, Cue Left]
[Step 2] Action: [Move to RIGHT ARM]
[Step 2] Observation: [RIGHT ARM,  Reward!, Cue Right]
 === Starting experiment === 
 Reward condition: Left, Observation: [CENTER, No reward, Cue Left]
[Step 0] Action: [Move to CUE LOCATION]
[Step 0] Observation: [CUE LOCATION,  No reward, Cue Left]
[Step 1] Action: [Move to LEFT ARM]
[Step 1] Observation: [LEFT ARM,  Reward!, Cue Left]
[Step 2] Action: [Move to LEFT ARM]
[Step 2] Observation: [LEFT ARM,  Reward!, Cue Right]
 === Starting experiment === 
 Reward condition: Right, Observation: [CENTER, No reward, Cue Left]
[Step 0] Action: [Move to CUE LOCATION]
[Step 0] Observation: [CUE LOCATION,  No reward, Cue Right]
[Step 1] Action: [Move to RIGHT ARM]
[Step 1] Observation: [RIGHT ARM,  Reward!, Cue Left]
[Step 2] Action: [Move to RIGHT ARM]
[Step 2] Obser

In [5]:
# Inspect the policy object held by the agent after the rollouts above.
# NOTE(review): Jupyter only echoes the value of a cell's final expression, so the value of
# the next line is evaluated and discarded. If it is meant to be inspected, show it
# explicitly (e.g. print/display a small slice); otherwise the line can go -- confirm that
# the attribute access has no needed side effect before removing it.
agent.policy.observation_sequences
# Query the policy for the one-step observation history ((0, 0, 0),).
# Presumably the result is a column of action probabilities (the recorded output is a
# 4x1 float32 array whose entries sum to ~1) -- TODO confirm against the policy class.
agent.policy.policy_for_observations(((0,0,0),))

Array([[2.7379731e-04],
       [3.6035053e-04],
       [3.6035053e-04],
       [9.9900550e-01]], dtype=float32)