In [1]:
import sys

# Make the project's `network/` source directory importable (the relative path
# resolves against the kernel's working directory). Any existing entry is removed
# first so re-running this cell does not pile up duplicates, while the directory
# still ends up at the front of the search path (shadowing same-named modules).
NETWORK_DIR = '../../../network'
while NETWORK_DIR in sys.path:
    sys.path.remove(NETWORK_DIR)
sys.path.insert(0, NETWORK_DIR)

In [2]:
import argparse
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from network import Population, RateNetwork
from learning import ReachingTask
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule
from sequences import GaussianSequence

logging.basicConfig(level=logging.INFO)

In [3]:
# Two rate populations ('ctx' and 'bg') sharing one transfer function and one
# plasticity rule.
phi = ErrorFunction(mu=0.22, sigma=0.1).phi
ctx = Population(N=1000, tau=1e-2, phi=phi, name='ctx')
bg = Population(N=1000, tau=1e-2, phi=phi, name='bg')
plasticity = ThresholdPlasticityRule(x_f=0.5, q_f=0.8)

# S sequences of P patterns for each population.
S, P = 1, 3
SEED_CTX, SEED_BG = 114, 29
CONNECTION_PROB = 0.05  # sparsity `p` shared by all four projections

# NOTE(review): assumes `seed` fully determines a GaussianSequence. Offsetting it
# by the sequence index keeps the S sequences distinct (a fixed seed would make
# them all identical); for i = 0 the original seeds are reproduced exactly.
sequences_ctx = [GaussianSequence(P, ctx.size, seed=SEED_CTX + i) for i in range(S)]
patterns_ctx = np.stack([s.inputs for s in sequences_ctx])
# bg patterns are sized by the bg population (previously ctx.size, which only
# worked because both populations happen to have N=1000).
sequences_bg = [GaussianSequence(P, bg.size, seed=SEED_BG + i) for i in range(S)]
patterns_bg = np.stack([s.inputs for s in sequences_bg])


def _attractor_connection(source, target, source_patterns, target_patterns, A, rule):
    """Build a sparse source->target projection and store patterns as attractors.

    Args:
        source, target: Populations the projection goes from / to.
        source_patterns, target_patterns: patterns passed (in this order) to
            ``SparseConnectivity.store_attractors``.
        A: parameter forwarded to ``LinearSynapse``.
        rule: plasticity rule whose ``f`` / ``g`` are used for storage.

    Returns:
        ``(connectivity, synapse)``.
    """
    conn = SparseConnectivity(source=source, target=target, p=CONNECTION_PROB)
    synapse = LinearSynapse(conn.K, A=A)
    conn.store_attractors(source_patterns, target_patterns, synapse.h_EE,
                          rule.f, rule.g)
    return conn, synapse


# Symmetric connections: patterns of the first sequence stored as attractors.
# Construction order (cc, bb, cb, then bc) matches the original cell.
J_cc, synapse_cc = _attractor_connection(ctx, ctx, patterns_ctx[0], patterns_ctx[0],
                                         A=5, rule=plasticity)
J_bb, synapse_bb = _attractor_connection(bg, bg, patterns_bg[0], patterns_bg[0],
                                         A=5, rule=plasticity)
J_cb, synapse_cb = _attractor_connection(bg, ctx, patterns_bg[0], patterns_ctx[0],
                                         A=1, rule=plasticity)

# Asymmetric connection: ctx -> bg, stored with store_sequences.
# NOTE(review): A=0 presumably means this projection starts silent and is shaped
# during learning — confirm against LinearSynapse.
J_bc = SparseConnectivity(source=ctx, target=bg, p=CONNECTION_PROB)
synapse_bc = LinearSynapse(J_bc.K, A=0)
J_bc.store_sequences(patterns_ctx, patterns_bg, synapse_bc.h_EE,
                     plasticity.f, plasticity.g)

