In [1]:
import sys
# Prepend the `network` directory to the module search path so it takes
# precedence over other entries. The path is relative, so it resolves against
# the kernel's current working directory.
# NOTE(review): presumably this is where the project modules imported in the
# next cell live — confirm.
sys.path.insert(0, '../../../network')

In [2]:
import logging
import argparse
import numpy as np
from network import Population, RateNetwork
from learning import ReachingTask
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule
from sequences import GaussianSequence
import matplotlib.pyplot as plt
# Uncomment to print NumPy arrays in full instead of a truncated summary.
# np.set_printoptions(threshold=sys.maxsize)
# Emit INFO-level log records (e.g. the `INFO:connectivity:...` progress
# messages) in the cell output.
logging.basicConfig(level=logging.INFO)

In [3]:
# --- Populations -----------------------------------------------------------
# Both populations share the same transfer function and time constant.
# NOTE(review): 'ctx' / 'bg' presumably stand for cortex / basal ganglia — confirm.
phi = ErrorFunction(mu=0.22, sigma=0.1).phi
ctx = Population(N=3000, tau=1e-2, phi=phi, name='ctx')
bg = Population(N=3000, tau=1e-2, phi=phi, name='bg')
plasticity = ThresholdPlasticityRule(x_f=0.5, q_f=0.8)

# --- Stored patterns -------------------------------------------------------
# S sequences of P patterns per population; `patterns_*` stacks the per-sequence
# inputs along a new leading axis, so patterns_*[0] is the first sequence.
# NOTE(review): every sequence in a list is built with the same seed; with
# S > 1 the sequences would presumably be identical — confirm this is intended.
S, P = 1, 4
sequences_ctx = [GaussianSequence(P, ctx.size, seed=11) for _ in range(S)]
patterns_ctx = np.stack([s.inputs for s in sequences_ctx])
# Size the bg patterns from the bg population itself (previously ctx.size,
# which only worked because both populations have the same N).
sequences_bg = [GaussianSequence(P, bg.size, seed=367) for _ in range(S)]
patterns_bg = np.stack([s.inputs for s in sequences_bg])

# --- Connectivity ----------------------------------------------------------
# Recurrent ctx -> ctx: store the ctx patterns as attractors.
J_cc = SparseConnectivity(source=ctx, target=ctx, p=0.5)
synapse_cc = LinearSynapse(J_cc.K, A=6)
J_cc.store_attractors(patterns_ctx[0], patterns_ctx[0], synapse_cc.h_EE, plasticity.f, plasticity.g)

# Recurrent bg -> bg: store the bg patterns as attractors.
J_bb = SparseConnectivity(source=bg, target=bg, p=0.5)
synapse_bb = LinearSynapse(J_bb.K, A=6)
J_bb.store_attractors(patterns_bg[0], patterns_bg[0], synapse_bb.h_EE, plasticity.f, plasticity.g)

# bg -> ctx: pair each bg pattern with the ctx pattern of the same index.
J_cb = SparseConnectivity(source=bg, target=ctx, p=0.5)
synapse_cb = LinearSynapse(J_cb.K, A=5.5)
J_cb.store_attractors(patterns_bg[0], patterns_ctx[0], synapse_cb.h_EE, plasticity.f, plasticity.g)

# ctx -> bg: sequence storage with synaptic amplitude A=0.
# NOTE(review): presumably this pathway starts silent and is shaped later by
# `simulate_learning` — confirm.
J_bc = SparseConnectivity(source=ctx, target=bg, p=0.5)
synapse_bc = LinearSynapse(J_bc.K, A=0)
J_bc.store_sequences(patterns_ctx, patterns_bg, synapse_bc.h_EE, plasticity.f, plasticity.g)

# --- Networks --------------------------------------------------------------
# Each network receives its own recurrent matrix plus one cross-population matrix.
net_ctx = RateNetwork(ctx, c_EE=J_cc, c_IE=J_bc, formulation=4, disable_pbar=False)
net_bg = RateNetwork(bg, c_II=J_bb, c_EI=J_cb, formulation=4)

INFO:connectivity:Building connections from ctx to ctx
INFO:connectivity:Storing attractors
100%|█████████████████████████████████████| 3000/3000 [00:01<00:00, 1788.71it/s]
INFO:connectivity:Building connections from bg to bg
INFO:connectivity:Storing attractors
100%|█████████████████████████████████████| 3000/3000 [00:01<00:00, 1824.53it/s]
INFO:connectivity:Building connections from bg to ctx
INFO:connectivity:Storing attractors
100%|█████████████████████████████████████| 3000/3000 [00:01<00:00, 1812.25it/s]
INFO:connectivity:Building connections from ctx to bg
INFO:connectivity:Storing sequences
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                  | 0/3000 [00:00<?, ?it/s][A
  4%|█▎                                    | 108/3000 [00:00<00:02, 1074.52it/s][A
  8%|███▏                                  | 250/3000 [00:00<00:02, 1276.08it/s][A
 14%|█████▍                                | 427/3000 [00:00<

In [None]:
# Draw an unseeded standard-normal initial input for each population, shaped
# like a single stored pattern. (Alternative: start from a stored pattern,
# e.g. patterns_ctx[0][2] / patterns_bg[0][2].)
init_input_ctx = np.random.RandomState().normal(
    loc=0, scale=1, size=patterns_ctx[0, 0].shape
)
init_input_bg = np.random.RandomState().normal(
    loc=0, scale=1, size=patterns_bg[0, 0].shape
)

T = 3
mouse = ReachingTask(2, alpha=0.2)

# Run the coupled ctx/bg simulation with learning driven by the task object.
behaviors = net_ctx.simulate_learning(
    mouse,
    net_bg,
    T,
    init_input_ctx,
    init_input_bg,
    patterns_ctx[0],
    patterns_bg[0],
    plasticity,
    noise=20,
)

INFO:network:Integrating network dynamics
  0%|                                                  | 0/2999 [00:00<?, ?it/s]

updating
updating
updating


  2%|▊                                        | 58/2999 [01:23<33:10,  1.48it/s]

In [None]:
# Fresh unseeded standard-normal initial inputs, one vector per population,
# each shaped like a single stored pattern.
init_input_ctx = np.random.RandomState().normal(
    loc=0, scale=1, size=patterns_ctx[0, 0].shape
)
init_input_bg = np.random.RandomState().normal(
    loc=0, scale=1, size=patterns_bg[0, 0].shape
)

T = 10
# Integrate both networks together from the random initial inputs.
net_ctx.simulate_euler2(
    net_bg,
    T,
    init_input_ctx,
    init_input_bg,
    noise=10,
)

In [None]:
# Compare the simulated ctx activity against the first stored ctx sequence:
# once with the default measure, once with correlation=True.
first_sequence_ctx = sequences_ctx[0]
overlaps = first_sequence_ctx.overlaps(net_ctx, ctx, phi=phi)
correlations = first_sequence_ctx.overlaps(
    net_ctx, ctx, phi=phi, correlation=True
)

In [None]:
# Plot one trace per row of `correlations`, labelled by its row index.
for i, row in enumerate(correlations):
    plt.plot(row, label=str(i))

# The x-window is identical for every trace, so set it once after plotting
# instead of on every loop iteration.
plt.xlim(25000, 35000)
# Trailing `;` suppresses the Legend object repr in the notebook output.
plt.legend();

In [None]:
# Read the shape straight from the weight matrix; calling .toarray() first
# would allocate a full dense copy just to inspect its dimensions.
# NOTE(review): assumes W is a scipy-style sparse matrix exposing `.shape` — confirm.
net_ctx.W.shape
