In [1]:
import sys

# Prepend the directory holding the local modules imported in the next cell
# (network, learning, transfer_functions, connectivity, sequences) so they are
# found before anything else on the search path. The path is relative to the
# notebook's working directory.
sys.path[:0] = ['../../../network']

In [2]:
import logging
import argparse
import numpy as np
from network import Population, RateNetwork
from learning import ReachingTask
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule, set_connectivity 
from sequences import GaussianSequence
import matplotlib.pyplot as plt
import seaborn as sns
logging.basicConfig(level=logging.INFO)

In [3]:
# Pre-computed model inputs: N[0..2] are used as population sizes, cp/cw/A and
# patterns feed set_connectivity, and sequences is used for the overlap analysis.
# allow_pickle=True is presumably needed because some entries are object arrays;
# it lets np.load unpickle arbitrary objects, so only load a params.npz you trust.
# For .npz archives np.load returns a lazily-reading NpzFile that holds the file
# open; the `with` block closes it once every entry has been read into memory.
with np.load("./params.npz", allow_pickle=True) as params:
    N, sequences, patterns, cp, cw, A = (
        params['N'], params['sequences'], params['patterns'],
        params['cp'], params['cw'], params['A'],
    )

In [4]:
# Transfer function shared by all populations, and the plasticity rule (the same
# rule object is handed to simulate_learning later on).
phi = ErrorFunction(mu=0.22, sigma=0.1).phi
plasticity = ThresholdPlasticityRule(x_f=0.5, q_f=0.8)

# The three populations share tau and phi; only their size (taken from N) and
# name differ.
ctx, d1, d2 = [
    Population(N=N[idx], tau=1e-2, phi=phi, name=label)
    for idx, label in enumerate(('ctx', 'd1', 'd2'))
]

# Connectivity between every ordered pair of populations, then the rate network.
J = set_connectivity([ctx, d1, d2], cp, cw, A, patterns, plasticity)
network = RateNetwork([ctx, d1, d2], J, formulation=4, disable_pbar=False)

INFO:connectivity:Building connections from ctx to ctx
INFO:connectivity:Building connections from ctx to d1
INFO:connectivity:Building connections from ctx to d2
INFO:connectivity:Building connections from d1 to ctx
INFO:connectivity:Building connections from d1 to d1
INFO:connectivity:Building connections from d1 to d2





INFO:connectivity:Building connections from d2 to ctx
INFO:connectivity:Building connections from d2 to d1
INFO:connectivity:Building connections from d2 to d2





In [None]:
# Zero initial input for every population, in (ctx, d1, d2) order.
init_inputs = [np.zeros(pop.size) for pop in (ctx, d1, d2)]
# First entry of each element of `patterns` (presumably one pattern set per
# population -- confirm against params.npz).
input_patterns = [pattern_set[0] for pattern_set in patterns]

T = 50  # ms
mouse = ReachingTask()
network.simulate_learning(
    mouse, T, init_inputs, input_patterns, plasticity,
    delta_t=750, eta=0.01, tau_e=2350, lamb=0.5, etrace=True, hyper=False,
    # presumably one noise level per population (ctx, d1, d2) -- confirm
    noise=[0.3, 0.13, 0.13],
    env=0.4, a=0.6,
    # one callable of time per population; only d1 gets a nonzero (0.5) value
    r_ext=[lambda t: 0, lambda t: .5, lambda t: 0],
    print_output=True,
)

INFO:network:Integrating network dynamics
  0%|          | 78/49999 [00:02<15:07, 55.01it/s]  

null-->aim


  1%|▏         | 735/49999 [00:05<03:23, 242.16it/s]

aim-->lick
None 0


  2%|▏         | 760/49999 [00:05<07:46, 105.55it/s]

[-1, 0] 16


  2%|▏         | 779/49999 [00:07<17:38, 46.48it/s] 

[-1, -1] 7
[-1, 2] 4


  3%|▎         | 1362/49999 [00:41<48:23, 16.75it/s]

lick-->scavenge


  3%|▎         | 1388/49999 [00:42<46:27, 17.44it/s]

[0, 2] 603


  3%|▎         | 1398/49999 [00:43<47:21, 17.10it/s]

[0, -1] 9


  3%|▎         | 1470/49999 [00:47<47:26, 17.05it/s]

