In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import os, sys
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.utils import *
from rnn_scripts.model import RNN
from scipy.optimize import fsolve

from rnn_scripts.utils import green_blue_colours
# Colour palette from the project helper; the plotting cells below index it
# from the end (cls[-1], cls[-2]) to pick line colours.
cls = green_blue_colours()

In [None]:
# Load the trained network analysed in the first part of the notebook
# -------------------------------------------------------------------

# Output and checkpoint folders sit next to the working directory. Both strings
# keep their trailing slash because later cells append file names directly.
fig_dir = f"{os.getcwd()}/../figures/"
model_dir = f"{os.getcwd()}/../models/"

model = "N512_T1206-175730"  # Rank 1 rate-model
rnn, params, task_params, training_params = load_rnn(f"{model_dir}{model}")

# Neither call's return value is used: both are relied on for their effect on `rnn`.
rnn.rnn.svd_orth()
weight_scalers_to_1(rnn)

In [None]:
# Perform and plot clustering of a trained rank 1 network
# ---------------------------------------------

n_cl=2 # number of clusters

# extract loadings
loadings = extract_loadings(rnn, orth_I=False, zero_center=False)

# cluster
# NOTE: random_state=None means the component numbering can change between runs.
# NOTE: 10e5 == 1e6.
z,gmm = cluster(loadings,
    n_components=n_cl,
    bayes=True,
    n_init=500,
    random_state=None,
    mean_precision_prior=10e5,
    mean_prior=np.zeros(6),
    weight_concentration_prior_type="dirichlet_process",
    weight_concentration_prior=None,
    init_params="kmeans")

# make plot
# Components are shown in reverse index order (as before), but the order is now
# derived from n_cl and shared by the covariances AND the titles, so every panel
# is captioned with the weight of the component it actually displays.
# Assumes the labels in `z` index the components of `gmm` -- TODO confirm against `cluster`.
display_order = list(range(n_cl-1,-1,-1))
covs = [gmm.covariances_[comp] for comp in display_order]
titles =["Pop. "+str(panel+1) + " (w = {:.2f})".format(np.sum(z==comp)/len(z))
         for panel,comp in enumerate(display_order)]

vm = 20
fig,_ = plot_covs(covs, vm =vm,
                  labels = [r"$I_{osc}$", r"$I_{s_a}$", r"$I_{s_b}$", r"$n$", r"$m$", r"w"],
                  titles=titles,
                    figsize=(2.5,2), 
                    dpi=100, 
                    fontsize=8, 
                    float_labels=True,atol=0.5, float_lims=5,
                    label_fs=8,
                    numbers_fs=6)

plt.savefig(fig_dir + "FS_conn.svg")



In [None]:
# Define some helpers for finding fixed points

class Roots_FS():
    """Residual function whose roots are the fixed points of the rank 1 RNN.

    Instances are callables meant to be handed to a root finder (the notebook
    passes them to ``scipy.optimize.fsolve``). Calling an instance does not
    search for a root; it evaluates the rate of change of the latent variable
    at the given value.

    The notebook globals ``rnn`` and ``loadings`` are looked up at call time,
    so re-binding either of them changes what is evaluated.
    """
    def __init__(self,alpha):
        self.alpha=alpha #integration constant (divides the state difference in __call__)
        # NOTE(review): stored but never read by __call__, which always feeds zero input.
        self.input=torch.zeros(3)

    def __call__(self,x):
        """Return the projected state change at latent value ``x`` with zero input.

        ``x`` is expanded along ``loadings[4]`` into a network state, the state
        is passed through ``rnn.rnn`` (presumably one update step -- confirm),
        and ``(X - X_in) / alpha`` is projected back onto ``loadings[4]``.
        The result is a squeezed NumPy value.
        """
        # Build the network state x * loadings[4] and cast it to a float tensor.
        X_in=np.outer(np.expand_dims(x,0),loadings[4])
        X_in = torch.from_numpy(X_in).float()
        # Zero external input for a single state.
        I_in=torch.zeros((1,3))
        with torch.no_grad():
            X,O= rnn.rnn(I_in,X_in)
        dX = (X-X_in)/self.alpha
        # Projection divides by the norm of loadings[4] (not the squared norm).
        dK = dX.numpy()@loadings[4]/np.linalg.norm(loadings[4])
        return dK.squeeze()
    
