# Effects of criticality on short term memory in a Reservoir Computing network

This notebook investigates the effects of network criticality (in terms of network parameters) on tasks that require short-term memory, namely direct memory retrieval and a 3-bit parity task (non-linear computation). The network used is a Reservoir Computing network (an RNN with fixed weights and only a trainable output layer) where the encoder weights and network weights are initialized following Gaussian distributions with zero mean and a given variance. The details of how the networks were trained are explained in the thesis and can be seen in the code.

In [None]:
#import packages

import sklearn as skl
import matplotlib.pyplot as plt
import numpy as np
import sys
import scipy.stats as sps
from scipy.sparse import csr_matrix
import scipy.integrate as integrate
import warnings
warnings.filterwarnings('ignore')

from sklearn.utils.validation import check_is_fitted
from sklearn.cluster import SpectralClustering
from scipy.io import loadmat
from sklearn.linear_model import RidgeClassifier, LinearRegression
from scipy.spatial import distance

import networkx as nx

import math
import random

# 1. Model set up

The following class defines the parameters and structure of the network

In [None]:
class RNN_setup(skl.base.BaseEstimator, skl.base.TransformerMixin):
    """Generate the fixed random parts of the reservoir network.

    Parameters
    ----------
    n_spins : int
        Number of units (spins) in the network.
    K : int
        Number of units projecting to each spin (in-degree of every unit).
        Must satisfy ``0 <= K < n_spins`` because self-connections are excluded.
    var_w : float
        Variance of the zero-mean Gaussian recurrent weights.
    var_e : float
        Variance of the zero-mean Gaussian encoder weights.
    """

    def __init__(self,n_spins, K, var_w , var_e):

        self.n_spins = n_spins
        self.K = K #number of units projecting to each spin
        self.var_w = var_w #variance of weights, gaussian with 0 mean
        self.var_e = var_e #variance of encoder, gaussian with 0 mean

    def setConnectivity(self):
        """Draw the 0/1 connectivity matrix.

        Returns
        -------
        scipy.sparse.csr_matrix
            Matrix of shape ``(n_spins, n_spins)`` in which every row holds
            exactly ``K`` ones at distinct off-diagonal columns.

        Raises
        ------
        ValueError
            If ``K`` is negative or not smaller than ``n_spins`` (a unit
            cannot receive ``K`` inputs from the ``n_spins - 1`` other units).
        """
        n, k = self.n_spins, self.K
        if not 0 <= k < n:
            raise ValueError(
                "K must satisfy 0 <= K < n_spins (no self connectivity); "
                "got K={} and n_spins={}".format(k, n)
            )
        if k == 0:
            return csr_matrix((n, n))

        rows = np.repeat(np.arange(n), k)
        # Pre-allocated instead of growing with np.append inside the loop.
        cols = np.empty(n * k, dtype=int)

        for i in range(n):
            # Sample among the n-1 *other* units so that every row keeps exactly
            # K inputs. Sampling among all units and zeroing the diagonal
            # afterwards would leave some rows with only K-1 inputs.
            idx = np.random.choice(n - 1, k, replace = False)
            idx[idx >= i] += 1 #shift indices at or above i to skip column i
            cols[i * k:(i + 1) * k] = idx

        data = np.ones(n * k)

        return csr_matrix((data, (rows, cols)), shape=(n, n))

    def setWeights(self):
        """Draw the recurrent weight matrix.

        Returns
        -------
        scipy.sparse.csr_matrix
            Matrix with the sparsity pattern of ``setConnectivity()`` whose
            stored entries are i.i.d. samples from N(0, var_w).
        """
        J_mat = self.setConnectivity().copy()
        # Drop any explicitly stored zeros so that only real connections get a weight.
        J_mat.eliminate_zeros()
        # One vectorised draw for all stored entries instead of per-element
        # assignment into the sparse matrix.
        J_mat.data = np.random.normal(0, np.sqrt(self.var_w), size=J_mat.nnz)
        return J_mat

    def setEncoder(self):
        '''
        Generate the encoder: one N(0, var_e) weight per spin.
        '''
        encoder = np.random.normal(0, np.sqrt(self.var_e), self.n_spins)
        return encoder

# 2. Running the model

The following class runs the network for a given period of time and observes its evolution

