In [3]:
import sys, os, pickle
import torch
sys.path.append('/home/om2382/mft-theory/')
from cluster import *
from core import *
from empirics import *
from functions import *
from ode_methods import *
from plotting import *
from theory import *
from utils import *
from functools import partial
import matplotlib.pyplot as plt

In [2]:
### To do
# - Resolve the memory issue so that psi can be estimated with more data points and larger networks
# - Turn the psi estimation into a function suite covering the different cases
# - Find the relevant range of connectivity parameter values to simulate

In [4]:
### --- SET UP ALL CONFIGS --- ###
from itertools import product
n_seeds = 5
macro_configs = config_generator(g=[3, 5.5, 8, 10, 15, 20])

# One job per (macro config, seed index) pair; the seed index varies fastest.
micro_configs = tuple(product(macro_configs, list(range(n_seeds))))
prototype = False

### --- SELECT PARTICULAR CONFIG --- ###
# On the cluster the job index comes from the array task ID (shifted down by one,
# i.e. task IDs are assumed to start at 1). Without that variable we fall back to
# job 0 and flag the run as a prototype.
try:
    i_job = int(os.environ['SLURM_ARRAY_TASK_ID']) - 1
except KeyError:
    i_job = 0
    prototype = True
params, i_seed = micro_configs[i_job]
# Index of the macro config this job belongs to (n_seeds consecutive jobs share one).
i_config = i_job//n_seeds

# Seed torch alongside numpy so that any torch-side randomness in the simulation
# is reproducible per job as well (previously only numpy was seeded).
new_random_seed_per_condition = True
if new_random_seed_per_condition:
    np.random.seed(i_job)
    torch.manual_seed(i_job)
else: #Match random seeds across conditions
    np.random.seed(i_seed)
    torch.manual_seed(i_seed)

In [5]:
### --- Set parameters --- ###
# Network size, coupling gain of the selected config, and the lag grid
# (window length T_window sampled every dT).
N = 2000
g = params['g']
T_window = 100
dT = 0.05


def phi_torch(x):
    """Transfer function: erf rescaled so that its slope at the origin is 1."""
    return torch.erf((np.sqrt(np.pi)/2)*x)


def phi_prime_torch(x):
    """Derivative of phi_torch with respect to its argument."""
    return torch.exp(-(np.pi/4)*x**2)


n_lags = int(T_window/dT)
lags = np.arange(0, T_window, dT)
# Two-sided lag axis: negated lags in reverse order, then the positive lags,
# then the last lag repeated once more as padding.
lags_full = np.concatenate((-lags[-n_lags:][::-1],
                            lags[1:n_lags],
                            lags[-1:]))

In [6]:
### --- Estimate psi empirically --- ###

# The simulation is sampled on its own, coarser lag grid (step dT_emp).
dT_emp = 0.5
n_lags_emp = int(T_window/dT_emp)
lags_emp = np.arange(0, T_window, dT_emp)

# Draw the coupling matrix: i.i.d. Gaussian entries with standard deviation g/sqrt(N).
W = g*np.random.normal(loc=0, scale=1/np.sqrt(N), size=(N, N))
# NOTE(review): .to(0) targets CUDA device index 0, so this cell needs a GPU.
W_ = torch.from_numpy(W).type(torch.FloatTensor).to(0)
Psi = estimate_psi(
    lags_emp,
    T_sim=5000,
    dt_save=dT_emp,
    dt=0.1,
    W=W_,
    phi_torch=phi_torch,
    T_save_delay=100,
    N_batch=1,
    N_loops=10,
)

# Move psi back to numpy, then build the two-sided versions of psi and of the lag
# axis: the reversed array followed by the array without its first entry.
# Assumes Psi has one entry per element of lags_emp -- TODO confirm.
Psi = Psi.cpu().detach().numpy()
Psi_emp = np.concatenate((Psi[-n_lags_emp:][::-1],
                          Psi[1:n_lags_emp]))
lags_emp_full = np.concatenate((-lags_emp[-n_lags_emp:][::-1],
                                lags_emp[1:n_lags_emp]))