class Roots_MF():
    """Residual function whose roots are the fixed points of the mean field eqs. of the rank 1 RNN.

    Instances are callables meant to be handed to a root finder (the notebook
    passes them to ``scipy.optimize.fsolve``). Set ``self.input`` (3 values) to
    choose the input under which the residual is evaluated.

    The notebook global ``rnn`` is looked up at call time.
    """
    def __init__(self,alpha):
        self.alpha=alpha  # divides the latent difference in __call__
        self.input=torch.zeros(3)  # copied into columns 1..3 of the state on every call
    def __call__(self,x):
        """Return ``(K[:,0] - x) / alpha`` for latent value ``x`` as a NumPy value.

        ``x`` must be a NumPy array (it goes through ``torch.from_numpy``).
        """
        # State row: column 0 holds x, columns 1..3 hold self.input.
        K_in = torch.zeros((1,4))
        K_in[:,0]=torch.from_numpy(x)
        K_in[:,1:]=self.input.unsqueeze(0)
        # Zero external input for a single state.
        I_in=torch.zeros((1,3))
        with torch.no_grad():
            K,O= rnn.rnn(I_in,K_in)
        # Subtracts the original NumPy `x`, not the tensor copy stored in K_in.
        # NOTE(review): presumably the precision of this subtraction matters to
        # fsolve's finite differences -- check convergence before changing it.
        dK = (K[:,0]-x)/self.alpha
        return dK.squeeze().numpy()


In [None]:
# Flow of the latent variable in the trained rank 1 network
# ---------------------------------------------------------

# Settings that later cells reuse
ns = 100       # number of kappa values on the grid
alpha = .1     # divides the state difference to give a rate of change
s = 12         # marker size for the fixed-point dots

# Grid of latent values, lifted to network states along loadings[4]
k_range = torch.linspace(-1, 1, ns)
loadings = extract_loadings(rnn)
X_in = torch.from_numpy(np.outer(k_range, loadings[4])).float()
func = Roots_FS(alpha=alpha)

# Evaluate the network with no input and project the change back onto loadings[4]
I_in = torch.zeros((ns, 3))
with torch.no_grad():
    X, O = rnn.rnn(I_in, X_in)
dX = (X - X_in) / alpha
dK = (dX.numpy() @ loadings[4]) / np.linalg.norm(loadings[4])

# Curve plus dashed zero line
fig, ax = plt.subplots(figsize=(1.1, 1.1))
ax.plot(k_range, dK, color=cls[-1])
ax.plot(k_range, torch.zeros_like(k_range), ls='--', color='grey')
ax.set_xlabel(r"$\kappa$")
ax.set_ylabel(r"$\frac{d\kappa}{dt}$", rotation=0)
ax.set_xlim(-1, 1)
ax.set_yticks([-2, 0, 2])

# Fixed points: one fsolve run per starting guess.
# Grey dots were labelled "unstable" and white dots "stable" in the original notes.
# NOTE(review): for a bounded 1-D flow the outermost roots should be the stable
# ones, so those labels look swapped -- confirm from the slope of the curve.
for face_colour, start_guesses in (('grey', [-1, 0, 1]), ('white', [-.2, .2])):
    for stroots in start_guesses:
        root = fsolve(func, stroots)
        ax.scatter(root, 0, color=face_colour, s=s, zorder=10, edgecolors='black')

plt.savefig(fig_dir + "FS_dK.svg")

In [None]:
# Initialise a mean-field RNN with a specific covariance matrix
# -------------------------------------------------------------

dt = 2

# RNN settings (key order kept as in the original definition)
params = {
    # nonlinearities and read-out
    "nonlinearity": "tanh",
    "out_nonlinearity": "tanh",
    "readout_kappa": False,
    # mean-field description with two supports
    "train_meanfield": True,
    "n_supports": 2,
    "train_cov": False,
    # sizes
    "rank": 1,
    "n_inp": 3,
    "p_inp": 1,
    "n_rec": 4,
    "p_rec": 1,
    "n_out": 1,
    # covariance / loadings are filled in after construction
    "cov": None,
    "loadings": None,
    "cov_init_noise": 0,
    # scaling factors
    "scale_w_inp": 1,
    "scale_w_out": 1,
    "scale_n": 1,
    "scale_m": 1,
    "1overN_out_scaling": True,
    # every train_* entry below is switched off
    "train_w_inp": False,
    "train_w_inp_scale": False,
    "train_w_rec": False,
    "train_b_rec": False,
    "train_m": False,
    "train_n": False,
    "train_taus": False,
    "train_w_out": False,
    "train_w_out_scale": False,
    "train_b_out": False,
    "train_x0": False,
    # time constants, step size, noise and initial state
    "tau_lims": [20],
    "dt": dt,
    "noise_std": 0,
    "scale_x0": 0.1,
    "randomise_x0": True,
    "orth_indices": [],
}

