In [None]:
from environments import PlusMaze
from combined_agent import Agent
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from plotting import tsplot_boot
%matplotlib notebook
# Which brain area do we sample from? 
from collections import Counter
import pandas as pd

In [None]:
import random

# Configuration for the lesion/control learning-curve simulations below.
n_agents = 30
n_episodes = 25

# Seed the global RNGs so "Restart & Run All" reproduces the same figures.
# NOTE(review): assumes Agent / PlusMaze draw from numpy's global RNG and/or
# Python's `random` module -- confirm in combined_agent / environments.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

env = PlusMaze()



In [None]:
import os

# Directory that receives the saved SVG figures. Create it up front so the
# savefig calls at the end of the notebook don't raise FileNotFoundError on a
# fresh checkout; exist_ok keeps the cell safe to re-run.
output_folder = 'figures/'
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Train three agent variants side by side and log each one's per-episode reward:
# column 0 = hippocampal lesion, 1 = striatal lesion, 2 = intact control.
# NOTE(review): despite its name, `escape_times` holds the *reward* returned by
# train_one_episode (the plot below labels the y-axis 'Reward').
escape_times = np.zeros((n_agents, n_episodes, 3))

for agent_idx in tqdm(range(n_agents)):
    hipp = Agent(env=env, lesion_hippocampus=True)
    stria = Agent(env=env, lesion_striatum=True)
    stria.hippocampus.max_goal_response = .05
    cont = Agent(env=env)
    cont.hippocampus.max_goal_response = .05
    # Within every episode the groups are trained in a fixed order: hipp, stria, cont.
    group_agents = (hipp, stria, cont)
    for ep in range(n_episodes):
        for group_col, learner in enumerate(group_agents):
            _t, episode_reward, _locs, _choices = learner.train_one_episode()
            escape_times[agent_idx, ep, group_col] = episode_reward

In [None]:
# Bootstrapped learning curves for the three groups trained above.
fig, ax = plt.subplots()
tsplot_boot(ax, escape_times[:, :, 0])                  # hippocampal lesion (default colour)
for group_col, colour in ((1, 'r'), (2, 'g')):           # striatal lesion, control
    tsplot_boot(ax, escape_times[:, :, group_col], color=colour)
ax.set_ylabel('Reward')
ax.set_xlabel('Trial')
# NOTE(review): legend labels are matched to artists by draw order; confirm that
# tsplot_boot adds exactly one legend-eligible artist per call.
ax.legend(['Hippocampal lesion', 'Striatal lesion', 'Control'])
plt.show()

In [None]:
# Inspect the striatal weight matrix of the last control agent (`cont`).
fig, ax = plt.subplots()
ax.imshow(cont.striatum.weight_mat)

In [None]:
# Plot the hippocampus module's value function for the last control agent
# (`cont`, created in the three-group training loop).
cont.hippocampus.plot_value_function()

In [None]:
# Count, per episode, how many actions were attributed to each controller
# ('random', 'striatum', 'hippocampus') for intact agents.
n_agents = 30
n_episodes = 15

allrand = np.zeros((n_agents, n_episodes))
allstriat = np.zeros((n_agents, n_episodes))
allhipp = np.zeros((n_agents, n_episodes))

for it in tqdm(range(n_agents)):
    ag = Agent(env=env)
    ag.hippocampus.max_goal_response = .05

    all_choices = []
    for ep in range(n_episodes):
        tc, r, locs, choices = ag.train_one_episode()
        all_choices.append(choices)

    # One Counter per episode, then one row per agent for each controller.
    episode_tallies = [Counter(episode_choices) for episode_choices in all_choices]
    allrand[it] = [tally['random'] for tally in episode_tallies]
    allstriat[it] = [tally['striatum'] for tally in episode_tallies]
    allhipp[it] = [tally['hippocampus'] for tally in episode_tallies]

In [None]:
# Mean (over agents) number of actions per episode attributed to each controller.
rand = np.mean(allrand, axis=0)
# Previously averaged `allrand` here, so the "Striatum" series silently
# duplicated the random-action counts.
striat = np.mean(allstriat, axis=0)
hipp = np.mean(allhipp, axis=0)

In [None]:
# One row per episode; columns hold the mean per-episode action counts per controller.
r = np.arange(len(all_choices))
df = pd.DataFrame(dict(Random=rand, Striatum=striat, Hippocampus=hipp))




In [None]:
fig, ax = plt.subplots()

