In [1]:
import numpy as np
import copy
from pprint import pprint
import tqdm

from pymdp.utils import plot_beliefs, plot_likelihood
from pymdp import utils
from pymdp.envs import TMazeEnv, MultiArmedBanditEnv
from pymdp.pdo_agents import PDOAgentGradient, EVAgentGradient, EVAgentDirect, PDOAgentDirect
from pymdp.agent import Agent

np.set_printoptions(linewidth=60, precision=3, suppress=False)

In [2]:
# reward_probabilities = [0.98, 0.02] # probabilities used in the original SPM T-maze demo
# reward_probabilities = [0.5, 0.5] # probabilities used in the original SPM T-maze demo
# env = TMazeEnv(reward_probs = reward_probabilities)·
# env = MultiArmedBanditEnv(4)
# env.values = [1.0, 2.0, -1.0, 0.0, 0.0] # The last one is the initial state value

In [3]:
reward_probabilities = [1.0, 0.0]
# reward_probabilities = [0.98, 0.02] # probabilities used in the original SPM T-maze demo
env = TMazeEnv(reward_probs=reward_probabilities)

# these are useful for displaying read-outs during the loop over time
reward_conditions = ["Right", "Left"]
location_observations = ['CENTER', 'RIGHT ARM', 'LEFT ARM', 'CUE LOCATION']
location_codes = "0RLC"
reward_observations = ['No reward', 'Reward!', 'Loss!']
cue_observations = ['Cue Right', 'Cue Left']


def obs_text(obs):
    return str((location_observations[obs[0]], reward_observations[obs[1]], cue_observations[obs[2]]))


def counts(vals):
    counts = {}
    for v in vals:
        counts.setdefault(v, 0)
        counts[v] += 1
    return sorted(counts.items(), key=lambda x: x[1], reverse=True)


def run_experiment(env, agent_init_func, reps, steps, progress=True, reinit_agent=True, show_reps=0):
    rews = []
    loc_obs = []
    if not reinit_agent:
        # Initialize the agent only once
        agent = agent_init_func()
        # Let it infer the complete policy (if e.g. PDO agent)
        agent.infer_states(env.reset())
        agent.infer_policies()

    for rep in (bar := tqdm.trange(reps, desc="Sampling runs", disable=not progress, leave=True)):
        if reinit_agent:
            agent = agent_init_func()
        else:
            agent.reset()

        obs = env.reset()  # reset the environment and get an initial observation
        rew = 0.0
        loc_ob = f"c{'RL'[env.reward_condition]}:{location_codes[obs[0]]}"

        if rep < show_reps:
            bar.write(f"=== New run ({
                      reward_conditions[env.reward_condition]}, observation: {obs_text(obs)} ===")

        for t in range(steps):
            agent.infer_states(obs)
            agent.infer_policies()
            action = agent.sample_action()
            obs = env.step(action)
            if rep < show_reps:
                bar.write(f"[Step {t}] action: [Move to {
                          location_observations[int(action[0])]}], observation: {obs_text(obs)}")

            for i in range(len(agent.C)):
                rew += agent.C[i][obs[i]]
            loc_ob = f"{loc_ob}-{location_codes[obs[0]]}"

        rews.append(float(rew))
        loc_obs.append(loc_ob)
        bar.set_postfix(Umean=np.mean(rews), Ustd=np.std(rews))
    bar.close()

    rews = np.array(rews)

    print(f"Rewards: {counts(rews)}")
    print(f"Mean reward of {reps}: {np.mean(rews):.3}, std={np.std(rews):.3}")
    print(f"Location sequences: {counts(loc_obs)}")

    return agent


T_CLUE = -1.7
T_WIN = 1.0
T_LOSS = -4.0

PDO_BETA = 0.0
PDO_ITERATIONS = 1000
PDO_LR = 1.
PDO_EV = False
PDO_DIRECT = True
EFE_SOPHISTICATED = False

REPS = 100
T = 2


def init_efe():
    A_gm = copy.deepcopy(env.get_likelihood_dist())
    B_gm = copy.deepcopy(env.get_transition_dist())
    agent = Agent(A=A_gm, B=B_gm, control_fac_idx=[
                  0], sophisticated=EFE_SOPHISTICATED)
    agent.D[0] = utils.onehot(0, agent.num_states[0])
    agent.C[1][1] = T_WIN
    agent.C[1][2] = T_LOSS
    agent.C[0][3] = T_CLUE
    return agent


def init_pdo():
    A_gm = copy.deepcopy(env.get_likelihood_dist())
    B_gm = copy.deepcopy(env.get_transition_dist())
    if PDO_EV:
        if PDO_DIRECT:
            agent = EVAgentDirect(A=A_gm, B=B_gm, time_horizon=T, env=env)
        else:
            agent = EVAgentGradient(A=A_gm, B=B_gm, time_horizon=T, env=env,
                            policy_lr=PDO_LR, policy_iterations=PDO_ITERATIONS)
    else:
        if PDO_DIRECT:
            agent = PDOAgentDirect(A=A_gm, B=B_gm, time_horizon=T, env=env, beta=PDO_BETA)
        else:
            agent = PDOAgentGradient(A=A_gm, B=B_gm, time_horizon=T, env=env,
                             policy_lr=PDO_LR, beta=PDO_BETA, policy_iterations=PDO_ITERATIONS)
    print(f"Using {type(agent)}")
    agent.D[0] = utils.onehot(0, agent.num_states[0])
    agent.C[1][1] = T_WIN
    agent.C[1][2] = T_LOSS
    agent.C[0][3] = T_CLUE
    return agent


