In [1]:
import os
import sys

# Make the project's `network` package (network, learning, connectivity, ...)
# importable from this notebook.
# - Resolve to an absolute path: a relative sys.path entry is re-resolved against
#   the *current* working directory, so a later os.chdir() would break imports.
# - Remove any existing copy before inserting, so re-running this cell keeps
#   exactly one entry, always at the front (highest priority).
NETWORK_DIR = os.path.abspath('../../../network')
while NETWORK_DIR in sys.path:
    sys.path.remove(NETWORK_DIR)
sys.path.insert(0, NETWORK_DIR)

In [2]:
import logging
import argparse
import numpy as np
from network import Population, RateNetwork
from learning import ReachingTask
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule, set_connectivity 
from sequences import GaussianSequence
import matplotlib.pyplot as plt
import seaborn as sns
logging.basicConfig(level=logging.INFO)

In [3]:
# Load the pre-computed network parameters (sizes, sequences/patterns and
# connectivity settings) consumed by the cells below.
# NOTE(review): allow_pickle=True unpickles object arrays, which can execute
# arbitrary code -- only load params.npz files you produced or trust.
PARAMS_PATH = "./params.npz"

# np.load on an .npz returns an NpzFile that holds the file open until closed;
# the `with` block closes it once every array has been read into memory.
with np.load(PARAMS_PATH, allow_pickle=True) as params:
    N = params['N']
    sequences = params['sequences']
    patterns = params['patterns']
    cp = params['cp']
    cw = params['cw']
    A = params['A']
    plasticity_rule = params['plasticity_rule']

In [4]:
# Shared rate nonlinearity and the threshold plasticity rule. The rule object is
# passed to set_connectivity here and to simulate_learning further down.
phi = ErrorFunction(mu=0.22, sigma=0.1).phi
plasticity = ThresholdPlasticityRule(x_f=0.5, q_f=0.8)

# Populations, built in the order ctx, d1, d2; their sizes are N[0], N[1], N[2].
ctx, d1, d2 = [
    Population(N=N[idx], tau=1e-2, phi=phi, name=pop_name)
    for idx, pop_name in enumerate(('ctx', 'd1', 'd2'))
]

# Connectivity between every ordered pair of populations, then the rate network.
J = set_connectivity([ctx, d1, d2], cp, cw, A, plasticity_rule, patterns, plasticity)
network = RateNetwork([ctx, d1, d2], J, formulation=4, disable_pbar=False)

INFO:connectivity:Building connections from ctx to ctx
INFO:connectivity:Building connections from ctx to d1
INFO:connectivity:Building connections from ctx to d2
INFO:connectivity:Building connections from d1 to ctx
INFO:connectivity:Building connections from d1 to d1
INFO:connectivity:Building connections from d1 to d2
INFO:connectivity:Building connections from d2 to ctx
INFO:connectivity:Building connections from d2 to d1
INFO:connectivity:Building connections from d2 to d2


In [None]:
# Initial inputs: ctx and d2 start from zero; d1 is seeded with patterns[1][0][2]
# (pattern index 2 of d1's first sequence).
init_inputs = [
    np.zeros(ctx.size),
    patterns[1][0][2],
    np.zeros(d2.size),
]
# First sequence of each population, used as the input patterns for learning.
input_patterns = [pop_patterns[0] for pop_patterns in patterns]

# NOTE(review): the original comment says ms, but the progress bar reports
# ~5e5 integration steps for T=500 -- confirm the units expected by simulate_learning.
T = 500
mouse = ReachingTask()
# NOTE(review): noise is injected but no random seed is set in this notebook,
# so runs are not reproducible -- consider seeding the RNG the network uses.
network.simulate_learning(
    mouse,
    T,
    init_inputs,
    input_patterns,
    plasticity,
    delta_t=100,
    eta=0.001,
    tau_e=500,
    lamb=0.8,
    noise=[0.13] * 3,
    etrace=True,
    hyper=False,
    r_ext=[lambda t: 0, lambda t: 1, lambda t: 1],
    print_output=True,
)

INFO:network:Integrating network dynamics
  0%|          | 84/499999 [00:03<3:13:25, 43.08it/s] 

lick-->lick
lick-->aim
None 0


  0%|          | 112/499999 [00:03<2:48:22, 49.48it/s]

[-1, 0] 16


  0%|          | 253/499999 [00:09<5:06:24, 27.18it/s]

[2, 0] 129


  0%|          | 268/499999 [00:09<5:07:00, 27.13it/s]

