In [None]:
import sys
# Put the sibling 'network' directory at the front of the module search path
# so the project-local imports in the next cell can be resolved.
# NOTE(review): the path is relative, so it depends on the kernel's working
# directory being this notebook's folder — confirm before running elsewhere.
sys.path.insert(0, '../../../network')

In [None]:
import logging
import argparse
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from network import Population, RateNetwork
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule
from sequences import GaussianSequence

# Configure the root logger so INFO-level (and higher) messages are shown.
logging.basicConfig(level=logging.INFO)

In [None]:
# Transfer function shared by both populations.
phi = ErrorFunction(mu=0.22, sigma=0.1).phi

# Two equally sized rate populations, named 'ctx' and 'bg'.
ctx = Population(N=1000, tau=1e-2, phi=phi, name='ctx')
bg = Population(N=1000, tau=1e-2, phi=phi, name='bg')

# Plasticity rule whose f/g functions are passed to every store/update call below.
plasticity = ThresholdPlasticityRule(x_f=0.5, q_f=0.8)

# S sequences of P patterns each, per population.
# NOTE(review): every sequence in a list is built with the same seed, so with
# S > 1 the sequences would presumably be identical — confirm before raising S.
S, P = 1, 3
sequences_ctx = [GaussianSequence(P, ctx.size, seed=5) for _ in range(S)]
patterns_ctx = np.stack([s.inputs for s in sequences_ctx])

# The bg patterns are sized by the bg population (the original used ctx.size,
# which only worked because both populations currently have the same N).
sequences_bg = [GaussianSequence(P, bg.size, seed=27) for _ in range(S)]
patterns_bg = np.stack([s.inputs for s in sequences_bg])

In [None]:
# Parameter sweep over the ctx->bg coupling strength A_bc.
#
# For each A = [A_cc, A_bb, A_cb, A_bc] the cell rebuilds all four sparse
# connectivity matrices, simulates the two coupled rate networks for T time
# units, and saves overlaps/correlations with the stored patterns to one
# .npz file per parameter vector.

# NOTE(review): machine-specific absolute path, kept unchanged so existing
# outputs and any downstream loaders keep working. TODO: make configurable.
OUTPUT_DIR = '/Users/stanleypark/Desktop/code/hebbian_sequence_learning/figures/article/1/output/'

# Simulation duration passed to simulate_euler2 (loop-invariant).
T = 2

# A_cc, A_bb and A_cb are held fixed; only A_bc is swept over [0, 3) in 0.1 steps.
As = [np.array([5, 5, 1, a_bc]) for a_bc in np.arange(0, 3, 0.1)]

for A in As:
    A_cc, A_bb, A_cb, A_bc = A

    # The connectivity objects are created in the same order as before, in
    # case their construction draws from a shared random state.

    # ctx -> ctx: each ctx pattern stored onto itself.
    J_cc = SparseConnectivity(source=ctx, target=ctx, p=0.05)
    synapse_cc = LinearSynapse(J_cc.K, A=A_cc)
    J_cc.store_attractors(patterns_ctx[0], patterns_ctx[0], synapse_cc.h_EE,
                          plasticity.f, plasticity.g)

    # bg -> bg: each bg pattern stored onto itself.
    J_bb = SparseConnectivity(source=bg, target=bg, p=0.05)
    synapse_bb = LinearSynapse(J_bb.K, A=A_bb)
    J_bb.store_attractors(patterns_bg[0], patterns_bg[0], synapse_bb.h_EE,
                          plasticity.f, plasticity.g)

    # bg -> ctx: bg pattern k paired with ctx pattern k.
    J_cb = SparseConnectivity(source=bg, target=ctx, p=0.05)
    synapse_cb = LinearSynapse(J_cb.K, A=A_cb)
    J_cb.store_attractors(patterns_bg[0], patterns_ctx[0], synapse_cb.h_EE,
                          plasticity.f, plasticity.g)

    # ctx -> bg: initialised through a synapse with A=0, then the individual
    # transitions are added below with strength A_bc.
    J_bc = SparseConnectivity(source=ctx, target=bg, p=0.05)
    # Fixed: use this matrix's own K (the original passed J_cc.K here).
    synapse_bc = LinearSynapse(J_bc.K, A=0)
    J_bc.store_sequences(patterns_ctx, patterns_bg, synapse_bc.h_EE, plasticity.f, plasticity.g)

    # Cyclic transitions: ctx pattern k -> bg pattern (k + 1) mod n_patterns.
    # For three patterns this is exactly 0->1, 1->2, 2->0, as written originally.
    n_patterns = len(patterns_ctx[0])
    for k in range(n_patterns):
        J_bc.update_sequences(patterns_ctx[0][k], patterns_bg[0][(k + 1) % n_patterns],
                              A_bc, lamb=1, f=plasticity.f, g=plasticity.g)

    # Keyword-to-matrix pairing is kept exactly as in the original.
    net_ctx = RateNetwork(ctx, c_EE=J_cc, c_IE=J_bc, formulation=4)
    net_bg = RateNetwork(bg, c_II=J_bb, c_EI=J_cb, formulation=4)

    # Both networks start from the transfer function applied to their first pattern.
    # Recomputed every iteration in case the simulation modifies its inputs in place.
    init_input_ctx = phi(patterns_ctx[0][0])
    init_input_bg = phi(patterns_bg[0][0])
    net_ctx.simulate_euler2(net_bg, T, init_input_ctx, init_input_bg)

    overlaps_ctx = sequences_ctx[0].overlaps(net_ctx, ctx, phi=phi)
    overlaps_bg = sequences_bg[0].overlaps(net_bg, bg, phi=phi)
    corr_ctx = sequences_ctx[0].overlaps(net_ctx, ctx, phi=phi, correlation=True)
    # Fixed: correlate bg activity with the bg sequence, matching overlaps_bg
    # (the original used sequences_ctx[0] here).
    corr_bg = sequences_bg[0].overlaps(net_bg, bg, phi=phi, correlation=True)

    # The file name is the string form of the parameter vector, as before.
    np.savez(OUTPUT_DIR + str(A) + '.npz',
             overlaps_ctx=overlaps_ctx, overlaps_bg=overlaps_bg,
             corr_ctx=corr_ctx, corr_bg=corr_bg)