In [None]:
from agent import KalmanSR
from environment import SimpleMDP
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
from dynamic_programming import value_iteration
import seaborn as sns
from tqdm import tqdm_notebook as tqdm
from itertools import product

In [None]:
# Select Matplotlib's interactive "notebook" backend for every figure created below.
# NOTE(review): this backend is tied to the classic Notebook front end; other
# front ends may need a different backend — confirm for your environment.
%matplotlib notebook

In [None]:
# Environment shared by every cell below.
# NOTE(review): the argument 5 is presumably the number of states — confirm
# against environment.SimpleMDP.
env = SimpleMDP(5)

In [None]:
# The return value is not stored; presumably this builds the graph that
# env.show_graph() draws in the next cell — TODO confirm in environment.py.
env.create_graph()

In [None]:
# Draw the environment graph in a new figure, with nodes 0..4 placed on a
# horizontal line one unit apart (node i at coordinates (i, 0)).
plt.figure()

positions = {node: (node, 0) for node in range(5)}

env.show_graph(layout=positions)

# Learning the SR using a Kalman filter

In [None]:
# Hyperparameters and prior for the Kalman-filter estimate of the SR.
# The SR matrix is stored flattened (row-major), so the parameter vector has
# env.nr_states ** 2 entries and its covariance is square in that size.
transition_noise = .005 * np.eye(env.nr_states ** 2)  # added to the covariance in every prediction step
gamma = .9  # discount applied to the next-state features in the TD features
prior_M = np.eye(env.nr_states).flatten()  # prior mean: identity matrix, flattened
prior_covariance = np.eye(env.nr_states ** 2)
observation_noise_variance = np.eye(env.nr_states)
# Alternative dense choices that were tried for the two matrices above:
#   prior_covariance = np.ones((env.nr_states**2, env.nr_states**2))
#   observation_noise_variance = np.ones([env.nr_states, env.nr_states])

# Working copies updated by the filter. They must be copies: the learning loop
# updates M in place (`M += delta_M`), which would otherwise silently overwrite
# prior_M through the shared array.
M = prior_M.copy()
covariance = prior_covariance.copy()


def get_feature_representation(state_idx):
    """Return the feature vector for the state with index ``state_idx``.

    A non-terminal state maps to a one-hot vector of length ``env.nr_states``
    with a 1 at position ``state_idx``. A state for which
    ``env.is_terminal`` is true maps to the all-zero vector.
    """
    features = np.zeros(env.nr_states)
    if not env.is_terminal(state_idx):
        features[state_idx] = 1.
    return features

In [None]:
# Learn the SR online with a Kalman filter.
#
# The flattened SR matrix M is the filter state. At every step the one-hot
# features of the current state are treated as a noisy observation of
#     feature_block_matrix.T @ M  ==  H @ M.reshape(nr_states, nr_states)
# with H = features - gamma * next_features.
nr_episodes = 100
max_steps_per_episode = 1000  # cap on episode length if no terminal state is reached

# Start every run of this cell from the prior. The copies matter: M is updated
# in place below, so without them prior_M would be overwritten and re-running
# the cell would continue from the previous run instead of starting over.
M = prior_M.copy()
covariance = prior_covariance.copy()

identity = np.eye(env.nr_states)  # loop-invariant factor of the Kronecker product

for episode in tqdm(range(nr_episodes)):
    env.reset()
    t = 0
    s = env.get_current_state()
    features = get_feature_representation(s)

    while not env.is_terminal(env.get_current_state()) and t < max_steps_per_episode:
        a = 1  # fixed action on every step (alternative: np.random.choice([0, 1]))

        next_state, reward = env.act(a)
        next_features = get_feature_representation(next_state)
        H = features - gamma * next_features  # Temporal difference features

        # Prediction step: the estimate is carried over unchanged, only its
        # uncertainty grows by the transition noise.
        a_priori_covariance = covariance + transition_noise

        # Observation model for the flattened M, shape (nr_states**2, nr_states).
        feature_block_matrix = np.kron(H, identity).T

        phi_hat = feature_block_matrix.T @ M  # predicted observation
        delta_t = features - phi_hat  # innovation
        parameter_error_cov = a_priori_covariance @ feature_block_matrix
        residual_cov = (feature_block_matrix.T @ a_priori_covariance @ feature_block_matrix
                        + observation_noise_variance)

        # Correction step.
        kalman_gain = parameter_error_cov @ np.linalg.inv(residual_cov)
        if t == 0:
            # Keep the gain of the first step of the episode for later inspection.
            kmgain = kalman_gain
        delta_M = kalman_gain @ delta_t

        M += delta_M

        covariance = a_priori_covariance - kalman_gain @ residual_cov @ kalman_gain.T

        s = next_state
        features = get_feature_representation(s)

        t += 1

# Show the learned SR as an (nr_states x nr_states) matrix.
np.around(M.reshape(env.nr_states, -1), decimals=3)

In [None]:
# New figure showing row 2 of the SR estimate (M viewed as a matrix with
# env.nr_states rows).
plt.figure()
plt.plot(M.reshape((env.nr_states, -1))[2, :])

In [None]:
# How does the 1 --> 2 predictiveness covary with the 2 --> 3 predictiveness?
# M is stored flat, so this is element 1 of the flat vector, i.e. row 0,
# column 1 of the matrix view.
M.reshape(env.nr_states, -1)[0, 1]

In [None]:
# Indices of the non-zero entries in the first row of the posterior covariance.
covariance[0].nonzero()

In [None]:
# One label per entry of the flattened SR, in row-major order: 'M<row>-<col>'.
m_labels = [
    'M{}-{}'.format(row, col)
    for row in range(env.nr_states)
    for col in range(env.nr_states)
]

In [None]:
# Heat-map of the posterior covariance between entries of the flattened SR.
# The last env.nr_states rows and columns (the last row of the SR matrix) are
# cropped from the view.
plt.figure()

n_shown = env.nr_states ** 2 - env.nr_states
tick_positions = list(range(n_shown))
# Pass exactly one label per tick. Handing over all nr_states**2 labels with
# only n_shown tick positions raises a ValueError in recent Matplotlib versions.
shown_labels = m_labels[:n_shown]

plt.imshow(covariance[:-env.nr_states, :-env.nr_states])
plt.xticks(ticks=tick_positions, labels=shown_labels, rotation=90)
plt.yticks(ticks=tick_positions, labels=shown_labels, rotation=0)

plt.colorbar()

In [None]:
# Display the raw SR estimate as stored: a flat vector with
# env.nr_states ** 2 entries (row-major flattening of the SR matrix).
M