[2, -1] 14


  0%|          | 277/499999 [00:09<5:05:24, 27.27it/s]

[-1, -1] 8
[0, -1] 4


  0%|          | 283/499999 [00:10<5:03:45, 27.42it/s]

aim-->reach


  0%|          | 739/499999 [00:26<5:04:52, 27.29it/s]

[0, 1] 455
[0, -1] 2


  0%|          | 763/499999 [00:27<5:04:42, 27.31it/s]

[-1, -1] 22
[1, -1] 1
reach-->lick


  0%|          | 943/499999 [00:34<5:07:28, 27.05it/s]

[1, 2] 176
[1, -1] 0
lick-->lick


  0%|          | 952/499999 [00:34<5:06:43, 27.12it/s]

[1, 2] 6
[1, -1] 0
lick-->lick
[1, 2] 3


  0%|          | 976/499999 [00:35<5:08:05, 27.00it/s]

[1, -1] 20
[-1, -1] 0
lick-->aim


  0%|          | 988/499999 [00:36<5:07:49, 27.02it/s]

[-1, 0] 8


  0%|          | 1321/499999 [00:48<5:12:03, 26.63it/s]

[2, 0] 334
[2, -1] 0
[2, 0] 0
aim-->aim
[2, -1] 1


  0%|          | 1336/499999 [00:48<5:08:25, 26.95it/s]

[-1, -1] 10


  0%|          | 1354/499999 [00:49<5:08:17, 26.96it/s]

[0, -1] 16
aim-->reach


  0%|          | 1804/499999 [01:06<5:03:26, 27.36it/s]

[0, 1] 449
[0, -1] 2


  0%|          | 1816/499999 [01:06<5:02:24, 27.46it/s]

[-1, -1] 8


  0%|          | 1828/499999 [01:07<5:03:30, 27.36it/s]

[1, -1] 10
reach-->lick


  0%|          | 1987/499999 [01:12<5:07:26, 27.00it/s]

[1, 2] 160


  0%|          | 1999/499999 [01:13<5:07:40, 26.98it/s]

[1, -1] 11


  0%|          | 2005/499999 [01:13<5:05:17, 27.19it/s]

[-1, -1] 5
lick-->aim


  0%|          | 2014/499999 [01:13<5:05:13, 27.19it/s]

[-1, 0] 7


  0%|          | 2437/499999 [01:29<5:10:12, 26.73it/s]

[2, 0] 422


  0%|          | 2461/499999 [01:30<5:21:09, 25.82it/s]

[2, -1] 24
[-1, -1] 0
aim-->reach


  0%|          | 2479/499999 [01:31<5:15:41, 26.27it/s]

[-1, 1] 14


  1%|          | 2968/499999 [01:49<5:06:23, 27.04it/s]

[0, 1] 490


  1%|          | 2992/499999 [01:50<5:03:26, 27.30it/s]

[0, -1] 23
[-1, -1] 3


  1%|          | 2998/499999 [01:50<5:03:27, 27.30it/s]

reach-->lick
[-1, 2] 6


  1%|          | 3298/499999 [02:01<5:12:51, 26.46it/s]

Mouse received reward


  1%|          | 3334/499999 [02:02<5:05:35, 27.09it/s]

[1, 2] 328


  1%|          | 3355/499999 [02:03<5:05:44, 27.07it/s]

[1, -1] 20
lick-->aim


  1%|          | 3364/499999 [02:03<5:05:52, 27.06it/s]

[-1, 0] 9


  1%|          | 3793/499999 [02:19<5:03:54, 27.21it/s]

[2, 0] 428


  1%|          | 3802/499999 [02:19<5:02:19, 27.35it/s]

[2, -1] 9


  1%|          | 3814/499999 [02:20<5:02:57, 27.30it/s]

[-1, -1] 11
[0, -1] 1
aim-->reach


  1%|          | 4312/499999 [02:38<5:04:13, 27.16it/s]

[0, 1] 493


  1%|          | 4318/499999 [02:38<5:10:12, 26.63it/s]

[0, -1] 5


  1%|          | 4339/499999 [02:39<5:07:41, 26.85it/s]

[-1, -1] 21
[1, -1] 4


  1%|          | 4345/499999 [02:40<5:07:00, 26.91it/s]

reach-->lick


  1%|          | 4576/499999 [02:48<5:05:05, 27.06it/s]

[1, 2] 230


  1%|          | 4591/499999 [02:49<5:14:58, 26.21it/s]

