In [1]:
import os
import sys
import pathlib
import numpy as np
import copy

from pymdp.agent import Agent
from pymdp.utils import plot_beliefs, plot_likelihood
from pymdp import utils
from pymdp.envs import TMazeEnv

In [116]:
import numpy as np
import itertools

def infer_states(A, B, D, obs, policy):
    """
    Exact posterior over whole state trajectories, obtained by brute-force enumeration.

    Parameters
    ----------
    A : array indexed as A[observation, state]
    B : array indexed as B[next_state, previous_state, action]
    D : array indexed as D[initial_state]
    obs : sequence of observation indices for the first len(obs) time steps
    policy : sequence of action indices; the trajectory spans len(policy) + 1 time steps

    Returns
    -------
    q : 1-D array holding the normalized joint probability of each trajectory
    state_trajectories : list of state-index tuples, in the same order as q

    Note: every one of the num_states ** (len(policy) + 1) trajectories is visited.
    """
    n_states = A.shape[1]
    horizon = len(policy) + 1

    # every possible assignment of a state to each time step
    state_trajectories = list(itertools.product(range(n_states), repeat=horizon))

    weights = np.zeros((len(state_trajectories),))

    for k, path in enumerate(state_trajectories):

        # likelihood of the observations seen so far along this path
        likelihood = 1
        for step, outcome in enumerate(obs):
            likelihood *= A[outcome, path[step]]

        # prior probability of the path under the initial state distribution and the policy
        prior = D[path[0]]
        for step in range(1, horizon):
            prior *= B[path[step], path[step - 1], policy[step - 1]]

        weights[k] = likelihood * prior

    q = weights / np.sum(weights)

    return q, state_trajectories


def make_A_flat(A):
    """
    Collapse a list of per-modality likelihoods
    [p(o^1|s^1,..., s^f), ..., p(o^m|s^1,..., s^f)] into one matrix p(o|s).

    Each modality is first flattened over the state factors, then the modalities
    are combined into one joint observation axis. Returns the flat matrix only.
    """
    state_multi_indices = make_state_mapping(A)

    per_modality_flat = []
    for A_m in A:
        per_modality_flat.append(flatten_state_factors2(A_m, state_multi_indices))

    observation_multi_indices = make_observation_mapping(A)

    return combine_observation_modalities(per_modality_flat, observation_multi_indices)


def flatten_state_factors(A_m, multi_index_list_states):
    """
    Go from p(o^i|s^1,..., s^f) to p(o^i|s) by reshaping all state-factor axes
    into a single axis. `multi_index_list_states` is accepted but not used.
    """
    return A_m.reshape(A_m.shape[0], -1)


def flatten_state_factors2(A_m, multi_index_list_states):
    """
    Go from p(o^i|s^1,..., s^f) to p(o^i|s) by explicit lookup.

    Column j of the result holds the entries of `A_m` at the state multi-index
    `multi_index_list_states[j]`; rows are the observation levels of `A_m`.
    """
    n_obs = A_m.shape[0]
    flat = np.zeros((n_obs, len(multi_index_list_states)))

    for col, factor_indices in enumerate(multi_index_list_states):
        for row in range(n_obs):
            # one tuple index addresses the observation and every state factor at once
            flat[row, col] = A_m[(row, *factor_indices)]

    return flat
    

def combine_observation_modalities(A_flat_states, multi_index_list_observations):
    """
    Go from the per-modality matrices p(o^1|s), ..., p(o^m|s) to a single p(o|s).

    Row i of the result corresponds to the observation multi-index
    `multi_index_list_observations[i]`; each entry is the product over modalities
    of the matching per-modality probabilities for that state.
    """
    n_states = len(A_flat_states[0][0, :])
    n_modalities = len(A_flat_states)

    combined = np.zeros((len(multi_index_list_observations), n_states))

    for row, obs_multi_index in enumerate(multi_index_list_observations):
        for col in range(n_states):
            # p(o^m | s) for every modality m at this fixed state column
            per_modality = [A_flat_states[m][obs_multi_index[m], col] for m in range(n_modalities)]
            combined[row, col] = product_of_elements(per_modality)

    return combined

def make_observation_mapping(A):
    """
    Enumerate all joint observations as tuples of per-modality outcome indices.

    The number of outcomes of each modality is read from the first axis of the
    corresponding array in `A`; the last modality varies fastest.
    """
    outcome_ranges = (range(A_m.shape[0]) for A_m in A)
    return list(itertools.product(*outcome_ranges))

