In [14]:
# --- Third-party numerical / deep-learning stack ---
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from itertools import permutations 
import sys
# NOTE(review): `base` is a machine-specific absolute path. It must point at the
# directory holding the project-local modules imported below (datasets, EM,
# RNNcell); the notebook will not run elsewhere until this is changed.
base = '/home3/ebrahim/isr/isr_model_review/BP06/'
# Make the project-local modules importable from this notebook.
sys.path.append(base)
# --- Project-local modules (resolved through the sys.path entry above) ---
from datasets import OneHotLetters, OneHotLetters_test
from EM import EM
from RNNcell import RNN_one_layer
# --- Analysis / plotting utilities ---
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import r2_score
from matplotlib import pyplot as plt
from skimage.measure import block_reduce
from scipy.stats import pearsonr
import pandas as pd
import seaborn as sns
import pickle
# NOTE(review): `permutations` is already imported above; only `islice` is new here.
from itertools import permutations, islice
import wandb
# All tensors and modules in this notebook are placed on the CPU.
device = torch.device("cpu")


In [16]:
class RNNcell(nn.Module):

    """ Vanilla RNN cell with:
            - Feedback from the previous output into the hidden layer
            - Selectable nonlinearity over hidden activations
              (sigmoid, relu, tanh, linear or relu6)
            - Leaky integration of the pre-activation, controlled by alpha_s
            - Uniform weight initialization (per the original author, following
              Botvinick and Plaut, 2006)
    """

    # Maps the `nonlin` constructor argument to the module class applied to the
    # hidden pre-activation.
    _NONLINEARITIES = {
        'sigmoid': nn.Sigmoid,
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'linear': nn.Identity,
        'relu6': nn.ReLU6,
    }

    def __init__(self, data_size, hidden_size, output_size, noise_std, nonlin,
                bias, feedback_bool, alpha_s):

        """ Init model.
        :param (int) data_size: Input size
        :param (int) hidden_size: the size of hidden states
        :param (int) output_size : number of classes (also the size of the feedback vector)
        :param (float) noise_std: std. dev. for gaussian noise added to the pre-activation
        :param (str) nonlin: Nonlinearity for hidden activations: sigmoid, relu, tanh,
            linear, or relu6.
        :param (bool) bias: if true, the recurrent (hidden-to-hidden) layer has a bias term;
            the input and feedback projections never have one
        :param (bool) feedback_bool: stored on the instance; NOTE it is not consulted by
            forward(), which always adds the feedback projection
        :param (float) alpha_s: mixing coefficient of the leaky integration in forward();
            1.0 discards the previous pre-activation entirely
        :raises ValueError: if nonlin is not one of the supported names
        """
        super(RNNcell, self).__init__()

        # Fail at construction time instead of leaving self.F undefined and
        # crashing later inside forward() with an AttributeError.
        if nonlin not in self._NONLINEARITIES:
            raise ValueError(
                "Unknown nonlin {!r}; expected one of {}".format(
                    nonlin, sorted(self._NONLINEARITIES)))

        self.hidden_size = hidden_size
        self.output_size = output_size
        self.nonlin = nonlin
        self.noise_std = noise_std
        self.feedback_bool = feedback_bool
        self.alpha_s = alpha_s

        # recurrent to recurrent connections 
        self.h2h = nn.Linear(hidden_size, hidden_size, bias=bias)
        nn.init.uniform_(self.h2h.weight, -0.5, 0.5)
        
        # input to recurrent unit connections 
        self.i2h = nn.Linear(data_size, hidden_size, bias=False)
        nn.init.uniform_(self.i2h.weight, -1.0, 1.0)

        # output to recurrent connections; the feedback vector has the same
        # size as the output
        feedback_size = output_size

        self.o2h = nn.Linear(feedback_size, hidden_size, bias=False)
        nn.init.uniform_(self.o2h.weight, -1.0, 1.0)

        # hidden nonlinearity (registered last, as in the original layout)
        self.F = self._NONLINEARITIES[nonlin]()

    def forward(self, data, h_prev, feedback, i_prev, device):
        
        """
        Advance the cell by one timestep.

        @param data: input at time t
        @param h_prev: hidden activations (post-nonlinearity) at time t-1
        @param feedback: feedback vector from the previous timestep
        @param i_prev: hidden pre-activation at time t-1, used by the leaky integration
        @param device: device the noise tensor is moved to
        @return: (h, i) -- hidden activations and pre-activation at time t
        """
        
        # Gaussian noise with the same shape as the hidden state.
        noise = self.noise_std*torch.randn(h_prev.shape).to(device)

        # Total synaptic drive from input, recurrence and feedback, plus noise.
        drive = self.i2h(data) + self.h2h(h_prev) + self.o2h(feedback) + noise

        # Leaky integration: blend the previous pre-activation with the new drive.
        i = (1-self.alpha_s)*i_prev + self.alpha_s*drive
        h = self.F(i)
    
        return h, i 

