In [70]:
import sys
sys.path.append("../../")
import numpy as np
from matplotlib import pyplot as plt
import json
import pickle
import os
import torch
from rnn_coach.src.RNN_torch import *
from rnn_coach.src.DynamicSystemAnalyzer import *
from rnn_coach.src.RNN_numpy import *
from rnn_coach.src.Task import *
from rnn_coach.src.DataSaver import *
from latent_circuit_inference.src.LatentCircuit import *
from latent_circuit_inference.src.LatentCircuitFitter import *
from latent_circuit_inference.src.LCAnalyzer import *
from latent_circuit_inference.src.utils import *
from latent_circuit_inference.src.circuit_vizualization import *
from copy import deepcopy
from sklearn.decomposition import IncrementalPCA as iPCA

In [71]:
# Locate the trained RNN run and load its config and parameters.
RNN_folder = "20230207-08111997"
RNN_folder_full_path = os.path.join("../", "../", "rnn_coach", "data", "trained_RNNs", "CDDM", RNN_folder)
# The folder holds "<mse>_config.json" plus a params pickle under one of two naming
# schemes. os.listdir() order is arbitrary, so take the score from the config file name
# instead of from whichever file happens to be listed first.
config_names = sorted(name for name in os.listdir(RNN_folder_full_path) if name.endswith("_config.json"))
if not config_names:
    raise FileNotFoundError(f"No *_config.json file in {RNN_folder_full_path}")
mse_score_RNN = config_names[0].split("_")[0]
with open(os.path.join(RNN_folder_full_path, f"{mse_score_RNN}_config.json"), mode="r", encoding="utf-8") as config_fh:
    rnn_config = json.load(config_fh)

# Explicit existence check instead of a bare `except:`, which also hid unpickling errors
# in the first file.
params_path = os.path.join(RNN_folder_full_path, f"{mse_score_RNN}_params_CDDM.pkl")
if not os.path.exists(params_path):
    params_path = os.path.join(RNN_folder_full_path, f"params_CDDM_{mse_score_RNN}.pkl")
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(params_path, "rb") as params_fh:
    rnn_data = pickle.load(params_fh)

with open(os.path.join("../", "data", "configs", "LCI_config.json"), mode="r", encoding="utf-8") as lci_fh:
    LCI_config_file = json.load(lci_fh)
task_data = rnn_config["task_params"]

In [72]:
# RNN hyper-parameters from the trained network's config.
activation_name = rnn_config["activation"]
RNN_N = rnn_config["N"]
if activation_name == 'relu':
    activation_RNN = torch.relu
elif activation_name == 'tanh':
    activation_RNN = torch.tanh
elif activation_name == 'sigmoid':
    activation_RNN = torch.sigmoid
elif activation_name == 'softplus':
    # log(1 + exp(5x)); F.softplus avoids the exp() overflow of the explicit formula
    activation_RNN = lambda x: torch.nn.functional.softplus(5 * x)
else:
    # Fail loudly instead of leaving activation_RNN undefined (or stale from an earlier run).
    raise ValueError(f"Unsupported activation: {activation_name!r}")
dt = rnn_config["dt"]
tau = rnn_config["tau"]
connectivity_density_rec = rnn_config["connectivity_density_rec"]
spectral_rad = rnn_config["sr"]
sigma_inp = rnn_config["sigma_inp"]
sigma_rec = rnn_config["sigma_rec"]

# A fresh seed is drawn on every run and printed so the run can be reproduced;
# use LCI_config_file["seed"] instead for a fixed seed.
seed = np.random.randint(1000000)
print(f"seed: {seed}")
if seed is not None:
    # NOTE(review): `rng` is not defined in this notebook; presumably a torch.Generator
    # exported by one of the star imports above -- confirm.
    rng.manual_seed(seed)
input_size = np.array(rnn_data["W_inp"]).shape[1]
output_size = np.array(rnn_data["W_out"]).shape[0]

# Task:
n_steps = task_data["n_steps"]
task = TaskCDDM(n_steps=n_steps, n_inputs=input_size, n_outputs=output_size, task_params=task_data)

seed: 376628