In [None]:
### -- Optional: measure single-unit properties empirically --- ###
# Both guards are off by default. Flip the first to True to estimate the
# single-unit autocovariance (C_Phi_half_emp) and alpha (alpha_emp) from simulation.
if False:
    # Fresh coupling matrix per call; g and N are bound via partial below.
    generate_W = lambda g, N: g*np.random.normal(0, 1/np.sqrt(N), (N, N))

    # phi_prime_torch (defined with the parameters) is the derivative handed to the
    # estimator; the duplicate, never-used local lambda for it has been removed.
    C_Phi_half_emp, alpha_emp = estimate_single_unit_autocov_and_alpha(
        lags, T_sim=3000, dt_save=dT, dt=0.05,
        generate_W=partial(generate_W, g=g, N=N),
        N=N, phi_torch=phi_torch, phi_prime=phi_prime_torch,
        T_save_delay=100,
        N_batch=2, N_disorder=2)
# Flip the second guard to use the empirical values as alpha / C_Phi_half.
# It needs alpha_emp and C_Phi_half_emp, i.e. the first guard must have run.
if False:
    alpha = alpha_emp
    C_Phi_half = C_Phi_half_emp.cpu().detach().numpy()

In [None]:
### --- Compute single-unit properties

# d is obtained from the gain g and is the input to every step below.
d = compute_Delta_0(g=g)
# Integrate over the same window and number of points as the lag grid.
time, Delta_T = integrate_potential(
    d,
    g=g,
    tau_max=T_window,
    N_tau=int(T_window/dT),
)
Delta_T = fix(Delta_T)
# One-sided autocovariance of phi and the average of phi' used by the theory.
C_Phi_half = compute_C_simple(d, Delta_T)
alpha = compute_phi_prime_avg(d)

In [None]:
### --- Compute Psi from theory --- ###

# Two-sided autocovariance, built like lags_full: reversed half, the half without
# its first entry, then the last sample repeated once.
C_phi = np.concatenate((C_Phi_half[-n_lags:][::-1],
                        C_Phi_half[1:n_lags],
                        np.array([C_Phi_half[-1]])))