[0, 3] 70


  3%|▎         | 1484/49999 [00:48<47:31, 17.01it/s]

[-1, 3] 13


  4%|▍         | 1974/49999 [01:17<48:17, 16.58it/s]

[2, 3] 490


  4%|▍         | 2108/49999 [01:25<49:27, 16.14it/s]

scavenge-->aim


  4%|▍         | 2112/49999 [01:25<49:42, 16.06it/s]

[-1, 3] 136
[3, 3] 1


  4%|▍         | 2116/49999 [01:25<49:43, 16.05it/s]

[-1, 3] 1


  4%|▍         | 2142/49999 [01:27<47:10, 16.90it/s]

[3, 3] 26


  4%|▍         | 2156/49999 [01:28<46:28, 17.15it/s]

[3, -1] 12


  5%|▌         | 2560/49999 [01:51<45:22, 17.43it/s]

[3, 0] 404


  5%|▌         | 2568/49999 [01:52<45:15, 17.47it/s]

[-1, 0] 7


  5%|▌         | 2578/49999 [01:52<47:04, 16.79it/s]

[3, 0] 10
[-1, 0] 0


  5%|▌         | 2588/49999 [01:53<48:13, 16.39it/s]

[3, 0] 6


  6%|▌         | 2856/49999 [02:09<47:10, 16.66it/s]

[-1, 0] 268
aim-->lick


  6%|▌         | 2884/49999 [02:10<44:54, 17.49it/s]

[0, 0] 27


  6%|▌         | 2900/49999 [02:11<45:04, 17.41it/s]

[0, -1] 15


  7%|▋         | 3582/49999 [02:53<48:32, 15.94it/s]

[0, 2] 681


  7%|▋         | 3606/49999 [02:55<47:51, 16.16it/s]

[-1, 2] 23


  7%|▋         | 3634/49999 [02:56<48:38, 15.88it/s]

[2, 2] 26
[-1, 2] 1


  7%|▋         | 3692/49999 [03:00<47:32, 16.24it/s]

lick-->scavenge


  7%|▋         | 3724/49999 [03:02<47:26, 16.26it/s]

[2, 2] 87


  7%|▋         | 3734/49999 [03:03<47:20, 16.29it/s]

[2, -1] 10


  8%|▊         | 4116/49999 [03:27<47:40, 16.04it/s]

[2, 3] 381


  9%|▊         | 4308/49999 [03:39<47:13, 16.12it/s]

scavenge-->aim


  9%|▊         | 4342/49999 [03:41<47:07, 16.15it/s]

[-1, 3] 225


  9%|▊         | 4354/49999 [03:42<49:24, 15.40it/s]

[-1, -1] 10


  9%|▉         | 4442/49999 [03:47<47:09, 16.10it/s]

[-1, 0] 87


 10%|▉         | 4878/49999 [04:14<46:35, 16.14it/s]

[3, 0] 436


 10%|█         | 5058/49999 [04:26<46:03, 16.26it/s]

[-1, 0] 178


 10%|█         | 5196/49999 [04:35<46:20, 16.11it/s]

aim-->lick


 10%|█         | 5226/49999 [04:36<46:51, 15.92it/s]

[0, 0] 167


 10%|█         | 5236/49999 [04:37<46:24, 16.08it/s]

[0, -1] 10


 11%|█         | 5544/49999 [04:56<46:05, 16.08it/s]

[0, 2] 306
[-1, 2] 0
[0, 2] 0


 12%|█▏        | 5790/49999 [05:11<46:48, 15.74it/s]

lick-->scavenge


 12%|█▏        | 5812/49999 [05:13<46:11, 15.94it/s]

[-1, 2] 266


 12%|█▏        | 5826/49999 [05:14<45:54, 16.04it/s]

[-1, -1] 12


 12%|█▏        | 5944/49999 [05:21<45:39, 16.08it/s]

[-1, 3] 118


 13%|█▎        | 6474/49999 [05:54<44:51, 16.17it/s]

[2, 3] 528


 13%|█▎        | 6538/49999 [05:58<44:59, 16.10it/s]

[-1, 3] 64


 13%|█▎        | 6544/49999 [05:58<44:50, 16.15it/s]

[3, 3] 5
scavenge-->aim


 13%|█▎        | 6548/49999 [05:59<44:54, 16.13it/s]

