In [1]:
import os, time
from copy import deepcopy
import numpy as np
import pickle
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.linear_model import RidgeCV

# Pytorch: select the compute device once for the whole notebook.
import torch
use_cuda = True
if torch.cuda.is_available() and use_cuda:
    device = torch.device('cuda')
    print("Use cuda.")
else:
    device = torch.device('cpu')
    print("Use cpu.")


def to_dev(arr):
    """Return `arr` as a torch tensor located on `device`.

    NumPy arrays (including ndarray subclasses such as memmaps) are first
    wrapped with `torch.from_numpy`; any other object is expected to provide
    a `.to(device)` method (e.g. a torch tensor).
    """
    if isinstance(arr, np.ndarray):
        arr = torch.from_numpy(arr)
    return arr.to(device)

# Data path
from specs import data_path

from helpers import gen_corr, ridge_CCA
from rnn_model_dt import RNN_Net, find_fp
from task_generators import cycling, flipflop, mante, romo, complex_sine
task_gens = [cycling, flipflop, mante, romo, complex_sine]

# Tasks analysed in this notebook: identifiers (used in file names) and plot labels
task_names = ["cycling", "flipflop", "mante", "romo", "complex_sine"]
task_lbls = ["Cycling", "3-bit flipflop", "Mante", "Romo", "Complex sine"]
n_task = len(task_names)
# One pickle file with training results per task
file_names = ["neuro_noisy_%s_n_512.pkl" % name for name in task_names]

# Whether to analyse the dynamics with neural noise (True) or the noise-free ones (False)
with_noise = False

#######################################################################################
# Number of samples
n_samples = 5
# Number of scenarios = weight initializations
n_sce = 4
# Joint (sample, scenario) index shape
n_mi = (n_samples, n_sce)

# Grace time: minimal time for states and output to be taken into account
t_pc_min = 10

Use cpu.


In [2]:
# Compute the weight/activity correlations and the regression of the output on PCs.
# Takes about 4 min.

# Correlation
loss_sce, n_steps_sce, lr0s_sce = [], [], []
w_keys = ["rnn.weight_ih_l0", "rnn.weight_hh_l0", "decoder.weight"]
n_ws = len(w_keys)
# Three entries per quantity (the loop below fills: initial, final, difference)
n_ifd = 3
# Leading axes shared by all result arrays: (task, sample, scenario)
shape_mi = (n_task, *n_mi)
norm_w_sce = np.zeros(shape_mi + (n_ws, n_ifd))
norm_h_sce = np.zeros(shape_mi + (n_ifd,))
corr_w_h_sce = np.zeros(shape_mi + (n_ifd, 2))

# Number of PCs to be regressed on: 1, 2, ..., n_comp_max
n_comp_max = 30
n_comps = np.arange(1, n_comp_max + 1)
n_nc = len(n_comps)
# Fraction of the batch used for training (the rest is used for testing)
frac_train = 3/4
# Candidate regularization parameters
alpha_range = np.logspace(-3, 6, 20)
# Results: one entry per (task, sample, scenario[, number of PCs])
shape_nc = shape_mi + (n_nc,)
loss = torch.zeros(shape_mi)
loss_fit = torch.zeros(shape_nc)
loss_test = torch.zeros(shape_nc)
r_sq = np.zeros(shape_nc)
r_sq_test = np.zeros(shape_nc)
alphas = np.zeros(shape_nc)
bs_train_all = np.zeros(n_task, dtype=int)
cevr_fit = np.zeros(shape_nc)