In [None]:
class RNN_run(skl.base.BaseEstimator, skl.base.TransformerMixin):
    """Simulate the +/-1 spin network driven by a random scalar input.

    Parameters
    ----------
    n_spins : number of units in the network.
    J : recurrent weight matrix applied to the current spin state.
    enc : encoder weights multiplying the scalar input (used when ``encoder`` is True).
    mean_u : constant offset added to every generated input value.
    r : success probability of the Bernoulli draw behind each input value.
    encoder : whether the input is multiplied by ``enc`` before being added to the field.
    """

    def __init__(self, n_spins, J, enc, mean_u = 0,  r = 0.5, encoder = True):
        self.n_spins = n_spins
        self.J = J
        self.enc = enc
        self.mean_u = mean_u
        self.r = r
        self.encoder = encoder
        # random initial configuration with entries in {-1, +1}
        self.spins = np.random.randint(2, size = self.n_spins)*2 - 1

    def generateInput(self, time_steps):
        '''
        Generate the input sequence: time_steps Bernoulli(r) draws mapped
        from {0, 1} to {-1, +1}, then shifted by mean_u.
        '''
        bernoulli_draws = np.random.binomial(1, self.r, time_steps)
        return bernoulli_draws*2 - 1 + self.mean_u

    def updateSpins(self, u):
        '''
        Advance the network by one step with a Heaviside-type update and
        return the new state (which is also stored in self.spins).
        '''
        drive = self.enc*u if self.encoder == True else u
        local_field = self.J@self.spins + drive

        # a unit becomes +1 where its field is >= 0 and -1 where it is negative
        new_state = np.ones(self.n_spins)
        new_state[local_field < 0] = -1

        self.spins = new_state
        return new_state

    def runsim(self, t_time):
        '''
        Run the simulation for t_time steps.

        Returns the state history (one column per time step), the input
        sequence, and the product of every three consecutive inputs
        (two entries shorter than the input sequence).
        '''
        history = np.zeros((self.n_spins, t_time))
        u_input = self.generateInput(t_time)
        u_input_parity = u_input[:-2]*u_input[1:-1]*u_input[2:]

        for step, u in enumerate(u_input):
            self.updateSpins(u)
            history[:, step] = self.spins

        return history, u_input, u_input_parity
                   

# 3. Train and test the model

The following class trains and tests the network (for a given set of parameters) on the direct memory task and on the 3-bit parity task. The parameters define how far and to what side of the critical line the network is operating in.

