In [2]:
import sys, os, pickle
import torch
sys.path.append('/home/om2382/mft-theory/')
from cluster import *
from core import *
from empirics import *
from functions import *
from ode_methods import *
from plotting import *
from theory import *
from utils import *
from functools import partial
import matplotlib.pyplot as plt

In [3]:
### --- SET UP ALL CONFIGS --- ###
from itertools import product
mode = 'sim' #or 'theory'
if mode == 'theory':
    n_seeds = 1
    macro_configs = config_generator(PR_G=list(np.round(np.arange(0.06, 1.02, 0.02), 2)),
                                     PR_D=[0.1, 1, 10],
                                     g=[3, 6, 10])
if mode == 'sim':
    n_seeds = 10
    macro_configs = config_generator(PR_G=list(np.round(np.arange(0.1, 1.1, 0.1), 2)),
                                     PR_D=[0.1, 1, 10],
                                     g=[3, 6, 10])

micro_configs = tuple(product(macro_configs, list(range(n_seeds))))
prototype = False

### --- SELECT PARTICULAR CONFIG --- ###
try:
    i_job = int(os.environ['SLURM_ARRAY_TASK_ID']) - 1
except KeyError:
    i_job = 0
    prototype = True
params, i_seed = micro_configs[i_job]
i_config = i_job//n_seeds

new_random_seed_per_condition = True
if new_random_seed_per_condition:
    np.random.seed(i_job)
else: #Match random seeds across conditions
    np.random.seed(i_seed)

In [None]:
### --- Set empirical parameters --- ###

#network properties size
N = 5000
g = params['g']
#g = 6
phi_torch = lambda x: torch.erf((np.sqrt(np.pi)/2)*x)
phi_prime_torch = lambda x: torch.exp(-(np.pi/4)*x**2)

In [None]:
### --- Estimate psi empirically --- ###

#Generate LDRG matrix

PR_D = params['PR_D']
PR_G = params['PR_G']

if PR_D < 1:
    alpha = 1
else:
    alpha = PR_D
    
K = int(alpha * N)
L = np.random.normal(0, 1/np.sqrt(N), (N, K))
RT = np.random.normal(0, 1/np.sqrt(N), (K, N))

if PR_D < 1:
    beta_D = invert_PR_by_newton(PR_D)
    D = np.exp(-beta_D*np.arange(K)/K)
else:
    D = np.ones(K)
    
if PR_G < 1:
    beta_G = invert_PR_by_newton(PR_G)
    G = np.exp(-beta_G*np.arange(N)/N)
else:
    G = np.ones(N)

g_correction = g / np.sqrt(np.sum(D**2)/N*np.sum(G**2)/N)
D = D * g_correction

L = torch.from_numpy(L).type(torch.FloatTensor).to(0)
D_ = torch.from_numpy(D).type(torch.FloatTensor).to(0)
RT = torch.from_numpy(RT).type(torch.FloatTensor).to(0)
G_ = torch.from_numpy(G).type(torch.FloatTensor).to(0)

W = torch.einsum('ik, k, kj, j -> ij', L, D_, RT, G_)
del L
del RT

In [None]:
### --- Estimate psi empirically --- ###

compute_empirical_psi = (mode == 'sim')
if compute_empirical_psi:
    W_ = W
    r_cov = estimate_Psi_with_on_diagonals(lags=[0], T_sim=2000, dt_save=1, dt=0.05, W=W_, phi_torch=phi_torch,
                                           T_save_delay=1000, N_batch=1, N_loops=200, runga_kutta=True,
                                           noise_sigma=0, mode='tau_tau', return_raw_cov=True)
    #x_cov, r_cov = estimate_cov_eigs(T_sim=2000, dt_save=1, dt=0.1, W=W_, phi_torch=phi_torch,
    #                                 T_save_delay=1000, N_batch=1, N_loops=60,
    #                                 return_raw_covs=True, runga_kutta=True)
    
    r_cov = np.squeeze(r_cov.cpu().detach().numpy())
    dim_emp = np.trace(r_cov)**2 / (r_cov**2).sum() / N
    Gr_cov = G[:,None] * r_cov * G[None,:]
    dim_nn_emp = np.trace(Gr_cov)**2 / (Gr_cov**2).sum() / N
    G_shuff = np.random.permutation(G)
    Gr_readout_cov = G_shuff[:,None] * r_cov * G_shuff[None,:]
    dim_readout_emp = np.trace(Gr_readout_cov)**2 / (Gr_readout_cov**2).sum() / N
