In [None]:
from combined_agent import Agent
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from plotting import tsplot_boot
%matplotlib notebook

## Performance of hippocampus, striatum and combined

In [None]:
n_agents = 20
n_episodes = 15

# escape_times[agent, episode, condition], where condition is
#   0 -> agent built with lesion_hippocampus=True
#   1 -> agent built with lesion_striatum=True
#   2 -> unlesioned control agent
escape_times = np.zeros((n_agents, n_episodes, 3))

for it in tqdm(range(n_agents)):
    # Fresh agents for every repetition, created in a fixed order.
    hipp = Agent(lesion_hippocampus=True)
    stria = Agent(lesion_striatum=True)
    cont = Agent()
    # NOTE(review): only the control agent gets this override; the
    # striatal-lesion agent keeps its default -- confirm this is intended.
    cont.hippocampus.max_goal_response = 8

    for ep in range(n_episodes):
        # Train the three agents in the same order as their condition index
        # and keep only the first returned value (stored as the escape time).
        for condition, agent in enumerate((hipp, stria, cont)):
            elapsed, *_others = agent.train_one_episode()
            escape_times[it, ep, condition] = elapsed

In [None]:
fig, ax = plt.subplots()

# One trace per condition (last axis of escape_times); the first trace uses
# tsplot_boot's default colour, the other two are red and green.
trace_styles = ({}, {'color': 'r'}, {'color': 'g'})
for condition, style in enumerate(trace_styles):
    tsplot_boot(ax, escape_times[:, :, condition], **style)

# Legend entries follow the order in which the traces were drawn.
condition_labels = ['Hippocampal lesion', 'Striatal lesion', 'Control']
plt.legend(condition_labels)
plt.show()

In [None]:
# Start a new figure before plotting (presumably plot_value_function draws on
# the current pyplot figure -- TODO confirm).
plt.figure()
# `cont` is the control agent left over from the final iteration of the
# simulation loop above, so this shows that single agent's hippocampus only.
cont.hippocampus.plot_value_function()

## Distribution of choices from hippocampus and striatum

In [None]:
# Which brain area do we sample from? 
from collections import Counter
import pandas as pd

In [None]:
n_agents = 30
n_episodes = 15

# Per-agent, per-episode counts of how often each choice label occurred.
allrand = np.zeros((n_agents, n_episodes))
allstriat = np.zeros((n_agents, n_episodes))
allhipp = np.zeros((n_agents, n_episodes))

for it in tqdm(range(n_agents)):
    ag = Agent()
    ag.hippocampus.max_goal_response = 8

    # Phase 1: run every episode and keep the fourth returned value
    # (the per-step choice labels) for each one.
    all_choices = []
    for _episode in range(n_episodes):
        _steps, _reward, _locations, episode_choices = ag.train_one_episode()
        all_choices.append(episode_choices)

    # Phase 2: tally the labels of each episode only after all training is
    # done; a label that never occurs counts as 0 (Counter default).
    tallies = [Counter(episode_choices) for episode_choices in all_choices]
    allrand[it] = [tally['random'] for tally in tallies]
    allstriat[it] = [tally['striatum'] for tally in tallies]
    allhipp[it] = [tally['hippocampus'] for tally in tallies]

In [None]:
# Average the choice counts over agents (axis 0), leaving one mean value per
# episode for each choice source.
rand = np.mean(allrand, axis=0)
# Fixed: this previously averaged `allrand`, so the "striatum" series was a
# copy of the random-choice series and `allstriat` was never used.
striat = np.mean(allstriat, axis=0)
hipp = np.mean(allhipp, axis=0)

In [None]:
# Display the `rand` array from the averaging cell (one value per episode).
rand

In [None]:
# Display the `striat` array from the averaging cell (one value per episode).
striat

In [None]:
# Display the `hipp` array from the averaging cell (one value per episode).
hipp

In [None]:
# Data
# Gather the across-agent mean choice counts into one table, one row per
# episode and one column per choice source.
raw_data = {'Random': rand, 'Striatum': striat, 'Hippocampus': hipp}
df = pd.DataFrame(raw_data)

# x positions for the bar chart, one per row of `df`. Sized from the table
# itself instead of `all_choices`, which is a leftover loop variable from the
# simulation cell and can be undefined or stale when cells are re-run.
r = np.arange(len(df))



In [None]:
#ag.hippocampus.plot_value_function()

In [None]:
fig, ax = plt.subplots()

# Turn the mean counts into percentages of the non-random choices: the
# denominator is striatum + hippocampus, random choices are left out.
totals = [striatal + hippocampal
          for striatal, hippocampal in zip(df['Striatum'], df['Hippocampus'])]
greenBars = [count / total * 100
             for count, total in zip(df['Hippocampus'], totals)]
orangeBars = [count / total * 100
              for count, total in zip(df['Striatum'], totals)]

# Stacked bars: hippocampus share at the bottom (green), striatum share
# stacked on top of it (orange).
barWidth = 0.85
names = r
shared_bar_style = {'edgecolor': 'white', 'width': barWidth}
ax.bar(r, greenBars, color='#b5ffb9', **shared_bar_style)
ax.bar(r, orangeBars, bottom=greenBars, color='#f9bc86', **shared_bar_style)

# One tick per bar, labelled with the bar's index.
ax.set_xticks(r)
ax.set_xticklabels(names)
ax.set_xlabel("Trial")
ax.set_ylabel('% choices')
ax.legend(['Hippocampus', 'Striatum'])
ax.set_title('Frequency of choices')

plt.show()
