In [None]:
import sys
sys.path.append('../..')

from tqdm import tqdm_notebook as tqdm
from hippocampus.plotting import tsplot_boot
import matplotlib.pyplot as plt
from definitions import ROOT_FOLDER
import os
import pandas as pd
from hippocampus.agents import CombinedAgent
from hippocampus.environments import HexWaterMaze
import numpy as np
import random
import seaborn as sns
from multiprocessing import Pool
%matplotlib notebook

In [None]:

# Build the water-maze environment and draw its grid.
# NOTE(review): the argument 6 is presumably the grid radius (a later comment mentions an "r = 10" case) — confirm.
g = HexWaterMaze(6)
g.plot_grid()


In [None]:
inv_temp = 6.  # inverse temperature passed to the agent when it is constructed

# Determine the platform sequence: one platform state per session (11 sessions in total).
possible_platform_states = np.array([48, 45, 42, 39, 60, 57, 54, 51])
#possible_platform_states = np.array([192, 185, 181, 174, 216, 210, 203, 197])  # for the r = 10 case

indices = np.arange(len(possible_platform_states))
usage = np.zeros(possible_platform_states.shape)  # number of times each platform has been picked

# Draw the first platform by index so that it is counted in `usage` like every later pick;
# otherwise it could end up being used three times despite the cap of two.
platform_idx = np.random.choice(indices)
usage[platform_idx] += 1.
platform_sequence = [possible_platform_states[platform_idx]]

for ses in range(1,11):
    # Grid distance from the previous session's platform to every possible platform.
    distances = np.array([g.grid.distance(platform_sequence[ses-1], s) for s in possible_platform_states])
    # Eligible platforms: picked fewer than twice so far and farther than g.grid.radius
    # from the previous session's platform.
    candidates = indices[np.logical_and(usage < 2, distances > g.grid.radius)]
    if candidates.size == 0:
        # Fail with a readable message instead of numpy's generic "a must be non-empty".
        raise ValueError(
            "No eligible platform for session %d: every platform is either used twice "
            "or too close to the previous one. Re-run the cell to draw a new sequence." % ses
        )
    platform_idx = np.random.choice(candidates)
    platform_sequence.append(possible_platform_states[platform_idx])
    usage[platform_idx] += 1.

In [None]:
# Display the generated platform sequence (one platform state per session).
platform_sequence

In [None]:


# Initial platform location; it is replaced at the start of every session below.
#random.shuffle(possible_platform_states)
g.set_platform_state(possible_platform_states[6])

agent = CombinedAgent(
    g,
    init_sr='rw',
    lesion_dls=False,
    lesion_hpc=True,
    inv_temp=inv_temp,
    gamma=.99,
)

# Per-trial logs: the result returned by each episode and its escape time.
agent_results, agent_ets = [], []
session = 0
total_trial_count = 0  # running trial index across all sessions

for ses in tqdm(range(11)):
    # The platform moves once per session, before the session's first trial.
    g.set_platform_state(platform_sequence[ses])

    for trial in tqdm(range(4), leave=False):
        res = agent.one_episode(random_policy=False)
        escape_time = res.time.max()  # largest logged time of the episode

        # Tag every logged row with its trial/session bookkeeping.
        res['trial'] = trial
        res['escape time'] = escape_time
        res['session'] = ses
        res['total trial'] = total_trial_count

        agent_results.append(res)
        agent_ets.append(escape_time)
        total_trial_count += 1

    # Disabled: per-session increase of the inverse temperature.
    #inv_temp += .8
    #agent.set_exploration(inv_temp)

# One long table with all episodes, plus a running row counter.
agent_df = pd.concat(agent_results)
agent_df['total time'] = np.arange(len(agent_df))



In [None]:
# Inspect the shape of `agent.weights` after the simulation.
agent.weights.shape

In [None]:
# Inspect the shape of the DLS feature representation returned for the arguments (0, 30).
agent.DLS.get_feature_rep(0,30).shape

In [None]:
# Escape time against the running trial index.
plt.figure()
sns.lineplot(data=agent_df, x='total trial', y='escape time')

# Faint red vertical line at every fourth trial index (0, 4, ..., 40),
# i.e. at the first trial of each session.
for session_start in range(0, 44, 4):
    plt.axvline(x=session_start, ymin=0, ymax=1, linewidth=1, color='r', alpha=.3)


In [None]:
# 'P(SR)' against the running trial index.
# Open a fresh figure first, as the other plotting cells do: with the interactive
# notebook backend the line would otherwise be drawn into the currently active figure.
plt.figure()
sns.lineplot(data=agent_df, x='total trial', y='P(SR)')

In [None]:
# Regression of escape time on 'P(SR)', restricted to rows of the first trial (trial == 0) of each session.
# Open a fresh figure first, as the other plotting cells do: with the interactive
# notebook backend the plot would otherwise be drawn into the currently active figure.
plt.figure()
sns.regplot(data=agent_df[agent_df.trial==0], x='P(SR)', y='escape time')

In [None]:
# Keep only rows from the first (0) and last (3) trial of every session,
# then plot escape time per session with one hue per trial.
first_and_last = agent_df[agent_df.trial.isin([0, 3])]
plt.figure()
sns.lineplot(
    data=first_and_last,
    x='session',
    y='escape time',
    hue='trial',
    estimator=None,
)

In [None]:
# analyse two subsequent sessions

In [None]:


# Rows of two consecutive sessions.
# NOTE(review): despite their names, `ses6` holds session 2 and `ses5` holds session 1;
# the names are kept because later cells refer to them.
ses6 = agent_df[agent_df.session==2]
ses5 = agent_df[agent_df.session==1]

In [None]:
# 'state' column for the first trial (trial == 0) of the session held in `ses6`.
states = ses6.loc[ses6['trial'] == 0, 'state']

In [None]:
# Occupancy plot on the grid for the first trial (trial == 0) of the session held in `ses6`.
agent.env.plot_occupancy_on_grid(ses6[ses6['trial']==0])

In [None]:
# Peek at the first rows of `ses6`.
ses6.head()

In [None]:
# First entries of the 'state' column of `ses6`.
ses6.state.head()

In [None]:
# M_hat from the first logged row of `ses6`.
# NOTE(review): presumably the agent's successor-representation estimate — confirm.
mm = ses6.M_hat.iloc[0]

In [None]:
# Platform logged in `ses5` (second row), i.e. in the session preceding the one held in `ses6`.
ses5.platform.iloc[1]

In [None]:
# R_hat from the first logged row of `ses6`.
# NOTE(review): presumably the agent's reward estimate — confirm.
rr = ses6.R_hat.iloc[0]

In [None]:
# Plot entry 72 of `mm` on the grid, with show_state_idx enabled.
# NOTE(review): presumably mm[72] is the row belonging to state 72 — confirm.
agent.env.plot_grid(mm[72], show_state_idx=True) 

In [None]:
# Plot the matrix product mm @ rr on the grid.
# NOTE(review): presumably the value estimate implied by M_hat and R_hat — confirm.
agent.env.plot_grid(mm @ rr)

In [None]:
# 'Q_mf' from the second logged row of the first trial (trial == 0) in `ses6`.
ses6.loc[ses6['trial'] == 0, 'Q_mf'].iloc[1]

In [None]:
# 'Q' from the second logged row of the first trial (trial == 0) in `ses6`.
ses6.loc[ses6['trial'] == 0, 'Q'].iloc[1]

In [None]:
# Matrix product of `mm` and `rr`, kept for indexing in the next cell.
# NOTE(review): presumably the value estimate per state — confirm.
v = mm @ rr

In [None]:
# Entry 56 of `v` (presumably the value at state 56 — confirm).
v[56]

In [None]:
# 'P(SR)' from the second logged row of the first trial (trial == 0) in `ses6`.
ses6.loc[ses6['trial'] == 0, 'P(SR)'].iloc[1]

In [None]:
# Subset of columns used for the pair plot in the next cell.
df1 = agent_df[['P(SR)','escape time', 'total trial', 'session', 'HPC reliability', 'DLS reliability']]


In [None]:
# Pairwise relationships between all columns selected in `df1`.
sns.pairplot(df1)

In [None]:
# The two expressions were fused on one line, which is a SyntaxError; they are split here.
# Platform logged in the first row of every session. It is shown with display() because
# only the last expression of a cell is rendered automatically.
display([agent_df[agent_df.session==i]['platform'].iloc[0] for i in range(11)])
# Current value of the agent's inv_temp attribute.
agent.inv_temp

In [None]:
# Select two consecutive trials by their running trial index.
# NOTE(review): the names say 40/39 but the filters pick 'total trial' 36 and 35; with
# 4 trials per session these are the first trial of session 9 and the last trial of session 8.
# The names are kept because later cells refer to them.
trial_40 = agent_df[agent_df['total trial']==36]
trial_39 = agent_df[agent_df['total trial']==35]

# Occupancy plot on the grid for the later of the two trials.
g.plot_occupancy_on_grid(trial_40)

In [None]:
# Plot the product of M_hat and R_hat from the first logged row of `trial_40` on the grid.
# NOTE(review): presumably the value estimate at the start of that trial — confirm.
g.plot_grid(trial_40['M_hat'].iloc[0] @ trial_40['R_hat'].iloc[0])

In [None]:
# R_hat from the first logged row of `trial_40`.
trial_40['R_hat'].iloc[0]

In [None]:
# Platform logged in the first row of `trial_40`.
trial_40['platform'].iloc[0]

In [None]:
# Platform logged in the first row of `trial_39` (the trial just before `trial_40`).
trial_39['platform'].iloc[0]

In [None]:
from hippocampus.utils import softmax

In [None]:
# Softmax, with beta set to the agent's inv_temp, over weights^T @ features taken from the
# second logged row of `trial_40`.
# NOTE(review): presumably the action probabilities of the feature-based (model-free) system — confirm.
softmax(trial_40['weights'].iloc[1].T @ trial_40['features'].iloc[1], beta=agent.inv_temp)

In [None]:
# 'Q_mf' from the second logged row of `trial_40`.
trial_40['Q_mf'].iloc[1]

In [None]:
# 'Q' from the second logged row of `trial_40`.
trial_40['Q'].iloc[1]

In [None]:
# Product of M_hat and R_hat from the second logged row of `trial_40`.
# NOTE(review): presumably the value estimate per state — confirm.
V = trial_40['M_hat'].iloc[1] @ trial_40['R_hat'].iloc[1]

In [None]:
# Entry 88 of `V` (presumably the value at state 88 — confirm).
V[88]

In [None]:
# Open question: how do the platform sequences of good runs differ from those of bad runs?
# Show the sequence used in this run for comparison.
platform_sequence