[1, -1] 16
lick-->aim


  1%|          | 4621/499999 [02:50<5:02:17, 27.31it/s]

[1, 0] 27


  1%|          | 4630/499999 [02:50<5:04:12, 27.14it/s]

[-1, 0] 8


  1%|          | 4939/499999 [03:01<5:10:44, 26.55it/s]

[2, 0] 309


  1%|          | 4951/499999 [03:02<5:06:49, 26.89it/s]

[-1, 0] 11


  1%|          | 4966/499999 [03:02<5:05:32, 27.00it/s]

[0, 0] 14


  1%|          | 4984/499999 [03:03<5:04:03, 27.13it/s]

[0, -1] 18
aim-->reach


  1%|          | 5374/499999 [03:17<5:05:00, 27.03it/s]

[0, 1] 389


  1%|          | 5392/499999 [03:18<5:03:48, 27.13it/s]

[0, -1] 16


  1%|          | 5407/499999 [03:19<5:08:21, 26.73it/s]

[-1, -1] 15
reach-->lick


  1%|          | 5422/499999 [03:19<5:05:43, 26.96it/s]

[-1, 2] 12


  1%|          | 5701/499999 [03:30<5:01:37, 27.31it/s]

[1, 2] 279


  1%|          | 5710/499999 [03:30<5:01:26, 27.33it/s]

Mouse received reward
[1, -1] 10
lick-->aim


  1%|          | 5737/499999 [03:31<5:02:39, 27.22it/s]

[1, 0] 23


  1%|          | 5743/499999 [03:31<5:06:33, 26.87it/s]

[-1, 0] 5


  1%|          | 5989/499999 [03:40<5:04:09, 27.07it/s]

In [None]:
# Overlap of each population's activity with the first stored sequence of that
# same population: sequences[i][0] against network.pops[i] for ctx, d1, d2.
overlaps_ctx = sequences[0][0].overlaps(network.pops[0])
overlaps_d1 = sequences[1][0].overlaps(network.pops[1])
overlaps_d2 = sequences[2][0].overlaps(network.pops[2])

# Optional persistence of the overlaps; off by default so running the notebook
# has no file side effects. This replaces a commented-out snippet that referred
# to names not defined in this notebook (overlaps_bg, correlations_ctx,
# correlations_bg, net_ctx, net_bg) and would have raised NameError if enabled.
SAVE_OVERLAPS = False
OVERLAPS_FILENAME = 'learning-0005-1600-600-5-1000-v0'
if SAVE_OVERLAPS:
    np.savez('./data/' + OVERLAPS_FILENAME + '.npz',
             overlaps_ctx=overlaps_ctx,
             overlaps_d1=overlaps_d1,
             overlaps_d2=overlaps_d2)

In [None]:
sns.set_style('white')
colors = sns.color_palette('deep')

# One row per population; in every row, one line per stored pattern index.
# (label, linestyle, palette index) for pattern indices 0, 1, 2.
ACTION_STYLES = [
    ('Aim', 'solid', 8),
    ('Reach', 'dashed', 0),
    ('Lick', 'dotted', 3),
]
PANELS = [
    ('CTX', overlaps_ctx),
    ('D1', overlaps_d1),
    ('D2', overlaps_d2),
]
# NOTE(review): absolute, machine-specific output path -- make this relative or
# configurable before sharing the notebook.
FIGURE_PATH = '/work/jp464/striatum-sequence/output/adaptation-symmetric.jpg'

fig, axes = plt.subplots(len(PANELS), 1, sharex=True, sharey=True,
                         tight_layout=True, figsize=(20, 20))
for row, (ax, (title, overlaps)) in enumerate(zip(axes, PANELS)):
    for pattern_idx, (label, linestyle, color_idx) in enumerate(ACTION_STYLES):
        line_kwargs = dict(linestyle=linestyle, linewidth=3, color=colors[color_idx])
        if row == 0:
            # Label only the first panel so the figure legend lists each action once.
            line_kwargs['label'] = label
        ax.plot(overlaps[pattern_idx], **line_kwargs)
    ax.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])
    ax.set_title(title, fontsize=25)
axes[-1].set_xlabel('Time (ms)', fontsize=20)

# Shared y-axis label, common limits, one figure-level legend.
fig.text(-0.01, 0.5, 'Overlap', va='center', rotation='vertical', fontsize=20)
plt.setp(axes, xlim=(0, 1000), ylim=(-.2, .4))
fig.legend(fontsize=20)
fig.savefig(FIGURE_PATH, bbox_inches="tight", format='jpg')

plt.show()