# Convert mean action counts into percentages of the hippocampus + striatum total
# ('Random' actions are deliberately left out of the denominator).
striatum_counts = df['Striatum'].to_numpy()
hippocampus_counts = df['Hippocampus'].to_numpy()
totals = striatum_counts + hippocampus_counts
greenBars = hippocampus_counts / totals * 100
orangeBars = striatum_counts / totals * 100

# Stacked bars: hippocampus share at the bottom (green), striatum share on top (orange).
barWidth = 0.85
names = r
ax.bar(r, greenBars, width=barWidth, color='#b5ffb9', edgecolor='white')
ax.bar(r, orangeBars, bottom=greenBars, width=barWidth, color='#f9bc86', edgecolor='white')

ax.set_xticks(r)
ax.set_xticklabels(names)
ax.set_xlabel("Trial")
ax.set_ylabel('% choices')
ax.legend(['Hippocampus', 'Striatum'])
ax.set_title('Frequency of choices')
plt.show()

## Now run the actual Packard & McGaugh experiment

In [None]:
# Fresh intact agent with its own PlusMaze; used for the single-episode
# trajectory run and plot that follow.
agent = Agent(env=PlusMaze())

In [None]:
def packard_exp(agent, n_trials=5):
    """Run one Packard & McGaugh-style session and return where the probe ended.

    The agent first trains for ``n_trials`` episodes starting from the original
    side, then runs one probe episode starting from the opposite side. The probe
    is also run through ``train_one_episode``, so the agent learns from it too.
    The environment is left set to the opposite start side afterwards.

    :param agent: object whose ``env`` provides ``start_on_original_side()`` and
        ``start_on_opposite_side()``, and whose ``train_one_episode()`` returns a
        4-tuple ``(t, reward, locs, choices)``.
    :param n_trials: number of training episodes before the probe.
    :return: ``locs[-1]`` from the probe episode, i.e. its last recorded location.
    """
    agent.env.start_on_original_side()
    for _trial in range(n_trials):
        _t, _reward, _locs, _choices = agent.train_one_episode()

    # Probe trial from the other arm.
    agent.env.start_on_opposite_side()
    _t, _reward, probe_locs, _choices = agent.train_one_episode()
    return probe_locs[-1]

In [None]:
# Run one episode from the original start side and keep its trajectory for the plot below.
# (Call agent.env.start_on_opposite_side() instead to inspect a probe-side start.)
agent.env.start_on_original_side()
episode_result = agent.train_one_episode()
t, reward, locs, choices = episode_result

In [None]:
# Trajectory of that episode (column 0 = x, column 1 = y), clipped to the maze
# bounds and overlaid on the maze outline.
fig, ax = plt.subplots()
ax.plot(locs[:, 0], locs[:, 1])
ax.set_xlim([agent.env.minx, agent.env.maxx])
ax.set_ylim([agent.env.miny, agent.env.maxy])
agent.env.draw_maze()



## Sanity check: hippocampal lesions should produce only response strategies

In [None]:
# Hippocampal-lesion agent over 10 sessions; expected to end probes on the 'response' side.
# NOTE(review): final x <= 0.1 is scored 'place', >= 0.9 'response', anything in
# between 'other' -- confirm these thresholds against the PlusMaze geometry.
agent = Agent(env=PlusMaze(), lesion_hippocampus=True)
choices = []
for _session in range(10):
    probe_end_x = packard_exp(agent, n_trials=5)[0]
    choices.append(
        'place' if probe_end_x <= .1
        else 'response' if probe_end_x >= .9
        else 'other'
    )

In [None]:
# Probe outcomes ('place' / 'response' / 'other') of the hippocampal-lesion agent, one per session.
choices

## Sanity check: striatal lesions should produce only place strategies

In [None]:
# Striatal-lesion agent over 10 sessions; expected to end probes on the 'place' side.
# NOTE(review): final x <= 0.1 is scored 'place', >= 0.9 'response', anything in
# between 'other' -- confirm these thresholds against the PlusMaze geometry.
agent = Agent(env=PlusMaze(), lesion_striatum=True)
choices = []
for _session in range(10):
    probe_end_x = packard_exp(agent, n_trials=5)[0]
    choices.append(
        'place' if probe_end_x <= .1
        else 'response' if probe_end_x >= .9
        else 'other'
    )

In [None]:
# Probe outcomes ('place' / 'response' / 'other') of the striatal-lesion agent, one per session.
choices

## Now the full experiment

