In [1]:
import os
import json
import pickle
import sys
sys.path.insert(0, '../')
from src.DataSaver import DataSaver
from src.DynamicSystemAnalyzer import *
from src.PerformanceAnalyzer import *
from src.RNN_numpy import RNN_numpy
from src.utils import get_project_root, numpify, orthonormalize
from src.Trainer import Trainer
from src.RNN_torch import RNN_torch
from src.Task import *
from matplotlib import pyplot as plt
import torch
import time
from sklearn.decomposition import IncrementalPCA as iPCA
from sklearn.cluster import KMeans, SpectralClustering, DBSCAN
from sklearn.decomposition import TruncatedSVD
from sklearn.mixture import GaussianMixture
np.set_printoptions(suppress = True)
%matplotlib inline

In [2]:
disp = True
activation = "relu"
taskname = "CDDM"

In [28]:
RNN_folder = '0.0070184_20230222-083339'
# RNN_folder = '0.0077046_20230222-045211'
RNN_folder_full_path = os.path.join("../", "../", "rnn_coach", "data", "trained_RNNs", "CDDM", RNN_folder)
mse_score_RNN = os.listdir(RNN_folder_full_path)[0].split("_")[0]
rnn_config = json.load(open(os.path.join(RNN_folder_full_path, f"{mse_score_RNN}_config.json"), "rb+"))
rnn_data = json.load(open(os.path.join(RNN_folder_full_path, f"{mse_score_RNN}_params_CDDM.json"), "rb+"))
train_config_file = f"train_config_{taskname}_{activation}.json"

In [29]:
activation_name = rnn_config["activation"]
RNN_N = rnn_config["N"]
n_steps = rnn_config["n_steps"]
task_params = rnn_config["task_params"]
if activation_name == 'relu':
    activation_RNN = lambda x: torch.maximum(x, torch.tensor(0))
elif activation_name == 'tanh':
    activation_RNN = torch.tanh
elif activation_name == 'sigmoid':
    activation_RNN = lambda x: 1/(1 + torch.exp(-x))
elif activation_name == 'softplus':
    activation_RNN = lambda x: torch.log(1 + torch.exp(5 * x))
dt = rnn_config["dt"]
tau = rnn_config["tau"]
connectivity_density_rec = rnn_config["connectivity_density_rec"]
spectral_rad = rnn_config["sr"]
sigma_inp = rnn_config["sigma_inp"]
sigma_rec = rnn_config["sigma_rec"]
seed = np.random.randint(1000000)
print(f"seed: {seed}")
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
rng = torch.Generator(device=torch.device(device))
if not seed is None:
    rng.manual_seed(seed)
input_size = np.array(rnn_data["W_inp"]).shape[1]
output_size = np.array(rnn_data["W_out"]).shape[0]

RNN = RNN_torch(N=RNN_N, dt=dt, tau=tau, input_size=input_size, output_size=output_size,
                      activation=activation_RNN, random_generator=rng, device=device,
                      sigma_rec=sigma_rec, sigma_inp=sigma_inp)
RNN_params = {"W_inp": np.array(rnn_data["W_inp"]),
              "W_rec": np.array(rnn_data["W_rec"]),
              "W_out": np.array(rnn_data["W_out"]),
              "b_rec": np.array(rnn_data["bias_rec"]),
              "y_init": np.zeros(RNN_N)}
RNN.set_params(RNN_params)
Task = eval("Task" + taskname)(n_steps=n_steps, n_inputs=input_size, n_outputs=output_size, task_params=task_params)


seed: 741209
Using cpu for RNN!


In [30]:
input_batch, target_batch, conditions_batch = Task.get_batch()
RNN.sigma_rec = RNN.sigma_inp = torch.tensor(0, device=RNN.device)
y, predicted_output_rnn = RNN(torch.from_numpy(input_batch.astype("float32")).to(RNN.device))
Y = np.hstack([y.detach().numpy()[:, :, i] for i in range(y.shape[-1])]).T

In [31]:
Y_mean = np.mean(np.abs(Y), axis=0)
th = np.percentile(Y_mean, 50)
inds = np.where(Y_mean > th)[0]

In [32]:
fig = plt.figure(figsize = (5,2))
plt.plot(np.sort(Y_mean))
plt.axhline(th, linestyle='--', color='k')
plt.show()

<IPython.core.display.Javascript object>

In [33]:
svd = TruncatedSVD(n_components=20, n_iter=11)
svd.fit(Y)

In [34]:
Y_projected = svd.components_.T
Y_projected_filtered = Y_projected[inds, :]

In [35]:
svd.explained_variance_ratio_

array([0.16951543, 0.414403  , 0.19579157, 0.13889371, 0.02602155,
       0.02754449, 0.00729338, 0.00456504, 0.00395573, 0.00344499,
       0.00164066, 0.00129018, 0.0011503 , 0.00078292, 0.00069209,
       0.00058184, 0.0004403 , 0.00032033, 0.00028983, 0.00024822,
       0.00022541, 0.00015467, 0.00011755, 0.00011502, 0.00009506,
       0.00008256, 0.00007871, 0.00007269, 0.00005596, 0.00005069,
       0.00004395, 0.0000384 , 0.00003668, 0.00003334, 0.00002745,
       0.00002596, 0.00002243, 0.00001937, 0.0000178 , 0.00001472],
      dtype=float32)

In [36]:
gm = GaussianMixture(n_components=8, max_iter=10000, n_init=101, tol=1e-13, verbose=1, init_params='k-means++')
lbls = gm.fit_predict(Y_projected_filtered)

Initialization 0
Initialization converged: True
Initialization 1
Initialization converged: True
Initialization 2
Initialization converged: True
Initialization 3
Initialization converged: True
Initialization 4
Initialization converged: True
Initialization 5
Initialization converged: True
Initialization 6
Initialization converged: True
Initialization 7
Initialization converged: True
Initialization 8
Initialization converged: True
Initialization 9
Initialization converged: True
Initialization 10
Initialization converged: True
Initialization 11
Initialization converged: True
Initialization 12
Initialization converged: True
Initialization 13
Initialization converged: True
Initialization 14
Initialization converged: True
Initialization 15
Initialization converged: True
Initialization 16
Initialization converged: True
Initialization 17
Initialization converged: True
Initialization 18
Initialization converged: True
Initialization 19
Initialization converged: True
Initialization 20
Initializati

In [37]:
Y_projected_filtered.shape

(50, 40)

In [38]:
%matplotlib notebook
colors = ['r', 'lightgreen', 'lightblue', 'yellow', 'orange', 'magenta', 'cyan', 'pink', 'k']

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
for i, lbl in enumerate(lbls):
    xs = Y_projected_filtered[i, 0]
    ys = Y_projected_filtered[i, 1]
    zs = Y_projected_filtered[i, 2]
    ax.scatter(xs, ys, zs, color=colors[lbl], edgecolor='k')
#     ax.scatter(xs, ys, zs, color='r', s=30, edgecolor='k')
plt.show()

<IPython.core.display.Javascript object>

In [50]:
trial_num = 68
fig = plt.figure()
tmp = np.where(lbls == np.int32(0))[0]
inds_lbl = [inds[i] for i in tmp]
for ind in inds_lbl:
    plt.plot(y[ind, :, trial_num].detach().numpy())
plt.show()

<IPython.core.display.Javascript object>

IndexError: index 9 is out of bounds for axis 0 with size 7