class RNN_one_layer_EM(nn.Module):

    """ Single layer RNN whose hidden state is augmented by an episodic memory (EM).

    At every timestep the recurrent cell produces a hidden state; a learned
    sigmoid gate (computed from the hidden state and the EM-free output) is
    passed to the EM as its input strength, the recalled memory is added to the
    hidden state, and the output is read out from that sum.
    """

    def __init__(self, input_size, hidden_size, output_size, feedback_bool, bias, 
        nonlin='sigmoid', noise_std=0.0, alpha_s=1.0, storage_capacity=3, cmpt=0.8):

        """ Init model.
        :param int input_size: Input size
        :param int hidden_size: the size of hidden states
        :param int output_size: number of classes
        :param bool feedback_bool: forwarded to the recurrent cell
        :param bool bias: Set to True to allow for bias term (recurrent layer and readout)
        :param str nonlin: hidden nonlinearity, forwarded to the recurrent cell
        :param float noise_std: std. dev. of hidden noise, forwarded to the recurrent cell
        :param float alpha_s: leaky-integration coefficient, forwarded to the recurrent cell
        :param int storage_capacity: first argument handed to the EM constructor
        :param float cmpt: default level of competition between memories at recall
        """
        super(RNN_one_layer_EM, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.cmpt = cmpt # competition between memories in EM 
        self.F = nn.Sigmoid()  # squashes the EM gate
            
        self.RNN = RNNcell(input_size, hidden_size, output_size, noise_std, nonlin, 
        bias=bias, feedback_bool=feedback_bool, alpha_s=alpha_s)

        # hidden to output readout
        self.h2o = nn.Linear(hidden_size, output_size, bias=bias)
        nn.init.uniform_(self.h2o.weight, -1.0, 1.0)
        
        self.em = EM(storage_capacity, hidden_size, 'cosine')
        # maps [hidden state, EM-free output] to the scalar EM gate
        self.hpc = nn.Linear(hidden_size+output_size, 1)

    def forward(self, data, h_prev, o_prev, i_prev, device):
        """
        Run one timestep of the RNN + episodic memory.

        @param data: input at time t
        @param h_prev : hidden activations at time t-1 
        @param o_prev: output at time t-1 (fed back into the recurrent cell)
        @param i_prev: hidden pre-activation at time t-1
        @param device: device used for the noise and the recalled memory
        @return: (output_EM, h, i, EM_mem) -- readout of the EM-augmented hidden
            state, the hidden state *before* adding the memory, the hidden
            pre-activation, and the recalled memory
        """
        h, i = self.RNN(data, h_prev, o_prev, i_prev, device)

        # Readout without any contribution from episodic memory; only used as
        # an input to the gate below.
        output_no_EM = self.h2o(h)
        
        hpc_input = torch.cat([h, output_no_EM], dim=1)
        
        # EM gate is a scalar in (0, 1) which is handed to the EM as its input
        # strength, i.e. it controls how much is recalled.
        EM_gate = self.F(self.hpc(hpc_input))
        
        EM_mem = self.recall_from_EM(h, EM_gate).to(device)
        
        # Memory-augmented hidden state: this is both what gets stored in the
        # EM (when encoding is on) and what the output is read from.
        h_EM = h + EM_mem
        
        self.encode_to_EM(h_EM)
        
        output_EM = self.h2o(h_EM)
        
        # Note: the un-augmented hidden state h is returned, not h_EM.
        return output_EM, h, i, EM_mem
    
    def recall_from_EM(self, c_t, inps_t, comp_t=None):

        """
        Code from [1]
        run the "pattern completion" procedure
        Parameters
        ----------
        c_t : torch.tensor
            cue used to query the memory (the current hidden state)
        inps_t : torch.tensor
            LCA param, input strength / feedforward weights
        comp_t : torch.tensor or float, optional
            LCA param, lateral inhibition; defaults to self.cmpt
        Returns
        -------
        tensor
            recalled item, as returned by the EM's get_memory (leak fixed at 0)
        """
        
        # level of lateral inhibition (beta in paper)
        if comp_t is None:
            comp_t = self.cmpt

        return self.em.get_memory(c_t, leak=0, comp=comp_t, w_input=inps_t)
    
    def encode_to_EM(self, hidden_state):
        """ Store hidden_state in the EM unless encoding is switched off. """
        
        if not self.em.encoding_off:
            self.em.save_memory(hidden_state)

    def init_states(self, batch_size, device, h0_init_val):
        """
        Build the initial states for a new sequence.

        @param batch_size: number of sequences in the batch
        @param device: device to place the tensors on
        @param h0_init_val: constant every initial hidden activation is set to
        @return: (output, h0, i0) -- zero feedback of shape (batch, output_size),
            constant hidden state and zero pre-activation of shape (batch, hidden_size)
        """

        output = torch.zeros(batch_size, self.output_size).to(device)
        h0 = torch.full((batch_size, self.hidden_size), float(h0_init_val)).to(device)
        i0 = torch.full((batch_size, self.hidden_size), float(0.0)).to(device)
       
        return output, h0, i0
    
    def init_em_config(self):
        """ Empty the EM and switch both encoding and retrieval off. """
        self.flush_episodic_memory()
        self.encoding_off()
        self.retrieval_off()

    def flush_episodic_memory(self):
        """ Ask the EM to flush its stored memories. """
        self.em.flush()

    def encoding_off(self):
        """ Stop storing hidden states in the EM. """
        self.em.encoding_off = True

    def retrieval_off(self):
        """ Set the EM's retrieval_off flag. """
        self.em.retrieval_off = True

    def encoding_on(self):
        """ Allow hidden states to be stored in the EM. """
        self.em.encoding_off = False

    def retrieval_on(self):
        """ Clear the EM's retrieval_off flag. """
        self.em.retrieval_off = False

In [17]:
# initialize untrained model
# (28 input units, 200 hidden units, 28 output units, no hidden noise)
batch_size = 1
model = RNN_one_layer_EM(input_size=28, hidden_size=200, output_size=28,
                         feedback_bool=True, bias=False, noise_std=0)
model = model.to(device)

# create dataloader
# The test lists live under the project directory `base` defined with the
# imports, so build the path from it instead of repeating the absolute path.
test_list_path = base + 'test_set/test_lists_cleaned_26_set.pkl'

# NOTE(review): the meaning of the positional arguments (9, 100, ..., 28) is
# defined by OneHotLetters in the project-local `datasets` module -- presumably
# the maximum list length, a count of lists/cycles, and the one-hot width;
# confirm against that module. Not shuffled, so lists arrive in dataset order.
rtt = DataLoader(
    OneHotLetters(9, 100, test_list_path, 28,
                  batch_size=batch_size, num_letters=26,
                  delay_start=3, delay_middle=1),
    batch_size=batch_size, shuffle=False)


In [18]:
# Build the starting feedback (y0), hidden activations (h0) and hidden
# pre-activations (i0) for one batch; every hidden unit starts at 0.5.
y0, h0, i0 = model.init_states(batch_size=batch_size, device=device,
                               h0_init_val=0.5)

In [19]:
# Let's test if the model works 
# Smoke test: push every list in the dataloader through the untrained model and
# record the hidden state and recalled memory at each timestep.
h_current_list = []
EM_mem_storage = []
end_token_length = 1
max_list_length = 9
# Must match the delay_start / delay_middle passed to OneHotLetters above.
delay_start = 3
delay_middle = 1

# Forward passes only: no loss or backward pass happens here, so skip building
# autograd graphs across timesteps.
with torch.no_grad():
    for batch_idx, (X,y) in enumerate(rtt):

        # Start every list from an empty episodic memory with encoding and
        # retrieval off, so nothing stored (or left switched on) by the
        # previous list carries over into this one.
        model.init_em_config()

        # NOTE(review): assumes the dataset cycles through list lengths
        # 1..max_list_length in order -- confirm against OneHotLetters.
        list_length = batch_idx%max_list_length + 1

        # Timesteps at which the current state is written to episodic memory:
        # the last presented item and the last recalled item.
        list_presented_time = delay_start + list_length - 1
        list_recalled_time = list_presented_time + delay_middle + list_length
        EM_encode_timesteps = [list_presented_time, list_recalled_time]

        # Retrieval is enabled only during the recall window [start, end).
        recall_start_time = delay_start + list_length + delay_middle
        recall_end_time = recall_start_time + list_length

        # run the RNN over every timestep of this list
        for timestep in range(X.shape[1]):

            if timestep in EM_encode_timesteps:
                model.encoding_on()
            else:
                model.encoding_off()

            if timestep == recall_start_time:
                model.retrieval_on()
            if timestep == recall_end_time:
                model.retrieval_off()

            if timestep == 0:
                # initial feedback and initial hidden states
                y_hat, h, i, EM_mem = model(X[:, timestep, :], h0, y0, i0, device)
            else:
                # feed back the target from the previous timestep (not y_hat)
                y_hat, h, i, EM_mem = model(X[:, timestep, :], h, y[:, timestep-1, :], i, device)

            h_current_list.append(h.detach())
            EM_mem_storage.append(EM_mem.detach())

print("Model can leave the port!")

Model can leave the port!