else:
    dim_emp = 0
    dim_nn_emp = 0
    dim_readout_emp = 0

In [None]:
### --- Set theory parameters --- ###
T_window = 200
dT = 0.025

In [None]:
### --- Compute single-unit properties

d = compute_Delta_0(g=g)
time, Delta_T = integrate_potential(d, g=g, tau_max=T_window, N_tau=int(T_window/dT))
Delta_T = fix(Delta_T)
C_phi_half = compute_C_simple(d, Delta_T)
avg_gain = compute_phi_prime_avg(d)

In [None]:
### --- Compute Psi from theory --- ###

#Define relevant single-unit functions
C_phi = np.concatenate([C_phi_half,
                        np.array([C_phi_half[-1]]),
                        C_phi_half[1:][::-1]])
C_phi_omega = fft(C_phi, dT)
T = len(C_phi)
t_indices= np.concatenate([np.arange(0, T//2), np.arange(-T//2, 0)])
sampfreq = 1/dT
w = 2*np.pi*sampfreq*t_indices/T
C_phi_C_phi = np.multiply.outer(C_phi_omega, C_phi_omega)
S_phi = avg_gain/(np.sqrt(2*np.pi)*(1 + 1j*w))
S_phi_S_phi = np.multiply.outer(S_phi, S_phi)

#Compute psi of normalized units for LDRG network
num = 1 + (1/(alpha*PR_D) + 1/PR_G - 1)*np.abs(2*np.pi*g**2 * S_phi_S_phi)**2
denom = np.abs(1 - 2*np.pi*g**2 * S_phi_S_phi)**2
Psi_omega1_omega2 = (num / denom) * C_phi_C_phi
Psi_tau1_tau2 = ifft(Psi_omega1_omega2, dT)

#Compute psi of non-normalized units for LDRG network
num = 1/PR_G + (1/(alpha*PR_D))*np.abs(2*np.pi*g**2 * S_phi_S_phi)**2
denom = np.abs(1 - 2*np.pi*g**2 * S_phi_S_phi)**2
q22 = np.mean(G**2)**2
Psi_omega1_omega2 = (num / denom) * q22*C_phi_C_phi
Psi_nn_tau1_tau2 = ifft(Psi_omega1_omega2, dT)

#Compute psi of random readout units for LDRG network
num = ((1/PR_G - 1) * (np.abs(1 - 2*np.pi*g**2 * S_phi_S_phi)**2 + np.abs(2*np.pi*g**2 * S_phi_S_phi)**2)
       + 1 + 1/(alpha * PR_D) * np.abs(2*np.pi*g**2 * S_phi_S_phi)**2)
denom = np.abs(1 - 2*np.pi*g**2 * S_phi_S_phi)**2
Psi_omega1_omega2 = (num / denom) * q22*C_phi_C_phi
Psi_readout_tau1_tau2 = ifft(Psi_omega1_omega2, dT)

In [None]:
help(np.fft)

In [None]:
C2 = C_phi_half[0]**2
dim_theory = C2/Psi_tau1_tau2[0,0].real

In [None]:
dim_nn_theory = q22*C2/Psi_nn_tau1_tau2[0,0].real

In [None]:
dim_readout_theory = q22*C2/Psi_readout_tau1_tau2[0,0].real

In [None]:
processed_data = np.array([dim_emp, dim_theory,
                           dim_nn_emp, dim_nn_theory,
                           dim_readout_emp, dim_readout_theory])

In [None]:
### --- SAVE RESULTS -- ###
result = {'sim': None, 'i_seed': i_seed, 'config': params,
          'i_config': i_config, 'i_job': i_job}
try:
    result['processed_data'] = processed_data
except NameError:
    pass
    
try:
    save_dir = os.environ['SAVEDIR']
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    save_path = os.path.join(save_dir, 'result_{}'.format(i_job))

    with open(save_path, 'wb') as f:
        pickle.dump(result, f)
except KeyError:
    pass

In [5]:
###Truncate file above
file_name = 'LDRG_PR_match_PRG'
job_name = 'LDRG_PR_match_PRG_sim_final_3'
project_dir = '/home/om2382/low-rank-dims/'
main_script_path = os.path.join(project_dir, 'cluster_main_scripts', job_name + '.py')
get_ipython().run_cell_magic('javascript', '', 'IPython.notebook.save_notebook()')
get_ipython().system('jupyter nbconvert --to script --no-prompt {}.ipynb'.format(file_name))
get_ipython().system('awk "/###Truncate/ {{exit}} {{print}}" {}.py'.format(file_name))
get_ipython().system('sed -i "/###Truncate/Q" {}.py'.format(file_name))
get_ipython().system('mv {}.py {}'.format(file_name, main_script_path))

<IPython.core.display.Javascript object>

[NbConvertApp] Converting notebook LDRG_PR_match_PRG.ipynb to script
[NbConvertApp] Writing 14423 bytes to LDRG_PR_match_PRG.py
awk: cmd. line:1: /###Truncate/ <IPython.core.autocall.ZMQExitAutocall object at 0x2aef78db1ad0> <built-in function print>
awk: cmd. line:1:                       ^ syntax error
awk: cmd. line:1: /###Truncate/ <IPython.core.autocall.ZMQExitAutocall object at 0x2aef78db1ad0> <built-in function print>
awk: cmd. line:1:                                                                                ^ syntax error


In [None]:
###Submit job to cluster
n_jobs = len(micro_configs)
write_job_file(job_name, py_file_name='{}.py'.format(job_name),
               mem=64, n_hours=24, n_gpus=1,
               results_subdir='PRL_Submission')
job_script_path = os.path.join(project_dir, 'job_scripts', job_name + '.s')
submit_job(job_script_path, n_jobs, execute=False, results_subdir='PRL_Submission', lkumar=True)

In [None]:
###Get job status
get_ipython().system('squeue -u om2382')

In [None]:
project_dir = '/home/om2382/low-rank-dims/'
job_name = 'LDRG_PR_match_PRG_theory_final'
job_script_path = os.path.join(project_dir, 'job_scripts', job_name + '.s')
theory_results = unpack_processed_data(job_script_path, results_subdir='PRL_Submission')

In [6]:
project_dir = '/home/om2382/low-rank-dims/'
job_name = 'LDRG_PR_match_PRG_sim_final_3'
job_script_path = os.path.join(project_dir, 'job_scripts', job_name + '.s')
sim_results = unpack_processed_data(job_script_path, results_subdir='PRL_Submission')

In [None]:
### --- Save packaged results --- ###
with open('packaged_results/LDRG_PR_match_PRG_theory_final', 'wb') as f:
    pickle.dump(theory_results, f)
    

In [7]:
with open('packaged_results/LDRG_PR_match_PRG_sim_final_3', 'wb') as f:
    pickle.dump(sim_results, f)

In [None]:
sim_results[1].shape

In [None]:
fig, ax = plt.subplots(3, 3)
col1 = '#FF0000'
col2 = '#9580D6'
col3 = '#4AB7ED'
for i in range(3):
    for j in range(3):
        s1 = sim_results[1][:,i,j,:,0].mean(-1)
        s2 = sim_results[1][:,i,j,:,2].mean(-1)
        s3 = sim_results[1][:,i,j,:,4].mean(-1)
        ax[i,j].plot(sim_results[0]['PR_G'], s1, '.', color=col1)
        ax[i,j].plot(sim_results[0]['PR_G'], s2, '.', color=col2)
        ax[i,j].plot(sim_results[0]['PR_G'], s3, '.', color=col3)
        for k in range(10):
            s1 = sim_results[1][:,i,j,k,0]
            s2 = sim_results[1][:,i,j,k,2]
            s3 = sim_results[1][:,i,j,k,4]
            ax[i,j].plot(sim_results[0]['PR_G'], s1, '.', alpha=0.1, color=col1)
            ax[i,j].plot(sim_results[0]['PR_G'], s2, '.', alpha=0.1, color=col2)
            ax[i,j].plot(sim_results[0]['PR_G'], s3, '.', alpha=0.1, color=col3)
        t1 = sim_results[1][:,i,j,:,1].mean(-1)
        t2 = sim_results[1][:,i,j,:,3].mean(-1)
        t3 = sim_results[1][:,i,j,:,5].mean(-1)
        ax[i,j].plot(sim_results[0]['PR_G'], t1, color=col1)
        ax[i,j].plot(sim_results[0]['PR_G'], t2, color=col2)
        ax[i,j].plot(sim_results[0]['PR_G'], t3, color=col3)
        ax[i,j].set_ylim([0, 0.12])

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(4, 4))
col1 = '#FF0000'
col2 = '#4AB7ED'
for i_PRD in range(2):
    ax[i_PRD].plot(theory_results[0]['PR_G'], theory_results[1][:,i_PRD,0,1], color=col1)
    ax[i_PRD].plot(theory_results[0]['PR_G'], theory_results[1][:,i_PRD,0,3], color=col2)
    for k in range(3):
        pass
        ax[i_PRD].plot(sim_results[0]['PR_G'], sim_results[1][:,i_PRD,k,0], '.', color=col1, alpha=0.1, label='_nolegend_')
        ax[i_PRD].plot(sim_results[0]['PR_G'], sim_results[1][:,i_PRD,k,2], '.', color=col2, alpha=0.1, label='_nolegend_',
                       fillstyle='full')
    ax[i_PRD].plot(sim_results[0]['PR_G'], sim_results[1][:,i_PRD,:,0].mean(-1),'.', color=col1,
                   markersize=10, alpha=1, label='_nolegend_')
    ax[i_PRD].plot(sim_results[0]['PR_G'], sim_results[1][:,i_PRD,:,2].mean(-1),'.', color=col2,
               markersize=10, alpha=1, label='_nolegend_')

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
col1 = '#FF0000'
col2 = '#4AB7ED'
for k in range(10):
    pass
    ax.plot(sim_results[0]['PR_G'], sim_results[1][:,0,0,k,0], '.', color=col1, alpha=0.1, label='_nolegend_')
    ax.plot(sim_results[0]['PR_G'], sim_results[1][:,0,0,k,2], '.', color=col2, alpha=0.1, label='_nolegend_',
            fillstyle='full')
ax.plot(sim_results[0]['PR_G'], sim_results[1][:,0,0,:,0].mean(-1),'.', color=col1,
        markersize=10, alpha=1, label='_nolegend_')
ax.plot(theory_results[0]['PR_G'], theory_results[1][:,0,0,0,1], color=col1)
ax.plot(sim_results[0]['PR_G'], sim_results[1][:,0,0,:,2].mean(-1), '.', color=col2,
        markersize=10, alpha=1, label='_nolegend_')
ax.plot(theory_results[0]['PR_G'], theory_results[1][:,0,0,0,3], color=col2)
ax.set_ylim([0, 0.12])
ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
#ax.legend(['$PR(\{G_i \phi_i\})$ sim', '$PR(\{G_i \phi_i\})$ theory',
#           '$PR(\{\phi_i\})$ sim', '$PR(\{\phi_i\})$ theory'])
ax.legend(['$PR(\{\phi_i\})$', '$PR(\{G_i \phi_i\})$'],frameon=False)
ax.set_xlabel('$PR(\{G_i\})$')
ax.set_yticks([0, 0.03, 0.06])
#fig.savefig('figs/LDRG_PR_match_2.pdf')

In [None]:
plt.figure()
for k in range(10):
    plt.plot(configs_array_1['PR_G'], results_array_1[:,k,0], '.', color='C0', alpha=0.1)
plt.plot(configs_array_1['PR_G'], results_array_1[:,:,0].mean(-1), '.', color='C0', markersize=10, alpha=1)
plt.plot(configs_array['PR_G'], results_array[:,0,1], color='k')
plt.ylim([0, 0.1])
plt.figure()
for k in range(10):
    plt.plot(configs_array_1['PR_G'], results_array_1[:,k,2], '.', color='C0', alpha=0.1)
plt.plot(configs_array_1['PR_G'], results_array_1[:,:,2].mean(-1), '.', color='C0', markersize=10, alpha=1)
plt.plot(configs_array['PR_G'], results_array[:,0,3], color='k')
plt.ylim([0, 0.1])

In [None]:
from scipy.stats import bootstrap
mean_pr = results_array[:,:,0].mean(-1)
sem_pr = results_array[:,:,0].std(-1)/np.sqrt(10)
sem_pr = bootstrap((results_array[:,:,0],), np.mean, axis=1)
#plt.errorbar(configs_array['g'], mean_pr, yerr=sem_pr, color='C0')
plt.fill_between(configs_array['alpha'], sem_pr.confidence_interval.low,
                 sem_pr.confidence_interval.high)
for i_seed in range(10):
    pass
    #plt.plot(configs_array['g'], results_array[:,i_seed,0], '.', color='C0')
plt.plot(configs_array['alpha'], results_array[:,0,1], color='k')
plt.xscale('log')
#plt.ylim([0, 0.07])

In [None]:
from scipy.stats import bootstrap
mean_pr = results_array[:,:,0].mean(-1)
sem_pr = results_array[:,:,0].std(-1)/np.sqrt(10)
sem_pr = bootstrap((results_array[:,:,0],), np.mean, axis=1)
#plt.errorbar(configs_array['g'], mean_pr, yerr=sem_pr, color='C0')
plt.fill_between(configs_array['PR_G'], sem_pr.confidence_interval.low,
                 sem_pr.confidence_interval.high)
for i_seed in range(10):
    pass
    #plt.plot(configs_array['g'], results_array[:,i_seed,0], '.', color='C0')
plt.plot(configs_array['PR_G'], results_array[:,0,1], color='k')
#plt.xscale('log')
#plt.ylim([0, 0.07])

In [None]:
from scipy.stats import bootstrap
mean_pr = results_array[:,:,2].mean(-1)
sem_pr = results_array[:,:,2].std(-1)/np.sqrt(10)
sem_pr = bootstrap((results_array[:,:,2],), np.mean, axis=1)
#plt.errorbar(configs_array['g'], mean_pr, yerr=sem_pr, color='C0')
plt.fill_between(configs_array['PR_G'], sem_pr.confidence_interval.low,
                 sem_pr.confidence_interval.high)
for i_seed in range(10):
    pass
    #plt.plot(configs_array['g'], results_array[:,i_seed,0], '.', color='C0')
plt.plot(configs_array['PR_G'], results_array[:,0,3], color='k')

In [None]:
from scipy.stats import bootstrap
mean_pr = results_array[:,:,0].mean(-1)
sem_pr = results_array[:,:,0].std(-1)/np.sqrt(10)
sem_pr = bootstrap((results_array[:,:,0],), np.mean, axis=1)
#plt.errorbar(configs_array['g'], mean_pr, yerr=sem_pr, color='C0')
plt.fill_between(configs_array['g'], sem_pr.confidence_interval.low,
                 sem_pr.confidence_interval.high)
for i_seed in range(10):
    pass
    #plt.plot(configs_array['g'], results_array[:,i_seed,0], '.', color='C0')
plt.plot(configs_array['g'], results_array[:,0,1], color='k')
plt.ylim([0, 0.07])