# Loss criterion used for all tasks
loss_crit = torch.nn.MSELoss()
for i_task in range(n_task):
    task_name = task_names[i_task]
    file_name = file_names[i_task]
    data_file = data_path + file_name
    print(task_name)
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted result files.
    with open(data_file, 'rb') as handle:
        res = pickle.load(handle)
    # Unpack the first 38 entries of the stored list. Note that this rebinds
    # n_samples, n_sce and n_mi to the values stored with the data.
    [
        n_steps, n_samples, gs, out_scales, n_sce, opt_gens, lr0s, n_mi, dim_hid, dim_in, dim_out,
        dt, rec_step_dt, n_layers, bias, train_in, train_hid, train_out, train_layers, nonlin,
        gaussian_init, h_0_std, noise_input_std, noise_init_std, noise_hid_std, batch_size,
        task_params, task_params_ev, n_t_ev, task_ev, n_if, n_ifn, steps, loss_all,
        output_all, hids_all,
        h_0_all, sd_if_all,
    ] = res[:38]
    del res
    print('Loaded from ', data_file)
    # Number of recorded time steps discarded as grace period
    n_t_min = int(t_pc_min / (dt * rec_step_dt))
    ts_ex, input_ex, target_ex, mask_ex, noise_input_ex, noise_init_ex = [to_dev(arr) for arr in task_ev]
    # The regression below runs on the CPU (scikit-learn) and yields CPU tensors, so
    # the masked losses need the mask and target on the CPU as well. Without these
    # copies the loss computation mixes devices whenever `device` is CUDA.
    mask_cpu = mask_ex.cpu()
    target_cpu = target_ex.cpu()

    # Only keep the hidden states necessary
    if with_noise:
        # Dynamics with the noise used during training.
        hids_init_all = hids_all[0]
        hids_final_all = hids_all[1]
        output_final_all = output_all[1]
    else:
        # Noise-free testing dynamics: i_ifn = 2, 3
        hids_init_all = hids_all[2]
        hids_final_all = hids_all[3]
        output_final_all = output_all[3]
    del hids_all, output_all

    ################################################################################
    # Change in weights: RMS of initial weights, final weights, and their difference
    for mi in np.ndindex(*n_mi):
        for i_w, key in enumerate(w_keys):
            w_init = sd_if_all[0][mi][key]
            w_final = sd_if_all[1][mi][key]
            dw = w_final - w_init
            for i_if, w_i in enumerate([w_init, w_final, dw]):
                norm_w_sce[i_task][mi][i_w, i_if] = torch.sqrt((w_i**2).mean())

    #################################################################################
    # Activity and correlation to weights
    for mi in np.ndindex(*n_mi):
        hids_init = hids_init_all[mi]
        hids_final = hids_final_all[mi]
        # Difference. Note that the noise is the same for both init and final!
        d_hids = hids_final - hids_init
        for i_h, hids_i in enumerate([hids_init, hids_final, d_hids]):
            # Discard some initial dynamics
            hids_i = hids_i[:, n_t_min:]
            # Norm (RMS over all entries)
            norm_h_sce[i_task][mi][i_h] = torch.sqrt((hids_i**2).mean())
            # Correlation to output weights: initial weights for the initial states,
            # final weights for both the final states and the difference.
            i_if = [0, 1, 1][i_h]
            w_out = sd_if_all[i_if][mi]["decoder.weight"]
            # gen_corr is a project helper; its result is stored in a length-2 slot.
            corr_w_h_sce[i_task][mi][i_h] = gen_corr(w_out, hids_i)

    time0 = time.time()
    for mi in np.ndindex(*n_mi):
        i_s, i_sce = mi

        # Select states and output
        hids = hids_final_all[i_s, i_sce]
        output = output_final_all[i_s, i_sce]

        # Regress output from hidden states after PCA
        # Fit the PCA on the states after the grace period ...
        pca = PCA(n_comp_max)
        h = hids[:, n_t_min:, :].reshape(-1, dim_hid)
        pca.fit(h)
        # ... but project the states at all time points.
        h_proj = pca.transform(hids.reshape(-1, dim_hid)).reshape(batch_size, -1, n_comp_max)
        cevr_fit[i_task][mi] = pca.explained_variance_ratio_.cumsum()
        # Regression on the activity projected on the PCs.
        # The first `bs_train` trials are used for training, the remaining ones for testing.
        bs_train = int(frac_train * batch_size)
        bs_train_all[i_task] = bs_train
        y_train = output[:bs_train].reshape(-1, dim_out)
        for i_nc, n_comp in enumerate(n_comps):
            X_train = h_proj[:bs_train, :, :n_comp].reshape(-1, n_comp)
            ridge = RidgeCV(alphas=alpha_range).fit(X_train, y_train)
            # Output for the entire batch (separate train and test later)
            output_fit = ridge.predict(h_proj[:, :, :n_comp].reshape(-1, n_comp)).reshape(batch_size, -1, dim_out)
            output_fit = torch.from_numpy(output_fit.astype('float32'))
            # Task loss of the fitted output: entire batch, and test trials only
            loss_fit[i_task][mi][i_nc] = loss_crit(output_fit[mask_cpu], target_cpu[mask_cpu])
            loss_test[i_task][mi][i_nc] = loss_crit(output_fit[bs_train:][mask_cpu[bs_train:]], target_cpu[bs_train:][mask_cpu[bs_train:]])
            # R^2 of the fit with respect to the network output. The reference mean is
            # taken over the entire batch and all time points, also for the test split.
            r_sq[i_task][mi][i_nc] = 1 - ((output - output_fit)**2).mean()  / ((output - output.mean((0, 1), keepdims=True))**2).mean()
            r_sq_test[i_task][mi][i_nc] = 1 - ((output - output_fit)[bs_train:]**2).mean()  / ((output - output.mean((0, 1), keepdims=True))[bs_train:]**2).mean()
            # Regularization strength selected by cross-validation
            alphas[i_task][mi][i_nc] = ridge.alpha_

    print("Took %.1f sec."% (time.time() - time0))
        
# Scenario labels, built from the gain and output scale of each scenario
lbl_fmt = r"$g=%.1f$, $\sigma^{(0)}$ %s"
lbl_sce = [lbl_fmt % (gs[k], out_scales[k]) for k in range(n_sce)]

# Everything that is stored, in the order in which it is read back
res = [
    task_names, task_lbls, n_task, lbl_sce,
    n_samples, n_sce, n_mi, dim_hid,
    with_noise, loss_sce, n_steps_sce, lr0s_sce,
    norm_w_sce, norm_h_sce, corr_w_h_sce,
    n_comp_max, n_comps, n_nc, frac_train, alpha_range, bs_train_all,
    loss, loss_fit, loss_test, r_sq, r_sq_test, alphas, cevr_fit,
]

# The file name encodes whether the noise-free dynamics were analysed
file_name = "neuro_corr_regression" + ("" if with_noise else "_no_noise")
file_name = file_name.replace('.', "_")
data_file = data_path + file_name + ".pkl"
with open(data_file, 'wb') as handle:
    pickle.dump(res, handle)
print('Saved to ', data_file)

cycling
Loaded from  ../data/neuro_noisy_cycling_n_512.pkl
Took 68.0 sec.
flipflop
Loaded from  ../data/neuro_noisy_flipflop_n_512.pkl
Took 20.6 sec.
mante
Loaded from  ../data/neuro_noisy_mante_n_512.pkl
Took 27.5 sec.
romo
Loaded from  ../data/neuro_noisy_romo_n_512.pkl
Took 20.0 sec.
complex_sine
Loaded from  ../data/neuro_noisy_complex_sine_n_512.pkl
Took 57.3 sec.
Saved to  ../data/neuro_corr_regression_no_noise.pkl


