In [1]:
import sys
from copy import deepcopy
from geoopt.optim import RiemannianAdam
sys.path.append("./")
sys.path.append("../")
sys.path.append("../../")
import numpy as np
import json
import pickle
import os
import torch
from rnn_coach.src.RNN_torch import *
from rnn_coach.src.DynamicSystemAnalyzer import *
from rnn_coach.src.RNN_numpy import *
from rnn_coach.src.Task import *
from rnn_coach.src.DataSaver import *
from src.utils import jsonify
from latent_circuit_inference.src.LatentCircuit import *
from latent_circuit_inference.src.LatentCircuitFitter import *
from latent_circuit_inference.src.LCAnalyzer import *
from latent_circuit_inference.src.utils import *
from latent_circuit_inference.src.circuit_vizualization import *
from matplotlib import pyplot as plt
from scipy.stats import zscore
import pickle
import json
from pathlib import Path
from tqdm.auto import tqdm
from rnn_coach.src.utils import get_colormaps
# Project colormaps from rnn_coach; `cmp` is the colormap plot_matrix hands to imshow.
colors, cmp = get_colormaps()

In [2]:
def mse_scoring(x, y):
    """Return the mean squared error between ``x`` and ``y``.

    Parameters
    ----------
    x, y : array-like
        NumPy arrays or anything ``np.asarray`` accepts (lists, tuples,
        scalars). Their shapes must be broadcastable against each other.

    Returns
    -------
    NumPy scalar
        Mean of the element-wise squared differences ``(x - y) ** 2``.
    """
    # Coerce first so that plain Python sequences work: `list - list` would
    # otherwise raise a TypeError before NumPy ever sees the data.
    x = np.asarray(x)
    y = np.asarray(y)
    return np.mean((x - y) ** 2)

def R2(x, y):
    """Coefficient of determination of ``x`` measured against the reference ``y``.

    Evaluates ``1 - MSE(x, y) / Var(y)``, where the mean squared error comes
    from ``mse_scoring`` and the variance from ``np.var``.
    """
    residual_power = mse_scoring(x, y)
    reference_variance = np.var(y)
    return 1.0 - residual_power / reference_variance