# ctx network gets recurrent J_cc plus ctx->bg J_bc; bg network gets recurrent
# J_bb plus bg->ctx J_cb. See RateNetwork for the c_* keyword convention and
# the meaning of formulation=4.
net_ctx = RateNetwork(ctx, c_EE=J_cc, c_IE=J_bc, formulation=4)
net_bg = RateNetwork(bg, c_II=J_bb, c_EI=J_cb, formulation=4)

INFO:connectivity:Building connections from ctx to ctx
INFO:connectivity:Storing attractors
100%|████████████████████████████████████| 1000/1000 [00:00<00:00, 31623.62it/s]
INFO:connectivity:Building connections from bg to bg
INFO:connectivity:Storing attractors
100%|████████████████████████████████████| 1000/1000 [00:00<00:00, 33764.04it/s]
INFO:connectivity:Building connections from bg to ctx
INFO:connectivity:Storing attractors
100%|████████████████████████████████████| 1000/1000 [00:00<00:00, 34513.63it/s]
INFO:connectivity:Building connections from ctx to bg
INFO:connectivity:Storing sequences
  0%|                                                     | 0/1 [00:00<?, ?it/s]
100%|████████████████████████████████████| 1000/1000 [00:00<00:00, 34404.64it/s][A
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 31.97it/s]
INFO:connectivity:Applying synaptic transfer function
INFO:connectivity:Building sparse matrix


In [None]:
# Learning-run configuration. A fixed seed makes the random initial conditions
# (and therefore the whole run) reproducible; unseeded RandomState() instances
# gave a different trajectory on every execution.
SIM_SEED = 0
T = 1000  # ms
DELTA_T, ETA, TAU_E, LAMB = 300, 0.0005, 1600, 0.5

rng = np.random.RandomState(SIM_SEED)
# NOTE(review): ReachingTask / simulate_learning may draw from NumPy's global RNG
# (their code is not visible here); seed it too so the run is repeatable end to end.
np.random.seed(SIM_SEED)

# Standard-normal initial inputs with the shape of one stored pattern.
init_input_ctx = rng.normal(0, 1, size=patterns_ctx[0][0].shape)
init_input_bg = rng.normal(0, 1, size=patterns_bg[0][0].shape)

mouse = ReachingTask()
net_ctx.simulate_learning(mouse, net_bg, T, init_input_ctx, init_input_bg,
                          patterns_ctx[0], patterns_bg[0], plasticity,
                          delta_t=DELTA_T, eta=ETA, tau_e=TAU_E, lamb=LAMB,
                          print_output=True)

INFO:network:Integrating network dynamics
  return 1 - erfc(x)
  0%|                                   | 180/999999 [00:01<1:52:44, 147.80it/s]

null-->lick


  0%|                                   | 266/999999 [00:02<1:11:04, 234.45it/s]

None 0


  0%|                                   | 342/999999 [00:03<2:03:11, 135.24it/s]

[-1, 2] 49


  0%|                                    | 545/999999 [00:07<6:03:19, 45.85it/s]

[2, 2] 188


  0%|                                    | 700/999999 [00:11<6:35:12, 42.14it/s]

[2, -1] 153


  0%|                                    | 755/999999 [00:12<6:38:55, 41.75it/s]

lick-->aim


  0%|                                   | 1055/999999 [00:19<6:37:28, 41.89it/s]

[2, 0] 352
[-1, 0] 1


  0%|                                   | 1205/999999 [00:23<6:37:33, 41.87it/s]

[0, 0] 150


  0%|                                   | 1390/999999 [00:27<6:33:43, 42.27it/s]

[0, -1] 182


  0%|                                   | 1460/999999 [00:29<6:33:12, 42.32it/s]

aim-->lick


  0%|                                   | 1745/999999 [00:36<6:33:02, 42.33it/s]

[0, 2] 354


  0%|                                   | 1760/999999 [00:36<6:38:19, 41.77it/s]

