In [1]:
import numpy as np
import torch
import torchvision 
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm
from matplotlib import pyplot as plt
import math
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1082f7fd0>

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class CFD(Dataset):
    
    def __init__(self, num_trials, trial_length, coh, start_delay, check_length, mid_delay, target_length, post_delay):
        
        self.num_trials = num_trials
        self.trial_length = trial_length
        self.coh = coh
        self.start_delay = start_delay
        self.check_length = check_length
        self.mid_delay = mid_delay
        self.target_length = target_length
        self.post_delay = post_delay
        
        X = np.zeros((self.num_trials, self.trial_length))
        
        # create useful variables
        num_coh = self.num_trials // len(self.coh)
        target_onset = self.start_delay + self.check_length + self.mid_delay
        
        # uniformly assign coherences 
        X[:, self.start_delay:self.start_delay + self.check_length] = np.tile(np.repeat(self.coh, num_coh), (self.check_length,1)).T
        
        # set target location, -1 means red target is on left, 1 means red target is on right
        X[:, target_onset:target_onset+self.target_length] = np.tile([-1,1], (self.target_length, num_trials//2)).T
        
        Y = np.zeros((self.num_trials, self.trial_length, 2))
        
        # green checkerboard trials, 0 is left reach, 1 is right reach
        Y[:self.num_trials//2, target_onset:target_onset+self.target_length, 0] = np.tile([0,1], (self.target_length, num_trials//4)).T
        Y[:self.num_trials//2, target_onset:target_onset+self.target_length, 1] = np.tile([1,0], (self.target_length, num_trials//4)).T
        
        # red checkerboard trials
        Y[self.num_trials//2:, target_onset:target_onset+self.target_length, 0] = np.tile([1,0], (self.target_length, num_trials//4)).T
        Y[self.num_trials//2:, target_onset:target_onset+self.target_length, 1] = np.tile([0,1], (self.target_length, num_trials//4)).T
        
        self.data = torch.from_numpy(np.expand_dims(X,axis=-1)).float()
        self.labels = torch.from_numpy(Y).float()
        
    def __len__(self):
        
        return (self.data.shape[0])
    
    def __getitem__(self, idx):
        
        return self.data[idx], self.labels[idx]

In [4]:
def visualize_data(CFD_object, show_all):
    
    data = CFD_object.data
    labels = CFD_object.labels
    
    num_trials = len(data)
    coh = CFD_object.coh
    
    if show_all:
        for i in range(num_trials):
            fig, axs = plt.subplots(1,3,sharey=True)
            axs[0].plot(data[i, :])
            axs[1].plot(labels[i, :, 0], color='r')
            axs[2].plot(labels[i, :, 1], color='g')
            plt.show()

    else: 
    
        print("Red target on left")
        for i in range(0, num_trials, num_trials//len(coh)):
            fig, axs = plt.subplots(1,3,sharey=True)
            axs[0].plot(data[i, :])
            axs[1].plot(labels[i, :, 0], color='r')
            axs[2].plot(labels[i, :, 1], color='g')
            plt.show()

        print("Green target on left")
        for i in range(1, num_trials, num_trials//len(coh)):
            fig, axs = plt.subplots(1,3,sharey=True)
            axs[0].plot(data[i, :])
            axs[1].plot(labels[i, :, 0], color='r')
            axs[2].plot(labels[i, :, 1], color='g')
            plt.show()

In [5]:
c = CFD(4, 100, [.1,.3,.7,.9], 0, 50, 0, 50, 0)
#visualize_data(c, show_all=True)

In [9]:
# -------- RNN ---------
# model-related
input_size = 1 # features
sequence_length = 20 # timesteps
hidden_size = 5 # RNN units
num_layers = 1 # RNN layers
output_dim = 2
nonlin='relu'

# training-related
num_epochs = 2
learning_rate = 1e-3
batch_size = 10
lambda_omega = 2
lambda_l2 = 1
clip_value = .1


class RNN(nn.Module):
    
    def __init__(self, input_size, hidden_size, num_layers, output_dim, ratio):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.ratio = ratio 
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity=nonlin, batch_first=True, bias=False)
        self.fc = nn.Linear(hidden_size, output_dim, bias=False)
        
        
    def dale_weight_init(self):
        
        with torch.no_grad():
        
            num_exc = np.int(self.ratio[0]*self.hidden_size)
            num_inh = np.int(self.hidden_size - num_exc)

            D = torch.diag_embed(torch.cat((torch.ones(num_exc), -1*torch.ones(num_inh)))) 
            self.rnn.weight_hh_l0 = torch.nn.Parameter(torch.abs(self.rnn.weight_hh_l0.detach()).matmul(D))
                                                    

    def enforce_dale(self):
        
        with torch.no_grad():
            
            num_exc = np.int(self.ratio[0]*self.hidden_size)
            num_inh = np.int(self.hidden_size - num_exc)

            self.rnn.weight_hh_l0[:num_exc, :].clamp(min=0)
            self.rnn.weight_hh_l0[num_exc:, :].clamp(max=0)

        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, requires_grad=True).to(device)
        h_t, _ = self.rnn(x,h0)
        h_t.retain_grad()
        self.h_t = h_t
        print(h_t.shape)
        out = self.fc(h_t)
        return out
    

model = RNN(input_size, hidden_size, num_layers, output_dim, ratio=[.8,.8,.8]).float()
#model.dale_weight_init()
       
# loss function and gradient desc. algorithm 
criterion = nn.L1Loss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=lambda_l2)

training_set = CFD(1000, sequence_length, [.1,.9], 0, 10, 0, 10, 0)
trainloader = DataLoader(training_set, batch_size=batch_size,
                         shuffle=True)

In [10]:
def pascanu_regularizer(model, nonlin):
    
    '''
    
     Step 1: compute gradient of ht w.r.t ht-1
         This equals Wrec.T*I(Winxt + Wrecht-1 + b), I(x) = 0 if x <= 0, and 1 else
         Compute I(Winxt + Wrecht-1 + b) from ht

     Step 2: multiply BPTT computed dloss/dht w/ dht/dht-1 
     
     Step 3: calculate omega
     
    '''
    global first_run
    
    if first_run:
        first_run=False
        return 0
    
    if nonlin == 'relu':
            
        # Step 1 
        h_t_binary = torch.squeeze(torch.diag_embed(torch.squeeze(model.h_t.grad))) # diag matrix, hidden_size*hidden_size 
        h_t_binary[h_t_binary!=0] = 1 # convert to I(...) 
        dht_dht_prev = torch.matmul(model.rnn.weight_hh_l0.T, h_t_binary) # seq_len x hidden_size * hidden_size

        # Step 2 
        dl_dht_prev = torch.squeeze(torch.bmm(torch.unsqueeze(torch.squeeze(model.h_t.grad),axis=1), dht_dht_prev)) # seq_len x hidden_size

        # Step 3
        omega = torch.sum(torch.pow((torch.norm(dl_dht_prev,dim=1) / (torch.norm(torch.squeeze(model.h_t.grad),dim=1))) - 1, 2))
        return omega 

    else:
        
        print("Code has not been written for other nonlin functions yet, omega is 0")
        omega = 0 
        
    return omega

In [11]:
first_run = True
i = 0
for epoch in range(num_epochs):

    for idx, (data, targets) in enumerate(trainloader):
        
        data = data.to(device=device)

        targets = targets.to(device=device)
        
        omega = pascanu_regularizer(model, nonlin)
        
        # forward
        scores = model(data)
        l1_loss = criterion(scores, targets)
        
        # backward
        optimizer.zero_grad()
        
        l1_loss.backward()
        print(l1_loss)
        print(model.rnn.weight_hh_l0.grad)
        
        if omega != 0:
            print("----------omega-----------")
            omega.backward()
            print(model.rnn.weight_hh_l0.grad)
            print("\n")
        


        
        # clip gradients 
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
    
        # use saved gradients in params to step 
        optimizer.step()
        
        # impose dale's principle
        #model.enforce_dale()
        

torch.Size([10, 1])
torch.Size([10, 1])


In [83]:
omega.grad

  """Entry point for launching an IPython kernel.


In [60]:
for name, param in model.named_parameters():
    if 'rnn'in name:
        print(name)

rnn.weight_ih_l0
rnn.weight_hh_l0


In [14]:
# examining loss function 
criterion = nn.L1Loss()
pred = torch.zeros(100, 2)
pred[50:, 0] = .96
label = torch.zeros(100, 2)
label[50:, 0] = 1
criterion(pred,label)


tensor(0.0100)

In [23]:
import torch
device = 'cpu'
a = torch.tensor([-10,0,20], dtype=torch.float64, device=device, requires_grad=True)
b = torch.tensor([0,0,0], dtype=torch.float64, device=device)
loss = torch.nn.functional.l1_loss(a, b, reduction='sum')
loss.backward()
print(a.grad)

tensor([-1.,  0.,  1.], dtype=torch.float64)


In [30]:
Wrec = torch.ones(2,2)
h_t_binary = torch.ones(5,2,2)
h_t_binary[:,1, 1] = 0
torch.matmul(Wrec.T, h_t_binary)


tensor([[[2., 1.],
         [2., 1.]],

        [[2., 1.],
         [2., 1.]],

        [[2., 1.],
         [2., 1.]],

        [[2., 1.],
         [2., 1.]],

        [[2., 1.],
         [2., 1.]]])

In [50]:
b = torch.ones(1,5)
e = torch.ones(2,5)
f = torch.matmul(e,d)
print(f)




tensor([[ 1.,  1.,  1., -1., -1.],
        [ 1.,  1.,  1., -1., -1.]])


In [104]:
class NN (nn.Module):
    
    def __init__(self, input_size, output_dim):
        super(NN, self).__init__()
        self.input_size = input_size
        self.output_dim = output_dim
        self.fc = nn.Linear(input_size, output_dim, bias=False)
        
        
    def forward(self, x):
        out = self.fc(x)
        return out
        

In [114]:
vanilla_NN = NN(10,1)
out = vanilla_NN(2*torch.ones(10))
out.backward()
out.grad_fn

<SqueezeBackward3 at 0x120777f10>

In [29]:
l2 = torch.Tensor([1,2,3,1,2,3])
torch.norm(l2)
torch.sum(l2**2)**.5


tensor(5.2915)

In [13]:
torch.mm(torch.ones(10,10), torch.ones(32,10))

RuntimeError: size mismatch, m1: [10 x 10], m2: [32 x 10] at /Users/distiller/project/conda/conda-bld/pytorch_1595629430416/work/aten/src/TH/generic/THTensorMath.cpp:41