In [1]:
import sys

# Put the project's simulation-package directory at the front of the import path so the
# local modules imported in the next cell (network, learning, connectivity, ...) resolve.
# The entry is relative, so Python resolves it against the kernel's working directory.
sys.path[:0] = ['../../../network']

In [2]:
import argparse
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from network import Population, RateNetwork
from learning import ReachingTask
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule
from sequences import GaussianSequence
logging.basicConfig(level=logging.INFO)

In [3]:
# Two rate populations, cortex (ctx) and basal ganglia (bg), sharing one transfer function.
phi = ErrorFunction(mu=0.22, sigma=0.1).phi
ctx = Population(N=1000, tau=1e-2, phi=phi, name='ctx')
bg = Population(N=1000, tau=1e-2, phi=phi, name='bg')
plasticity = ThresholdPlasticityRule(x_f=0.5, q_f=0.8)

P_CONNECT = 0.05  # connection probability shared by all four sparse projections

# S sequences of P patterns per population. The seed is offset by the sequence index so
# that S > 1 does not produce S identical sequences; with S == 1 the seeds are unchanged.
S, P = 1, 3
sequences_ctx = [GaussianSequence(P, ctx.size, seed=114 + i) for i in range(S)]
patterns_ctx = np.stack([s.inputs for s in sequences_ctx])
# bg patterns are sized by the bg population (previously ctx.size, which only worked
# because both populations happen to have N=1000).
sequences_bg = [GaussianSequence(P, bg.size, seed=29 + i) for i in range(S)]
patterns_bg = np.stack([s.inputs for s in sequences_bg])


def _attractor_connection(source, target, A, patterns_source, patterns_target, rule, p=P_CONNECT):
    """Build a sparse source->target projection and store attractor patterns in it.

    Parameters
    ----------
    source, target : Population
        Pre- and post-synaptic populations.
    A : float
        Amplitude passed to ``LinearSynapse``.
    patterns_source, patterns_target
        Patterns passed (in this order) to ``store_attractors``.
    rule : ThresholdPlasticityRule
        Rule whose ``f`` and ``g`` functions are used for storage.
    p : float
        Connection probability for ``SparseConnectivity``.

    Returns
    -------
    (SparseConnectivity, LinearSynapse)
    """
    J = SparseConnectivity(source=source, target=target, p=p)
    synapse = LinearSynapse(J.K, A=A)
    J.store_attractors(patterns_source, patterns_target, synapse.h_EE, rule.f, rule.g)
    return J, synapse


# Attractor projections: recurrent ctx->ctx and bg->bg (same patterns on both sides),
# plus the bg -> ctx projection. Built in the same order as before.
J_cc, synapse_cc = _attractor_connection(ctx, ctx, 5, patterns_ctx[0], patterns_ctx[0], plasticity)
J_bb, synapse_bb = _attractor_connection(bg, bg, 5, patterns_bg[0], patterns_bg[0], plasticity)
J_cb, synapse_cb = _attractor_connection(bg, ctx, 1, patterns_bg[0], patterns_ctx[0], plasticity)

# Asymmetric ctx -> bg projection storing the pattern *sequence*.
# NOTE(review): A=0 presumably gives this pathway zero initial amplitude so that it is
# shaped only by learning; confirm against LinearSynapse.
J_bc = SparseConnectivity(source=ctx, target=bg, p=P_CONNECT)
synapse_bc = LinearSynapse(J_bc.K, A=0)
J_bc.store_sequences(patterns_ctx, patterns_bg, synapse_bc.h_EE, plasticity.f, plasticity.g)

net_ctx = RateNetwork(ctx, c_EE=J_cc, c_IE=J_bc, formulation=4)
net_bg = RateNetwork(bg, c_II=J_bb, c_EI=J_cb, formulation=4)

