In [None]:
import math
import random
import sys
sys.path.append('..')
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from plotting import tsplot, tsplot_boot



In [None]:
from hippocampus_watermaze import HippocampalAgent
from striatum_blocking_model import Agent

In [None]:
# Experiment configuration.
n_trials = 60        # trials per simulated agent
n_simulations = 1    # number of independent agents simulated per model

# Seed the global RNGs so Restart & Run All reproduces the same escape times.
# NOTE(review): assumes the agent models draw from `random` / `np.random` — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# Striatal agent (Gaussian RFs): run the blocking protocol once per simulation
# and keep that run's escape times. `sa` stays bound to the last agent, which
# later cells read (time_bin, trajectories).
list_of_escape_times = []
for num in tqdm(range(n_simulations)):
    sa = Agent(n_trials=n_trials)
    sa.run_blocking_experiment()
    escape_times_this_run = sa.escape_times
    list_of_escape_times.append(escape_times_this_run)

In [None]:
# Scale the striatal escape times by the agent's time_bin (plotted as seconds).
ar_striatum = sa.time_bin * np.asarray(list_of_escape_times)

In [None]:
# Inspect the scaled escape times: one row per simulation (presumably one column per trial).
ar_striatum

In [None]:
# Striatal agent alone: escape time across trials (via tsplot_boot), with bars
# showing the span of each landmark.
fig, ax = plt.subplots()
tsplot_boot(ax, ar_striatum)

# Landmark bars; xmin/xmax are axes fractions, presumably marking the trials in
# which each landmark is present — confirm. Property names must be lowercase
# (`linewidth`): capitalized aliases such as `LineWidth` are no longer accepted.
ax.axhline(y=105, xmin=0, xmax=.66667, color='r', alpha=.5, linewidth=4)
ax.axhline(y=100, xmin=.33333, xmax=1, color='g', alpha=.5, linewidth=4)
ax.text(1, 108, 'Landmark 1', fontsize=20)
ax.text(48, 103, 'Landmark 2', fontsize=20)
ax.set_ylim([0, 115])
ax.set_ylabel('Escape time (s)')
ax.set_xlabel('Trials')

SAVE_FIGURE = False  # set True to write this figure to `filename`
filename = '../figs/blocking_effect_160218.png'
if SAVE_FIGURE:
    fig.savefig(filename)
plt.show()

In [None]:
# Now simulate the hippocampal system

In [None]:
# Hippocampal agent (Gaussian RFs). One extra trial is simulated and the first
# is discarded, presumably so the kept escape times align with the n_trials
# striatal ones — confirm.
list_of_escape_times = []
for num in tqdm(range(n_simulations)):
    aca = HippocampalAgent(n_trials=n_trials + 1)
    aca.run_simulation()
    trial_column = aca.position_log['Trial']
    # Escape time per trial, in time bins: the number of positions logged for it.
    escape_times = [aca.position_log[trial_column == trial].shape[0]
                    for trial in range(aca.n_trials)]
    list_of_escape_times.append(escape_times[1:])
    # Reset aca.env.trial to 0 after each run.
    # NOTE(review): presumably the environment's trial state outlives the agent — confirm.
    aca.env.trial = 0

In [None]:
# Scale the hippocampal escape times (time bins) by the environment's time_bin.
ar_hpc = aca.env.time_bin * np.asarray(list_of_escape_times)

In [None]:
# Hippocampal agent alone, with the same axis labels as the striatal figure
# so the panel stands on its own.
fig, ax = plt.subplots()
tsplot_boot(ax, ar_hpc)
ax.set_ylabel('Escape time (s)')
ax.set_xlabel('Trials')

plt.show()

In [None]:
import matplotlib

# The bundled seaborn style sheets were renamed to 'seaborn-v0_8-*' in
# matplotlib 3.6 and the old names removed in 3.8; use whichever exists here.
_poster_style = ('seaborn-poster' if 'seaborn-poster' in matplotlib.style.available
                 else 'seaborn-v0_8-poster')
matplotlib.style.use(_poster_style)

# Large fonts for the comparison figure. 'normal' is not a font family
# (matplotlib just warns and falls back), so name a real generic family.
font = {'family': 'sans-serif',
        'weight': 'normal',
        'size': 35}

matplotlib.rc('font', **font)

In [None]:

# Side-by-side comparison: hippocampal agent (left, red) and striatal agent
# (right, green), sharing the escape-time axis.
fig, [ax1, ax2] = plt.subplots(1, 2, sharey=True)


def _annotate_cues(ax):
    """Draw the two cue bars and their labels on ``ax``.

    Cue 1 spans the first two thirds of the x-axis and Cue 2 the last two
    thirds (xmin/xmax are axes fractions); text positions are data coordinates.
    """
    ax.axhline(y=85, xmin=0, xmax=.66667, color='y', alpha=.5, linewidth=4)
    ax.axhline(y=80, xmin=.33333, xmax=1, color='b', alpha=.5, linewidth=4)
    ax.text(1, 86, 'Cue 1', fontsize=25)
    ax.text(40, 81, 'Cue 2', fontsize=25)