def make_state_mapping(A=None, B=None):
    """
    Enumerate all joint states as tuples of per-factor state indices.

    The factor sizes are taken from `B` (first axis of each factor's array) when
    `B` is given, otherwise from the trailing axes of the first array in `A`.
    The last factor varies fastest.

    Parameters
    ----------
    A : sequence of likelihood arrays shaped (num_obs, s^1, ..., s^f), optional
    B : sequence of per-factor transition arrays, optional; takes precedence over A

    Returns
    -------
    list of tuples, one per joint state

    Raises
    ------
    ValueError
        If neither `A` nor `B` is provided.
    """
    if A is None and B is None:
        raise ValueError("make_state_mapping requires either A or B")

    if B is not None:
        # use the argument's own length, not a notebook-level `num_factors`
        num_states_per_factor = [B_f.shape[0] for B_f in B]
    else:
        # use the argument itself, not the notebook-level `A_gp`
        num_states_per_factor = list(A[0].shape[1:])

    return list(itertools.product(*[range(s) for s in num_states_per_factor]))
    

def make_B_flat(B):
    """
    Go from the per-factor transitions
    [p(s^1_tau|s^1_{tau-1}, u_{tau-1}), ..., p(s^f_tau|s^f_{tau-1}, u_{tau-1})]
    to a single joint transition array p(s_tau|s_{tau-1}, u_{tau-1}).

    Returns
    -------
    B_flat : array indexed as [next_state, previous_state, action], each entry
        being the product of the per-factor transition probabilities
    multi_index_list_states : list mapping each flat state index to its tuple
        of per-factor indices
    """
    num_factors = len(B)
    factor_sizes = [B[f].shape[0] for f in range(num_factors)]
    num_states = product_of_elements(factor_sizes)
    num_actions = max([B[f].shape[-1] for f in range(num_factors)])

    # give every factor the same number of action slices
    B = preprocess_B(B, num_actions)

    B_flat = np.zeros((num_states, num_states, num_actions))

    multi_index_list_states = make_state_mapping(B=B)
    indexed_states = list(enumerate(multi_index_list_states))

    for flat_next, multi_next in indexed_states:
        for flat_prev, multi_prev in indexed_states:
            for action in range(num_actions):
                per_factor = [B[f][multi_next[f], multi_prev[f], action] for f in range(num_factors)]
                B_flat[flat_next, flat_prev, action] = product_of_elements(per_factor)

    return B_flat, multi_index_list_states


def preprocess_B(B, num_actions):
    """
    Give every factor's transition array `num_actions` action slices.

    This accounts for uncontrollable transition dynamics that have a smaller
    number of actions. A factor with fewer than `num_actions` slices along its
    last axis is replaced by its action-0 slice repeated `num_actions` times;
    other factors are passed through unchanged. Returns a new list.
    """
    def _expand(B_f):
        if B_f.shape[-1] >= num_actions:
            return B_f
        first_action = B_f[:, :, 0]
        return np.repeat(first_action[:, :, np.newaxis], num_actions, axis=2)

    return [_expand(B_f) for B_f in B]


def product_of_elements(lst):
    """Return the product of all items in `lst`; an empty input gives 1."""
    running_product = 1
    for factor in lst:
        running_product *= factor
    return running_product

def make_D_flat(D):
    """
    Combine the per-factor initial state distributions in `D` into one
    distribution over joint states.

    Each joint-state entry is the product of the per-factor entries; joint
    states are ordered with the last factor varying fastest.
    """
    num_factors = len(D)
    factor_sizes = [D[f].shape[0] for f in range(num_factors)]

    D_flat = np.zeros((product_of_elements(factor_sizes),))

    joint_states = itertools.product(*[range(size) for size in factor_sizes])

    for flat_idx, multi_idx in enumerate(joint_states):
        per_factor = [D[f][multi_idx[f]] for f in range(num_factors)]
        D_flat[flat_idx] = product_of_elements(per_factor)

    return D_flat



In [3]:
# Build the T-maze environment used by the cells below.
reward_probabilities = [0.98, 0.02] # probabilities used in the original SPM T-maze demo
env = TMazeEnv(reward_probs = reward_probabilities)

In [10]:
# Fetch the likelihood arrays from the environment and show the [:, :, 0] slice
# of the third one.
A_gp = env.get_likelihood_dist()
A_gp[2][:,:,0]

array([[0.5, 0.5, 0.5, 1. ],
       [0.5, 0.5, 0.5, 0. ]])

In [117]:
# Display the flattened likelihood p(o|s): rows are joint observations, columns joint states.
make_A_flat(A_gp)