In [None]:
# Intact (control) group: 30 agents x 5 sessions; each session is 5 training
# trials plus 1 probe (see packard_exp). goal_cell_decay_factor is set to 1 for every agent.
all_choices = []
for _agent_idx in range(30):
    agent = Agent(env=PlusMaze())
    agent.hippocampus.goal_cell_decay_factor = 1
    session_labels = []
    for session in range(5):
        probe_end_x = packard_exp(agent, n_trials=5)[0]
        session_labels.append(
            'place' if probe_end_x <= .1
            else 'response' if probe_end_x >= .9
            else 'other'
        )
    all_choices.append(session_labels)

In [None]:
# (n_agents, n_sessions) array of probe labels for the control group.
ac = np.asarray(all_choices)

In [None]:
# Per session: [place %, response %] among probes that ended in one of the two arms
# ('other' outcomes are excluded from the denominator).
data = []
for session in range(5):
    tally = Counter(ac[:, session])
    n_place, n_response = tally['place'], tally['response']
    data.append(np.array([n_place, n_response]) / (n_place + n_response) * 100)

In [None]:
# Reshape to (2, n_sessions): row 0 = place %, row 1 = response %.
# NOTE(review): this cell is not idempotent -- running it twice transposes `data` back.
data = np.transpose(np.array(data))
data.shape[1]

In [None]:
# Place-strategy percentage per session for the control group (row 0 after the transpose).
data[0]

In [None]:
# x positions for the per-session bars: one per column of `data`, i.e. per session.
r = np.arange(data.shape[1])

In [None]:
fig, ax = plt.subplots()
barWidth = 0.85
names = r + 1   # 1-based session numbers for the tick labels
# Stacked bars per session: place % (green) at the bottom, response % (orange) on top.
ax.bar(r, data[0], width=barWidth, color='#b5ffb9', edgecolor='white')
ax.bar(r, data[1], bottom=data[0], width=barWidth, color='#f9bc86', edgecolor='white')

ax.set_xticks(r)
ax.set_xticklabels(names)
ax.set_xlabel("Session")
ax.set_ylabel('% choices')
ax.legend(['Place', 'Response'])
ax.set_title('Frequency of choices on probe trial')
plt.show()

## Recreate original figure

In [None]:
# Hippocampal-lesion group for the recreated figure: 30 agents x 5 sessions.
# Unlike the control and striatal-lesion groups, goal_cell_decay_factor is not set here.
all_choices_hipp = []
for _agent_idx in range(30):
    agent = Agent(env=PlusMaze(), lesion_hippocampus=True)
    session_labels = []
    for session in range(5):
        probe_end_x = packard_exp(agent, n_trials=5)[0]
        session_labels.append(
            'place' if probe_end_x <= .1
            else 'response' if probe_end_x >= .9
            else 'other'
        )
    all_choices_hipp.append(session_labels)

In [None]:
# Striatal-lesion group for the recreated figure: 30 agents x 5 sessions,
# with goal_cell_decay_factor set to 1 as in the control group.
all_choices_striat = []
for _agent_idx in range(30):
    agent = Agent(env=PlusMaze(), lesion_striatum=True)
    agent.hippocampus.goal_cell_decay_factor = 1
    session_labels = []
    for session in range(5):
        probe_end_x = packard_exp(agent, n_trials=5)[0]
        session_labels.append(
            'place' if probe_end_x <= .1
            else 'response' if probe_end_x >= .9
            else 'other'
        )
    all_choices_striat.append(session_labels)

In [None]:
# Striatal-lesion group: [place %, response %] on the first (0) and last (4) session,
# excluding 'other' outcomes from the denominator.
acs = np.asarray(all_choices_striat)
data_striat = []
for session in (0, 4):
    tally = Counter(acs[:, session])
    n_place, n_response = tally['place'], tally['response']
    data_striat.append(np.array([n_place, n_response]) / (n_place + n_response) * 100)

In [None]:
# Hippocampal-lesion group: [place %, response %] on the first (0) and last (4) session,
# excluding 'other' outcomes from the denominator.
ach = np.asarray(all_choices_hipp)
data_hipp = []
for session in (0, 4):
    tally = Counter(ach[:, session])
    n_place, n_response = tally['place'], tally['response']
    data_hipp.append(np.array([n_place, n_response]) / (n_place + n_response) * 100)

In [None]:
# now make the plots

In [None]:
# Bar heights for the model figure, each list ordered [striatum group, hippocampus group].
# bottom_* = place %, top_* = response %. The saline (intact) bars reuse the same
# control-group value for both groups. Early = first probe session (index 0),
# late (*_l) = last probe session (index -1).
first, last = 0, -1