In [None]:
class RNN_results(skl.base.BaseEstimator, skl.base.TransformerMixin):
    """Train and test a linear read-out on the reservoir for several delays and weight variances.

    Parameters
    ----------
    K : int
        Number of units projecting to each spin.
    tau : array-like of int
        Delays between the network state and the label it is paired with.
    var_e : float
        Variance of the encoder weights.
    task : str
        Either ``"direct_memory"`` (label is the delayed input) or
        ``"3parity"`` (label is the delayed product of three consecutive inputs).
    var_w : array-like of float
        Weight variances to evaluate; only used when ``var_in`` is True.
    n_spins : int
        Number of units in the network.
    time_train, time_test : int
        Number of simulated steps kept (after burn-in) for the training / test run.
    burn_t : int
        Number of initial steps discarded from every run.
    num_init : int
        Number of independent runs (fresh initial state and input) concatenated
        into each data set.
    num_iter : int
        Number of repetitions of the whole experiment.
    n_pert : int
        Number of weight variances generated around the critical value when
        ``var_in`` is False.
    var_in : bool
        If True use the values given in ``var_w``; otherwise generate them.
    """

    def __init__(self, K, tau, var_e, task, var_w, n_spins = 300, time_train = 1000, time_test = 500,
                 burn_t = 500, num_init = 10, num_iter = 50, n_pert = 10, var_in = False):

        self.K = K
        self.tau = tau
        self.var_e = var_e

        self.var_in = var_in #if true give inputs for variance, otherwise generate them in the class
        self.var_w_list = var_w #if you want to give in specific var_w values

        self.var_w = self.critical_line_encoder() #initialize, then use critical line values
        self.n_spins = n_spins
        self.time_train = time_train
        self.time_test = time_test
        self.burn_t = burn_t
        self.num_init = num_init
        self.num_iter = num_iter
        self.n_pert= n_pert
        #task type, can either be "3parity" or "direct_memory"
        self.task = task


    def shift_dataset(self,X, y, t, b_t):
        '''
        Pairs the state of the network at time s with the label at time s-t,
        discarding the first b_t steps.

        X has one column per time step; y has one entry per time step.
        '''
        X_shift = X[:,b_t+t:]
        # Explicit end index: y[b_t:-t] would be empty for t == 0.
        y_shift = y[b_t:len(y)-t]

        return X_shift, y_shift

    def generate_data(self, t, w):
        '''
        Build one network with weight variance w and simulate num_init
        training runs and num_init test runs on it.

        Returns X_train, y_train, X_test, y_test (samples along the first axis
        of X) followed by the parity sequence and the raw input of the last
        training run.

        Raises ValueError for an unknown task or num_init < 1.
        '''
        if self.task not in ("direct_memory", "3parity"):
            raise ValueError(
                "task must be 'direct_memory' or '3parity', got {!r}".format(self.task)
            )
        if self.num_init < 1:
            raise ValueError("num_init must be at least 1, got {}".format(self.num_init))

        #generate the network
        RNN = RNN_setup(n_spins = self.n_spins, K = self.K, var_w = w, var_e = self.var_e)
        #get the weight matrix
        J_mat = RNN.setWeights()
        #get the encoder vector
        enc = RNN.setEncoder()

        # Collect the pieces and concatenate once at the end.
        X_train_parts, y_train_parts = [], []
        X_test_parts, y_test_parts = [], []

        for ni in range(self.num_init):
            #run two separate networks (same weights, independent initial state and input)
            RNN_r1 = RNN_run(self.n_spins, J_mat, enc, encoder = True)
            RNN_r2 = RNN_run(self.n_spins, J_mat, enc, encoder = True)

            X1, u1, u1_par = RNN_r1.runsim(self.time_train+self.burn_t)
            X2, u2, u2_par = RNN_r2.runsim(self.time_test+self.burn_t)

            # The task is read from the instance, not from a module-level
            # variable, so the object keeps the task it was created with.
            if self.task == "direct_memory":
                #labels are the raw input
                X_tr, y_tr = self.shift_dataset(X1, u1, t, self.burn_t)
                X_te, y_te = self.shift_dataset(X2, u2, t, self.burn_t)
            else:
                #labels are the 3-input product; it has two fewer entries than the
                #input, so the first two state columns are dropped to keep them aligned
                X_tr, y_tr = self.shift_dataset(X1[:,2:], u1_par, t, self.burn_t)
                X_te, y_te = self.shift_dataset(X2[:,2:], u2_par, t, self.burn_t)

            X_train_parts.append(X_tr)
            y_train_parts.append(y_tr)
            X_test_parts.append(X_te)
            y_test_parts.append(y_te)

        X_train = np.concatenate(X_train_parts, axis = 1).T
        X_test = np.concatenate(X_test_parts, axis = 1).T
        # Labels are returned as floats, as before.
        y_train = np.concatenate(y_train_parts).astype(float)
        y_test = np.concatenate(y_test_parts).astype(float)

        return X_train, y_train, X_test, y_test, u1_par, u1


    def critical_line_encoder(self):
        '''
        Computes the variance of the network weights on the critical line
        for the encoder variance self.var_e and in-degree self.K.
        '''

        sigma_e = np.sqrt(self.var_e)

        tan_arg = np.tan(np.pi/(2*self.K)) #tangent expression
        # NOTE(review): for K <= 2 the denominator below is (numerically) zero,
        # so the result is not finite/meaningful - confirm the intended range of K.
        sigma_w_val = (sigma_e*tan_arg)/(np.sqrt(1- self.K*(tan_arg**2) + (tan_arg**2)))

        #returns the variance of the network weights
        return sigma_w_val**2


    def train_test_tau(self):
        '''
        Train and test the read-out for every weight variance and every tau,
        repeated num_iter times.

        Returns the accuracy and normalized mutual information arrays (shape:
        number of variances x number of taus x num_iter), the weight variances
        used, the critical weight variance, and the parity sequence and input
        of the last generated training run (None if nothing was run).
        '''
        if self.var_in:
            var_w_perturbations = np.atleast_1d(self.var_w_list)
        else:
            var_w_perturbations = np.linspace(self.var_w - 0.7*self.var_w, self.var_w+0.2*self.var_w, self.n_pert)

        tau_values = np.atleast_1d(self.tau)

        #initialize scores
        score_shape = (var_w_perturbations.shape[0], tau_values.shape[0], self.num_iter)
        mean_accuracy = np.zeros(score_shape)
        MI_score = np.zeros(score_shape)

        # Defined up front so the return statement is valid even if a loop is empty.
        u1_par = u1 = None

        #loop over sigma_w values and tau
        for it in range(self.num_iter):
            for i, w in enumerate(var_w_perturbations):
                for j, t in enumerate(tau_values):

                    X_train, y_train, X_test, y_test, u1_par, u1 = self.generate_data(t, w)#generate the data to train the network

                    #fit the linear decoder: ridge classifier
                    clf = RidgeClassifier().fit(X_train, y_train)

                    #compute the mean accuracy and normalized MI scores
                    mean_accuracy[i,j, it] = clf.score(X_test, y_test)
                    MI_score[i,j, it] = skl.metrics.normalized_mutual_info_score(y_test, clf.predict(X_test))


        return mean_accuracy, MI_score, var_w_perturbations, self.var_w, u1_par, u1
        