In [3]:
# Compute dissimilarity
# Takes about 1 min

### Dissimilarity metrics: Ridge CCA and co.
alpha_ridge_CCA = 1
# Per-task results, appended in the order of `file_names`
dist_eucl_sce, dist_ang_sce, svs_cca_sce = [], [], []
dist_ex_eucl_sce, dist_ex_ang_sce = [], []


def _flatten_after_grace(hids, n_t_skip, dim):
    """Drop the first `n_t_skip` entries along axis 1 and reshape to rows of length `dim`."""
    return hids[:, n_t_skip:].reshape(-1, dim)


for i_f, file_name in enumerate(file_names):
    data_file = data_path + file_name
    task_name = task_names[i_f]
    print(task_name)

    with open(data_file, 'rb') as handle:
        res = pickle.load(handle)
    # Unpack the first 38 entries (rebinds n_samples, n_sce, n_mi from the file)
    [
        n_steps, n_samples, gs, out_scales, n_sce, opt_gens, lr0s, n_mi, dim_hid, dim_in, dim_out,
        dt, rec_step_dt, n_layers, bias, train_in, train_hid, train_out, train_layers, nonlin,
        gaussian_init, h_0_std, noise_input_std, noise_init_std, noise_hid_std, batch_size,
        task_params, task_params_ev, n_t_ev, task_ev, n_if, n_ifn, steps, loss_all,
        output_all, hids_all,
        h_0_all, sd_if_all,
    ] = res[:38]
    del res
    print('Loaded from ', data_file)

    # Only keep the hidden states necessary: the dynamics with the training noise
    # are stored at indices 0 / 1, the noise-free testing dynamics at 2 / 3.
    i_init, i_final = (0, 1) if with_noise else (2, 3)
    hids_init_all = hids_all[i_init]
    hids_final_all = hids_all[i_final]
    del hids_all

    # Grace period (in recorded time steps)
    n_t_min = int(t_pc_min / (dt * rec_step_dt))

    ################################################################################
    # Within class: compare all unordered pairs of samples of the same scenario.
    i_s_is, i_s_js = np.triu_indices(n_samples, k=1)
    n_triu = len(i_s_is)
    dist_eucl = np.zeros((n_triu, n_sce))
    dist_ang = np.zeros((n_triu, n_sce))
    svs_cca = np.zeros((n_triu, n_sce, dim_hid))
    for i_triu in tqdm(range(n_triu)):
        i_s_i, i_s_j = i_s_is[i_triu], i_s_js[i_triu]
        for i_sce in range(n_sce):
            # Hidden states of the two samples
            h_i = _flatten_after_grace(hids_final_all[i_s_i, i_sce], n_t_min, dim_hid)
            h_j = _flatten_after_grace(hids_final_all[i_s_j, i_sce], n_t_min, dim_hid)
            # Ridge CCA
            d_eucl, d_ang, _u, S, _vt, norm_X_t, norm_Y_t = ridge_CCA(h_i, h_j, alpha_ridge_CCA)
            # Save
            dist_eucl[i_triu, i_sce] = d_eucl
            dist_ang[i_triu, i_sce] = d_ang
            svs_cca[i_triu, i_sce] = S / (norm_X_t * norm_Y_t)

    # Between class: compare every sample of one scenario with every sample of another.
    i_sce_ex, j_sce_ex = np.triu_indices(n_sce, k=1)
    n_triu_ex = len(i_sce_ex)
    dist_ex_eucl = np.zeros((n_triu_ex, n_samples, n_samples))
    dist_ex_ang = np.zeros((n_triu_ex, n_samples, n_samples))
    for i_triu in tqdm(range(n_triu_ex)):
        i_sce, j_sce = i_sce_ex[i_triu], j_sce_ex[i_triu]
        for i_s in range(n_samples):
            for j_s in range(n_samples):
                # Hidden states
                h_i = _flatten_after_grace(hids_final_all[i_s, i_sce], n_t_min, dim_hid)
                h_j = _flatten_after_grace(hids_final_all[j_s, j_sce], n_t_min, dim_hid)
                # Ridge CCA (only the two distances are kept here)
                d_eucl, d_ang, _wx, S, _wy, norm_X_t, norm_Y_t = ridge_CCA(h_i, h_j, alpha_ridge_CCA)
                # Save
                dist_ex_eucl[i_triu, i_s, j_s] = d_eucl
                dist_ex_ang[i_triu, i_s, j_s] = d_ang

    # Clear memory
    del hids_init_all, hids_final_all

    #################################################################################
    # Join
    dist_eucl_sce.append(dist_eucl)
    dist_ang_sce.append(dist_ang)
    svs_cca_sce.append(svs_cca)
    dist_ex_eucl_sce.append(dist_ex_eucl)
    dist_ex_ang_sce.append(dist_ex_ang)
    
# Scenario labels, built from the gain and output scale of each scenario
lbl_fmt = r"$g=%.1f$, $\sigma^{(0)}$ %s"
lbl_sce = [lbl_fmt % (gs[k], out_scales[k]) for k in range(n_sce)]

# Everything that is stored, in the order in which it is read back
res = [
    task_names, task_lbls, n_task, lbl_sce,
    n_samples, n_sce, n_mi, dim_hid,
    with_noise,
    alpha_ridge_CCA, dist_eucl_sce, dist_ang_sce, svs_cca_sce,
    dist_ex_eucl_sce, dist_ex_ang_sce,
]