# Angular-frequency grid: non-negative indices first, then the negative ones.
T = len(C_phi)
sampfreq = 1/dT
t_indices = np.concatenate((np.arange(0, T//2),
                            np.arange(-T//2, 0)))
w = 2*np.pi*sampfreq*t_indices/T

# Frequency-domain single-unit functions and their outer products over the two
# frequency arguments. fft/ifft are the two-argument helpers (array, dT),
# presumably provided by the project's star imports.
C_phi_omega = fft(C_phi, dT)
S_phi = alpha/(np.sqrt(2*np.pi)*(1 + 1j*w))
C_phi_C_phi = np.multiply.outer(C_phi_omega, C_phi_omega)
S_phi_S_phi = np.multiply.outer(S_phi, S_phi)

#Compute psi for iid random network
Psi_PRX = (
    1/(np.abs(1 - 2*np.pi*(g**2) * (S_phi_S_phi))**2) - 1
)*C_phi_C_phi
# Back to the lag domain; only the equal-lag diagonal is kept.
Psi_PRX_tau = ifft(Psi_PRX, dT)
Psi_tau_tau = np.diag(Psi_PRX_tau)

In [None]:
### --- Double check with David's code --- ###
# Second, independent route to the same quantity through compute_psi_theory.
Psi_tau_2 = compute_psi_theory(
    symmetrize(Delta_T),
    g,
    dT,
)
Ptt2 = np.diag(Psi_tau_2)
# Swap the two halves of the diagonal (last half first), presumably so that it
# lines up with the centred lag axis lags_full.
Psi_tau_tau_2 = np.concatenate((Ptt2[-len(Ptt2)//2:],
                                Ptt2[:len(Ptt2)//2]))

In [None]:
# Payload later stored under result['processed_data']:
# row 0 is the two-sided empirical lag axis, row 1 the empirical psi.
# Disabled alternative that packed the empirical and both theory curves instead:
#processed_data = np.array([Psi_emp, Psi_tau_tau, Psi_tau_tau_2])
processed_data = np.array([lags_emp_full, Psi_emp])

In [None]:
#plt.plot(lags_emp_full, Psi_emp)
#plt.plot(lags_full, Psi_tau_tau)
#plt.plot(lags_full, Psi_tau_tau_2)

In [None]:
### --- SAVE RESULTS -- ###
# Bookkeeping for this job; 'sim' is stored as a None placeholder.
result = {'sim': None, 'i_seed': i_seed, 'config': params,
          'i_config': i_config, 'i_job': i_job}
# processed_data is optional: if no cell defined it, save the bookkeeping only.
try:
    result['processed_data'] = processed_data
except NameError:
    pass

# Write to disk only when SAVEDIR is set; otherwise skip saving silently.
# Looking the variable up with .get() keeps the "not set" case separate from
# errors raised while saving, which now propagate instead of being swallowed.
save_dir = os.environ.get('SAVEDIR')
if save_dir is not None:
    # exist_ok=True: array tasks starting at the same time may all try to create
    # the directory; a check-then-mkdir would crash all but the first of them.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'result_{}'.format(i_job))

    with open(save_path, 'wb') as f:
        pickle.dump(result, f)

In [None]:
###Truncate file above
# Export this notebook as the cluster main script. The marker comment on the first
# line of this cell is what the awk/sed commands below search for, so it must stay
# exactly as written.
file_name = 'erf_psi_theory_match'
job_name = 'PRX_match_attempt_2'
project_dir = '/home/om2382/low-rank-dims/'
main_script_path = os.path.join(project_dir, 'cluster_main_scripts', job_name + '.py')
# Save the notebook first so the conversion sees the current cells.
get_ipython().run_cell_magic('javascript', '', 'IPython.notebook.save_notebook()')
# Convert <file_name>.ipynb to <file_name>.py without the In[..] prompts.
get_ipython().system('jupyter nbconvert --to script --no-prompt {}.ipynb'.format(file_name))
# awk only prints the part of the script that precedes the marker (file unchanged).
get_ipython().system('awk "/###Truncate/ {{exit}} {{print}}" {}.py'.format(file_name))
# sed then cuts the script in place at the marker line, dropping it and everything after.
get_ipython().system('sed -i "/###Truncate/Q" {}.py'.format(file_name))
# Move the truncated script to its place under cluster_main_scripts.
get_ipython().system('mv {}.py {}'.format(file_name, main_script_path))

In [None]:
###Submit job to cluster
# Write the job file for the exported script (64 mem units, 1 hour, 1 GPU as passed
# below), then hand its path to submit_job with one array entry per micro config.
write_job_file(
    job_name,
    py_file_name=job_name + '.py',
    mem=64,
    n_hours=1,
    n_gpus=1,
)
job_script_path = os.path.join(project_dir, 'job_scripts', '{}.s'.format(job_name))
n_jobs = len(micro_configs)
# execute=False: presumably prepares the submission without launching it -- TODO confirm.
submit_job(job_script_path, n_jobs, execute=False)

In [None]:
###Get job status
# Show the queue entries of user om2382 (user name is hard-coded).
get_ipython().system('squeue -u om2382')

In [None]:
# Re-declare the job identifiers (same values as in the export cell) and load
# everything the finished jobs saved for this job script.
job_name = 'PRX_match_attempt_2'
project_dir = '/home/om2382/low-rank-dims/'
job_script_path = os.path.join(project_dir, 'job_scripts', '{}.s'.format(job_name))
(configs_array,
 results_array,
 key_order,
 sim_dict) = unpack_processed_data(job_script_path)

In [None]:
# List the job scripts directory, most recently modified first.
!ls -t ../job_scripts/

In [None]:
# Sanity check of the loaded results; the plot below indexes this array as
# [config, seed, row, lag].
results_array.shape

In [None]:
### --- Compare empirical psi with theory, one panel per value of g --- ###
# results_array is indexed below as [config, seed, row, lag]; row 0 is plotted as
# the lag axis and row 1 as the psi estimate -- TODO confirm against
# unpack_processed_data.
# Panel and trace counts come from the data instead of being hard-coded.
n_panels, n_traces = results_array.shape[:2]
# squeeze=False keeps `ax` indexable even when there is only one config.
fig, ax = plt.subplots(n_panels, 1, figsize=(4, 2*n_panels), squeeze=False)
ax = ax[:, 0]
# Lag grid and transfer functions are re-declared with the same values as in the
# parameter cell, presumably so this cell runs without the simulation cells.
T_window = 100
dT = 0.05
phi_torch = lambda x: torch.erf((np.sqrt(np.pi)/2)*x)
phi_prime_torch = lambda x: torch.exp(-(np.pi/4)*x**2)
lags = np.arange(0, T_window, dT)
n_lags = int(T_window/dT)
lags_full = np.concatenate([-lags[-n_lags:][::-1], lags[1:n_lags], np.array([lags[-1]])])
for i in range(n_panels):

    #Choose g
    g = configs_array['g'][i]

    ### --- Compute single-unit properties
    d = compute_Delta_0(g=g)
    time, Delta_T = integrate_potential(d, g=g, tau_max=T_window, N_tau=int(T_window/dT))
    Delta_T = fix(Delta_T)
    C_Phi_half = compute_C_simple(d, Delta_T)
    alpha = compute_phi_prime_avg(d)

    ### --- Compute Psi from theory --- ###

    #Define relevant single-unit functions
    C_phi = np.concatenate([C_Phi_half[-n_lags:][::-1],
                            C_Phi_half[1:n_lags],
                            np.array([C_Phi_half[-1]])])
    C_phi_omega = fft(C_phi, dT)
    T = len(C_phi)
    t_indices= np.concatenate([np.arange(0, T//2), np.arange(-T//2, 0)])
    sampfreq = 1/dT
    w = 2*np.pi*sampfreq*t_indices/T
    C_phi_C_phi = np.multiply.outer(C_phi_omega, C_phi_omega)
    S_phi = alpha/(np.sqrt(2*np.pi)*(1 + 1j*w))
    S_phi_S_phi = np.multiply.outer(S_phi, S_phi)

    #Compute psi for iid random network
    Psi_PRX = (1/(np.abs(1 - 2*np.pi*(g**2) * (S_phi_S_phi))**2) - 1)*C_phi_C_phi
    Psi_PRX_tau = ifft(Psi_PRX, dT)
    Psi_tau_tau = np.diag(Psi_PRX_tau)

    ### --- Double check with David's code --- ###
    Psi_tau_2 = compute_psi_theory(symmetrize(Delta_T), g, dT)
    Ptt2 = np.diag(Psi_tau_2)
    Psi_tau_tau_2 = np.concatenate([Ptt2[-len(Ptt2)//2:], Ptt2[:len(Ptt2)//2]])

    # Faint trace per seed.
    for j in range(n_traces):
        ax[i].plot(results_array[i,j,0,:],
                   results_array[i,j,1,:], color='C0', alpha=0.1)

    # Seed-averaged curve. The lag axis is read from the last seed explicitly
    # rather than through the loop variable left over from the loop above.
    ax[i].plot(results_array[i,-1,0,:],
               results_array[i,:,1,:].mean(0), color='C0', alpha=1,
               label='empirical (seed mean)')
    ax[i].plot(lags_full, Psi_tau_tau, color='C1', label='theory')
    ax[i].plot(lags_full, Psi_tau_tau_2, color='C2', label='theory (compute_psi_theory)')
    ax[i].set_ylim([0, 14])
    ax[i].set_title('g = {}'.format(g))
    ax[i].set_ylabel('psi')

# Shared annotations: lag label on the bottom panel, legend on the top one.
ax[-1].set_xlabel('lag')
ax[0].legend(fontsize=6)
fig.tight_layout()