# 4. Results on the direct memory task

Train the network for different parameters (different distances from the critical line) and look at the results on the test set.

In [None]:
# Network architecture
n_spins = 300                            # number of units
K = 8                                    # number of units projecting to each spin

# Delays at which the read-out is evaluated: 1, 2, ..., 9
tau = np.arange(1, 10)

# Encoder variance and the explicit weight variances to evaluate
var_e = 5
var_w = np.array([0.137, 0.274, 0.301])
sigma_e_critical = np.array([np.sqrt(var_e)])

# Experiment size
it = 50                                  # number of repetitions
perturbations = 10                       # number of generated weight variances

task = "direct_memory"

In [None]:
# Evaluate the explicit var_w values (var_in = True) on the direct memory task
RNN_r = RNN_results(
    K,
    tau,
    var_e,
    task,
    var_w,
    n_spins = n_spins,
    time_train = 1000,
    time_test = 500,
    burn_t = 500,
    num_init = 10,
    num_iter = it,
    n_pert = perturbations,
    var_in = True,
)

(mean_accuracy,
 MI_score,
 var_w_perturbed,
 var_w_critical,
 u_par,
 u) = RNN_r.train_test_tau()

In [None]:
# Directory for the direct-memory result files ("gonzaga" is presumably the
# name of the machine this absolute path belongs to - adjust before running elsewhere).
# The file names below are appended to this string directly, so keep the trailing "/".

path = "/home/elosegui/MSc_thesis_project/numpy_results/direct_memory_task/"

In [None]:
# Save the results; every file name ends with the same run description
_run_tag = "K" + str(K) + "_iterations" + str(it) + "_N" + str(n_spins) + "_var_e" + str(var_e)

for _prefix, _values in (
    ("accuracy_", mean_accuracy),
    ("MI_", MI_score),
    ("var_w_perturbed_", var_w_perturbed),
    ("var_w_critical_", var_w_critical),
):
    np.save(path + _prefix + _run_tag, _values)

In [None]:
# Load the saved results back (same file names as written by the save cell, plus ".npy")
_run_tag = "K" + str(K) + "_iterations" + str(it) + "_N" + str(n_spins) + "_var_e" + str(var_e)

mean_accuracy = np.load(path + "accuracy_" + _run_tag + ".npy")
MI_score = np.load(path + "MI_" + _run_tag + ".npy")
var_w_perturbed = np.load(path + "var_w_perturbed_" + _run_tag + ".npy")
var_w_critical = np.load(path + "var_w_critical_" + _run_tag + ".npy")


### 4.1 Preliminary plots to look at the mean over trials of the MI score and the standard deviation

In [None]:

# Statistics over the repetitions (last axis of the score arrays)
MI_mean, MI_std = np.mean(MI_score, axis = 2), np.std(MI_score, axis = 2)
acc_mean, acc_std = np.mean(mean_accuracy, axis = 2), np.std(mean_accuracy, axis = 2)

# One curve per weight variance: mean accuracy as a function of tau
_legend_labels = ["var_w {}".format(np.round(w,3)) for w in var_w_perturbed]

plt.plot(tau, acc_mean.T)
plt.title("Mean acuracy for varying distances from the critical line")
plt.xlabel("tau")
plt.ylabel("Mean acuracy score")
plt.legend(_legend_labels, bbox_to_anchor=(1.05, 1), loc='upper left')

print(var_w_critical)



In [None]:
# Statistics over the repetitions (last axis of the score arrays)
MI_mean = np.mean(MI_score, axis = 2)
MI_std = np.std(MI_score, axis = 2)

acc_mean = np.mean(mean_accuracy, axis = 2)
acc_std = np.std(mean_accuracy, axis = 2)

plt.figure(figsize=(10,5))
# Standard deviation of the accuracy for the 4th and 5th tau values, one curve
# per weight variance. All weight variances are plotted, as in the mean-accuracy
# plot of this section: this section evaluates only the few values given in
# `var_w`, so a fixed [3:9] selection of variances would select nothing and
# leave the plot empty.
plt.plot(tau[3:5],acc_std.T[3:5])
plt.xlabel("tau")
plt.ylabel("Standard deviation")
plt.title("Standard deviation for varying distances from the critical line")
plt.legend(["var_w {}".format(np.round(w,3)) for w in var_w_perturbed], bbox_to_anchor=(1.05, 1), loc='upper left')