# The file name encodes whether the noise-free dynamics were analysed
file_name = "neuro_dissimilarity" + ("" if with_noise else "_no_noise")
file_name = file_name.replace('.', "_")
data_file = data_path + file_name + ".pkl"
with open(data_file, 'wb') as handle:
    pickle.dump(res, handle)
print('Saved to ', data_file)

cycling
Loaded from  ../data/neuro_noisy_cycling_n_512.pkl


100%|███████████████████████████████████████████| 10/10 [00:02<00:00,  3.56it/s]
100%|█████████████████████████████████████████████| 6/6 [00:11<00:00,  1.92s/it]


flipflop
Loaded from  ../data/neuro_noisy_flipflop_n_512.pkl


100%|███████████████████████████████████████████| 10/10 [00:01<00:00,  7.15it/s]
100%|█████████████████████████████████████████████| 6/6 [00:05<00:00,  1.14it/s]


mante
Loaded from  ../data/neuro_noisy_mante_n_512.pkl


100%|███████████████████████████████████████████| 10/10 [00:02<00:00,  4.54it/s]
100%|█████████████████████████████████████████████| 6/6 [00:08<00:00,  1.37s/it]


romo
Loaded from  ../data/neuro_noisy_romo_n_512.pkl


100%|███████████████████████████████████████████| 10/10 [00:01<00:00,  6.60it/s]
100%|█████████████████████████████████████████████| 6/6 [00:05<00:00,  1.03it/s]


complex_sine
Loaded from  ../data/neuro_noisy_complex_sine_n_512.pkl


100%|███████████████████████████████████████████| 10/10 [00:02<00:00,  4.57it/s]
100%|█████████████████████████████████████████████| 6/6 [00:08<00:00,  1.48s/it]

Saved to  ../data/neuro_dissimilarity_no_noise.pkl





In [None]:
### Perturbations
# This can take 40 min.

# Perturbation directions: (identifier, plot label)
_pert_dir_table = [
    ('w_out', r'$\mathbf{w}_\mathrm{out}$'),
    ('pc', r'PCs'),
    ('rand', r'rand'),
    ('w_in', r'$\mathbf{w}_\mathrm{in}$'),
]
pert_dirs = [ident for ident, _lbl in _pert_dir_table]
pert_dir_lbls = [_lbl for _ident, _lbl in _pert_dir_table]
n_pd = len(pert_dirs)

# Perturbation amplitudes: number of amplitudes and maximal amplitude per task
n_pa = 21
pert_amp_maxs = dict(cycling=10, flipflop=10, mante=10, romo=5, complex_sine=10)
# Minimal time at which perturbations start
t_pert_mins = dict(cycling=5, flipflop=5, mante=10, romo=1, complex_sine=5)
# Length of the interval on which perturbations happen.
# Only relevant for tasks in which the decision period starts right away, because
# this duration is used to shift the loss evaluation mask.
dt_pert_intvls = dict(cycling=10, flipflop=5, complex_sine=20)

# Perturbation time:
# There should be a minimal distance to the decision period.
# Either the decision starts late enough so that there is ample time (Mante, Romo),
# or the decision starts almost immediately (sine, cycling). In the latter case,
# the decision time is delayed.
# Minimal difference between perturbation and decision (= start of loss evaluation).
dt_pert_loss = 5
# Number of different perturbation times
n_pt = 10

# Joint indices: (direction, amplitude, time)
n_mip = (n_pd, n_pa, n_pt)

# Number of samples
batch_size_pert = 32

# Number of PCs and output vectors from which to draw perturbation direction
n_comp_pert = 2

# Response labels
resp_lbls = [r"short", r"long", r"loss"]
n_resp = len(resp_lbls)

#######################################################################################
# Results arrays
n_samples, n_sce = 5, 4
n_mi = (n_samples, n_sce)
loss_pre_task = torch.zeros((n_task, *n_mi))
loss_pert_task = torch.zeros((n_task, *n_mi, *n_mip))
# Per-task containers, filled in the order of the tasks
output_pre_task, output_pert_task, task_pert_task = [], [], []
wo_proj_pw_task, h_pre_proj_pw_task, h_pert_proj_pw_task = [], [], []
t_perts_task = torch.zeros((n_task, n_pt))