[-1, 3] 3
[3, 3] 0


 13%|█▎        | 6552/49999 [05:59<44:57, 16.11it/s]

[-1, 3] 2


 13%|█▎        | 6586/49999 [06:01<44:22, 16.30it/s]

[3, 3] 33


 13%|█▎        | 6598/49999 [06:02<44:26, 16.28it/s]

[3, -1] 11


 14%|█▍        | 7046/49999 [06:28<40:29, 17.68it/s]

[3, 0] 446


 15%|█▍        | 7252/49999 [06:41<43:37, 16.33it/s]

aim-->lick


 15%|█▍        | 7262/49999 [06:42<43:52, 16.23it/s]

[-1, 0] 216


 15%|█▍        | 7276/49999 [06:43<43:46, 16.27it/s]

[-1, -1] 13


 15%|█▍        | 7296/49999 [06:44<42:28, 16.76it/s]

[-1, 2] 19


 16%|█▌        | 7788/49999 [07:12<39:41, 17.73it/s]

[0, 2] 491


 16%|█▌        | 7906/49999 [07:18<39:37, 17.70it/s]

lick-->scavenge


 16%|█▌        | 7922/49999 [07:19<39:56, 17.56it/s]

[-1, 2] 132


 16%|█▌        | 7934/49999 [07:20<40:06, 17.48it/s]

[-1, -1] 11


 16%|█▌        | 8000/49999 [07:24<39:38, 17.66it/s]

[-1, 3] 66


 17%|█▋        | 8558/49999 [07:55<38:58, 17.72it/s]

[2, 3] 556


 17%|█▋        | 8562/49999 [07:55<39:21, 17.55it/s]

scavenge-->aim


 17%|█▋        | 8586/49999 [07:57<38:57, 17.71it/s]

[-1, 3] 27


 17%|█▋        | 8596/49999 [07:57<39:02, 17.67it/s]

[-1, -1] 9


 17%|█▋        | 8654/49999 [08:01<39:10, 17.59it/s]

[-1, 0] 58


 18%|█▊        | 9024/49999 [08:22<38:55, 17.55it/s]

[3, 0] 368


 18%|█▊        | 9028/49999 [08:22<38:39, 17.66it/s]

[-1, 0] 4


 18%|█▊        | 9032/49999 [08:22<38:48, 17.59it/s]

[3, 0] 3


 19%|█▊        | 9312/49999 [08:38<38:31, 17.60it/s]

[-1, 0] 278


 19%|█▉        | 9416/49999 [08:44<38:18, 17.66it/s]

aim-->lick
[0, 0] 107


 19%|█▉        | 9432/49999 [08:45<38:35, 17.52it/s]

[0, -1] 12


 20%|█▉        | 9984/49999 [09:16<38:20, 17.39it/s]

[0, 2] 550


 20%|█▉        | 9988/49999 [09:17<38:04, 17.51it/s]

lick-->scavenge


 20%|██        | 10014/49999 [09:18<37:31, 17.76it/s]

[-1, 2] 30


 20%|██        | 10028/49999 [09:19<37:48, 17.62it/s]

[-1, -1] 12


 20%|██        | 10166/49999 [09:27<37:45, 17.58it/s]

[-1, 3] 137


 21%|██        | 10620/49999 [09:53<37:12, 17.64it/s]

[2, 3] 453


 21%|██▏       | 10738/49999 [09:59<38:35, 16.96it/s]

[-1, 3] 117
[3, 3] 1


 21%|██▏       | 10744/49999 [10:00<37:49, 17.30it/s]

[-1, 3] 3


 22%|██▏       | 10838/49999 [10:05<37:23, 17.45it/s]

scavenge-->aim


 22%|██▏       | 10864/49999 [10:07<37:12, 17.53it/s]

[3, 3] 120


 22%|██▏       | 10878/49999 [10:07<40:11, 16.22it/s]

[3, -1] 12


 22%|██▏       | 11164/49999 [10:24<36:58, 17.50it/s]