print(var_w_critical)


# 5. Results on the 3-bit parity task

In [None]:
# Parameters of the 3-bit parity experiment

# Network architecture
n_spins = 300              # number of units
K = 8                      # number of units projecting to each spin

# Delays at which the read-out is evaluated: 1, 2, ..., 9
tau = np.arange(1, 10)

# Encoder variance; no explicit weight variances are given for this task
var_e = 5
var_w = np.empty(0)

# Experiment size
it = 50                    # number of repetitions
perturbations = 10         # number of generated weight variances

task = "3parity"

In [None]:
#train and test the network

# Pass the parameters defined in the cell above instead of repeating their
# values as literals: the saved file names are built from `n_spins` and `it`,
# so the run must use the same variables to stay in sync with them.
# var_in = False: the weight variances are generated around the critical value.
RNN_r = RNN_results(K, tau, var_e, task, var_w, n_spins = n_spins, time_train = 1000, time_test = 500,
                 burn_t = 500, num_init = 10, num_iter = it, n_pert = perturbations, var_in = False)

mean_accuracy, MI_score, var_w_perturbed, var_w_critical, u_par, u = RNN_r.train_test_tau()

In [None]:
# Directory for the 3-bit parity result files ("gonzaga" is presumably the
# name of the machine this absolute path belongs to - adjust before running elsewhere).
# The file names below are appended to this string directly, so keep the trailing "/".

path = "/home/elosegui/MSc_thesis_project/numpy_results/3bit_parity/"

In [None]:
# Save the results; every file name ends with the same run description
_run_tag = "K" + str(K) + "_iterations" + str(it) + "_N" + str(n_spins) + "_var_e" + str(var_e)

for _prefix, _values in (
    ("accuracy_", mean_accuracy),
    ("MI_", MI_score),
    ("var_w_perturbed_", var_w_perturbed),
    ("var_w_critical_", var_w_critical),
):
    np.save(path + _prefix + _run_tag, _values)



In [None]:
# Load the saved results back (same file names as written by the save cell, plus ".npy")
_run_tag = "K" + str(K) + "_iterations" + str(it) + "_N" + str(n_spins) + "_var_e" + str(var_e)

mean_accuracy = np.load(path + "accuracy_" + _run_tag + ".npy")
MI_score = np.load(path + "MI_" + _run_tag + ".npy")
var_w_perturbed = np.load(path + "var_w_perturbed_" + _run_tag + ".npy")
var_w_critical = np.load(path + "var_w_critical_" + _run_tag + ".npy")


### 5.1 Preliminary plots for the mean and standard deviation of the MI score across trials

In [None]:
# Statistics over the repetitions (last axis of the score arrays)
MI_mean, MI_std = np.mean(MI_score, axis = 2), np.std(MI_score, axis = 2)
acc_mean, acc_std = np.mean(mean_accuracy, axis = 2), np.std(mean_accuracy, axis = 2)

# Mean accuracy as a function of tau for the weight variances with index 3..8
_legend_labels = ["var_w {}".format(np.round(w,3)) for w in var_w_perturbed[3:9]]

plt.figure(figsize=(10,5))
plt.plot(tau, acc_mean.T[:,3:9])
plt.title("Mean acuracy for varying distances from the critical line")
plt.xlabel("tau")
plt.ylabel("Mean acuracy score")
plt.legend(_legend_labels, bbox_to_anchor=(1.05, 1), loc='upper left')

print(var_w_critical)

In [None]:
# Statistics over the repetitions (last axis of the score arrays)
MI_mean, MI_std = np.mean(MI_score, axis = 2), np.std(MI_score, axis = 2)
acc_mean, acc_std = np.mean(mean_accuracy, axis = 2), np.std(mean_accuracy, axis = 2)

# Standard deviation of the accuracy for the 4th and 5th tau values and the
# weight variances with index 3..8
_legend_labels = ["var_w {}".format(np.round(w,3)) for w in var_w_perturbed[3:9]]

plt.figure(figsize=(10,5))
plt.plot(tau[3:5], acc_std.T[3:5,3:9])
plt.title("Standard deviation for varying distances from the critical line")
plt.xlabel("tau")
plt.ylabel("Standard deviation")
plt.legend(_legend_labels, bbox_to_anchor=(1.05, 1), loc='upper left')

print(var_w_critical)