# Covariance settings handed to create_MF_covs_R1.
# Keys ending in 1 / 2 presumably describe population 1 / 2 -- confirm against the helper.
cov_params = {
    # first set
    "sdII1": .5,
    "sdIs1": 10,
    "sdIw1": -6,
    "sdn1": 500,
    "sdmn1": 4.5,
    "sdm1": .1,
    "sdW1": 1000,
    # second set
    "sdII2": .5,
    "sdIs2": 10,
    "sdIw2": 16,
    "sdn2": 100,
    "sdmn2": -7,
    "sdm2": 1,
    "sdW2": 1000,
}

# Build the covariance factors from the settings above
# (they are multiplied with their own transpose further down to obtain covariances)
chol_covs = create_MF_covs_R1(cov_params, plot=False)

# Create the mean-field network. This re-binds `rnn`, replacing the model loaded earlier.
rnn = RNN(params)

# Mixture weights of the two populations
weights = np.array([.67, .33])

# Overwrite the network's covariance factors and weights in place
with torch.no_grad():
    for target, values in ((rnn.rnn.cov_chols, chol_covs), (rnn.rnn.weights, weights)):
        target.copy_(torch.from_numpy(values))

# plot the covariances
# One panel per mixture component. The count is taken from `weights` (defined in
# this cell) instead of `n_cl`, which belonged to the unrelated clustering cell.
n_pops = len(weights)
covs = [chol_covs[i]@chol_covs[i].T for i in range(n_pops)]
titles =["Pop. "+str(i+1) + " (w = {:.2f})".format(weights[i]) for i in range(n_pops)]

# Value passed as `vm` to plot_covs. The former data-driven
# np.max(np.abs(covs)) was overwritten straight away and has been removed.
vm = 20
fig,_ = plot_covs(covs, vm =vm,
                  labels = [r"$I_{osc}$", r"$I_{s_a}$", r"$I_{s_b}$", r"$n$", r"$m$", r"w"],
                  titles=titles,
                    figsize=(2.5,2), 
                    dpi=100, 
                    fontsize=8, 
                    float_labels=True,atol=0.25, float_lims=5,
                    label_fs=8,
                    numbers_fs=6)

plt.savefig(fig_dir + "MF_conn.svg")


In [None]:
# Simulate the reduced (mean-field) equations
# -------------------------------------------

# Initial latent values
K_range = np.array([-7, -4, -.2, .2, 4, 7])
# The step of 3*pi is longer than the interval, so this holds the single phase -pi
theta_range = np.arange(-np.pi, np.pi, np.pi * 3)

tau = 20
w = 7.5 * 2 * np.pi
amp_v = 1 / np.sqrt(1 + (tau * w / 1000) ** 2)

# Time axis from 0 to T in steps of dt/1000
T = .4
time_vec = np.arange(0, T + dt / 1000, dt / 1000)

# Create initial conditions and inputs
x0s, input_ICs, phases = create_ICs_MF_R1(K_range, theta_range, tau, T, dt, w)
total = len(x0s)
w_phases = wrap(phases)

# Two copies of the input with a stimulus channel (1 or 2) switched on
# during the first 75 time steps
input_ICs_st1, input_ICs_st2 = input_ICs.clone(), input_ICs.clone()
input_ICs_st1[:, :75, 1] = 1
input_ICs_st2[:, :75, 2] = 1

# Run the three conditions in order: no stimulus, stimulus 1, stimulus 2
(ks0, O0), (ks1, O1), (ks2, O2) = [
    predict(rnn, condition_input, x0=x0s)
    for condition_input in (input_ICs, input_ICs_st1, input_ICs_st2)
]


In [None]:
# Flow of the latent variable in the mean-field RNN
# -------------------------------------------------