# PDO_EV = True
# PDO_DIRECT = True
# print(f"\nEV - direct\n")
# ag = run_experiment(env, init_pdo, REPS, steps=T,
#                     show_reps=0, reinit_agent=False)
# print(f"Stats: {ag.stats}")

# PDO_EV = False
# PDO_DIRECT = True
# print(f"\nPDO - direct\n")
# ag = run_experiment(env, init_pdo, REPS, steps=T, show_reps=0, reinit_agent=False)
# print(f"Stats: {ag.stats}")

# PDO_EV = True
# PDO_DIRECT = False
# print(f"\nEV - GD\n")
# ag = run_experiment(env, init_pdo, REPS, steps=T,
#                     show_reps=0, reinit_agent=False)

# PDO_EV = False
# PDO_DIRECT = False
# print(f"\nPDO - GD\n")
# run_experiment(env, init_pdo, REPS, steps=T, show_reps=0, reinit_agent=False)

# EFE_SOPHISTICATED = False
# print(f"\nEFE (soph={EFE_SOPHISTICATED})\n")
# run_experiment(env, init_efe, REPS, steps=T, show_reps=0, reinit_agent=True)

EFE_SOPHISTICATED = True
print(f"\nEFE (soph={EFE_SOPHISTICATED})\n")
run_experiment(env, init_efe, reps=100, steps=T, show_reps=1, reinit_agent=True)


EFE (soph=True)



Sampling runs:   0%|          | 0/100 [00:00<?, ?it/s]

Sampling runs:   1%|          | 1/100 [00:00<00:11,  8.38it/s, Umean=-3, Ustd=0]

=== New run (Left, observation: ('CENTER', 'No reward', 'Cue Left') ===
[Step 0] action: [Move to RIGHT ARM], observation: ('RIGHT ARM', 'Loss!', 'Cue Right')
[Step 1] action: [Move to LEFT ARM], observation: ('LEFT ARM', 'Reward!', 'Cue Right')


Sampling runs: 100%|██████████| 100/100 [00:13<00:00,  7.41it/s, Umean=-0.6, Ustd=2.5]  

Rewards: [(np.float64(-3.0), 52), (np.float64(2.0), 48)]
Mean reward of 100: -0.6, std=2.5
Location sequences: [('cL:0-R-L', 27), ('cR:0-L-R', 25), ('cL:0-L-L', 24), ('cR:0-R-R', 24)]





<pymdp.agent.Agent at 0x73916b616d50>

In [4]:

# agent = PDOAgent(A=A_gm, B=B_gm, time_horizon=T, env=env,
#                  policy_lr=10., beta=10.0, policy_iterations=500)
# obs = env.reset()  # reset the environment and get an initial observation

# print(f"Consistent observation sequences: {len(agent.generate_consistent_observation_seqs(
# ))} (out of approx {len(agent.possible_observations) ** T})")

# if isinstance(env, TMazeEnv):
#     agent.D[0] = utils.onehot(0, agent.num_states[0])
#     agent.C[1][1] = 1.0
#     agent.C[1][2] = -1.0
#     agent.C[0][3] = -T_CLUE_PENALTY

#     # these are useful for displaying read-outs during the loop over time
#     reward_conditions = ["Right", "Left"]
#     location_observations = ['CENTER', 'RIGHT ARM', 'LEFT ARM', 'CUE LOCATION']
#     reward_observations = ['No reward', 'Reward!', 'Loss!']
#     cue_observations = ['Cue Right', 'Cue Left']

#     rews = []
#     for reps in range(REPS):
#         rew = 0.0
#         agent.reset()
#         obs = env.reset()  # reset the environment and get an initial observation
#         if reps < REPS_SHOW:
#             msg = """ === Starting experiment === \n Reward condition: {}, Observation: [{}, {}, {}]"""
#             print(msg.format(reward_conditions[env.reward_condition],
#                   location_observations[obs[0]], reward_observations[obs[1]], cue_observations[obs[2]]))
#         for t in range(T):
#             agent.infer_states(obs)
#             agent.infer_policies()
#             action = agent.sample_action()

#             if reps < REPS_SHOW:
#                 msg = """[Step {}] Action: [Move to {}]"""
#                 print(msg.format(t, location_observations[int(action[0])]))
#             obs = env.step(action)
#             if reps < REPS_SHOW:
#                 msg = """[Step {}] Observation: [{},  {}, {}]"""
#                 print(msg.format(
#                     t, location_observations[obs[0]], reward_observations[obs[1]], cue_observations[obs[2]]))

#             rew += efe_agent.C[0][obs[0]]
#             rew += efe_agent.C[1][obs[1]]
#             rew += efe_agent.C[2][obs[2]]

#         rews.append(rew)

#     print(f"Mean reward (of {REPS}): {np.mean(rews)}")

# elif isinstance(env, MultiArmedBanditEnv):
#     agent.D[0] = utils.onehot(env.INITIAL_STATE, agent.num_states[0])
#     agent.C[1][:] = env.values

#     print(f"Initial observation: {obs}")
#     for t in range(T):
#         agent.infer_states(obs)
#         agent.infer_policies()
#         action = agent.sample_action()
#         obs = env.step(action)
#         print(f"Action: {np.array(action)}, Observation: {
#               np.array(obs)}, Reward: {env.values[obs[0]]}")

# else:
#     raise ValueError("Unknown environment type")