### Activity profile correlation: lambda vs sigma_z

In [1]:
import argparse
import pdb
import itertools
import sys
import os
import logging
import copy
import time
import signal
import numpy as np
import scipy
import ray
from tqdm import tqdm, trange
from apply import apply

In [2]:
sys.path.insert(0, '/home/mhg19/Manuscripts/PNAS19/network')
from network import Population, RateNetwork
from transfer_functions import ErrorFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule
from sequences import GaussianSequence

In [3]:
# Connect this notebook to the Ray cluster through the Redis server at the
# address below. `ignore_reinit_error=True` is presumably set so that the cell
# can be re-executed in the same kernel without ray.init raising.
ray.init(redis_address="10.122.160.26:6382", include_webui=True, ignore_reinit_error=True)

{'node_ip_address': '10.122.160.26',
 'redis_address': '10.122.160.26:6382',
 'object_store_address': '/tmp/ray/session_2020-01-23_19-15-43_013143_957635/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2020-01-23_19-15-43_013143_957635/sockets/raylet',
 'webui_url': 'http://10.122.160.26:8080/?token=45fa8858d23069adcb4476112d91a9cb6134286800bea7f2',
 'session_dir': '/tmp/ray/session_2020-01-23_19-15-43_013143_957635'}

In [4]:
@ray.remote
def f():
    """Return the IP address of the node on which this Ray task runs.

    The task sleeps for 10 ms before answering, so each call occupies its
    worker for a short moment.
    """
    time.sleep(0.01)
    node_ip = ray.services.get_node_ip_address()
    return node_ip

# Fire off 1000 short tasks and print the distinct IP addresses they report,
# i.e. the nodes that have joined the cluster and picked up work.
print({ip for ip in ray.get([f.remote() for _ in range(1000)])})

{'10.122.160.21', '10.122.160.26', '10.122.160.27', '10.122.160.34', '10.122.160.25', '10.122.160.35'}