[-1, 2] 14


  0%|                                   | 1895/999999 [00:39<6:55:55, 39.99it/s]

[2, 2] 134


  0%|                                   | 2064/999999 [00:44<6:37:48, 41.81it/s]

[2, -1] 169


  0%|                                   | 2124/999999 [00:45<6:36:12, 41.98it/s]

lick-->lick


  0%|                                   | 2129/999999 [00:45<6:42:50, 41.28it/s]

lick-->aim


  0%|                                   | 2419/999999 [00:52<6:36:37, 41.92it/s]

[2, 0] 355
[-1, 0] 0
[2, 0] 0


  0%|                                   | 2429/999999 [00:52<6:40:30, 41.51it/s]

[-1, 0] 7


  0%|                                   | 2569/999999 [00:56<6:38:07, 41.75it/s]

[0, 0] 139


  0%|                                   | 2724/999999 [00:59<6:36:07, 41.96it/s]

[0, -1] 150


  0%|                                   | 3229/999999 [01:11<6:35:39, 41.99it/s]

[0, 0] 505


  0%|                                   | 3364/999999 [01:15<6:34:57, 42.06it/s]

[0, -1] 134


  0%|                                   | 3444/999999 [01:16<6:36:07, 41.93it/s]

aim-->aim


  0%|                                   | 3454/999999 [01:17<6:40:19, 41.49it/s]

aim-->reach


  0%|▏                                  | 3744/999999 [01:24<6:32:43, 42.28it/s]

[0, 1] 379
[-1, 1] 0
[0, 1] 1


  0%|▏                                  | 3754/999999 [01:24<6:43:39, 41.13it/s]

[-1, 1] 8


  0%|▏                                  | 3869/999999 [01:27<6:34:43, 42.06it/s]

[1, 1] 113


  0%|▏                                  | 4029/999999 [01:30<6:36:36, 41.85it/s]

[1, -1] 158


  0%|▏                                  | 4534/999999 [01:42<6:37:51, 41.70it/s]

[1, 1] 505


  0%|▏                                  | 4684/999999 [01:46<6:35:26, 41.95it/s]

[1, -1] 148


  0%|▏                                  | 4804/999999 [01:49<6:35:16, 41.96it/s]

reach-->reach
reach-->aim


  1%|▏                                  | 5104/999999 [01:56<6:38:16, 41.63it/s]

[1, 0] 418
[-1, 0] 0
[1, 0] 0
[-1, 0] 2


  1%|▏                                  | 5189/999999 [01:58<6:38:04, 41.65it/s]

[0, 0] 81


  1%|▏                                  | 5324/999999 [02:01<6:38:40, 41.58it/s]

[0, -1] 132


  1%|▏                                  | 5399/999999 [02:03<6:37:13, 41.73it/s]

aim-->reach


  1%|▏                                  | 5694/999999 [02:10<6:35:45, 41.87it/s]

[0, 1] 369
[-1, 1] 4


  1%|▏                                  | 5829/999999 [02:14<6:36:00, 41.84it/s]

[1, 1] 129


  1%|▏                                  | 5984/999999 [02:17<6:34:16, 42.02it/s]

[1, -1] 158


  1%|▏                                  | 6494/999999 [02:29<6:35:23, 41.88it/s]

[1, 1] 505


  1%|▏                                  | 6669/999999 [02:34<6:36:11, 41.79it/s]

[1, -1] 178


  1%|▏                                  | 6719/999999 [02:35<6:35:43, 41.83it/s]

reach-->lick


  1%|▏                                  | 7019/999999 [02:42<6:37:03, 41.68it/s]

[1, 2] 345
Mouse received reward
[-1, 2] 2


  1%|▎                                  | 7179/999999 [02:46<6:40:21, 41.33it/s]

[2, 2] 156


  1%|▎                                  | 7319/999999 [02:49<6:37:55, 41.58it/s]

[2, -1] 141


  1%|▎                                  | 7394/999999 [02:51<6:40:14, 41.33it/s]