bottom_saline = [data[0][first]] * 2
top_saline = [data[1][first]] * 2
bottom_lido = [data_striat[first][0], data_hipp[first][0]]
top_lido = [data_striat[first][1], data_hipp[first][1]]

bottom_saline_l = [data[0][last]] * 2
top_saline_l = [data[1][last]] * 2
bottom_lido_l = [data_striat[last][0], data_hipp[last][0]]
top_lido_l = [data_striat[last][1], data_hipp[last][1]]


In [None]:
# Paired bar layout: saline bars at integer positions r1; lidocaine bars shifted
# right by one bar width (r2) so each pair sits side by side.
barWidth = 0.35
r1 = np.arange(len(bottom_saline))
r2 = np.add(r1, barWidth)

In [None]:
def zip_lists(list_1, list_2):
    """Interleave two sequences into ``[a0, b0, a1, b1, ...]``.

    Elements are paired with :func:`zip`. When the inputs differ in length, a
    warning is printed and the surplus tail of the longer one is dropped, so the
    result has ``2 * min(len(list_1), len(list_2))`` elements.

    :param list_1: sequence (supporting ``len``) supplying the even-indexed elements.
    :param list_2: sequence (supporting ``len``) supplying the odd-indexed elements.
    :return: a new list with the interleaved elements.
    """
    if len(list_1) != len(list_2):
        print('Warning: lists are of unequal length. Result length will match (2x) shortest.')
    interleaved = []
    for left, right in zip(list_1, list_2):
        interleaved.append(left)
        interleaved.append(right)
    return interleaved



In [None]:
# Two-panel summary of the model's probe-trial strategies: left = first session,
# right = last session. In each panel the bar pair at x=0 is the striatum group
# and the pair at x=1 the hippocampus group; r1 = saline (intact) bars,
# r2 = lidocaine (lesioned) bars. Green (bottom) = place %, orange (stacked) = response %.
fig, axs = plt.subplots(1,2,sharey=True, figsize =(7,5))

## Early 
plt.sca(axs[0])
plt.title('Early in training', y=1.15)
# Saline group
plt.bar(r1, bottom_saline, color='#b5ffb9', edgecolor='white', width=barWidth)
plt.bar(r1, top_saline, bottom=bottom_saline, color='#f9bc86', edgecolor='white', width=barWidth)
plt.box(False)
# Lidocaine group
plt.bar(r2, bottom_lido, color='#b5ffb9', edgecolor='white', width=barWidth)
plt.bar(r2, top_lido, bottom=bottom_lido, color='#f9bc86', edgecolor='white', width=barWidth)
plt.ylabel('% of agents')
# Tick labels alternate saline/lidocaine, matching the order produced by zip_lists(r1, r2).
plt.xticks(zip_lists(r1,r2), 
           [r'$\bf{Saline}$'+'\n(Control)', r'$\bf{Lidocaine}$'+'\n(Inactivation)', 
            r'$\bf{Saline}$'+'\n(Control)', r'$\bf{Lidocaine}$'+'\n(Inactivation)'], 
           rotation=60, fontsize=10)

# Layout is computed here before the group captions exist; a second tight_layout
# call with explicit padding later in this cell recomputes it.
plt.tight_layout()
# Group captions sit at y=105, just above the 100 % bar tops.
plt.text(-.15,105,'Striatum')
plt.text(.7,105,'Hippocampus')

## Late
plt.sca(axs[1])
plt.title('Late in training', y=1.15)

# Saline group
plt.bar(r1, bottom_saline_l, color='#b5ffb9', edgecolor='white', width=barWidth)
plt.bar(r1, top_saline_l, bottom=bottom_saline_l, color='#f9bc86', edgecolor='white', width=barWidth)
plt.box(False)
# Lidocaine group
plt.bar(r2, bottom_lido_l, color='#b5ffb9', edgecolor='white', width=barWidth)
plt.bar(r2, top_lido_l, bottom=bottom_lido_l, color='#f9bc86', edgecolor='white', width=barWidth)

plt.xticks(zip_lists(r1,r2), 
           [r'$\bf{Saline}$'+'\n(Control)', r'$\bf{Lidocaine}$'+'\n(Inactivation)', 
            r'$\bf{Saline}$'+'\n(Control)', r'$\bf{Lidocaine}$'+'\n(Inactivation)'], 
           rotation=60, fontsize=10)
# Final layout pass for the whole figure (extra padding leaves room for titles and captions).
plt.tight_layout(pad=5,w_pad=4, h_pad=3)
#plt.text(-.2,-35,'Striatum')
plt.text(-0.15,105,'Striatum')
plt.text(.7,105,'Hippocampus')