# Iterate over tasks: load the trained networks, perturb the hidden state at several
# times, along several directions and with several amplitudes, and store the outputs.
for i_task in range(n_task):
    task_name = task_names[i_task]
    file_name = file_names[i_task]
    data_file = data_path + file_name
    print(task_name)
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted result files.
    with open(data_file, 'rb') as handle:
        res = pickle.load(handle)
        # Unpack the first 38 entries; the stored outputs and hidden states are not
        # needed here. This rebinds n_samples, n_sce and n_mi from the file.
        [
            n_steps, n_samples, gs, out_scales, n_sce, opt_gens, lr0s, n_mi, dim_hid, dim_in, dim_out,
            dt, rec_step_dt, n_layers, bias, train_in, train_hid, train_out, train_layers, nonlin,
            gaussian_init, h_0_std, noise_input_std, noise_init_std, noise_hid_std, batch_size,
            task_params, task_params_ev, n_t_ev, task_ev, n_if, n_ifn, steps,
            loss_all,
            _, _ , #output_all, hids_all,
            h_0_all, sd_if_all,
        ] = res[:38]
        del res
    print('Loaded from ', data_file)

    # Task
    ts_ex, input_ex, target_ex, mask_ex, noise_input_ex, noise_init_ex = [to_dev(arr) for arr in task_ev]
    # Loss for zero output
    loss_crit = torch.nn.MSELoss()
    loss_0 = loss_crit(target_ex[mask_ex] * 0, target_ex[mask_ex]).item()
    task_ex = task_ev
    n_t_ex = len(ts_ex)
    # Hidden state noise. For initial and input, we use the frozen arrays!
    noise_hid_std_ex = 0.

    # Number of recorded time steps excluded from the PCA
    n_t_pc_min = int(t_pc_min / (dt * rec_step_dt))

    # Perturbation amplitudes: from zero up to pert_amp_max * sqrt(dim_hid)
    pert_amp_max = pert_amp_maxs[task_name]
    pert_amps = np.linspace(0, pert_amp_max * np.sqrt(dim_hid), n_pa)

    # Perturbation times. First compute the minimal decision times.
    t_pert_min = t_pert_mins[task_name]
    mask_pert = mask_ex.clone()
    target_pert = target_ex.clone()
    if task_name in ['cycling', 'flipflop', 'complex_sine']:
        # Shift the start of the loss evaluation to after the perturbation interval,
        # then choose the perturbation times on the freed interval.
        t_pert_max = t_pert_min + dt_pert_intvls[task_name]
        t_loss_min_i = t_pert_max + dt_pert_loss
        mask_pert[:, ts_ex < t_loss_min_i] = False
    elif task_name == 'mante':
        # The original Mante task demands fixation at zero during the evidence presentation.
        # We remove this part in order to allow for a perturbation meanwhile.
        # First time index with a nonzero target (argmax is evaluated with NumPy on a
        # CPU copy so that this also works for targets on another device).
        i_tg1 = np.argmax((target_ex != 0).cpu().numpy(), axis=1).min().item()
        mask_pert[:, :i_tg1] = False
        t0_loss = float(ts_ex[i_tg1])
        t_pert_max = t0_loss - dt_pert_loss
    elif task_name == 'romo':
        # Use the mask of the trial whose loss evaluation starts last for all trials,
        # then allow perturbations in the long interval before.
        i_b_last = torch.Tensor([ts_ex[m_i].min() for m_i in mask_ex[:, :, 0]]).argmax().item()
        mask_pert[:] = mask_ex[i_b_last]
        target_pert[mask_pert] = target_ex[mask_ex]
        t0_loss = float(ts_ex[mask_pert[0, :, 0]].min())
        t_pert_max = t0_loss - dt_pert_loss
    else:
        raise NotImplementedError()
    # n_pt perturbation times in [t_pert_min, t_pert_max), aligned to recorded time steps
    t_perts = t_pert_min + ((t_pert_max - t_pert_min) / (dt * rec_step_dt) * np.arange(n_pt) // n_pt) * (dt * rec_step_dt)

    # Output for saving
    output_pre_all = torch.zeros((*n_mi, batch_size_pert, n_t_ex, dim_out))
    output_pert_all = torch.zeros((*n_mi, *n_mip, batch_size_pert, n_t_ex, dim_out))
    # Projection on 2 PCs and w_out
    n_comp_pw = 3
    wo_proj_pw_all = torch.zeros((*n_mi, dim_out, n_comp_pw))
    h_pre_proj_pw_all = torch.zeros((*n_mi, batch_size_pert, n_t_ex, n_comp_pw))
    h_pert_proj_pw_all = torch.zeros((*n_mi, *n_mip, batch_size_pert, n_t_ex, n_comp_pw))

    time0 = time.time()
    for mi in np.ndindex(*n_mi):
        print(mi)
        i_s, i_sce = mi
        out_scale = out_scales[i_sce]
        g = gs[i_sce]
        # Network instance with the final (trained) weights
        net = RNN_Net(dim_in, dim_hid, dim_out, n_layers, nonlin, bias, out_scale, g, gaussian_init,
                      dt, rec_step_dt, train_layers)
        net.load_state_dict(sd_if_all[1][mi])
        h_0 = h_0_all[mi]
        # Transfer
        net.to(device)
        h_0 = h_0.to(device)

        with torch.no_grad():
            # Run unperturbed dynamics
            output_pre, hids_pre = net.forward_hid(input_ex + noise_input_ex,
                                           h_0 + noise_init_ex,
                                           noise_hid_std_ex)
            output_pre_all[mi] = output_pre.cpu()
            loss_pre_task[i_task][mi] = loss_crit(output_pre[mask_pert], target_pert[mask_pert]).item()
            # Input and output weights
            w_in = net.rnn.weight_ih_l0.clone()
            w_out = net.decoder.weight.clone()
            # PCs of the unperturbed states after the grace period. scikit-learn needs
            # a CPU array; the components are moved back to `device` afterwards.
            h = hids_pre[0, :, n_t_pc_min:].reshape(-1, dim_hid)
            pca = PCA(n_comp_pert)
            pca.fit(h.cpu().numpy())
            pcs = to_dev(np.float32(pca.components_))

            # Orthonormal basis from the leading 2 PCs and the 1st output vector (via QR).
            # Kept as a tensor on `device` so that it can be multiplied with the states.
            pcs_w_np = np.linalg.qr(np.r_[pca.components_[:2], w_out[:1].cpu().numpy()].T)[0].T
            pcs_w = torch.as_tensor(pcs_w_np, dtype=hids_pre.dtype, device=device)
            for i_out in range(len(pcs_w)):
                if w_out[0] @ pcs_w[i_out] < 0:
                    pcs_w[i_out] *= -1 # Fix signs
            h_pre_proj_pw_all[mi] = (hids_pre @ pcs_w.T).cpu()
            wo_proj_pw_all[mi] = (w_out @ pcs_w.T).cpu()

            # Correlations between output vectors and PCs (computed but not stored)
            corr_wo_pc = w_out @ pcs.T / (
                torch.linalg.norm(w_out, axis=-1)[:, None] *
                torch.linalg.norm(pcs, axis=-1)[None, :] )

            for mip in np.ndindex(*n_mip):
                i_pd, i_pa, i_pt = mip
                pert_dir = pert_dirs[i_pd]
                pert_amp = pert_amps[i_pa]
                t_pert = t_perts[i_pt]
                i_t_pert = int(t_pert / (dt * rec_step_dt))

                # Perturbation directions, one per trial. All of them are created on
                # `device`, since they are added to the hidden states below.
                if pert_dir == 'pc':
                    # Random mixture of the leading PCs
                    pert_vecs = torch.randn((batch_size_pert, n_comp_pert), device=device) @ pcs[:n_comp_pert]
                    pert_vecs[0] = pcs[0] # The 0th entry is always along the leading PC
                elif pert_dir == 'w_out':
                    # Random vector in the span of the output weights
                    pert_vecs = torch.randn((batch_size_pert, dim_out), device=device) @ torch.linalg.qr(w_out.T)[0].T
                    pert_vecs[0] = w_out[0] / torch.linalg.norm(w_out[0])
                elif pert_dir == 'rand':
                    # Random direction in the full state space
                    pert_vecs = torch.randn((batch_size_pert, dim_hid), device=device)
                elif pert_dir == 'w_in':
                    # Random vector in the span of the input weights
                    pert_vecs = torch.randn((batch_size_pert, dim_in), device=device) @ torch.linalg.qr(w_in)[0].T
                    pert_vecs[0] = w_in[:, 0] / torch.linalg.norm(w_in[:, 0])
                # Normalize
                pert_vecs /= torch.linalg.norm(pert_vecs, axis=-1, keepdims=True)

                # Initial state: state at t_pert + perturbation
                h_0_pert = hids_pre[:, :, i_t_pert] + pert_amp * pert_vecs[None]
                # Run perturbed dynamics. Note: input is not down-sampled by rec_step_dt (in contrast to output, target)
                output_pert = output_pre.clone()
                hids_pert = hids_pre.clone()
                # The recorded states are h_t (not h_t+1), so the perturbed run
                # overwrites the trajectory from i_t_pert on.
                output_pert[:, i_t_pert:], hids_pert[:, :, i_t_pert:] = net.forward_hid((input_ex + noise_input_ex)[:, i_t_pert * rec_step_dt:],
                                           h_0_pert,
                                           noise_hid_std_ex)
                output_pert_all[mi][mip] = output_pert.cpu()
                # Loss. Note that the mask is adapted, so we also compute the unperturbed loss separately each time.
                loss_pert_task[i_task][mi][mip] = loss_crit(output_pert[mask_pert], target_pert[mask_pert]).item()

                # Projection on the leading 2 PCs and the remaining direction for the first output vector
                h_pert_proj_pw_all[mi][mip] = (hids_pert @ pcs_w.T).cpu()

    print("Took %.1f sec."% (time.time() - time0))
    # Save output and task
    output_pre_task.append(output_pre_all)
    output_pert_task.append(output_pert_all)
    wo_proj_pw_task.append(wo_proj_pw_all)
    h_pre_proj_pw_task.append(h_pre_proj_pw_all)
    h_pert_proj_pw_task.append(h_pert_proj_pw_all)
    task_pert = ts_ex, input_ex, target_pert, mask_pert, noise_input_ex, noise_init_ex
    task_pert_task.append(task_pert)
    t_perts = torch.Tensor(t_perts.astype('float32'))
    t_perts_task[i_task] = t_perts

    # Save this task (in case saving all goes wrong...)
    # Note: loss_pre_task and loss_pert_task are the arrays over ALL tasks, so the file
    # of a task also contains the rows of the tasks processed before it in this run.
    lbl_sce = [r"$g=%.1f$, $\sigma^{(0)}$ %s" % (gs[i_sce], out_scales[i_sce])
               for i_sce in range(n_sce)]
    res = [
        task_names, task_lbls, n_task, lbl_sce,
        n_samples, n_sce, n_mi, dim_hid,
        pert_dirs, pert_dir_lbls, n_pd, n_pa, pert_amp_maxs, t_pert_mins, dt_pert_intvls, dt_pert_loss, n_pt, n_mip,
        batch_size_pert,
        n_comp_pert, resp_lbls, n_resp,
        loss_pre_task, loss_pert_task,
        noise_hid_std_ex,
        task_pert, t_perts,
        output_pre_all, output_pert_all,
        wo_proj_pw_all, h_pre_proj_pw_all, h_pert_proj_pw_all,
    ]
    file_name = "neuro_perturb_" + task_name
    file_name = "_".join(file_name.split('.'))
    data_file = data_path + file_name + ".pkl"
    with open(data_file, 'wb') as handle:
        pickle.dump(res, handle)
    print('Saved to ', data_file)

In [2]:
# The single datasets may be too large to fit in memory together.
# Remove the traces except for cycling (where we plot the example).

# Tasks: identifiers (used in the file names) and plot labels
task_names = ["cycling", "flipflop", "mante", "romo", "complex_sine"]
dim_hid = 512
task_lbls = ["Cycling", "3-bit flipflop", "Mante", "Romo", "Complex sine"]
n_task = len(task_names)

# Perturbation index shape: (direction, amplitude, time)
n_pd, n_pa, n_pt = 4, 21, 10
n_mip = (n_pd, n_pa, n_pt)

#######################################################################################
# Results arrays
import torch
n_samples, n_sce = 5, 4
n_mi = (n_samples, n_sce)
loss_pre_task = torch.zeros((n_task, *n_mi))
loss_pert_task = torch.zeros((n_task, *n_mi, *n_mip))
# Per-task containers, filled in the order of `task_names`
output_pre_task, output_pert_task, task_pert_task = [], [], []
wo_proj_pw_task, h_pre_proj_pw_task, h_pert_proj_pw_task = [], [], []
t_perts_task = torch.zeros((n_task, n_pt))

# Iterate over tasks: load the per-task perturbation results and merge them.
for i_task in range(n_task):
    task_name = task_names[i_task]
    data_file = os.path.join(data_path, "neuro_perturb_%s.pkl" % task_name)
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted result files.
    with open(data_file, 'rb') as handle:
        [
            task_names_i, _, _, lbl_sce,
            n_samples, n_sce, n_mi, dim_hid,
            pert_dirs, pert_dir_lbls, n_pd, n_pa, pert_amp_maxs, t_pert_mins, dt_pert_intvls, dt_pert_loss, n_pt, n_mip,
            batch_size_pert,
            n_comp_pert, resp_lbls, n_resp,
            loss_pre_task_i, loss_pert_task_i,
            noise_hid_std_ex,
            task_pert, t_perts,
            output_pre_all, output_pert_all,
            wo_proj_pw_all, h_pre_proj_pw_all, h_pert_proj_pw_all,
        ] = pickle.load(handle)

    # Drop the large traces for all tasks but cycling (placeholders keep the lists aligned)
    if not task_name == "cycling":
        output_pre_all = 0
        output_pert_all = 0
        h_pre_proj_pw_all = 0
        h_pert_proj_pw_all = 0

    # Save output and task
    output_pre_task.append(output_pre_all)
    output_pert_task.append(output_pert_all)
    wo_proj_pw_task.append(wo_proj_pw_all)
    h_pre_proj_pw_task.append(h_pre_proj_pw_all)
    h_pert_proj_pw_task.append(h_pert_proj_pw_all)
    task_pert_task.append(task_pert)
    t_perts_task[i_task] = t_perts

    # The loss arrays stored in each file span all tasks. Pick the row of this task
    # via the task list stored in the same file. (Searching for the single non-zero
    # row breaks when a file also contains the rows of previously processed tasks.)
    i_task_i = list(task_names_i).index(task_name)

    loss_pre_task[i_task] = loss_pre_task_i[i_task_i]
    loss_pert_task[i_task] = loss_pert_task_i[i_task_i]

# Save the merged results of all tasks, in the order in which they are read back
res = [
    task_names, task_lbls, n_task, lbl_sce,
    n_samples, n_sce, n_mi, dim_hid,
    pert_dirs, pert_dir_lbls, n_pd, n_pa, pert_amp_maxs, t_pert_mins, dt_pert_intvls, dt_pert_loss, n_pt, n_mip,
    batch_size_pert,
    n_comp_pert, resp_lbls, n_resp,
    loss_pre_task, loss_pert_task,
    output_pre_task, output_pert_task,
    noise_hid_std_ex, task_pert_task, t_perts_task,
    wo_proj_pw_task, h_pre_proj_pw_task, h_pert_proj_pw_task,
]
file_name = "neuro_perturb".replace('.', "_")
data_file = data_path + file_name + ".pkl"
with open(data_file, 'wb') as handle:
    pickle.dump(res, handle)
print('Saved to ', data_file)

Saved to  ../data/neuro_perturb.pkl


In [4]:
# Compute noise compression results
# Takes about 2 min.

####################################################################################
### Variance along directions
# Project the difference on the 1st PC and w_perp
# Number of tested task conditions and number of samples per task condition
n_tc, n_samp_per_tc = 4, 16
batch_size_ex = n_tc * n_samp_per_tc
# PCA of the trial-conditioned average
n_comp_tca = 2
pca = PCA(n_components=n_comp_tca)
# Number of sampled directions: random, along the output weights, along the PCs
n_rps, n_ops, n_pps = 2000, 100, 100
# Variance along each sampled direction, per (task, sample, scenario)
var_rps, var_ops, var_pps = (
    np.zeros((n_task, *n_mi, n_dirs)) for n_dirs in (n_rps, n_ops, n_pps)
)
####################################################################################
####################################################################################

# Load data and compute, for every network, the variance of the noise-induced
# fluctuations along random, output-weight and PC directions.
# NOTE(review): no random seed is set, so the results differ between runs.
for i_f, file_name in enumerate(file_names):
    data_file = data_path + file_name
    task_name = task_names[i_f]
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted result files.
    with open(data_file, 'rb') as handle:
        # Unpack the first 38 entries; the stored outputs and hidden states are not
        # needed here. This rebinds n_samples, n_sce and n_mi from the file.
        [
            n_steps, n_samples, gs, out_scales, n_sce, opt_gens, lr0s, n_mi, dim_hid, dim_in, dim_out,
            dt, rec_step_dt, n_layers, bias, train_in, train_hid, train_out, train_layers, nonlin,
            gaussian_init, h_0_std, noise_input_std, noise_init_std, noise_hid_std, batch_size,
            task_params, task_params_ev, n_t_ev, task_ev, n_if, n_ifn, steps,
            loss_all,
            _, _ , #output_all, hids_all,
            h_0_all, sd_if_all,
        ] = pickle.load(handle)[:38]
    print('Loaded from ', data_file)

    # Task
    ts_ex, input_ex, target_ex, mask_ex, noise_input_ex, noise_init_ex = [to_dev(arr) for arr in task_ev]
    # Loss for zero output
    loss_0 = torch.nn.MSELoss()(target_ex[mask_ex] * 0, target_ex[mask_ex]).item()
    n_t_ex = len(ts_ex)
    # Noise (if not using the pre-defined arrays anyways)
    noise_input_std_ex = noise_input_std
    noise_init_std_ex = noise_init_std
    noise_hid_std_ex = noise_hid_std

    # Adjust task to repeated conditions: the first n_tc trials are tiled, so that
    # trial b belongs to condition b % n_tc.
    input_ex = input_ex[:n_tc].repeat((n_samp_per_tc, 1, 1))
    target_ex = target_ex[:n_tc].repeat((n_samp_per_tc, 1, 1))
    mask_ex = mask_ex[:n_tc].repeat((n_samp_per_tc, 1, 1))
    # Fresh noise for the repeated trials. Converted to float32 tensors on `device`,
    # since they are added to the input and the initial state (device tensors) below.
    noise_input_ex = to_dev(np.float32(noise_input_std * np.random.randn(*input_ex.shape) / np.sqrt(dt)))
    noise_init_ex = to_dev(np.float32(noise_init_std * np.random.randn(n_layers-1, batch_size_ex, dim_hid)))

    time0 = time.time()
    for mi in np.ndindex(*n_mi):
        i_s, i_sce = mi
        out_scale = out_scales[i_sce]
        g = gs[i_sce]
        # Network instance with the final (trained) weights
        net = RNN_Net(dim_in, dim_hid, dim_out, n_layers, nonlin, bias, out_scale, g, gaussian_init,
                      dt, rec_step_dt, train_layers)
        net.load_state_dict(sd_if_all[1][mi])
        h_0 = h_0_all[mi]
        h_0 = h_0[:, :n_tc].repeat(1, n_samp_per_tc, 1)  # adjust to task conds.
        # Transfer
        net.to(device)
        h_0 = h_0.to(device)

        w_out = sd_if_all[1][mi]['decoder.weight'].clone()
        with torch.no_grad():
            # Run full dynamics
            output, hids = net.forward_hid(input_ex + noise_input_ex,
                                           h_0 + noise_init_ex,
                                           noise_hid_std_ex, last_time=False)

            # Trial conditioned average and fluctuations (kept on the CPU)
            h_tca = torch.zeros((n_tc, n_t_ex, dim_hid))
            dh = torch.zeros((batch_size_ex, n_t_ex, dim_hid))
            for i_tc in range(n_tc):
                # All trials of condition i_tc
                h_c = hids[0, i_tc::n_tc].cpu()
                h_c_m = h_c.mean(axis=-3)
                h_tca[i_tc] = h_c_m
                dh[i_tc::n_tc] = h_c - h_c_m
            h_tca = h_tca.reshape(-1, dim_hid)
            dh = dh.reshape((-1, dim_hid))

            # PCA of trial-cond. avg
            pca.fit(h_tca)
            pcs_tca = torch.from_numpy(np.float32(pca.components_))

            # Compute variance along random and nonrandom directions
            # Random projections (unit vectors)
            rps = torch.randn((n_rps, dim_hid))#, device=device)
            rps = rps / rps.norm(dim=1)[:, None]
            var_rps[i_f][mi] = (dh @ rps.T).var(axis=0).cpu().numpy()
            # Variance along output directions (random unit vectors in the span of w_out)
            w_mix = torch.randn((n_ops, dim_out)) @ torch.linalg.qr(w_out.T)[0].T
            w_mix /= torch.linalg.norm(w_mix, axis=-1, keepdims=True)
            var_ops[i_f][mi] = (dh @ w_mix.T).var(axis=0).cpu().numpy()
            # Variance along the trial-condition averaged PCs (random unit mixtures)
            w_mix = torch.randn((n_pps, n_comp_tca)) @ pcs_tca
            w_mix /= torch.linalg.norm(w_mix, axis=-1, keepdims=True)
            var_pps[i_f][mi] = (dh @ w_mix.T).var(axis=0).cpu().numpy()
    print("Took %.1f sec."% (time.time() - time0))

# Save all
# Scenario labels, built from the gain and output scale of each scenario
lbl_fmt = r"$g=%.1f$, $\sigma^{(0)}$ %s"
lbl_sce = [lbl_fmt % (gs[k], out_scales[k]) for k in range(n_sce)]

# Everything that is stored, in the order in which it is read back
res = [
    task_names, task_lbls, n_task, lbl_sce,
    dim_hid,
    n_samples, n_sce, n_mi,
    n_tc, n_samp_per_tc,batch_size_ex, n_comp_tca, n_rps, n_ops, n_pps, var_rps, var_ops, var_pps,
    t_pc_min,
]

file_name = "_".join("neuro_noise_compression".split("."))
data_file = data_path + file_name + ".pkl"
with open(data_file, 'wb') as handle:
    pickle.dump(res, handle)
print('Saved to ', data_file)


Loaded from  ../data/neuro_noisy_cycling_n_512.pkl
Took 18.5 sec.
Loaded from  ../data/neuro_noisy_flipflop_n_512.pkl
Took 8.4 sec.
Loaded from  ../data/neuro_noisy_mante_n_512.pkl
Took 13.1 sec.
Loaded from  ../data/neuro_noisy_romo_n_512.pkl
Took 9.5 sec.
Loaded from  ../data/neuro_noisy_complex_sine_n_512.pkl
Took 14.2 sec.
Saved to  ../data/neuro_noise_compression.pkl