lick-->aim


  1%|▎                                  | 7679/999999 [02:58<6:40:47, 41.27it/s]

[2, 0] 361


  1%|▎                                  | 7694/999999 [02:58<6:43:48, 40.96it/s]

[-1, 0] 14


  1%|▎                                  | 7824/999999 [03:01<6:39:38, 41.38it/s]

[0, 0] 127


  1%|▎                                  | 7974/999999 [03:05<6:37:36, 41.58it/s]

[0, -1] 151


  1%|▎                                  | 8054/999999 [03:07<6:38:19, 41.50it/s]

aim-->lick


  1%|▎                                  | 8344/999999 [03:14<6:38:59, 41.42it/s]

[0, 2] 365


  1%|▎                                  | 8354/999999 [03:14<6:42:58, 41.01it/s]

[-1, 2] 11


  1%|▎                                  | 8484/999999 [03:17<6:37:18, 41.59it/s]

[2, 2] 127


  1%|▎                                  | 8634/999999 [03:21<6:37:33, 41.56it/s]

[2, -1] 151


  1%|▎                                  | 8694/999999 [03:22<6:39:37, 41.34it/s]

lick-->reach


  1%|▎                                  | 8994/999999 [03:30<6:40:45, 41.21it/s]

[2, 1] 358
[-1, 1] 2


  1%|▎                                  | 9139/999999 [03:33<6:38:53, 41.40it/s]

[1, 1] 143


  1%|▎                                  | 9264/999999 [03:36<6:38:45, 41.41it/s]

[1, -1] 123


  1%|▎                                  | 9769/999999 [03:48<6:39:12, 41.34it/s]

[1, 1] 504


  1%|▎                                  | 9899/999999 [03:52<6:36:45, 41.59it/s]

[1, -1] 131


  1%|▎                                 | 10409/999999 [04:04<6:36:02, 41.65it/s]

[1, 1] 505


  1%|▎                                 | 10544/999999 [04:07<6:33:45, 41.88it/s]

[1, -1] 136


  1%|▍                                 | 11048/999999 [04:19<6:39:18, 41.28it/s]

[1, 1] 504


  1%|▍                                 | 11172/999999 [04:22<6:35:20, 41.69it/s]

[1, -1] 124


  1%|▍                                 | 11682/999999 [04:35<7:00:46, 39.15it/s]

[1, 1] 505


  1%|▍                                 | 11808/999999 [04:38<6:39:15, 41.25it/s]

[1, -1] 129


  1%|▍                                 | 11926/999999 [04:41<6:46:49, 40.48it/s]

reach-->aim
aim-->aim


  1%|▍                                 | 12221/999999 [04:49<6:39:56, 41.16it/s]

[1, 0] 409
[-1, 0] 4
[0, 0] 0
[-1, 0] 0


  1%|▍                                 | 12316/999999 [04:51<6:47:00, 40.44it/s]

[0, 0] 88


  1%|▍                                 | 12484/999999 [04:55<6:40:55, 41.05it/s]

[0, -1] 167
[0, 0] 0
[0, -1] 1
[0, 0] 1
[0, -1] 0


  1%|▍                                 | 12998/999999 [05:07<6:31:52, 41.98it/s]

[0, 0] 505


  1%|▍                                 | 13133/999999 [05:11<6:33:15, 41.82it/s]

[0, -1] 138


  1%|▍                                 | 13643/999999 [05:23<6:28:48, 42.28it/s]

[0, 0] 505


  1%|▍                                 | 13778/999999 [05:26<6:31:10, 42.02it/s]

[0, -1] 138


  1%|▍                                 | 13858/999999 [05:28<6:32:42, 41.85it/s]

aim-->reach


  1%|▍                                 | 14143/999999 [05:35<6:37:56, 41.29it/s]

[0, 1] 363


  1%|▍                                 | 14158/999999 [05:35<6:40:01, 41.07it/s]