array([[0.5 , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.5 , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.49, 0.01, 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.49, 0.01, 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.01, 0.49, 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.01, 0.49, 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.01, 0.49, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.01, 0.49, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.49, 0.01, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  ,

In [113]:
# Scratch array used to check numpy indexing.
x = np.array([[0,1],[0,1]])
x

array([[0, 1],
       [0, 1]])

In [115]:
# Scratch check: a tuple index picks out a single element.
x[0,0]

0

In [105]:
# NOTE: a type object never equals a string, so this is True for any value
# (including None); use `value is not None` to test for None.
type(np.array([0])) != 'NoneType'

True

In [106]:
# A type object never equals its name as a string, hence False; use `value is None` instead.
type(None) == 'NoneType'

False

In [18]:
# Fetch the transition arrays from the environment.
B_gp = env.get_transition_dist()

In [27]:
# Show the [:, :, 0] slice (first action) of the second transition array.
B_gp[1][:,:,0]

array([[1., 0.],
       [0., 1.]])

In [29]:
# make_B_flat returns (B_flat, state_mapping); take the array first, then show the
# joint transition matrix for action index 1.
make_B_flat(B_gp)[0][:,:,1]

[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]


array([[0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 1., 0., 1., 0., 1., 0.],
       [0., 1., 0., 1., 0., 1., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.]])

In [30]:
# Initial state distribution: one array per state factor, stored in an object array.
num_factors = 2
D = utils.obj_array(num_factors)

# First factor: uniform over the four locations.
D_location = np.array([1/4, 1/4, 1/4, 1/4])
D[0] = D_location


# Second factor: uniform over the two contexts.
D_context = np.array([0.5,0.5])
D[1] = D_context



In [34]:
# Display the initial distribution over the joint states.
make_D_flat(D)

array([0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])

In [46]:
# Build the environment and the flattened A, B and D used for exact inference below.
reward_probabilities = [0.98, 0.02] # probabilities used in the original SPM T-maze demo
env = TMazeEnv(reward_probs = reward_probabilities)
A_gp = env.get_likelihood_dist()
# make_A_flat returns only the flat array, so the observation mapping is built separately
A_flat = make_A_flat(A_gp)
observation_mapping = make_observation_mapping(A_gp)
B_gp = env.get_transition_dist()
B_flat, state_mapping = make_B_flat(B_gp)

num_factors = 2
D = utils.obj_array(num_factors)

D_location = np.array([1/4, 1/4, 1/4, 1/4])
D[0] = D_location


D_context = np.array([0.5,0.5])
D[1] = D_context

D_flat = make_D_flat(D)


[(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (0, 2, 0), (0, 2, 1), (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1), (1, 2, 0), (1, 2, 1), (2, 0, 0), (2, 0, 1), (2, 1, 0), (2, 1, 1), (2, 2, 0), (2, 2, 1), (3, 0, 0), (3, 0, 1), (3, 1, 0), (3, 1, 1), (3, 2, 0), (3, 2, 1)]
[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]


In [81]:
# Two example scenarios. "obs" holds flat observation indices (rows of A_flat) and
# "policy" holds one action index per transition.
obs_and_pol = {
    "go to cue":{
        "obs":[0,18],
        "policy": [3,1,0]
    },
    "go to reward arm":{
        "obs":[0,8],
        "policy": [1,0,0]
    }
}

# Posterior over all state trajectories for the "go to reward arm" scenario;
# q has one entry per enumerated trajectory.
q, trajectory_mapping = infer_states(A_flat, B_flat, D_flat, 
                                     obs_and_pol["go to reward arm"]["obs"], 
                                     obs_and_pol["go to reward arm"]["policy"])
len(q)

4096

In [82]:
# Print every trajectory with non-zero posterior probability, translating each
# flat state index back to its multi-index via state_mapping.
for trajectory_idx, q_s in enumerate(q):
    trajectory = trajectory_mapping[trajectory_idx]
    states = [state_mapping[state_idx] for state_idx in trajectory]
    if q_s > 0:
        print(q_s, states)

0.9995836802664447 [(0, 0), (1, 0), (1, 0), (0, 0)]
0.0004163197335553706 [(0, 1), (1, 1), (1, 1), (0, 1)]


In [73]:
# Show the [:, :, 0] slice of the second likelihood array.
A_gp[1][:,:,0]

array([[1.  , 0.  , 0.  , 1.  ],
       [0.  , 0.98, 0.02, 0.  ],
       [0.  , 0.02, 0.98, 0.  ]])

$o^1$ = location  
$o^2$ = reward  
$o^3$ = cue  

$s^1$ = location  
$s^2$ = context  

location 1 = start (index 0 in the code)  
location 2 = arm 1 (index 1)  
location 3 = arm 2 (index 2)  
location 4 = cue location (index 3)  

first cue = first context = reward in first arm

$o^2$[0] = no reward    
$o^2$[1] = reward   
$o^2$[2] = loss 

In [68]:
# List each flat observation index next to its multi-index tuple.
for idx, obs in enumerate(observation_mapping):
    print(idx, obs)

0 (0, 0, 0)
1 (0, 0, 1)
2 (0, 1, 0)
3 (0, 1, 1)
4 (0, 2, 0)
5 (0, 2, 1)
6 (1, 0, 0)
7 (1, 0, 1)
8 (1, 1, 0)
9 (1, 1, 1)
10 (1, 2, 0)
11 (1, 2, 1)
12 (2, 0, 0)
13 (2, 0, 1)
14 (2, 1, 0)
15 (2, 1, 1)
16 (2, 2, 0)
17 (2, 2, 1)
18 (3, 0, 0)
19 (3, 0, 1)
20 (3, 1, 0)
21 (3, 1, 1)
22 (3, 2, 0)
23 (3, 2, 1)


In [89]:
# Sizes of the trailing (non-observation) axes of the first likelihood array.
list(A_gp[0].shape[1:])

[4, 2]