In [None]:
from __future__ import division

import os
import sys
import logging
import argparse
import itertools

import numpy as np
from joblib import Parallel, delayed

sys.path.insert(0, '../../../network')
from network import Population, RateNetwork
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse
from sequences import GaussianSequence

# Configure the root logger to emit INFO-level records, so progress/log
# messages (from this notebook or, if they log, the imported network modules)
# appear in the cell output.
logging.basicConfig(level=logging.INFO)

In [None]:
# Number of parallel worker processes intended for the joblib-based
# perturbation sweep (the ``n_jobs`` value for ``joblib.Parallel``).
cpu_cores = 8

In [None]:
def simulate_network(z_sigma, S, P):
    """Simulate recall of the first stored sequence from a noise-perturbed start.

    Builds a 40000-unit excitatory population with an error-function transfer
    function, stores the stacked pattern inputs of ``S`` ``GaussianSequence``
    objects (``P`` patterns each, seeded 0..S-1) in a sparse (p=0.005)
    recurrent connectivity, and runs the rate network for 0.3 (in the time
    units of ``RateNetwork.simulate``).

    The initial rates are ``phi(first pattern of sequence 0 + z_sigma * noise)``
    where the standard-normal noise vector comes from a fixed
    ``RandomState(seed=42)``; every ``z_sigma`` therefore scales the same
    perturbation direction, and ``z_sigma=0`` gives an unperturbed start.

    Args:
        z_sigma: scale of the Gaussian perturbation added to the initial input.
        S: number of sequences stored in the connectivity.
        P: number of patterns per sequence.

    Returns:
        The value of ``sequences[0].overlaps(net, exc)`` — the overlaps of the
        simulated activity with the first sequence (layout as defined by
        ``GaussianSequence.overlaps``; presumably patterns x time — TODO confirm).
    """
    transfer = ErrorFunction(mu=0.22, sigma=0.1).phi
    population = Population(N=40000, tau=1e-2, phi=transfer)

    seq_list = []
    for seed in range(S):
        seq_list.append(GaussianSequence(P, population.size, seed=seed))
    stored = np.stack([seq.inputs for seq in seq_list])

    recurrent = SparseConnectivity(source=population, target=population, p=0.005)
    learning_rule = LinearSynapse(recurrent.K, A=1)
    recurrent.store_sequences(stored, learning_rule.h_EE)

    network = RateNetwork(population, c_EE=recurrent, formulation=1)

    # Local RNG with a fixed seed: the perturbation direction is identical
    # across calls and does not touch numpy's global random state.
    noise_rng = np.random.RandomState(seed=42)
    perturbation = z_sigma * noise_rng.normal(0, 1, size=population.size)
    initial_rates = population.phi(stored[0, 0, :] + perturbation)
    network.simulate(0.3, r0=initial_rates)

    return seq_list[0].overlaps(network, population)

In [None]:
def _mean_overlap_norm(overlaps, start, stop):
    """Mean, over time indices [start, stop), of the L2 norm of ``overlaps[:, t]``.

    The norm is taken over axis 0 (presumably the pattern axis — TODO confirm
    against ``GaussianSequence.overlaps``). An empty window yields NaN (with a
    numpy RuntimeWarning), exactly as ``np.mean`` of an empty array does.
    """
    window = overlaps[:, start:stop]
    return np.mean(np.sqrt(np.sum(window**2, axis=0)))


def compute_perturbations(z_sigma, S, P, n_jobs=None):
    """Compare recall quality of perturbed runs against the unperturbed run.

    Runs ``simulate_network`` once with no perturbation and once per entry of
    ``z_sigma`` (in parallel with joblib). The analysis window ends at
    ``t_max``: the time index at which the unperturbed run's last overlap row
    (``overlaps_unperturbed[-1]``, presumably the last pattern of the
    sequence) peaks. ``t_mid = t_max // 2`` marks the start of the second half
    of that window.

    Args:
        z_sigma: iterable of perturbation scales, one simulation per entry.
        S: number of stored sequences (forwarded to ``simulate_network``).
        P: number of patterns per sequence (forwarded to ``simulate_network``).
        n_jobs: joblib worker count; ``None`` uses the notebook-level
            ``cpu_cores`` setting.

    Returns:
        ``(distance1, distance2, distance3, distance4)`` where
        ``distance1`` / ``distance2`` are lists (one entry per ``z_sigma``) of
        the mean overlap-vector norm of each perturbed run over
        ``[0, t_max)`` / ``[t_mid, t_max)``, and ``distance3`` / ``distance4``
        are the same two scalars for the unperturbed run.
    """
    if n_jobs is None:
        # The original read ``args.cpu_cores``, but no ``args`` exists in this
        # notebook (NameError); use the notebook's own setting instead.
        n_jobs = cpu_cores

    overlaps_unperturbed = simulate_network(0, S, P)
    overlaps_perturbed_set = Parallel(n_jobs=n_jobs)(
        delayed(simulate_network)(sigma, S, P) for sigma in z_sigma)

    t_max = np.argmax(overlaps_unperturbed[-1])
    t_mid = t_max // 2
    if t_max == 0:
        # Every window below is empty, so all four results will be NaN.
        logging.warning(
            "Unperturbed run peaks on the last pattern at t=0; "
            "perturbation distances will be NaN.")

    # Mean norm from the start of the simulation up to t_max
    distance1 = [_mean_overlap_norm(overlaps, 0, t_max)
                 for overlaps in overlaps_perturbed_set]

    # Mean norm over the second half of the window, [t_mid, t_max)
    distance2 = [_mean_overlap_norm(overlaps, t_mid, t_max)
                 for overlaps in overlaps_perturbed_set]

    # Same two measures for the unperturbed recall
    distance3 = _mean_overlap_norm(overlaps_unperturbed, 0, t_max)
    distance4 = _mean_overlap_norm(overlaps_unperturbed, t_mid, t_max)

    return distance1, distance2, distance3, distance4

In [None]:
# Sweep 1: a single stored sequence (S=1) of P=16 patterns, perturbation
# scales 0.1, 0.2, ..., 2.5 (25 values). np.linspace is used instead of a
# float-step np.arange, whose element count depends on floating-point rounding
# of (stop - start) / step.
z_sigma1 = np.linspace(0.1, 2.5, 25)
distances1 = compute_perturbations(z_sigma1, S=1, P=16)

In [None]:
# Sweep 2: four stored sequences (S=4) of P=16 patterns each, with 20 evenly
# spaced perturbation scales from 0.1 to 10 inclusive.
z_sigma2 = np.linspace(start=0.1, stop=10, num=20)
distances2 = compute_perturbations(z_sigma=z_sigma2, S=4, P=16)

In [None]:
# Save [z_sigma1, z_sigma2, distances1, distances2] as a 4-element object
# array. The items are ragged (arrays of different lengths and tuples), so
# np.save cannot coerce a plain list of them into an array (NumPy >= 1.24
# raises ValueError); build the object array explicitly. Reading it back
# requires np.load(..., allow_pickle=True).
results_a = np.empty(4, dtype=object)
for slot, item in enumerate([z_sigma1, z_sigma2, distances1, distances2]):
    results_a[slot] = item

if not os.path.isdir("data"):
    os.makedirs("data")
# Context manager so the file handle is flushed and closed.
with open("data/data_a.npy", "wb") as fh:
    np.save(fh, results_a)