# Labels are matched by draw order to this axes' bar groups: the green (place)
# saline bars first, then the orange (response) saline bars.
plt.legend(['Place Strategy', 'Response Strategy'])



plt.savefig(os.path.join(output_folder, 'PackardMcGaughModel.svg'))




In [None]:
# NOTE(review): leftover interactive lookup of plt.title's docstring; it has no
# effect on the analysis and can be deleted from the final notebook.
help(plt.title)

## Plot original data

In [None]:
def _pct(count, total):
    """Return ``count`` as a percentage of ``total`` (computed as count / total * 100)."""
    return count / total * 100

# Published probe-trial results, as % of animals per group. Each list is ordered
# [striatum-infusion group, hippocampus-infusion group]; bottom_* = place strategy,
# top_* = response strategy; unsuffixed = test day 8, *_l = test day 16.
# These overwrite the model bar heights of the same names.
# NOTE(review): denominators 12 and 14 are presumably group sizes -- confirm against the paper.
bottom_saline = [_pct(10, 12), _pct(12, 14)]
top_saline = [_pct(2, 12), _pct(2, 14)]
bottom_lido = [_pct(10, 12), _pct(6, 12)]
top_lido = [_pct(2, 12), _pct(6, 12)]

bottom_saline_l = [_pct(2, 12), _pct(3, 14)]
top_saline_l = [_pct(10, 12), _pct(11, 14)]
bottom_lido_l = [_pct(11, 12), _pct(2, 12)]
top_lido_l = [_pct(1, 12), _pct(10, 12)]


In [None]:
# Same two-panel layout as the model figure, filled with the published animal
# data from the previous cell. Reuses r1, r2, barWidth and zip_lists defined earlier.
# Bar pair at x=0 = striatum group, x=1 = hippocampus group; r1 = saline, r2 = lidocaine.
fig, axs = plt.subplots(1,2,sharey=True, figsize =(7,5))

orange = '#f9bc86'
green = '#b5ffb9'
# Shared tick labels, alternating saline/lidocaine in zip_lists(r1, r2) order.
tick_labels = [r'$\bf{Saline}$'+'\n(Control)', r'$\bf{Lidocaine}$'+'\n(Inactivation)',
               r'$\bf{Saline}$'+'\n(Control)', r'$\bf{Lidocaine}$'+'\n(Inactivation)']

## Early
plt.sca(axs[0])
# Title casing now matches the 'Test Day 16' panel.
plt.title('Test Day 8', y=1.15)
# Saline group
plt.bar(r1, bottom_saline, color=green, edgecolor='white', width=barWidth)
plt.bar(r1, top_saline, bottom=bottom_saline, color=orange, edgecolor='white', width=barWidth)
plt.box(False)
# Lidocaine group
plt.bar(r2, bottom_lido, color=green, edgecolor='white', width=barWidth)
plt.bar(r2, top_lido, bottom=bottom_lido, color=orange, edgecolor='white', width=barWidth)
plt.ylabel('% of animals')
plt.xticks(zip_lists(r1,r2), tick_labels, rotation=60, fontsize=10)

# Layout computed before the captions exist; the padded tight_layout below recomputes it.
plt.tight_layout()
# Group captions at y=105, just above the 100 % bar tops.
plt.text(-.15,105,'Striatum')
plt.text(.7,105,'Hippocampus')

## Late
plt.sca(axs[1])
plt.title('Test Day 16', y=1.15)

# Saline group
plt.bar(r1, bottom_saline_l, color=green, edgecolor='white', width=barWidth)
plt.bar(r1, top_saline_l, bottom=bottom_saline_l, color=orange, edgecolor='white', width=barWidth)
plt.box(False)
# Lidocaine group
plt.bar(r2, bottom_lido_l, color=green, edgecolor='white', width=barWidth)
plt.bar(r2, top_lido_l, bottom=bottom_lido_l, color=orange, edgecolor='white', width=barWidth)
plt.xticks(zip_lists(r1,r2), tick_labels, rotation=60, fontsize=10)

# Final layout pass for the whole figure.
plt.tight_layout(pad=5,w_pad=4, h_pad=3)
plt.text(-.15,105,'Striatum')
plt.text(.7,105,'Hippocampus')

# Labels match this axes' bar groups in draw order: green (place) first, then orange (response).
plt.legend(['Place Strategy', 'Response Strategy'])


plt.savefig(os.path.join(output_folder, 'PackardMcGaughData.svg'))