In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import sys, os

sys.path.append("../../")
from vi_rnn.vae import VAE
from vi_rnn.train import train_VAE
from vi_rnn.datasets import SineWave, Oscillations_Cont
from torch.utils.data import Dataset, DataLoader
from py_rnn.model import RNN, predict
from vi_rnn.utils import *
from py_rnn.train import train_rnn
from py_rnn.train import save_rnn, load_rnn
import matplotlib.pyplot as plt
from vi_rnn.saving import save_model
from py_rnn.default_params import get_default_params
import matplotlib as mpl

%matplotlib inline

In [None]:
# Notebook configuration.
train_teacher = False  # False: load the already-trained teacher RNN; True: train it in this notebook
data_dir = "../../data/student_teacher/"  # the teacher RNN is saved to / loaded from here
model_dir = "../../models/students/"  # the inferred (student) model is stored here
cuda = True  # toggle if GPU is available
# NOTE(review): in the cells shown, `cuda` is reassigned to False in the
# "Initialise VI / student setup" cell before training_params reads it, so the
# value chosen here currently has no effect.

In [None]:
# initialise teacher RNN
# get_default_params(n_rec=20) supplies both the model and the training settings;
# only the "l2_rates_reg" training entry is changed here (set to 0.1).
model_params, training_params = get_default_params(n_rec=20)
training_params["l2_rates_reg"] = 0.1
rnn_osc = RNN(model_params)  # replaced by the loaded model further down when train_teacher is False


# initialise teacher RNN's task
# NOTE: `task_params` is rebound later in the notebook (by load_rnn when the
# teacher is loaded, and again when the student dataset is defined).
task_params = {
    "n_trials": 50,
    "dur": 200,
    "n_cycles": 4,
}
sine_task = SineWave(task_params)
x, y, m = sine_task[0]  # first item of the task; plotted in the next cell as input / output / mask

In [None]:
# Plot teacher task: overlay the three arrays taken from the first task item.
for trace, trace_label in zip((x, y, m), ("input", "output", "mask")):
    plt.plot(trace, label=trace_label)
plt.legend()

In [None]:
# Obtain the teacher RNN: restore a stored one unless train_teacher is set,
# in which case it is trained on the sine task and written to disk.

if not train_teacher:
    # load_rnn also returns the parameter dicts, which rebind the globals below
    rnn_osc, model_params, task_params, training_params = load_rnn(
        data_dir + "osc_rnn"
    )
else:
    losses, reg_losses = train_rnn(
        rnn_osc, training_params, sine_task, sync_wandb=False
    )
    # NOTE(review): the freshly trained teacher is saved as "osc_rnn_new" while
    # the load branch reads "osc_rnn" -- confirm the differing names are intended.
    save_rnn(
        data_dir + "osc_rnn_new",
        rnn_osc,
        model_params,
        task_params,
        training_params,
    )

In [None]:
# plot example output: run the teacher on an all-zero input of shape (1000, 1)
zero_input = torch.zeros(1000, 1)
rates, pred = predict(rnn_osc, zero_input)

fig, axs = plt.subplots(nrows=2, figsize=(4, 2))

# upper panel: first entry of the prediction
axs[0].plot(pred[0])
axs[0].set_xlabel("timesteps")

# lower panel: first entry of `rates`, passed through the teacher's nonlinearity
rates_first = torch.from_numpy(rates[0])
axs[1].plot(rnn_osc.rnn.nonlinearity(rates_first));

In [None]:
# Extract weights
# U, V and B are obtained from the teacher RNN and are passed to
# Oscillations_Cont further down to build the dataset for the student.
# NOTE(review): going by the helper's name these are presumably an
# orthogonalised basis of the teacher's weights -- confirm in its definition.
U, V, B = extract_orth_basis_rnn(rnn_osc)

In [None]:
# Build the dataset used for the student: Oscillations_Cont is constructed from
# the settings below together with the teacher's U, V, B. (The example trial
# and its latent signal are plotted in the next cell.)
# NOTE: this rebinds `task_params`, replacing the teacher's task settings.
batch_size = 4
task_params = {
    "dur": 75,
    "n_trials": 200,
    "name": "Sine",
    "n_neurons": 20,
    "out": "currents",
    "R_x": 0.1,
    "R_z": 0.2,
    "non_lin": nn.ReLU(),
}
task = Oscillations_Cont(task_params, U, V, B)
# NOTE(review): `data_loader` is not referenced by the other cells shown
# (train_VAE is given `task` directly, and training uses its own batch size
# from training_params) -- confirm whether it is still needed.
data_loader = DataLoader(task, batch_size=batch_size, shuffle=True)