ax1.set_ylabel('Escape time (s)')
ax1.set_xlabel('Trial')
_annotate_cues(ax1)
tsplot_boot(ax1, ar_hpc, color='red')

ax2.set_xlabel('Trial')
tsplot_boot(ax2, ar_striatum, color='green')
_annotate_cues(ax2)

# Enlarge titles, axis labels and tick labels only once both panels are
# complete, so ax2's label and the final tick labels are included.
for ax in [ax1, ax2]:
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(35)

plt.tight_layout()

plt.savefig('blocking_illustration.png')
plt.show()

In [None]:
# List the style-sheet names this matplotlib installation accepts in matplotlib.style.use.
matplotlib.style.available

In [None]:
matplotlib.style.use('seaborn-white')
%matplotlib notebook

In [None]:
import plotting as pl

# plot_trace is defined in a later cell of this notebook, so on a fresh kernel
# (Restart & Run All) it does not exist yet. Use the notebook version when it
# has been defined, otherwise fall back to the plotting module's plot_trace.
# NOTE(review): confirm both versions draw the same elements for this agent.
trace_plotter = globals().get('plot_trace', pl.plot_trace)
fig, axs = trace_plotter(sa, [0, 7, 10])

fig.set_figheight(2.5)
fig.savefig('trajectories.png')
plt.show()

In [None]:
# Show the docstring/signature of the plotting module's plot_trace, for comparison
# with the plot_trace defined in the next cell.
help(pl.plot_trace)

In [None]:
def plot_trace(agent, trials_to_plot=None):
    """Plot the swimming trajectory in the maze for selected trials.

    One panel is drawn per trial, at most 5 panels per row. Each panel shows
    the circular maze boundary, the agent's path in that trial, the platform
    (green) and the two landmarks (red and yellow).

    Parameters
    ----------
    agent : object
        Simulated agent. Must expose ``n_trials``; ``position_log``, indexable
        by the columns 'Trial', 'X position' and 'Y position'; ``maze_radius``
        and ``maze_centre``; ``platform_centre`` and ``platform_radius``;
        ``landmark_1_centre``, ``landmark_1_radius`` and ``landmark_2_centre``;
        and the plot limits ``minx``, ``maxx``, ``miny`` and ``maxy``.
        ``landmark_2_radius`` is used when present.
    trials_to_plot : sequence of int, optional
        Trial numbers to draw. When None or empty, this defaults to
        ``[1, n_trials // 2 + 1, n_trials]``.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axs : numpy.ndarray of Axes
        Flattened array with one Axes per grid cell. Cells without a trial are
        hidden.
    """
    if not trials_to_plot:
        trials_to_plot = [1, int(agent.n_trials / 2) + 1, agent.n_trials]

    n_panels = len(trials_to_plot)
    n_rows = int(math.ceil(n_panels / 5))
    n_cols = int(math.ceil(n_panels / n_rows))
    # squeeze=False keeps axs 2-D even for a single panel, so .ravel() always works.
    fig, axs = plt.subplots(n_rows, n_cols, sharex='row', sharey='row', squeeze=False)
    axs = axs.ravel()

    # Outline of the circular maze, identical in every panel.
    angles = np.linspace(0, 2 * np.pi, 100)
    x_marks = np.cos(angles) * agent.maze_radius + agent.maze_centre[0]
    y_marks = np.sin(angles) * agent.maze_radius + agent.maze_centre[1]

    # Landmark 2 was previously drawn with landmark_1_radius. Use a dedicated
    # radius when the agent defines one; otherwise keep the old behaviour.
    landmark_2_radius = getattr(agent, 'landmark_2_radius', agent.landmark_1_radius)

    for ax, trial in zip(axs, trials_to_plot):
        ax.plot(x_marks, y_marks)  # boundary of the circular maze
        trial_trajectory = agent.position_log[agent.position_log['Trial'] == trial]
        ax.plot(trial_trajectory['X position'], trial_trajectory['Y position'])
        ax.axis('equal')  # enforces equal axis sizes
        ax.set_title('Trial {}'.format(trial))

        ax.add_artist(plt.Circle(agent.platform_centre, agent.platform_radius, color='g'))
        ax.add_artist(plt.Circle(agent.landmark_1_centre, agent.landmark_1_radius, color='r'))
        ax.add_artist(plt.Circle(agent.landmark_2_centre, landmark_2_radius, color='y'))

        # Limits are set on this panel. plt.xlim/ylim would only affect the
        # current (last-created) axes.
        ax.set_xlim((agent.minx, agent.maxx))
        ax.set_ylim((agent.miny, agent.maxy))
        # matplotlib expects booleans here; the legacy 'on'/'off' strings are not supported.
        ax.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False)

    # Hide grid cells left empty when the panels do not fill the last row.
    for ax in axs[n_panels:]:
        ax.axis('off')

    return fig, axs


In [None]:
import math