In [1]:
import sys
sys.path.insert(0, '../../../network')
import logging
import argparse
import numpy as np
from network import Population, RateNetwork
from learning import ReachingTask
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule, set_connectivity 
from sequences import GaussianSequence
import matplotlib.pyplot as plt
import seaborn as sns
logging.basicConfig(level=logging.INFO)

### Initialization

In [2]:
# Load the pre-generated model parameters. N holds the population sizes; the
# remaining entries are passed to set_connectivity / the sequence objects
# (cp, cw, A presumably connection probabilities, weights and amplitudes — TODO confirm).
# allow_pickle=True is needed for object arrays; only load trusted files, since
# unpickling can execute arbitrary code.
# The context manager closes the underlying .npz file; every array is fully read
# into memory on access, so the loaded values stay valid after the `with` block.
with np.load("./ctx_str_params.npz", allow_pickle=True) as params:
    N, sequences, patterns, cp, cw, A = (
        params[key] for key in ('N', 'sequences', 'patterns', 'cp', 'cw', 'A')
    )

In [3]:
# Transfer function and plasticity rule shared by the whole network.
phi = ErrorFunction(mu=0.22, sigma=0.1).phi
plasticity = ThresholdPlasticityRule(x_f=0.5, q_f=0.8)

# Two populations with a common time constant. A third 'd2' population
# (size N[2]) is disabled in this version of the model.
POP_TAU = 1e-2
ctx, d1 = (
    Population(N=N[idx], tau=POP_TAU, phi=phi, name=pop_name)
    for idx, pop_name in enumerate(('ctx', 'd1'))
)

# Builds the connection blocks between every ordered pair of populations.
J = set_connectivity([ctx, d1], cp, cw, A, patterns, plasticity)

# Disabled experiment: extra sequence updates on J[0][1] through a linear synapse.
# synapse = LinearSynapse(J[0][1].K, 1.5)
# J[0][1].update_sequences(patterns[0][0][1], patterns[1][0][2], synapse.h_EE, plasticity.f, plasticity.g)
# J[0][1].update_sequences(patterns[0][0][0], patterns[1][0][1], synapse.h_EE, plasticity.f, plasticity.g)

network = RateNetwork([ctx, d1], J, formulation=4, disable_pbar=False)

INFO:connectivity:Building connections from ctx to ctx
INFO:connectivity:Building connections from ctx to d1
INFO:connectivity:Building connections from d1 to ctx
INFO:connectivity:Building connections from d1 to d1


### Simulation

In [None]:
# Start every population from zero input and cue each with the first entry of
# its stored pattern set.
init_inputs = [np.zeros(pop.size) for pop in (ctx, d1)]
input_patterns = [pop_patterns[0] for pop_patterns in patterns]

# Simulation duration handed to simulate_learning. The recorded progress bar
# runs ~1e5 steps for T=100, which suggests seconds at a 1 ms step rather than
# milliseconds — TODO confirm units.
T = 100
mouse = ReachingTask()

learning_kwargs = dict(
    delta_t=70, eta=0.05, tau_e=1500, lamb=0.6,
    # NOTE(review): 3 noise levels but only 2 populations are simulated
    # (leftover from the ctx/d1/d2 model?) — confirm how simulate_learning indexes it.
    noise=[0.13, 0.13, 0.13],
    e_bl=[0.06, 0.0187, 0.04, 0.07],  # earlier setting: [0.06, 0.016, 0.04, 0.07]
    alpha=0, gamma=0, adap=0, env=2.4, etrace=False,
    # Presumably one external drive per population (ctx: 0, d1: 0.5) — TODO confirm.
    r_ext=[lambda t: 0, lambda t: .5],
    print_output=True, track=True,
)
network.simulate_learning(mouse, T, init_inputs, input_patterns, plasticity,
                          **learning_kwargs)

INFO:network:Integrating network dynamics
  0%|          | 112/99999 [00:03<18:03, 92.22it/s] 

null-->aim


  0%|          | 460/99999 [00:05<09:20, 177.63it/s]

aim-->scavenge


  1%|          | 866/99999 [00:07<09:14, 178.63it/s]

scavenge-->aim


  1%|▏         | 1273/99999 [00:09<09:17, 177.22it/s]

aim-->lick


  2%|▏         | 1587/99999 [00:11<09:15, 177.08it/s]

lick-->scavenge


  2%|▏         | 1973/99999 [00:13<09:09, 178.36it/s]

scavenge-->aim


  2%|▏         | 2415/99999 [00:16<09:07, 178.10it/s]

aim-->reach


  3%|▎         | 2639/99999 [00:17<08:55, 181.65it/s]

reach-->lick


  3%|▎         | 2770/99999 [00:17<08:58, 180.65it/s]

Mouse drank water


  3%|▎         | 2911/99999 [00:19<10:01, 161.52it/s]

lick-->scavenge


  3%|▎         | 3268/99999 [00:21<08:48, 183.14it/s]

