In [1]:
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from itertools import permutations 
import sys
# Project root; appended to sys.path so the local `datasets` and `RNNcell` modules imported below resolve.
# NOTE(review): hard-coded absolute path — make configurable if this notebook is run elsewhere.
base = '/home3/ebrahim/isr/'
sys.path.append(base)
from datasets import OneHotLetters, OneHotLetters_test
from RNNcell import RNN_one_layer
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import r2_score
from matplotlib import pyplot as plt
from skimage.measure import block_reduce
from scipy.stats import pearsonr
import pandas as pd
import seaborn as sns
import pickle
from itertools import permutations, islice
import wandb
# Use the first GPU when available; fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
from pickle import FALSE
import torch.nn as nn 
import torch

class RNNcell(nn.Module):

    """ Vanilla RNN cell with:
            - Feedback from the previous output through o2h (see the NOTE on ``feedback_bool``)
            - Configurable nonlinearity over hidden activations (sigmoid, relu, relu6, tanh, linear)
            - Initialization follows Botvinick and Plaut, 2006
            - Optional plastic (hebbian) recurrent connections based on Miconi, 2018
            - Optional hebbian context-to-hidden connections (contextual variable binding)
            - Leaky integration of the pre-activation, controlled by ``alpha_s``
        The cell returns hidden activity only; no softmax/readout is applied here.
    """

    # Supported values of ``nonlin`` mapped to their activation module classes.
    _NONLINEARITIES = {
        'sigmoid': nn.Sigmoid,
        'relu': nn.ReLU,
        'relu6': nn.ReLU6,
        'tanh': nn.Tanh,
        'linear': nn.Identity,
    }

    def __init__(self, data_size, hidden_size, output_size, noise_std, nonlin,
                bias, feedback_bool, alpha_s, plastic, context, context_size, rule, h2h_weights):

        """ Init model.
        @param data_size (int): input size
        @param hidden_size (int): number of hidden units
        @param output_size (int): number of output classes; also the size of the feedback input
        @param noise_std (float): std. dev. of gaussian noise added to the pre-activation
        @param nonlin (str): hidden nonlinearity: 'sigmoid', 'relu', 'relu6', 'tanh', or 'linear'
        @param bias (bool): if true, the h2h layer has a bias term
        @param feedback_bool (bool): stored on the module. NOTE(review): forward() currently
            applies output feedback regardless of this flag.
        @param alpha_s (float): integration rate, i_t = (1 - alpha_s) * i_{t-1} + alpha_s * input_t;
            alpha_s = 1 gives a standard discrete-time RNN step
        @param plastic (bool): if true, add hebbian h2h connections with learned coefficients
            ``alpha`` and learned rate ``eta``
        @param context (bool): if true, hebbian weights are formed between the context signal and
            RNN hidden activity (learned ``alpha_c`` and ``eta_c``)
        @param context_size (int): size of the context signal
        @param rule (str): hebbian update rule, 'oja' or 'decay'
        @param h2h_weights (bool): if true, use the learned h2h layer; if false, h_prev is added
            directly (identity recurrence)
        @raises ValueError: if ``nonlin`` is not one of the supported names
        """
        super(RNNcell, self).__init__()
        if nonlin not in self._NONLINEARITIES:
            # Fail here rather than with an AttributeError on self.F inside forward().
            raise ValueError(
                f"Unsupported nonlin {nonlin!r}; expected one of {sorted(self._NONLINEARITIES)}")

        self.hidden_size = hidden_size
        self.output_size = output_size
        self.nonlin = nonlin
        self.noise_std = noise_std
        self.feedback_bool = feedback_bool
        self.alpha_s = alpha_s
        self.plastic = plastic
        self.context = context
        self.context_size = context_size
        self.rule = rule
        self.h2h_weights = h2h_weights

        # recurrent to recurrent connections
        self.h2h = nn.Linear(hidden_size, hidden_size, bias=bias)
        nn.init.uniform_(self.h2h.weight, -0.5, 0.5)

        # input to recurrent unit connections
        self.i2h = nn.Linear(data_size, hidden_size, bias=False)
        nn.init.uniform_(self.i2h.weight, -1.0, 1.0)

        # output to recurrent (feedback) connections; the feedback has the size of the output
        feedback_size = output_size
        self.o2h = nn.Linear(feedback_size, hidden_size, bias=False)
        nn.init.uniform_(self.o2h.weight, -1.0, 1.0)

        # nn.Parameter requires grad by default, so no explicit requires_grad is needed below.
        if self.plastic:
            # plasticity coefficients, uniform in [-1, 1)
            self.alpha = nn.Parameter(torch.rand((hidden_size, hidden_size)) * 2 - 1)
            # learning rate for plasticity, learned by the network
            self.eta = nn.Parameter(torch.tensor([0.01]))

        if self.context:
            self.alpha_c = nn.Parameter(torch.rand((context_size, hidden_size)) * 2 - 1)
            self.eta_c = nn.Parameter(torch.tensor([0.01]))

        self.F = self._NONLINEARITIES[nonlin]()

    def forward(self, data, h_prev, feedback, hebb, context_hebb, i_prev, context_signal, device):
        """ Advance the cell by one timestep.
        @param data: input at time t, (batch, data_size)
        @param h_prev: hidden activity at time t-1, (batch, hidden_size)
        @param feedback: feedback from the previous timestep, (batch, output_size)
        @param hebb: hebbian h2h trace; assumed (batch, hidden_size, hidden_size) as built by
            RNN_one_layer.init_states. Only batch element [0] is read and used for the update.
        @param context_hebb: hebbian context-to-hidden trace (used only when context=True)
        @param i_prev: pre-activation (integrator state) at time t-1, (batch, hidden_size)
        @param context_signal: context input (used only when context=True)
        @param device: device on which auxiliary tensors are created
        @returns: (h, hebb, i, context_hebb): new hidden activity, updated hebbian trace,
            new pre-activation, updated context trace. A trace is returned unchanged when its
            pathway is disabled.
        """
        batch_size = h_prev.shape[0]

        # Noise is drawn on the CPU and then moved, so it comes from the CPU RNG stream.
        noise = self.noise_std*torch.randn(h_prev.shape).to(device)

        # h2h hebbian connections
        if self.plastic:
            hebb_activity = h_prev@torch.mul(self.alpha, hebb)[0]
        else:
            hebb_activity = torch.zeros((batch_size, self.hidden_size), device=device)

        # context2h hebbian connections
        if self.context:
            context_activity = context_signal@torch.mul(self.alpha_c, context_hebb)[0]
        else:
            context_activity = torch.zeros((batch_size, self.hidden_size), device=device)

        # With h2h_weights=False the recurrence is the identity (h_prev is added directly).
        h_contribution = self.h2h(h_prev) if self.h2h_weights else h_prev

        # Leaky integration of the total input; alpha_s = 1 recovers a standard RNN step.
        i = (1-self.alpha_s)*i_prev + self.alpha_s*(hebb_activity + self.i2h(data) + h_contribution
                                                    + self.o2h(feedback) + context_activity + noise)
        h = self.F(i)

        if self.plastic:
            if self.rule == 'oja':
                # Oja's rule with pre = h_prev and post = h, using the h2h learning rate eta.
                hebb = hebb + self.eta * torch.mul((h_prev[0].unsqueeze(1) -
                    torch.mul(hebb, h[0].unsqueeze(0))), h[0].unsqueeze(0))
            elif self.rule == 'decay':
                # Exponential moving average of the (first batch element's) outer product h_prev h^T.
                hebb = (1-self.eta)*hebb + self.eta*torch.bmm(h_prev.unsqueeze(2), h.unsqueeze(1))[0]

        if self.context:
            if self.rule == 'oja':
                context_hebb = context_hebb + self.eta_c * torch.mul((context_signal[0].unsqueeze(1) -
                    torch.mul(context_hebb, h[0].unsqueeze(0))), h[0].unsqueeze(0))
            elif self.rule == 'decay':
                context_hebb = (1-self.eta_c)*context_hebb + self.eta_c*torch.bmm(
                    context_signal.unsqueeze(2), h.unsqueeze(1))[0]

        return h, hebb, i, context_hebb