In [None]:
# Show one example trial: two latent dimensions plotted against each other
# (left) and the first few rows of the corresponding data (right).
tr_i = 0  # index of the trial to display
n_obs = 5  # number of data rows drawn in the right panel
T1, T2 = 0, -1  # plotted window; NOTE(review): a stop of -1 leaves out the final sample
window = slice(T1, T2)

rates = task.data[tr_i]
latent_code = task.latents[tr_i]

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(4, 2))

# left panel: latent row 0 on the x-axis, latent row 1 on the y-axis
latent_dim0 = latent_code[0, window].numpy()
latent_dim1 = latent_code[1, window].numpy()
ax[0].plot(latent_dim0, latent_dim1)
ax[0].spines[["right", "top"]].set_visible(False)
ax[0].set_box_aspect(1)
ax[0].set_title("latent")

# right panel: each row is shifted up by 2 per index so the traces do not overlap
for i in range(n_obs):
    offset = i * 2
    ax[1].plot(rates[i, window].T + offset)
ax[1].spines[["right", "top"]].set_visible(False)
ax[1].set_title("observed")

In [None]:
# Initialise VI / student setup

# Sizes and run settings; both size entries are read from the dataset settings.
dim_z = 2
dim_N = dim_x = task_params["n_neurons"]
bs = 10
n_epochs = 1000
# NOTE(review): this overrides the `cuda = True` set in the configuration cell.
cuda = False
wandb = False

# initialise prior
rnn_params = {
    "transition": "low_rank",
    "observation": "one_to_one",
    "train_noise_x": True,
    "train_noise_z": True,
    "train_noise_z_t0": True,
    "init_noise_z": 0.1,
    "init_noise_z_t0": 1,
    "init_noise_x": task_params["R_x"],  # taken from the dataset settings
    "noise_z": "full",
    "noise_x": "diag",
    "noise_z_t0": "full",
    "identity_readout": True,
    "activation": "relu",
    "decay": 0.7,
    "readout_from": task_params["out"],  # taken from the dataset settings
    "train_obs_bias": False,
    "train_obs_weights": False,
    "train_neuron_bias": True,
    "weight_dist": "uniform",
    "weight_scaler": 1,  # author left a `/dim_N` variant commented out here
    "initial_state": "trainable",
    "obs_nonlinearity": "identity",
    "obs_likelihood": "Gauss",
    "simulate_input": True,
}

# Training settings for the student.
# NOTE: this rebinds `training_params`, replacing the teacher's training settings.
training_params = {
    "lr": 1e-3,
    "lr_end": 1e-5,
    "grad_norm": 0,
    "n_epochs": n_epochs,
    "eval_epochs": 50,
    "batch_size": bs,
    "cuda": cuda,
    "smoothing": 20,
    "freq_cut_off": 10000,
    "k": 64,
    "loss_f": "opt_smc",
    "resample": "systematic",  # other options noted by the author: multinomial or none
    "run_eval": True,
    "smooth_at_eval": False,
    "init_state_eval": "posterior_sample",
}

VAE_params = {
    "dim_x": dim_x,
    "dim_z": dim_z,
    "dim_N": dim_N,
    "rnn_params": rnn_params,
}

# Seed numpy and torch (CPU and CUDA) right before the model is constructed.
# NOTE(review): `task` was already built in an earlier cell, so this seeding
# cannot affect any randomness used while creating the dataset.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

vae = VAE(VAE_params)

In [None]:
# Train
# Fit the student `vae` on `task` using `training_params`.
# `wandb` (set in the setup cell) is forwarded as `sync_wandb` and `model_dir`
# as `out_dir`; switch `wandb` to True in the setup cell to change the former.
train_VAE(vae, training_params, task, sync_wandb=wandb, out_dir=model_dir, fname=None)

In [None]:
# Save the student model together with its training and dataset settings.
# NOTE(review): the name "SW20_1000_new2" under model_dir is hard-coded, so a
# re-run presumably writes to the same location -- confirm in vi_rnn.saving.
save_model(vae, training_params, task_params, name=model_dir + "SW20_1000_new2")

In [None]:
# Print the dataset's R_z setting (labelled "True noise") followed by the
# value the student reports for its own R_z.
print(f"True noise: {task_params['R_z']!s}")
print("Inferred noise:")
# NOTE: `vae` is rebound to the output of orthogonalise_network, so running
# this cell again applies the function to the already-processed model.
vae = orthogonalise_network(vae)
inferred_noise = vae.rnn.std_embed_z(vae.rnn.R_z)
print(inferred_noise)