In [5]:
def _rowwise_pearson(x, y):
    """Return the Pearson correlation between matching rows of two 2-D arrays.

    Vectorised equivalent of calling ``scipy.stats.pearsonr(x[i], y[i])[0]``
    for every row ``i``. The computation is carried out in float64.

    Args:
        x: Array of shape (rows, samples).
        y: Array with the same shape as ``x``.

    Returns:
        1-D float64 array of length ``rows``. Rows where either input is
        constant yield NaN (zero variance), as ``pearsonr`` does. Values are
        clipped to [-1, 1] to absorb floating-point overshoot.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    x = x - x.mean(axis=1, keepdims=True)
    y = y - y.mean(axis=1, keepdims=True)
    numerator = np.einsum('ij,ij->i', x, y)
    norm_x = np.sqrt(np.einsum('ij,ij->i', x, x))
    norm_y = np.sqrt(np.einsum('ij,ij->i', y, y))
    # 0/0 for constant rows is expected and deliberately mapped to NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        r = numerator / (norm_x * norm_y)
    return np.clip(r, -1.0, 1.0)


def run_simulation(lambda_, sigma_z):
    """Simulate sequence retrieval while the weights are perturbed day by day.

    A sparse recurrent network of ``N`` units stores one Gaussian sequence of
    ``P`` patterns. On each of ``n_days`` iterations a perturbation is added to
    the stored weights, the network is simulated for ``T`` from the first
    pattern, and the resulting activity and overlaps are recorded.

    The perturbation follows a first-order autoregressive update::

        W_pert[0] = sigma_z * z_0
        W_pert[n] = lambda_ * W_pert[n-1] + sqrt(1 - lambda_**2) * sigma_z * z_n

    with ``z_n`` drawn i.i.d. from a standard normal (fixed seed 43).
    ``sqrt(1 - lambda_**2)`` is NaN for ``|lambda_| > 1``; no check is made.

    Args:
        lambda_: Day-to-day persistence coefficient of the perturbation.
        sigma_z: Scale of the weight perturbation.

    Returns:
        dict with keys:
            'params': the ``lambda`` and ``sigma_z`` values used.
            'state': ``corr`` and ``corr_rev``, arrays of shape (n_days, N)
                holding, per unit, the Pearson correlation of its activity on
                day n with its activity on the first / last day respectively.
            'overlaps', 'correlations': per-day float32 results of
                ``sequences[0].overlaps`` with ``correlation`` False / True.
            'overlap_similarity_score': 1 minus the squared difference between
                first- and last-day overlaps, normalised by the squared
                first-day overlaps.
            'std_W_sequence': standard deviation of the unperturbed weights.
    """
    N = 40000
    T = 0.4
    S, P = 1, 30
    n_days = 30

    exc = Population(N, tau=1e-2, phi=ErrorFunction(mu=0.07, sigma=0.05).phi)

    conn = SparseConnectivity(source=exc, target=exc, p=0.005, seed=42, disable_pbar=True)
    sequences = [GaussianSequence(P, exc.size, seed=i) for i in range(S)]
    patterns = np.stack([s.inputs for s in sequences])

    plasticity = ThresholdPlasticityRule(x_f=1.645, q_f=0.8)
    synapse = LinearSynapse(conn.K, A=14)
    conn.store_sequences(patterns, synapse.h_EE, plasticity.f, plasticity.g)

    # Weight components: the stored sequence weights stay fixed, only the
    # additive perturbation evolves from day to day.
    rng = np.random.RandomState(seed=43)
    W_sequence = np.copy(conn.W.data)
    W_pert = None
    # Scale of the fresh noise mixed in on every day after the first.
    innovation_scale = np.sqrt(1 - lambda_**2) * sigma_z

    state = []
    overlaps = []
    correlations = []
    for n in range(n_days):
        z_n = rng.normal(scale=1, size=conn.W.data.size)
        if n == 0:
            W_pert = sigma_z * z_n
        else:
            W_pert = lambda_ * W_pert + innovation_scale * z_n
        conn.W.data = W_sequence + W_pert

        net = RateNetwork(
            exc,
            c_EE=conn,
            formulation=1,
            disable_pbar=True)
        net.clear_state()
        net.simulate(
            T,
            r0=exc.phi(plasticity.f(patterns[0, 0, :])))
        m = sequences[0].overlaps(
            net,
            exc,
            plasticity=plasticity,
            correlation=False,
            disable_pbar=True)
        rho = sequences[0].overlaps(
            net,
            exc,
            plasticity=plasticity,
            correlation=True,
            disable_pbar=True)
        # float32 halves the memory needed to keep every day's activity.
        state.append(net.exc.state.astype(np.float32))
        overlaps.append(m.astype(np.float32))
        correlations.append(rho.astype(np.float32))

    # State correlation: per-unit correlation of each day's activity with the
    # first day (corr) and with the last day (corr_rev). Assumes each recorded
    # state is 2-D with one row per unit, as the row indexing 0..N-1 implies.
    corr = np.zeros((n_days, N))
    corr_rev = np.zeros((n_days, N))
    first_day = state[0][:N]
    last_day = state[-1][:N]
    for n in range(n_days):
        today = state[n][:N]
        corr[n, :] = _rowwise_pearson(first_day, today)
        corr_rev[n, :] = _rowwise_pearson(last_day, today)

    # Overlap similarity
    overlap_similarity_score = 1 - np.sum((overlaps[0] - overlaps[-1])**2) / np.sum(overlaps[0]**2)

    return {
        'params': {
            'lambda': lambda_,
            'sigma_z': sigma_z,
        },
        'state': {
            'corr': corr,
            'corr_rev': corr_rev,
        },
        'overlaps': overlaps,
        'correlations': correlations,
        'overlap_similarity_score': overlap_similarity_score,
        'std_W_sequence': np.std(W_sequence),
    }

Parameter exploration

In [6]:
# Parameter grid. A wider lambda sweep would be np.arange(0.4, 1.15, 0.15).
lambda_ = [0.95]
sigma_z = np.arange(0.01, 0.06, 0.01)

# Every (lambda, sigma_z) pair to simulate.
combinations = list(
    itertools.product(np.atleast_1d(lambda_), np.atleast_1d(sigma_z)))

# With parallel=True each simulation is submitted as a Ray task reserving
# 4 CPUs and object_ids holds the returned handles; otherwise the simulation
# runs right here and object_ids holds the result dicts themselves.
parallel = True
run_simulation_ray = ray.remote(num_cpus=4)(run_simulation)

object_ids = []
for lambda__, sigma_z_ in combinations:
    func = run_simulation_ray.remote if parallel else run_simulation
    object_ids.append(func(lambda__, sigma_z_))

# Number of submitted runs, used as the progress-bar total when collecting.
n = len(object_ids)

Collect and store results

In [7]:
# Collect finished simulations one at a time and write each result dict to
# its own .npy file, named after the parameters it was run with.
directory = "/home/mhg19/Manuscripts/PNAS19/figures/notebooks/supplement/3/data/"
# Make sure the target exists before results of a long run are written.
os.makedirs(directory, exist_ok=True)

# The context manager closes the progress bar even if saving fails.
with tqdm(total=n) as pbar:
    while object_ids:
        if parallel:
            # Block until at least one task has finished, then fetch it.
            ready_object_ids, _ = ray.wait(object_ids)
            id_ = ready_object_ids[0]
            data = ray.get(id_)
            object_ids.remove(id_)
        else:
            # Serial mode: the list already holds the result dicts.
            data = object_ids.pop(0)
        params = data['params']
        # NOTE: this rebinds the module-level `lambda_` / `sigma_z` sweep
        # values to scalars, exactly as the original cell did.
        lambda_, sigma_z = params['lambda'], params['sigma_z']
        filename = "lambda%.4f_sigmaz%.4f" % (lambda_, sigma_z) + ".npy"
        filepath = directory + filename
        # np.save pickles the dict; load it back with allow_pickle=True.
        with open(filepath, 'wb') as fh:
            np.save(fh, data)
        pbar.update(1)
        time.sleep(1)

  0%|          | 0/5 [00:00<?, ?it/s]



100%|██████████| 5/5 [40:16<00:00, 585.62s/it]   2020-01-24 15:23:00,861	ERROR worker.py:1621 -- listen_error_messages_raylet: Connection closed by server.
2020-01-24 15:23:00,862	ERROR import_thread.py:89 -- ImportThread: Connection closed by server.
2020-01-24 15:23:00,864	ERROR worker.py:1521 -- print_logs: Connection closed by server.


[2m[36m(pid=957692)[0m Traceback (most recent call last):
[2m[36m(pid=957692)[0m   File "/home/mhg19/.local/lib/python3.7/site-packages/ray/workers/default_worker.py", line 98, in <module>
[2m[36m(pid=957692)[0m     ray.worker.global_worker.main_loop()
[2m[36m(pid=957692)[0m   File "/home/mhg19/.local/lib/python3.7/site-packages/ray/worker.py", line 954, in main_loop
[2m[36m(pid=957692)[0m     task = self._get_next_task_from_raylet()
[2m[36m(pid=957692)[0m   File "/home/mhg19/.local/lib/python3.7/site-packages/ray/worker.py", line 937, in _get_next_task_from_raylet
[2m[36m(pid=957692)[0m     task = self.raylet_client.get_task()
[2m[36m(pid=957692)[0m   File "python/ray/_raylet.pyx", line 335, in ray._raylet.RayletClient.get_task
[2m[36m(pid=957692)[0m   File "python/ray/_raylet.pyx", line 109, in ray._raylet.check_status
[2m[36m(pid=957692)[0m ray.exceptions.RayletError: The Raylet died with this message: [RayletClient] Raylet connection closed.
[2m[36m(