INFO:connectivity:Building connections from ctx to ctx
INFO:connectivity:Storing attractors
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 19238.51it/s]
INFO:connectivity:Building connections from bg to bg
INFO:connectivity:Storing attractors
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 21301.48it/s]
INFO:connectivity:Building connections from bg to ctx
INFO:connectivity:Storing attractors
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 22019.08it/s]
INFO:connectivity:Building connections from ctx to bg
INFO:connectivity:Storing sequences
  0%|                                                                                                          | 0/1 [00:00<?, ?it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1000/

In [None]:
# Random initial inputs to each population. Seeded so that re-running the notebook
# reproduces the same initial condition (previously two unseeded RandomState() draws).
# Randomness inside ReachingTask / simulate_learning is not controlled here.
INIT_SEED = 0
rng_init = np.random.RandomState(INIT_SEED)
init_input_ctx = rng_init.normal(0, 1, size=patterns_ctx[0][0].shape)
init_input_bg = rng_init.normal(0, 1, size=patterns_bg[0][0].shape)

# NOTE(review): T was commented as "ms", but the progress bar for this run reports
# ~1e6 integration steps; confirm the unit of T expected by simulate_learning.
T = 1000

# Learning hyperparameters forwarded to simulate_learning (semantics defined there).
DELTA_T = 300
ETA = 0.0005
TAU_E = 1600
LAMB = 0.5

mouse = ReachingTask()
net_ctx.simulate_learning(mouse, net_bg, T, init_input_ctx, init_input_bg,
                          patterns_ctx[0], patterns_bg[0], plasticity,
                          delta_t=DELTA_T, eta=ETA, tau_e=TAU_E, lamb=LAMB,
                          print_output=True)

INFO:network:Integrating network dynamics
  0%|                                                                                         | 118/999999 [00:02<3:55:59, 70.62it/s]

null-->lick


  0%|                                                                                        | 290/999999 [00:03<1:13:29, 226.71it/s]

None 0


  0%|                                                                                         | 344/999999 [00:04<3:34:16, 77.75it/s]

[-1, 2] 45


  0%|                                                                                        | 539/999999 [00:12<11:11:07, 24.82it/s]

[2, 2] 185


  0%|                                                                                        | 695/999999 [00:19<13:11:07, 21.05it/s]

[2, -1] 155


  0%|                                                                                       | 1200/999999 [00:47<11:15:32, 24.64it/s]

[2, 2] 504


  0%|                                                                                       | 1347/999999 [00:53<11:12:40, 24.74it/s]

[2, -1] 147


  0%|                                                                                       | 1404/999999 [00:55<11:08:48, 24.89it/s]

lick-->aim


  0%|▏                                                                                      | 1695/999999 [01:07<11:30:18, 24.10it/s]

[2, 0] 348


  0%|▏                                                                                      | 1704/999999 [01:07<11:20:41, 24.44it/s]

[-1, 0] 8


  0%|▏                                                                                      | 1851/999999 [01:13<11:11:01, 24.79it/s]

[0, 0] 146


  0%|▏                                                                                      | 1998/999999 [01:19<11:12:10, 24.75it/s]

[0, -1] 146


  0%|▏                                                                                      | 2505/999999 [01:40<11:04:35, 25.02it/s]

[0, 0] 505


  0%|▏                                                                                      | 2652/999999 [01:46<11:15:02, 24.62it/s]

[0, -1] 145


  0%|▏                                                                                      | 2721/999999 [01:49<11:10:47, 24.78it/s]

aim-->reach


  0%|▎                                                                                      | 3012/999999 [02:00<11:13:15, 24.68it/s]

[0, 1] 361


  0%|▎                                                                                      | 3021/999999 [02:01<11:14:54, 24.62it/s]

[-1, 1] 6


  0%|▎                                                                                      | 3156/999999 [02:06<11:08:26, 24.85it/s]

[1, 1] 135


  0%|▎                                                                                      | 3324/999999 [02:13<11:10:14, 24.78it/s]

[1, -1] 167


  0%|▎                                                                                      | 3828/999999 [02:34<11:17:02, 24.52it/s]

[1, 1] 505


  0%|▎                                                                                      | 3996/999999 [02:40<11:17:43, 24.49it/s]

[1, -1] 165


  0%|▎                                                                                      | 4050/999999 [02:43<11:24:33, 24.25it/s]

reach-->lick


  0%|▍                                                                                      | 4350/999999 [02:55<11:42:59, 23.60it/s]

[1, 2] 353
Mouse received reward
[-1, 2] 1


  0%|▍                                                                                      | 4500/999999 [03:01<11:47:45, 23.44it/s]

[2, 2] 148


  0%|▍                                                                                      | 4647/999999 [03:07<11:49:29, 23.38it/s]

[2, -1] 145


  1%|▍                                                                                      | 5151/999999 [03:29<11:48:55, 23.39it/s]

[2, 2] 505


  1%|▍                                                                                      | 5337/999999 [03:37<11:43:33, 23.56it/s]

[2, -1] 183


  1%|▌                                                                                      | 5841/999999 [03:59<11:43:17, 23.56it/s]

[2, 2] 505


  1%|▌                                                                                      | 5982/999999 [04:05<11:52:44, 23.24it/s]

[2, -1] 138


  1%|▌                                                                                      | 6054/999999 [04:08<11:44:41, 23.51it/s]

lick-->aim


  1%|▌                                                                                      | 6339/999999 [04:20<11:45:12, 23.48it/s]

[2, 0] 358


  1%|▌                                                                                      | 6354/999999 [04:20<11:45:37, 23.47it/s]

[-1, 0] 12


  1%|▌                                                                                      | 6486/999999 [04:26<11:43:07, 23.55it/s]

[0, 0] 133


  1%|▌                                                                                      | 6630/999999 [04:32<11:48:05, 23.38it/s]

[0, -1] 142


  1%|▌                                                                                      | 7134/999999 [04:54<11:45:54, 23.44it/s]

[0, 0] 504


  1%|▋                                                                                      | 7269/999999 [05:00<11:57:30, 23.06it/s]

[0, -1] 132


  1%|▋                                                                                      | 7773/999999 [05:21<11:45:33, 23.44it/s]

[0, 0] 504


  1%|▋                                                                                      | 7905/999999 [05:27<11:44:13, 23.48it/s]

[0, -1] 130


  1%|▋                                                                                      | 7974/999999 [05:30<11:47:16, 23.38it/s]

aim-->reach


  1%|▋                                                                                      | 8265/999999 [05:42<11:57:37, 23.03it/s]

[0, 1] 359


  1%|▋                                                                                      | 8274/999999 [05:43<11:52:18, 23.20it/s]

[-1, 1] 8


  1%|▋                                                                                      | 8409/999999 [05:48<11:54:16, 23.14it/s]

[1, 1] 135


  1%|▋                                                                                      | 8535/999999 [05:54<12:01:37, 22.90it/s]

[1, -1] 126


  1%|▋                                                                                      | 8598/999999 [05:57<12:37:12, 21.82it/s]

reach-->lick


  1%|▊                                                                                      | 8895/999999 [06:10<11:43:09, 23.49it/s]

[1, 2] 357


  1%|▊                                                                                      | 8898/999999 [06:10<12:00:14, 22.93it/s]

Mouse received reward
[-1, 2] 4


  1%|▊                                                                                      | 9042/999999 [06:16<10:58:18, 25.09it/s]

[2, 2] 141


  1%|▊                                                                                      | 9195/999999 [06:22<11:03:47, 24.88it/s]

[2, -1] 151


  1%|▊                                                                                      | 9279/999999 [06:25<11:01:51, 24.95it/s]

lick-->aim


  1%|▊                                                                                      | 9567/999999 [06:37<10:53:39, 25.25it/s]

[2, 0] 374


  1%|▊                                                                                      | 9579/999999 [06:37<11:00:51, 24.98it/s]

[-1, 0] 9


  1%|▊                                                                                      | 9699/999999 [06:42<10:56:10, 25.15it/s]

[0, 0] 119


  1%|▊                                                                                      | 9840/999999 [06:47<10:55:49, 25.16it/s]

[0, -1] 140


  1%|▉                                                                                     | 10347/999999 [07:08<10:55:53, 25.15it/s]

[0, 0] 505


  1%|▉                                                                                     | 10470/999999 [07:13<11:04:13, 24.83it/s]

[0, -1] 123


  1%|▉                                                                                     | 10536/999999 [07:15<10:53:35, 25.23it/s]

aim-->reach


  1%|▉                                                                                     | 10830/999999 [07:27<10:55:43, 25.14it/s]

[0, 1] 359


  1%|▉                                                                                     | 10836/999999 [07:27<11:04:36, 24.81it/s]

[-1, 1] 6


  1%|▉                                                                                     | 10977/999999 [07:33<10:57:50, 25.06it/s]

[1, 1] 138


  1%|▉                                                                                     | 11112/999999 [07:38<11:15:28, 24.40it/s]

[1, -1] 136


  1%|▉                                                                                     | 11172/999999 [07:41<10:54:43, 25.17it/s]

reach-->lick
lick-->lick


  1%|▉                                                                                     | 11223/999999 [07:43<10:49:56, 25.36it/s]

In [None]:
# Overlaps and correlations of each population's activity with its stored sequence.
overlaps_ctx = sequences_ctx[0].overlaps(net_ctx, ctx)
overlaps_bg = sequences_bg[0].overlaps(net_bg, bg)
correlations_ctx = sequences_ctx[0].overlaps(net_ctx, ctx, correlation=True)
correlations_bg = sequences_bg[0].overlaps(net_bg, bg, correlation=True)

# NOTE(review): the name seems to encode the simulation-cell hyperparameters
# (0005 ~ eta, 1600 = tau_e, 1000 = T), but '600' does not match delta_t=300 there; confirm.
filename = 'learning-0005-1600-600-5-1000-v0'
data_dir = './data'
# Create the output directory if needed so a long simulation is not lost to a
# FileNotFoundError at save time.
os.makedirs(data_dir, exist_ok=True)
np.savez(os.path.join(data_dir, filename + '.npz'),
         overlaps_ctx=overlaps_ctx, overlaps_bg=overlaps_bg,
         correlations_ctx=correlations_ctx, correlations_bg=correlations_bg,
         state_ctx=net_ctx.exc.state, state_bg=net_bg.exc.state)