[-1, 1] 12


  1%|▍                                 | 14288/999999 [05:38<6:35:30, 41.54it/s]

[1, 1] 128


  1%|▍                                 | 14413/999999 [05:41<6:35:09, 41.57it/s]

[1, -1] 125


  1%|▌                                 | 14918/999999 [05:54<6:36:50, 41.37it/s]

[1, 1] 504


  2%|▌                                 | 15043/999999 [05:57<6:36:15, 41.43it/s]

[1, -1] 127


  2%|▌                                 | 15553/999999 [06:09<6:36:59, 41.33it/s]

[1, 1] 505


  2%|▌                                 | 15708/999999 [06:13<6:37:10, 41.30it/s]

[1, -1] 154


  2%|▌                                 | 16213/999999 [06:25<6:34:36, 41.55it/s]

[1, 1] 505


  2%|▌                                 | 16348/999999 [06:28<6:34:19, 41.57it/s]

[1, -1] 135


  2%|▌                                 | 16413/999999 [06:30<6:36:10, 41.38it/s]

reach-->lick


  2%|▌                                 | 16708/999999 [06:37<6:37:45, 41.20it/s]

[1, 2] 361
Mouse received reward
[-1, 2] 1


  2%|▌                                 | 16853/999999 [06:40<6:31:23, 41.87it/s]

[2, 2] 141


  2%|▌                                 | 16988/999999 [06:43<6:38:40, 41.10it/s]

[2, -1] 131


  2%|▌                                 | 17037/999999 [06:45<6:33:00, 41.68it/s]

lick-->reach


  2%|▌                                 | 17337/999999 [06:52<6:32:01, 41.78it/s]

[2, 1] 348
[-1, 1] 1


  2%|▌                                 | 17492/999999 [06:56<6:32:53, 41.68it/s]

[1, 1] 153


  2%|▌                                 | 17607/999999 [06:58<7:04:42, 38.55it/s]

[1, -1] 116


  2%|▌                                 | 18116/999999 [07:11<6:41:38, 40.74it/s]

[1, 1] 505


  2%|▌                                 | 18231/999999 [07:14<6:39:56, 40.91it/s]

[1, -1] 114


  2%|▋                                 | 18736/999999 [07:26<6:34:01, 41.51it/s]

[1, 1] 505


  2%|▋                                 | 18851/999999 [07:29<6:37:47, 41.11it/s]

[1, -1] 113


  2%|▋                                 | 19353/999999 [07:41<6:30:38, 41.84it/s]

[1, 1] 504


  2%|▋                                 | 19473/999999 [07:44<6:46:05, 40.24it/s]

[1, -1] 116


  2%|▋                                 | 19602/999999 [07:47<6:30:24, 41.85it/s]

In [None]:
# Overlaps and correlations of each network's state with its stored sequence,
# plus the raw states, saved for offline analysis.
overlaps_ctx = sequences_ctx[0].overlaps(net_ctx, ctx)
overlaps_bg = sequences_bg[0].overlaps(net_bg, bg)
correlations_ctx = sequences_ctx[0].overlaps(net_ctx, ctx, correlation=True)
correlations_bg = sequences_bg[0].overlaps(net_bg, bg, correlation=True)

# NOTE(review): the '600' in this name does not match delta_t=300 passed to
# simulate_learning; the name is kept unchanged so downstream loaders still find
# the file — confirm which value it is meant to encode.
filename = 'learning-0005-1600-600-5-1000-v0'
data_dir = Path('./data')
data_dir.mkdir(parents=True, exist_ok=True)  # np.savez does not create missing directories
np.savez(data_dir / (filename + '.npz'),
         overlaps_ctx=overlaps_ctx, overlaps_bg=overlaps_bg,
         correlations_ctx=correlations_ctx, correlations_bg=correlations_bg,
         state_ctx=net_ctx.exc.state, state_bg=net_bg.exc.state)