In [None]:
import numpy as np
from torch.utils.data import DataLoader
import os, sys
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.model import *
from rnn_scripts.train import *
from rnn_scripts.utils import *
from tasks.seqDS import *
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Resolve the figure/model directories relative to the current working
# directory and load one trained network.
base_dir = os.getcwd()
fig_dir = base_dir + "/../figures/"
model_dir = base_dir + "/../models/"

# One saved model per rat; keep exactly one of these lines active.
#model = "N512_T0217-141442" #rat 1
model = "N512_T0217-151523" #rat 2
#model = "N512_T0217-151542" #rat 3

# load_rnn returns the network together with its model, task and training parameters
rnn, params, task_params, training_params = load_rnn(model_dir + model)

# NOTE(review): presumably orthogonalises the low-rank connectivity via SVD
# and folds the input/output weight scalers into the weights -- confirm
# against rnn_scripts before relying on this description.
rnn.rnn.svd_orth()
weight_scalers_to_1(rnn)

In [None]:
# Evaluate the loaded network on a single batch to obtain the baseline loss
ds = seqDS(task_params)
dataloader = DataLoader(ds, batch_size=128, shuffle=True)

# Draw one batch; the same batch is reused below for every resampled network
test_input, test_target, test_mask = next(iter(dataloader))

# bloss is the reference loss of the trained (non-resampled) network
rates, pred, bloss = predict(
    rnn,
    test_input,
    mse_loss,
    test_target,
    test_mask,
    return_loss=True,
)



In [None]:
def resample(gmm,params,n_units=512):
    """Build a new RNN whose loadings are sampled from a fitted mixture model.

    Args:
        gmm: fitted mixture model exposing ``sample(n)``; the first element
            of the returned value is taken as the sampled loadings.
        params: mapping of RNN parameters. It is copied before being
            modified, so the caller's object is left untouched.
        n_units: number of loading vectors to draw from ``gmm``. Defaults to
            512, the value that was previously hard-coded.

    Returns:
        An ``RNN`` constructed from the copied parameters, with ``'loadings'``
        set to the transposed samples and both ``'scale_w_out'`` and
        ``'scale_w_inp'`` set to 1.
    """
    import copy

    # Shallow copy: repeated calls must not leak sampled loadings or scaler
    # overrides back into the parameters of the original network.
    params_rs = copy.copy(params)

    # sample() returns a tuple; element 0 holds the samples (one per row),
    # which are transposed before being stored as loadings.
    sampled = gmm.sample(n_units)[0]
    params_rs['loadings'] = sampled.T
    params_rs['scale_w_out'] = 1
    params_rs['scale_w_inp'] = 1
    return RNN(params_rs)

In [None]:
# Fit mixture models with 1..7 components to the loadings, resample n_test
# networks from each fit and record the loss of every resampled network.
loadings = extract_loadings(rnn, orth_I=False, zero_center=False)
n_test = 30

# Despite its name, accs stores the loss returned by predict:
# row = number of components - 1, column = index of the resampled network.
accs = np.zeros((7, n_test))

for row, n_components in enumerate(range(1, 8)):
    z, gmm = cluster(loadings, n_components, n_init=500)
    for trial in range(n_test):
        rnn_rs = resample(gmm, params)
        # evaluated on the same test batch as the baseline network
        rates, pred, loss = predict(
            rnn_rs, test_input, mse_loss, test_target, test_mask, return_loss=True
        )
        accs[row, trial] = loss

In [None]:
# Plot the loss of the resampled networks over the number of mixture
# components, together with the baseline loss of the trained network.

plt.figure(figsize=(1.75,1.5),dpi=150)
alpha=0.5
# one scatter call per resampled network; each call draws a point for every
# number of components (x = 1..7)
for j in range(n_test):
    plt.scatter(np.arange(1,8),accs[:,j],color='black',alpha=alpha)
plt.axhline(bloss,label='baseline',color='red')
plt.ylim(-0.1,1.1)
plt.xlim(0.7,7.3)
# keep the tick marks but hide their labels
plt.yticks([0,0.5,1],labels=[])#,labels=['0','.5','1'])
plt.xlabel("number of components")
plt.xticks(np.arange(1,8))
# tight_layout only adjusts what exists at the time of the call, so it has to
# run after the axes, ticks and labels are created (it was previously called
# on the empty figure and therefore had no effect on the saved file).
plt.tight_layout()
# NOTE: the file name is hard-coded for rat 2; change it when loading another model
plt.savefig(fig_dir + "accs_rat2.svg")

In [None]:
# Fit a 3-component mixture to the loadings; z holds the per-unit component
# assignment and gmm the fitted model used by the plots below.

z, gmm = cluster(loadings, 3)

In [None]:
# Plot empirical covariance matrices 

# Covariance of the first 7 loading dimensions, computed separately over the
# units assigned to each of the 3 components.
covs = [np.cov(loadings[:7,z==i]) for i in np.arange(3)]
# Title of each panel: fraction of units assigned to that component
titles =["w = {:.2f}".format(np.sum(z==i)/len(z)) for i in np.arange(3)]
# vm is the largest absolute covariance entry across all components.
# The first label previously read "$I_{osc}}$"; the unbalanced closing brace
# is not valid mathtext and has been removed.
fig,_ = plot_covs(covs, vm =np.max(np.abs(covs)),
                  labels = [r"$I_{osc}$", r"$I_{s_a}$", r"$I_{s_b}$", r"$n_1$", r"$n_2$", r"$m_1$", r"$m_2$"],
                 titles = titles,
                 float_labels=True,
                 atol=0.5
)
plt.savefig(fig_dir + "FS_conn.svg")



In [None]:
# Plot pair plots of the loadings, one figure per mixture component
for cl in np.arange(3):
    # restrict to the units assigned to component cl; histogram limits are
    # shared across components (largest absolute loading overall)
    fig = plot_loadings(loadings[:7,z==cl],z[z==cl],alpha=1,colors=['#7ECCB9','#A382BB','#7ECCB9'],hist_lims=np.max(np.abs(loadings)))
    # Save inside the loop: previously the save ran once after the loop, so
    # only the last component's figure was written to disk.
    plt.savefig(fig_dir + "loadings"+str(cl)+".svg")