def plot_matrix(mat, vmin=None, vmax=None, show_numbers = False, figsize = (7,7)):
    """Show a matrix as an image using the module-level colormap ``cmp``.

    Parameters
    ----------
    mat : 2-D array
        Matrix to display.
    vmin, vmax : float, optional
        Colour limits; each defaults to the matrix minimum / maximum.
    show_numbers : bool
        When True, write every entry whose magnitude exceeds 0.01 into its
        cell, rounded to two decimals.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    colour_low = np.min(mat) if vmin is None else vmin
    colour_high = np.max(mat) if vmax is None else vmax

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(mat, cmap=cmp, vmin=colour_low, vmax=colour_high)

    if show_numbers:
        for (row, col), value in np.ndenumerate(mat):
            if np.abs(value) > 0.01:
                label = str(np.round(value, 2))
                # imshow puts columns on the x axis, hence (col, row).
                ax.text(col, row, label, ha="center", va="center", color='k', fontsize=7)

    # Tick every second column / row.
    n_rows, n_cols = mat.shape[0], mat.shape[1]
    ax.set_xticks(np.arange(0, n_cols, 2))
    ax.set_yticks(np.arange(0, n_rows, 2))
    plt.show()

    
def permute_input_matrix(mat, order):
    """Return a copy of ``mat`` whose rows are rearranged according to ``order``.

    Row ``i`` of the result is row ``order[i]`` of ``mat``; the input is not
    modified.

    Parameters
    ----------
    mat : array-like
        Matrix whose rows (first axis) are reordered.
    order : sequence of int
        Source row index for every output row. Must contain exactly one
        entry per row of ``mat``.

    Raises
    ------
    ValueError
        If ``order`` does not have one entry per row. The previous
        implementation filled an ``np.empty_like`` buffer, so a shorter
        ``order`` silently left uninitialised rows in the result.
    """
    mat = np.asarray(mat)
    order = list(order)
    if len(order) != mat.shape[0]:
        raise ValueError(
            f"order has {len(order)} entries but the matrix has {mat.shape[0]} rows"
        )
    # Fancy indexing always returns a new array, so the caller's data is untouched.
    return mat[np.asarray(order, dtype=np.intp)]

# --- Locate the trained RNN on disk and load its config / parameters ---
projects_folder = str(Path.home()) + "/Documents/GitHub/"
RNN = '0.0018708_CDDM_tanh;tanh;N=100;lmbdo=0.3;lmbdr=0.5;lr=0.002;maxiter=3000'
RNN_folder = RNN
RNNs_path = os.path.join(projects_folder, "rnn_coach", "data", "trained_RNNs", "CDDM_tanh")
# The leading "<score>_" of the folder name is reused as the prefix of the file names below.
RNN_score = float(RNN.split("_")[0])
RNN_path = os.path.join(RNNs_path, RNN)
# Read-only handles inside `with`: the files are only read, so write access
# ("rb+") is not needed, and the handles are closed deterministically.
with open(os.path.join(RNN_path, f"{RNN_score}_config.json"), "rb") as fh:
    rnn_config = json.load(fh)
with open(os.path.join(RNN_path, f"{RNN_score}_params_CDDM_tanh.json"), "rb") as fh:
    rnn_data = json.load(fh)
# NOTE(review): names a "relu" training config although the loaded network is a tanh RNN — confirm.
train_config_file = "train_config_CDDM_relu.json"

# --- Hyper-parameters taken from the RNN config ---
activation_name = rnn_config["activation"]
RNN_N = rnn_config["N"]
n_steps = rnn_config["n_steps"]
task_params = rnn_config["task_params"]
# The loaded network is a tanh RNN (see the folder / file names above), so the
# torch activation is tanh from the start rather than a ReLU placeholder that
# only becomes correct once a later cell overwrites it.
activation_torch = lambda x: torch.tanh(x)
dt = rnn_config["dt"]
tau = rnn_config["tau"]
connectivity_density_rec = rnn_config["connectivity_density_rec"]
spectral_rad = rnn_config["sr"]
sigma_inp = 0.03
sigma_rec = 0.03
# NOTE(review): the seed is drawn at random on every run and only feeds the
# torch generator below; NumPy's global RNG is not seeded from it.
seed = np.random.randint(1000000)


print(f"seed: {seed}")
device = torch.device('cpu')
rng = torch.Generator(device=device)
if seed is not None:
    rng.manual_seed(seed)
# Input / output sizes are read off the stored weight matrices.
input_size = np.array(rnn_data["W_inp"]).shape[1]
output_size = np.array(rnn_data["W_out"]).shape[0]
mask = np.array(rnn_config["mask"])

tag = '8nodes'
with open(os.path.join(projects_folder, "latent_circuit_inference", "data", "configs", f"LCI_config_{tag}.json"), mode="r", encoding='utf-8') as fh:
    LCI_config_file = json.load(fh)
# Work on a private copy: assigning the new coherence grid below must not
# leak into `task_params` / `rnn_config`, which refer to the same dict.
task_data = deepcopy(rnn_config["task_params"])
# Symmetric coherence grid: the largest configured coherence scaled by
# 2**-4 ... 2**0, mirrored around zero.
tmp = task_data["coherences"][-1] * np.logspace(-(5 - 1), 0, 5, base=2)
coherences = np.concatenate([-np.array(tmp[::-1]), np.array([0]), np.array(tmp)]).tolist()
task_data["coherences"] = deepcopy(coherences)

seed: 650538


In [3]:
# Scalars and weight matrices of the trained RNN, read from the stored parameter dict.
dt, tau, N = (rnn_data[key] for key in ("dt", "tau", "N"))
W_inp, W_rec, W_out = (
    np.array(rnn_data[key]) for key in ("W_inp", "W_rec", "W_out")
)
# Element-wise tanh. The previous `np.tanh(0, x)` passed `x` as the ufunc's
# `out` argument: it computed tanh(0), wrote zeros into the caller's array
# and returned that array instead of tanh(x).
activation_numpy = lambda x: np.tanh(x)

In [4]:
# Independent copy of the task parameters with a fixed nine-point coherence grid.
# (A three-point grid, -0.8 / 0 / 0.8, was previously kept here as a commented-out alternative.)
task_params_ = deepcopy(task_params)
task_params_.update(
    coherences=np.array([-0.8, -0.5, -0.2, -0.1, 0, 0.1, 0.2, 0.5, 0.8])
)

In [5]:
# Set up the NumPy implementation of the trained RNN (not the torch one),
# then build the task from task_params_ and draw one batch from it.
activation_numpy = lambda x: np.tanh(x)
RNN = RNN_numpy(
    N=N,
    dt=dt,
    tau=tau,
    W_inp=W_inp,
    W_rec=W_rec,
    W_out=W_out,
    activation=activation_numpy,
)
task = TaskCDDM_tanh(
    n_steps=n_steps,
    n_inputs=input_size,
    n_outputs=output_size,
    task_params=task_params_,
)
inputs, targets, conditions = task.get_batch()

In [6]:
# --- Task ---
n_steps = task_data["n_steps"]

# --- Latent circuit size ---
# NOTE(review): the original inline remark read "(10 clusters)" and the config
# tag is '8nodes', yet 9 units are used here — confirm the intended size.
LCn = 9
N_PCs = 20  # handed to the fitter as N_PCs

# Input mask: identity on the leading input_size x input_size block, zero elsewhere.
inp_connectivity_mask = np.zeros((LCn, input_size))
inp_connectivity_mask[:input_size, :input_size] = np.eye(input_size)
w_inp_init = inp_connectivity_mask.copy()

# Random initial weights (recurrent drawn first, then output); both masks are all ones.
w_rec_init = np.random.randn(LCn, LCn)
w_out_init = np.random.randn(output_size, LCn)
rec_connectivity_mask = np.ones((LCn, LCn))
out_connectivity_mask = np.ones((output_size, LCn))

# --- Fitter ---
lambda_w, max_iter, lr = 0.5, 1000, 0.01
# `actvation_name` (sic) keeps its original spelling, so it stays a separate
# name and does not overwrite `activation_name` read from the RNN config.
tol, actvation_name, Qinitialization = (
    LCI_config_file[key] for key in ("tol", "activation", "Qinitialization")
)

In [7]:
# Torch version of the trained RNN: construct it, then load the stored weights.
activation_torch = lambda x: torch.tanh(x)
rnn_torch = RNN_torch(
    N=RNN_N,
    input_size=input_size,
    output_size=output_size,
    dt=dt,
    tau=tau,
    activation=activation_torch,
    sigma_rec=sigma_rec,
    sigma_inp=sigma_inp,
    random_generator=rng,
    device=device,
)
# Map the parameter names expected by set_params to the keys used in the stored file.
RNN_params = {
    name: np.array(rnn_data[stored_key])
    for name, stored_key in (
        ("W_inp", "W_inp"),
        ("W_rec", "W_rec"),
        ("W_out", "W_out"),
        ("b_rec", "bias_rec"),
    )
}
RNN_params["y_init"] = np.zeros(RNN_N)
rnn_torch.set_params(RNN_params)

Using cpu for RNN!


In [8]:
# Rebuild `task` from task_data; this replaces the earlier instance built from task_params_.
task = TaskCDDM_tanh(
    n_steps=n_steps,
    n_inputs=input_size,
    n_outputs=output_size,
    task_params=task_data,
)

In [9]:
def _as_device_tensor(array):
    """Convert a NumPy array to a torch.Tensor placed on `device`."""
    return torch.Tensor(array).to(device)


# Latent circuit with LCn units, sharing activation, noise levels, device and
# random generator with the torch RNN.
lc = LatentCircuit(
    N=LCn,
    num_inputs=input_size,
    num_outputs=output_size,
    W_inp=_as_device_tensor(w_inp_init),
    W_out=_as_device_tensor(w_out_init),
    inp_connectivity_mask=_as_device_tensor(inp_connectivity_mask),
    rec_connectivity_mask=_as_device_tensor(rec_connectivity_mask),
    out_connectivity_mask=_as_device_tensor(out_connectivity_mask),
    dale_mask=None,
    activation=activation_torch,
    sigma_rec=sigma_rec,
    sigma_inp=sigma_inp,
    device=device,
    random_generator=rng,
)

# `w_rec_init` is not passed to the constructor; the manual override below stays disabled.
# lc.recurrent_layer.weight.data = deepcopy(torch.from_numpy(w_rec_init.astype("float32")))

Using cpu for Latent Circuit!


In [10]:
# Mean-squared-error objective and the fitter that ties circuit, RNN and task together.
criterion = torch.nn.MSELoss()
fitter = LatentCircuitFitter(
    # models and task
    LatentCircuit=lc,
    RNN=rnn_torch,
    Task=task,
    # embedding options
    N_PCs=N_PCs,
    encoding=True,
    # NOTE(review): hard-coded to False; the `Qinitialization` value read from
    # the LCI config is not used here — confirm this is intentional.
    Qinitialization=False,
    # optimisation settings
    criterion=criterion,
    lr=lr,
    max_iter=max_iter,
    tol=tol,
    lambda_w=lambda_w,
)

setting projection of RNN traces on the lower subspace


In [14]:
# Inspect the input layer of the latent circuit held by the fitter.
# NOTE(review): this cell carries execution count 14 but sits before cell 11,
# i.e. it was run out of order; re-run top to bottom before relying on it.
print(fitter.LatentCircuit.input_layer)

Linear(in_features=6, out_features=9, bias=False)


In [11]:
# Run the fit. Judging by the names, the four returned values are the inferred
# circuit, the training and validation losses, and the network parameters — confirm.
# NOTE(review): the saved output of this cell is a RuntimeError (tensor size 4
# vs 6 at non-singleton dimension 1), so the recorded run did not complete.
lc_inferred, train_losses, val_losses, net_params = fitter.run_training()

RuntimeError: The size of tensor a (4) must match the size of tensor b (6) at non-singleton dimension 1