fig, ax = plt.subplots(figsize=(1.1, 1.1))

# Range of latent values shown
lim1, lim2 = -5.2, 5.2
k_range = torch.linspace(lim1, lim2, ns)

# Root function evaluated with only the first input entry switched on
func = Roots_MF(alpha=alpha)
func.input = torch.zeros(3)
func.input[0] = np.sqrt(amp_v)

# Same situation on the whole grid: column 0 is kappa, column 1 the first input entry
K_in = torch.zeros((ns, 4))
K_in[:, 0] = k_range
K_in[:, 1] = np.sqrt(amp_v)
I_in = torch.zeros((ns, 3))

with torch.no_grad():
    K, O = rnn.rnn(I_in, K_in)
dK = (K[:, 0] - k_range) / alpha

# Curve plus dashed zero line
ax.plot(k_range, dK, color=cls[-1])
ax.plot(k_range, torch.zeros_like(k_range), ls='--', color='grey')
ax.set_xlabel(r"$\kappa$")
ax.set_xlim(lim1, lim2)
ax.set_ylim(-.75, .75)
ax.set_yticks([-.5, 0, .5])
ax.set_yticklabels(["-.5", "0", ".5"])

# Fixed points: one fsolve run per starting guess.
# Grey dots were labelled "unstable" and white dots "stable" in the original notes.
# NOTE(review): for a bounded 1-D flow the outermost roots should be the stable
# ones, so those labels look swapped -- confirm from the slope of the curve.
for face_colour, start_guesses in (('grey', [-6, 0, 6]), ('white', [-1, 1])):
    for stroots in start_guesses:
        root = fsolve(func, stroots)
        ax.scatter(root, 0, color=face_colour, s=s, zorder=10, edgecolors='black')

plt.savefig(fig_dir + "MF_dK.svg")

In [None]:
# Simulate several finite-size networks whose weights are sampled from the covariance matrices

# One list of latent trajectories and one of outputs per condition
# (0: no stimulus, 1: stimulus 1, 2: stimulus 2)
ks_fss0, ks_fss1, ks_fss2 = [], [], []
Os_fss0, Os_fss1, Os_fss2 = [], [], []

# Initialise the finite-size model: same settings as the mean-field one with a few overrides
params_fs = {
    **params,
    'train_meanfield': False,
    'train_cov': True,
    'n_rec': 4096,
    'noise_std': 0,
    'n_supports': 2,
}
rnn_fs = RNN(params_fs)

# Set covariances and weights.
# 0.33349609375 * 4096 == 1366, presumably chosen so both populations have whole unit counts.
with torch.no_grad():
    rnn_fs.rnn.cov_chols.copy_(torch.from_numpy(chol_covs))
    rnn_fs.rnn.weights.copy_(torch.from_numpy(np.array([1-0.33349609375,0.33349609375])))

# (input, latent store, output store) per condition, in the order they are simulated
conditions = (
    (input_ICs,     ks_fss0, Os_fss0),
    (input_ICs_st1, ks_fss1, Os_fss1),
    (input_ICs_st2, ks_fss2, Os_fss2),
)

for modeli in range(3):

    # Draw a new set of weights / loadings
    rnn_fs.rnn.resample()
    loadings = extract_loadings(rnn_fs)

    # Lift the initial conditions to network states using loadings[4] and loadings[0]
    x0s_fs = torch.from_numpy(
        np.outer(x0s[:,0],loadings[4])+np.outer(x0s[:,1],loadings[0])
    )

    for condition_input, ks_store, out_store in conditions:
        rates, outputs = predict(rnn_fs, condition_input, x0=x0s_fs)
        for ind in np.arange(total):
            # Project the rates of this trial onto loadings[4]
            ks_store.append(np.array(proj(np.expand_dims(loadings[4],0),rates[ind,:,:])).T)
            out_store.append(outputs[ind,:,:])

# Make arrays for easier handling
ks_fss0, ks_fss1, ks_fss2 = [np.array(runs) for runs in (ks_fss0, ks_fss1, ks_fss2)]
Os_fss0, Os_fss1, Os_fss2 = [np.array(runs) for runs in (Os_fss0, Os_fss1, Os_fss2)]

In [None]:
# Summary figure: flow (top), latent trajectories (middle), output (bottom)
# -------------------------------------------------------------------------

