In [1]:
import sys
sys.path.insert(0, "../")
sys.path.insert(0, "../dataset_generation/")

import torch
import numpy as np
import scipy.signal
import matplotlib.pyplot as plt

#plt.style.use("project_style.mplstyle")

import stats
import scipy.stats
from tqdm import tqdm
import analysis_tools as tools

In [2]:
samples = 100
timesteps = 400
sequence  = ["square"]

bins = [16,16]

In [3]:
model_name = "al1_10_l2_1"
path = f"../models/{model_name}"
model, params = tools.load_model(path, device = "cpu", model_type = "RNN")

In [4]:
noise_scales = np.concatenate(([0], np.geomspace(1e-6, 1, 7)))
noise_scales

array([0.e+00, 1.e-06, 1.e-05, 1.e-04, 1.e-03, 1.e-02, 1.e-01, 1.e+00])

In [5]:
xa, ra, va, ca = tools.test_dataset(samples*[sequence[0]], timesteps, context = params["context"], 
                                    device = "cpu", trajectories = True)
xb, rb, vb, cb = tools.test_dataset(samples*[sequence[0]], timesteps, context = params["context"], 
                                    device = "cpu", trajectories = True)

(torch.Size([100, 400, 8]),
 torch.Size([100, 400, 2]),
 torch.Size([100, 400, 8]),
 torch.Size([100, 400, 2]))

In [6]:
def time_ratemaps(states, r, bins):
    # state.shape = (N, T, Nc)
    ratemaps = np.zeros((states.shape[1], states.shape[-1], *bins)).astype("float32")
    for i in range(states.shape[1]):
        ratemaps[i] = stats.population_vector_ratemaps(states[None,:,i], r[None,:,i], [bins], smooth = True)  
    return ratemaps
        
def autocorr(x):
    z = np.zeros((len(x), len(x)))
    for i in tqdm(range(len(x))):
        for j in range(len(x)):
            z[i, j] = np.corrcoef(x[i], x[j])[1,0]
    return z

In [7]:
noisy_gs = []
noisy_ps = []
noise_corrs = []

for i in tqdm(range(len(noise_scales))):
    g_prev = None
    gs = []
    ps = []
    errors = []
    for x in (xa, xb):
        yhat, g, p, mu = model(x, g_prev = g_prev)
        noise = torch.normal(0, noise_scales[i], size = g[:,-1].shape)
        g_prev = g[:,-1] + noise
        g_ratemaps = time_ratemaps(g.detach().numpy(), x[1].detach().numpy(), bins)
        p_ratemaps = time_ratemaps(p.detach().numpy(), x[1].detach().numpy(), bins)
        gs.append(g_ratemaps.reshape((g_ratemaps.shape[0], -1)))
        ps.append(p_ratemaps.reshape((p_ratemaps.shape[0], -1)))

    noisy_gs.append(np.concatenate(gs, axis = 0))
    noisy_ps.append(np.concatenate(ps, axis = 0))

100%|█████████████████████████████████████████████| 8/8 [06:44<00:00, 50.54s/it]


In [8]:
corr = np.zeros(2*timesteps)
for i in tqdm(range(noisy_gs[0].shape[0])):
    for j in range(0, timesteps, 10):
        corr[i] += np.sum(noisy_gs[0][i]*noisy_gs[0][j])

100%|████████████████████████████████████████| 800/800 [00:05<00:00, 159.51it/s]