class RNN_one_layer(nn.Module):

    """ Single-layer RNN: one RNNcell followed by a linear hidden-to-output readout.

    NOTE: this definition shadows the ``RNN_one_layer`` imported from the ``RNNcell``
    module in the first cell; cells below this one use this local version.
    """

    def __init__(self, input_size, hidden_size, output_size, feedback_bool, bias, 
        nonlin='sigmoid', noise_std=0.0, plastic=False, alpha_s=1.0, context=False, context_size=0, rule='decay', 
        h2h_weights=True):

        """ Build the recurrent cell and the readout layer.
        @param input_size: size of the input at each timestep
        @param hidden_size: number of hidden units
        @param output_size: number of output classes (also the feedback size)
        The remaining arguments are forwarded unchanged to RNNcell.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.context_size = context_size

        cell_options = dict(bias=bias, feedback_bool=feedback_bool, plastic=plastic, alpha_s=alpha_s,
                            context=context, context_size=context_size, rule=rule,
                            h2h_weights=h2h_weights)
        self.RNN = RNNcell(input_size, hidden_size, output_size, noise_std, nonlin, **cell_options)

        # linear readout, uniformly initialized in [-1, 1]
        self.h2o = nn.Linear(hidden_size, output_size, bias=bias)
        nn.init.uniform_(self.h2o.weight, -1.0, 1.0)

    def forward(self, data, h_prev, o_prev, hebb_prev, ch_prev, i_prev, device, context_signal=None):
        """ Run one timestep.
        @param data: input at time t
        @param h_prev: hidden activity at time t-1
        @param o_prev: output (feedback) at time t-1
        @returns: (output, h, hebb, context_hebb, i)
        """
        new_h, new_hebb, new_i, new_context_hebb = self.RNN(
            data, h_prev, o_prev, hebb_prev, ch_prev, i_prev, context_signal, device)
        return self.h2o(new_h), new_h, new_hebb, new_context_hebb, new_i

    def init_states(self, batch_size, device, h0_init_val):
        """ Initial (output, hidden, hebb, context_hebb, pre-activation) states on ``device``.
        Hidden units start at ``h0_init_val``; every other state starts at zero.
        """
        n_hidden = self.hidden_size
        y_start = torch.zeros(batch_size, self.output_size, device=device)
        h_start = torch.full((batch_size, n_hidden), float(h0_init_val), device=device)
        hebb_start = torch.zeros(batch_size, n_hidden, n_hidden, device=device)
        context_hebb_start = torch.zeros(batch_size, self.context_size, n_hidden, device=device)
        i_start = torch.zeros(batch_size, n_hidden, device=device)
        return y_start, h_start, hebb_start, context_hebb_start, i_start

In [3]:
# Initialize an untrained model and a dataloader over the fixed test lists.
SEED = 0
torch.manual_seed(SEED)  # untrained weights are random; seed so Restart & Run All reproduces them

batch_size = 1
N_SYMBOLS = 28     # model input/output size; NOTE(review): presumably 26 letters + 2 extra symbols — confirm
HIDDEN_SIZE = 200

model = RNN_one_layer(N_SYMBOLS, HIDDEN_SIZE, N_SYMBOLS, noise_std=0,
                      feedback_bool=True, bias=False, plastic=False, context=False, context_size=0)
model = model.to(device)

# create dataloader; the test-list path is built from the project root `base`
# instead of repeating the absolute path
test_lists_path = base + 'test_set/test_lists_set.pkl'
rtt = DataLoader(OneHotLetters(9, 100, test_lists_path, 28, batch_size=batch_size, num_letters=26,
                               delay_start=3, delay_middle=1),
                 batch_size=batch_size, shuffle=False)


In [4]:
# init initial states
# Initial states for the first timestep; hidden units start at 0.5, everything else at zero.
y0, h0, hebb0, ch0, i0 = model.init_states(batch_size, device, h0_init_val=0.5)

In [5]:
# Let's test if the model works 
for batch_idx, (X,y) in enumerate(rtt):

    h_current_list = []

    X = X.to(device)
    y = y.to(device)
    
    # run RNN and compute loss
    for timestep in range(X.shape[1]):

        # initial feedback 
        if timestep == 0:
            y_hat, h, hebb, c_hebb, i = model(X[:, timestep, :], h0, y0, hebb0, ch0, i0, device)
        else:
            y_hat, h, hebb, c_hebb, i = model(X[:, timestep, :], h, y[:, timestep-1, :], hebb, c_hebb, i, device)
            
        h_current_list.append(h.detach())

    break

print("Model can leave the port!")

Model can leave the port!


In [6]:
# One hidden state is recorded per timestep of the first batch.
assert len(h_current_list) == X.shape[1], "expected one hidden state per timestep"
len(h_current_list)

7

In [107]:
# Shape check for a batched version of the residual term in RNNcell's Oja context update,
# (context_signal - h * context_hebb), computed for every batch element instead of only [0].
# Names are prefixed so they don't clobber the model state variables (h, c_hebb) set earlier.
demo_c_hebb = torch.zeros((20, 100))  # presumably (context_size, hidden_size)
demo_h = torch.zeros((2, 100))        # presumably (batch, hidden_size)
demo_cs = torch.zeros((2, 20))        # presumably (batch, context_size)

h_chebb = demo_cs.unsqueeze(2) - torch.mul(demo_h.unsqueeze(1), demo_c_hebb)
assert h_chebb.shape == (2, 20, 100)
print(h_chebb.shape)
# TODO: finish the batched context_hebb update (an earlier commented-out draft was incomplete).


torch.Size([2, 20, 100])


In [119]:
# Check that the batched Oja residual reproduces the single-sample update term used in RNNcell
# (which only uses batch element [0]). Names are prefixed to avoid clobbering model state.
demo_h = torch.Tensor([[.5, .25, .05, .05, .01], [1, .25, .05, .05, .01]])  # (batch=2, hidden=5)
demo_chebb = torch.ones((3, 5))  # (context=3, hidden=5)
demo_cs = torch.ones((2, 3))     # (batch=2, context=3)

# single-sample update term, written exactly as in RNNcell's oja rule
g = torch.mul((demo_cs[0].unsqueeze(1) - torch.mul(demo_chebb, demo_h[0].unsqueeze(0))), demo_h[0].unsqueeze(0))
# batched residual (context - h * context_hebb) for every batch element
batched_residual = demo_cs.unsqueeze(2) - torch.mul(demo_h.unsqueeze(1), demo_chebb)

# scaled by the post-synaptic activity, batch element 0 must reproduce g
assert torch.allclose(batched_residual[0] * demo_h[0].unsqueeze(0), g)
print(g.shape)
print(batched_residual)

torch.Size([3, 5])
tensor([[[0.5000, 0.7500, 0.9500, 0.9500, 0.9900],
         [0.5000, 0.7500, 0.9500, 0.9500, 0.9900],
         [0.5000, 0.7500, 0.9500, 0.9500, 0.9900]],

        [[0.0000, 0.7500, 0.9500, 0.9500, 0.9900],
         [0.0000, 0.7500, 0.9500, 0.9500, 0.9900],
         [0.0000, 0.7500, 0.9500, 0.9500, 0.9900]]])