In [73]:
# Numpy copy of the trained RNN, used below for simulation, vector fields and Jacobians.
N = RNN_N
W_rec = rnn_data["W_rec"]
W_inp = rnn_data["W_inp"]
W_out = rnn_data["W_out"]
# NOTE(review): these overwrite the dt/tau read from rnn_config earlier; confirm they agree.
dt = rnn_data["dt"]
tau = rnn_data["tau"]
# Use the nonlinearity the network was trained with; relu used to be hard-coded here
# regardless of rnn_config["activation"].
numpy_activations = {
    'relu': lambda x: np.maximum(0, x),
    'tanh': np.tanh,
    'sigmoid': lambda x: 0.5 * (1 + np.tanh(0.5 * x)),  # == 1/(1+exp(-x)), overflow-safe
    'softplus': lambda x: np.logaddexp(0, 5 * x),  # == log(1+exp(5x)), overflow-safe
}
rnn_activation_name = rnn_config["activation"]
if rnn_activation_name not in numpy_activations:
    raise ValueError(f"Unsupported activation: {rnn_activation_name!r}")
activation_fun_RNN = numpy_activations[rnn_activation_name]
# NOTE(review): confirm RNN_numpy.rhs_jac derives the Jacobian from `activation`
# rather than assuming relu.
RNN = RNN_numpy(N=N, W_rec=W_rec, W_inp=W_inp, W_out=W_out, dt=dt, tau=tau, activation=activation_fun_RNN)
RNN.y = np.zeros(N)

In [74]:
# Simulate one batch of trials without noise and fit a 10-component PCA to all visited states.
input_batch, target_batch, conditions_batch = task.get_batch()
# Separate names so the config noise levels (sigma_rec / sigma_inp) loaded earlier
# are not overwritten by this noise-free evaluation.
sigma_rec_eval = sigma_inp_eval = 0
y, predicted_output_rnn = RNN.run_multiple_trajectories(input_timeseries=input_batch,
                                                        sigma_rec=sigma_rec_eval, sigma_inp=sigma_inp_eval)
# First axis of y is the N units; flatten the remaining axes into samples -> (samples, N).
Y = y.reshape(RNN.N, -1).T
pca = iPCA(n_components=10, batch_size=1000)
# A single partial_fit call processes all of Y as one batch (batch_size is only used by .fit).
pca.partial_fit(Y)
U = pca.components_

In [75]:
np.shape(U)  # (n_components, N)

(10, 100)

In [76]:
pca.explained_variance_ratio_.cumsum()  # cumulative fraction of variance captured by the first k PCs

array([0.427626  , 0.63245977, 0.80556567, 0.92296227, 0.95432405,
       0.9703748 , 0.98058787, 0.9856451 , 0.98959716, 0.99180045])

In [77]:
np.shape(y)  # first axis is the N units

(100, 750, 242)

In [78]:
# Project every simulated state onto the PCs: (N, T, K) -> (n_components, T, K).
# NOTE(review): pca.mean_ is not subtracted, so this is a projection onto the PC
# subspace, not pca.transform() coordinates -- confirm this is intended.
y_time_major = y.transpose(1, 0, 2)
Uy = np.matmul(U, y_time_major).transpose(1, 0, 2)

In [79]:
np.shape(input_batch)  # presumably (n_inputs, time, trials) -- axes 1/2 match y

(6, 750, 242)

In [80]:
from tqdm.auto import tqdm

# Collect (state, input, vector field C, Jacobian J) records, all expressed in the
# PCA subspace U, every SAMPLE_STRIDE time steps of every trial.
# NOTE(review): the recorded run of this cell was interrupted (KeyboardInterrupt);
# if it is slow, profile RNN.rhs_jac.
SAMPLE_STRIDE = 25
# Derived from the time axis instead of a hard-coded 30 (== 750 // 25 for this run),
# so a different n_steps can neither index past the end nor skip the tail.
n_time_samples = y.shape[1] // SAMPLE_STRIDE
points = []
for trial in tqdm(range(y.shape[-1])):
    for j in range(n_time_samples):
        t = max(0, j * SAMPLE_STRIDE - 1)
        u_t = input_batch[:, t, trial]
        y_t = y[:, t, trial]
        # Explicit copies detach the stored vectors from the large Uy / input_batch
        # arrays (replaces a deepcopy of the whole dict; C and J are already new arrays).
        points.append({
            "state": Uy[:, t, trial].copy(),
            "input": u_t.copy(),
            "C": U @ RNN.rhs_noisless(input=u_t, y=y_t),
            "J": U @ RNN.rhs_jac(input=u_t, y=y_t) @ U.T,
        })

  0%|          | 0/242 [00:00<?, ?it/s]


KeyboardInterrupt



In [None]:
# Persist the sampled records; the context manager guarantees the file is flushed
# and closed (the previous bare open() leaked the handle).
with open('../sampled_points.pkl', "wb") as points_file:
    pickle.dump(points, points_file)