In [None]:
import argparse
import sys
import logging
import numpy as np
import scipy

sys.path.insert(0, '../../../network')
from network import Population, RateNetwork
from transfer_functions import ErrorFunction 
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule
from sequences import GaussianSequence

import matplotlib.pyplot as plt

# Emit INFO-and-above messages (e.g. the "Saving data" progress message).
logging.basicConfig(level="INFO")

In [None]:
# Destination file for the simulated network states; relative paths are
# resolved against the current working directory.
datapath = "/".join(("data", "data_a.npy"))

In [None]:
# --- Population and recurrent connectivity ---------------------------------
N_E = 40000  # number of units in the `exc` population
T = 0.4      # duration later passed to net.simulate(T, ...)

exc = Population(
    N_E,
    tau=1e-2,
    phi=ErrorFunction(mu=0.07, sigma=0.05).phi,
)
conn = SparseConnectivity(source=exc, target=exc, p=0.005, seed=123)

# --- Stored sequences ------------------------------------------------------
# P sequences are generated, sequence number i using seed=i; S is presumably
# the number of patterns per sequence (TODO confirm against GaussianSequence).
P = 2
S = 30
sequences = [GaussianSequence(S, exc.size, seed=seed) for seed in range(P)]
# Stack each sequence's `inputs` along a new leading axis (one entry per sequence).
patterns = np.stack([seq.inputs for seq in sequences])

# --- Learning rule and synapse; write the sequences into `conn` ------------
plasticity = ThresholdPlasticityRule(x_f=1.645, q_f=0.8)
synapse = LinearSynapse(conn.K, A=14)
conn.store_sequences(patterns, synapse.h_EE, plasticity.f, plasticity.g)

# Rate network using the connectivity that now holds the stored sequences.
net = RateNetwork(exc, c_EE=conn, formulation=1)

# Run the network twice. Each run starts from initial rates
# r0 = phi(f(pattern)), where pattern is the first entry of stored sequence 0
# (first run) or stored sequence 1 (second run). The transposed state of each
# run is kept as float32.
r0 = np.zeros(exc.size)
r0[:] = exc.phi(plasticity.f(patterns[0, 0, :]))
net.simulate(T, r0=r0)
# astype() always returns a new array, so no separate np.copy() is needed;
# that would allocate an extra full-size float64 temporary of the whole
# recorded state before the float32 conversion.
state1 = net.exc.state.T.astype(np.float32)

# Presumably resets the recorded state so the second run starts fresh
# (TODO confirm against RateNetwork.clear_state).
net.clear_state()
r0[:] = exc.phi(plasticity.f(patterns[1, 0, :]))
net.simulate(T, r0=r0)
state2 = net.exc.state.T.astype(np.float32)

# Overlaps between the latest run (started from sequence 1's first pattern)
# and stored sequence 1, as computed by GaussianSequence.overlaps.
overlaps = sequences[1].overlaps(net, exc)

# Average squared rate: mean of the squared recorded state over axis 0,
# taken from the most recent run.
M = (net.exc.state ** 2).mean(axis=0)

logging.info("Saving data")
# A context manager guarantees the file is flushed and closed as soon as the
# write finishes, instead of whenever the anonymous handle is garbage-collected.
with open(datapath, "wb") as fh:
    # Both runs go into one array with a leading axis of length 2
    # (index 0: run seeded from sequence 0, index 1: from sequence 1).
    # np.stack raises if the runs differ in shape rather than writing a
    # ragged/object array.
    np.save(fh, np.stack([state1, state2]))