# Settings shared by all three columns
lim1, lim2 = -6, 6     # x limits of the flow panels, y limits of the kappa panels
s = 20                 # marker size of the fixed-point dots
FS_alpha = .15         # opacity of the finite-size traces
k_range = torch.linspace(lim1, lim2, ns)

# Root function evaluated with only the first input entry switched on
func.input = torch.zeros(3)
func.input[0] = np.sqrt(amp_v)

fig, ax = plt.subplots(3, 3, figsize=(4, 4))
fig.tight_layout(pad=.2)

# First column: no stimulus
#--------------------------------------------
flow_ax, kappa_ax, out_ax = ax[:, 0]
time_ticks = [0, 0.2, 0.4]

# Mean-field flow on the kappa grid
K_in = torch.zeros((ns, 4))
K_in[:, 0] = k_range
K_in[:, 1] = np.sqrt(amp_v)
I_in = torch.zeros((ns, 3))
with torch.no_grad():
    K, O = rnn.rnn(I_in, K_in)
dK = (K[:, 0] - k_range) / alpha
flow_ax.plot(k_range, dK, color=cls[-1])
flow_ax.plot(k_range, torch.zeros_like(k_range), ls='--', color='grey')

# Fixed points: grey dots for guesses -6, 0, 6 and white dots for guesses -1, 1
for face_colour, start_guesses in (('grey', [-6, 0, 6]), ('white', [-1, 1])):
    for stroots in start_guesses:
        root = fsolve(func, stroots)
        flow_ax.scatter(root, 0, color=face_colour, s=s, zorder=10, edgecolors='black')

flow_ax.set_xlabel(r"$\kappa$")
flow_ax.set_ylabel(r"$\frac{d\kappa}{dt}$", rotation=0)
flow_ax.set_xlim(lim1, lim2)
flow_ax.set_ylim(-1, 1)
flow_ax.set_yticks([-1, 0, 1])
flow_ax.set_box_aspect(1)

# Latent trajectories: mean-field in colour, finite-size networks in grey on top
kappa_ax.plot(time_vec, ks0[:, :, 0].T, color=cls[-2])
kappa_ax.plot(time_vec, ks_fss0[:, :, 0].T, color='grey', alpha=FS_alpha, zorder=10)
kappa_ax.set_ylabel(r"$\kappa$")
kappa_ax.set_xlim(0, time_vec[-1])
kappa_ax.set_ylim(lim1, lim2)
kappa_ax.set_yticks([-5, 0, 5])
kappa_ax.set_xticks(time_ticks)
kappa_ax.set_xticklabels([])
kappa_ax.set_box_aspect(1)

# Output: plotted against time_vec[1:], i.e. without the first time point
out_ax.plot(time_vec[1:], O0[:, :, 0].T, color='darkred')
out_ax.plot(time_vec[1:], Os_fss0[:, :, 0].T, color='grey', alpha=FS_alpha, zorder=10)
out_ax.set_ylabel(r"output")
out_ax.set_xlabel(r"time (s)")
out_ax.set_xlim(0, time_vec[-1])
out_ax.set_xticks(time_ticks)
out_ax.set_xticklabels(["0", ".2", ".4"])
out_ax.set_ylim(-3.2, 3.2)
out_ax.set_yticks([-2, 0, 2])
out_ax.set_box_aspect(1)

# Second column, stimulus 1
#--------------------------------------------

# Meanfield dynamics: column 0 is kappa, column 1 the first input entry,
# column 2 the stimulus-1 entry (switched on here)
K_in = torch.zeros((ns,4))
K_in[:,0]=k_range
K_in[:,1]=np.sqrt(amp_v)
K_in[:,2]=1

I_in = torch.zeros((ns,3))
with torch.no_grad():
    K,O = rnn.rnn(I_in,K_in)
    dK = (K[:,0]-k_range)/alpha
ax[0,1].plot(k_range,dK, color=cls[-1])
ax[0,1].plot(k_range,torch.zeros_like(k_range),ls='--',color='grey')

# Find fixed points
# The root function must see the same stimulus as the plotted curve:
# Roots_MF copies func.input into K_in[:,1:], so K_in[:,2] corresponds to func.input[1].
# Previously the search ran with the stimulus off, so the marker belonged to the
# no-stimulus flow. The entry is switched back off afterwards so the third
# column starts from the same input as before.
func.input[1]=1
for stroots in [0]:
    root = fsolve(func, stroots)
    ax[0,1].scatter(root,0,color='grey',s=s,zorder=10,edgecolors='black')