In [None]:
# For population k (0=ctx, 1=d1, 2=d2), overlaps of its activity with
# sequences[k][0], evaluated in that order.
overlaps_ctx, overlaps_d1, overlaps_d2 = [
    sequences[pop_idx][0].overlaps(network.pops[pop_idx]) for pop_idx in range(3)
]
# To persist the results (an earlier version of this snippet referred to
# variables such as overlaps_bg / net_ctx that do not exist in this notebook):
# filename = 'learning-0005-1600-600-5-1000-v0'
# np.savez('./data/' + filename + '.npz',
#          overlaps_ctx=overlaps_ctx, overlaps_d1=overlaps_d1, overlaps_d2=overlaps_d2)

In [None]:
sns.set_style('white')
colors = sns.color_palette('deep')

# (legend label, line style, index into `colors`) for each behaviour, in the
# row order of the overlap arrays (row 0 = Aim, 1 = Reach, ...).
# NOTE(review): Lick and Scavenge share the 'dotted' style, so they differ only
# by colour -- consider a distinct style for one of them.
behaviour_styles = [
    ('Aim', 'solid', 8),
    ('Reach', 'dashed', 0),
    ('Lick', 'dotted', 3),
    ('Scavenge', 'dotted', 2),
]
# One panel per population, top to bottom; append ('D2', overlaps_d2) to also
# plot the D2 population (figsize may then need adjusting).
panels = [('CTX', overlaps_ctx), ('D1', overlaps_d1)]

# squeeze=False keeps `axes` indexable even if only a single panel is requested.
fig, axes = plt.subplots(len(panels), 1, sharex=True, sharey=True, tight_layout=True,
                         figsize=(20, 10), squeeze=False)
axes = axes[:, 0]
for row, (ax, (title, overlaps)) in enumerate(zip(axes, panels)):
    for k, (label, linestyle, color_idx) in enumerate(behaviour_styles):
        # Label only the first panel's lines so the figure legend lists each
        # behaviour once.
        label_kw = {'label': label} if row == 0 else {}
        ax.plot(overlaps[k], linestyle=linestyle, linewidth=3,
                color=colors[color_idx], **label_kw)
    ax.set_title(title, fontsize=25)

# With sharex=True only the bottom panel needs an x label.
# NOTE(review): x values are sample indices; the 'ms' unit assumes one sample
# per millisecond -- confirm against the integration step.
axes[-1].set_xlabel('Time (ms)', fontsize=20)
fig.text(-0.01, 0.5, 'Overlap', va='center', rotation='vertical', fontsize=20)
plt.setp(axes, xlim=(0, 20000), ylim=(-.2, .4))
fig.legend(fontsize=20, loc='upper left')
# NOTE(review): absolute, user-specific output path; consider a configurable
# output directory so the notebook runs on other machines.
fig.savefig('/work/jp464/striatum-sequence/output/secondary-env.jpg', bbox_inches="tight", format='jpg')

plt.show()


In [None]:
def temporal_diff(A, B, max_iter):
    """Average offset between entries of ``B`` and the entries of ``A`` one step behind.

    For every index ``i >= 1`` of ``A`` (index 0 is skipped), pairs ``A[i]``
    with ``B[i + 1]`` and accumulates ``B[i + 1][1] - A[i][1]``. Pairing stops
    at the first ``None`` entry of ``B`` or as soon as ``B`` has no entry at
    ``i + 1``.

    NOTE(review): entries are assumed to be indexable with the compared value
    at position 1 (presumably a time step) -- confirm against
    ``ReachingTask.behaviors``.

    Parameters
    ----------
    A, B : sequence
        Behaviour logs; ``B`` is read one position ahead of ``A``.
    max_iter : int
        Currently unused; kept so existing calls keep working.

    Returns
    -------
    float
        Mean of the paired differences, or ``nan`` when no pair could be formed.
    """
    total = 0
    count = 0
    for i in range(1, len(A)):
        # Stop when B has no entry one step ahead (previously an IndexError) or
        # when that entry is a None placeholder. `is None` avoids the
        # element-wise comparison `== None` performs on array entries.
        if i + 1 >= len(B) or B[i + 1] is None:
            break
        total += B[i + 1][1] - A[i][1]
        count += 1
    # Avoid ZeroDivisionError when nothing could be paired.
    return total / count if count else float('nan')

temporal_diff(A=mouse.behaviors[0], B=mouse.behaviors[1], max_iter=100)  # B is read one step ahead of A

In [None]:
mouse.behaviors[1][:50]  # first 50 entries of the second behaviour log