scavenge-->aim


  4%|▎         | 3698/99999 [00:23<09:08, 175.69it/s]

aim-->lick


  4%|▍         | 3981/99999 [00:25<08:46, 182.37it/s]

lick-->scavenge


  4%|▍         | 4452/99999 [00:27<08:50, 180.10it/s]

scavenge-->aim


  5%|▍         | 4909/99999 [00:30<08:39, 183.01it/s]

aim-->lick


  5%|▌         | 5163/99999 [00:32<38:14, 41.33it/s] 

lick-->scavenge


  5%|▌         | 5214/99999 [01:09<33:54:12,  1.29s/it]

### Results

In [None]:
# Overlap of each population's activity with the first stored sequence of that
# population (the d2 population is not simulated in this version).
overlaps_ctx = sequences[0][0].overlaps(network.pops[0])
overlaps_d1 = sequences[1][0].overlaps(network.pops[1])

# Persist the overlaps for later analysis. The output name comes from
# `filename`, which was previously declared but ignored in favour of a
# hard-coded 'test'.
OUTPUT_DIR = '/work/jp464/striatum-sequence/'
filename = 'simulation-test'
np.savez(OUTPUT_DIR + filename + '.npz',
         overlaps_ctx=overlaps_ctx, overlaps_d1=overlaps_d1)

In [None]:
# Figure styling: white seaborn background, large tick labels, thin axis spines.
sns.set_style('white')
colors = sns.color_palette('deep')
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)
plt.rcParams['axes.linewidth'] = 0.1

# One (legend label, line style, palette index) entry per overlap row, in row order.
BEHAVIOR_STYLES = [
    ('Aim', 'solid', 8),
    ('Reach', 'dashed', 0),
    ('Lick', 'dotted', 3),
    ('Scavenge', 'dotted', 2),
]


def plot_overlaps(ax, overlaps, title, with_labels=False):
    """Draw one trace per behavior row of `overlaps` on `ax` and title the panel.

    Legend labels are attached only when `with_labels` is True, so a figure-level
    legend (if enabled) lists each behavior once.
    """
    for row, (label, linestyle, color_idx) in enumerate(BEHAVIOR_STYLES):
        label_kwargs = {'label': label} if with_labels else {}
        ax.plot(overlaps[row], linestyle=linestyle, linewidth=3,
                color=colors[color_idx], **label_kwargs)
    ax.set_title(title, fontsize=25)


fig, axes = plt.subplots(2, 1, sharex=True, sharey=True, tight_layout=True, figsize=(20, 10))
plot_overlaps(axes[0], overlaps_ctx, "CTX", with_labels=True)
plot_overlaps(axes[1], overlaps_d1, "STR")

axes[1].set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])
# x is the sample index of each overlap trace; labelled in ms assuming one
# sample per ms (TODO confirm against the simulation step).
axes[1].set_xlabel('Time (ms)', fontsize=20)
fig.text(-0.01, 0.5, 'Overlap', va='center', rotation='vertical', fontsize=20)
# Limits are applied after set_yticks so they are not widened by the tick call.
plt.setp(axes, xlim=(0, 10000))
plt.setp(axes, ylim=(-.2, .4))
# Legend intentionally disabled; re-enable with fig.legend(fontsize=20, loc='upper right').
fig.savefig('/work/jp464/striatum-sequence/output/simulation-online-learning-concept.jpg',
            bbox_inches="tight", format='jpg')

plt.show()


In [None]:
def temporal_diff(A, B, max_iter):
    """Mean offset of element ``[1]`` between paired entries of two logs.

    Entry ``A[i]`` is paired with ``B[i + 1]`` for ``i >= 1`` (``A[0]`` is
    skipped), and the mean of ``B[i + 1][1] - A[i][1]`` over all pairs is
    returned. Pairing stops at the first ``None`` in ``B`` or when ``B`` runs
    out of entries.

    Parameters
    ----------
    A, B : sequence
        Logs whose non-None entries are indexable; element ``[1]`` is
        subtracted (presumably a timestamp — TODO confirm against
        ``ReachingTask.behaviors``).
    max_iter : int
        Currently unused; kept so existing calls keep working.

    Returns
    -------
    float
        Mean offset over all pairs, or ``nan`` when no pair could be formed.
    """
    total = 0
    n_pairs = 0
    for i in range(1, len(A)):
        # Stop cleanly at the end of B instead of raising IndexError, and use
        # `is None` so array-valued entries are not compared elementwise.
        if i + 1 >= len(B) or B[i + 1] is None:
            break
        total += B[i + 1][1] - A[i][1]
        n_pairs += 1
    if n_pairs == 0:
        # No comparable pairs: report "undefined" rather than dividing by zero.
        return float('nan')
    return total / n_pairs

# Mean offset between the first two behavior logs recorded by the task
# (what each log represents is defined by ReachingTask — TODO confirm).
behaviors_first, behaviors_second = mouse.behaviors[0], mouse.behaviors[1]
temporal_diff(behaviors_first, behaviors_second, 100)