func.input[1]=0
ax[0,1].set_yticks([-2,0,2])
ax[0,1].set_box_aspect(1)
ax[0,1].set_xlim(lim1,lim2)

# Recurrent dynamics: mean-field in colour, finite-size networks in grey on top (zorder)
ax[1,1].set_yticks([-5,0,5])
ax[1,1].set_xticks([0,0.2,0.4])
ax[1,1].set_xticklabels([])
ax[1,1].set_yticklabels([])
ax[1,1].set_box_aspect(1)
ax[1,1].set_xlim(0,time_vec[-1])
ax[1,1].plot(time_vec, ks_fss1[:,:,0].T,color='grey',alpha=FS_alpha,zorder=10)
ax[1,1].plot(time_vec, ks1[:,:,0].T, color =cls[-2]);
ax[1,1].set_ylim(lim1,lim2)

# Output: plotted against time_vec[1:], i.e. without the first time point
ax[2,1].plot(time_vec[1:], O1[:,:,0].T, color ='darkred');
ax[2,1].plot(time_vec[1:], Os_fss1[:,:,0].T,color='grey',alpha=FS_alpha,zorder=10)
ax[2,1].set_xlim(0,time_vec[-1])
ax[2,1].set_xticks([0,0.2,0.4])
ax[2,1].set_xticklabels(["0",".2",".4"])
ax[2,1].set_ylim(-3.2,3.2)
ax[2,1].set_yticks([-2,0,2])
ax[2,1].set_yticklabels([])
ax[2,1].set_box_aspect(1)

# Third column: stimulus 2
#--------------------------------------------
flow_ax, kappa_ax, out_ax = ax[:, 2]

# Mean-field flow with the stimulus-2 entry (column 3) switched on
K_in = torch.zeros((ns, 4))
K_in[:, 0] = k_range
K_in[:, 1] = np.sqrt(amp_v)
K_in[:, 3] = 1
I_in = torch.zeros((ns, 3))
with torch.no_grad():
    K, O = rnn.rnn(I_in, K_in)
dK = (K[:, 0] - k_range) / alpha
flow_ax.plot(k_range, dK, color=cls[-1])
flow_ax.plot(k_range, torch.zeros_like(k_range), ls='--', color='grey')

# Fixed points under the same stimulus: grey dots for guesses -6 and 6, white for 0.
# func.input[2] stays switched on after this cell.
func.input[2] = 1
for face_colour, start_guesses in (('grey', [-6, 6]), ('white', [0])):
    for stroots in start_guesses:
        root = fsolve(func, stroots)
        flow_ax.scatter(root, 0, color=face_colour, s=s, zorder=10, edgecolors='black')
flow_ax.set_xlim(lim1, lim2)
flow_ax.set_box_aspect(1)

# Latent trajectories: mean-field in colour, finite-size networks in grey on top
kappa_ax.plot(time_vec, ks2[:, :, 0].T, color=cls[-2])
kappa_ax.plot(time_vec, ks_fss2[:, :, 0].T, color='grey', alpha=FS_alpha, zorder=10)
kappa_ax.set_xlim(0, time_vec[-1])
kappa_ax.set_ylim(lim1, lim2)
kappa_ax.set_yticks([-5, 0, 5])
kappa_ax.set_xticks([0, 0.2, 0.4])
kappa_ax.set_xticklabels([])
kappa_ax.set_yticklabels([])
kappa_ax.set_box_aspect(1)

# Output: plotted against time_vec[1:], i.e. without the first time point
out_ax.plot(time_vec[1:], O2[:, :, 0].T, color='darkred')
out_ax.plot(time_vec[1:], Os_fss2[:, :, 0].T, color='grey', alpha=FS_alpha, zorder=10)
out_ax.set_xlim(0, time_vec[-1])
out_ax.set_xticks([0, 0.2, 0.4])
out_ax.set_xticklabels(["0", ".2", ".4"])
out_ax.set_ylim(-3.2, 3.2)
out_ax.set_yticks([-2, 0, 2])
out_ax.set_yticklabels([])
out_ax.set_box_aspect(1)

plt.savefig(fig_dir+ "supp_rate.svg")