In [1]:
### Importing libraries ###

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.distributions.bernoulli import Bernoulli

import torchvision
from torchvision.datasets import MNIST
from torchvision.datasets import FashionMNIST
from torchvision import transforms

import matplotlib.pyplot as plt
import time
import random
import pickle as pkl

import import_ipynb
from TNNColumn import *

importing Jupyter notebook from TNNColumn.ipynb
importing Jupyter notebook from TNN.ipynb


SyntaxError: duplicate argument 'lr_capture' in function definition (<string>, line 50)

In [None]:
class RecurrentTNN():
    """Two-column temporal neural network whose hidden column is recurrent.

    At every time step the hidden column receives the current input step
    concatenated with its own output from the previous step; the output column
    reads the hidden column's output. Both columns are ``TNNColumn`` objects
    (from the imported TNNColumn notebook) and learn with STDP-style local
    rules, either stochastically (pre-sampled Bernoulli random variables, see
    ``generate_rv``) or deterministically (fixed learning rates from ``lr``).

    NOTE(review): ``TNNColumn`` is not visible here. This class assumes that
    ``column(x)`` / ``column(x, k)`` returns ``(spike_times, winner)``, where
    ``winner`` is ``-1`` when there is no winner and a 0-d tensor otherwise,
    and that ``float('inf')`` in a spike-time vector means "no spike" --
    confirm against TNNColumn.
    """

    def __init__(self, res_arr, size_arr, stdp_arr, theta_arr, winit_arr, ntype, rv=None, lr=None):
        """Store the configuration and build the hidden and output columns.

        Args:
            res_arr:   ``(tmax_hidden, tmax_out, wres)``; ``wres`` is a bit
                       width, and ``wmax = 2**wres - 1`` is presumably the
                       largest weight value.
            size_arr:  ``(input_size, hidden_size, out_size)``.
            stdp_arr:  ``(hidden_stdp, out_stdp, stdp_random)``; the first two
                       name each column's learning rule (e.g. ``'stdp'``,
                       ``'rstdp'``), the third selects stochastic (truthy) or
                       deterministic (falsy) training in ``train``.
            theta_arr: ``(theta_hidden, theta_out)``, forwarded to the columns
                       (presumably firing thresholds).
            winit_arr: ``(winit_hidden, winit_out)``, forwarded to the columns.
            ntype:     neuron type, forwarded to both columns.
            rv:        per-column ``[capture, search, backoff, min]`` values, used
                       only by ``generate_rv`` (each is divided by 1024 to form a
                       Bernoulli probability).
            lr:        learning rates ``(capture, capture_min, minus, minus_min,
                       backoff, search)``, read by ``train`` in deterministic mode.
        """
        self.input_size   = size_arr[0]
        self.hidden_size  = size_arr[1]
        self.out_size     = size_arr[2]
        self.hidden_stdp  = stdp_arr[0]
        self.out_stdp     = stdp_arr[1]
        self.stdp_random  = stdp_arr[2]
        self.theta_hidden = theta_arr[0]
        self.theta_out    = theta_arr[1]
        self.winit_hidden = winit_arr[0]
        self.winit_out    = winit_arr[1]
        self.ntype        = ntype
        self.tmax_hidden  = res_arr[0]
        self.tmax_out     = res_arr[1]
        self.wres         = res_arr[2]
        self.wmax         = 2**(self.wres) - 1
        self.rv           = rv
        self.lr           = lr

        # The hidden column's previous output is appended after the input_size
        # input entries of every step, so the recurrent block has one entry
        # per hidden neuron.
        self.recurrent_size    = self.hidden_size

        self.layer1_input_size = self.input_size + self.recurrent_size
        self.layer1_out_size   = self.hidden_size
        self.layer2_input_size = self.layer1_out_size
        self.layer2_out_size   = self.out_size

        self.hidden_layer = TNNColumn(self.tmax_hidden,
                                      self.layer1_input_size,
                                      self.layer1_out_size,
                                      self.wres,
                                      self.theta_hidden,
                                      self.ntype,
                                      self.hidden_stdp,
                                      self.stdp_random,
                                      self.winit_hidden)
        self.out_layer    = TNNColumn(self.tmax_out,
                                      self.layer2_input_size,
                                      self.layer2_out_size,
                                      self.wres,
                                      self.theta_out,
                                      self.ntype,
                                      self.out_stdp,
                                      self.stdp_random,
                                      self.winit_out)

    def generate_rv(self, datasize, seq_len, p, q, layer):
        """Pre-sample the Bernoulli random variables used by stochastic STDP.

        Args:
            datasize: number of samples to draw for (leading dimension).
            seq_len:  number of time steps per sample.
            p:        input width of the column (last dimension).
            q:        number of neurons in the column.
            layer:    index into ``self.rv`` (``train`` passes 0 for the hidden
                      column and 1 for the output column).

        Returns:
            ``(capture, search, backoff, min, stickup, stickdown)``. With scalar
            ``self.rv`` entries the first four have shape
            ``(datasize, seq_len, q, p)``; ``stickup``/``stickdown`` carry an
            extra trailing dimension of size ``wmax + 1`` (one probability per
            weight value).
        """
        start = time.time()

        ucapture = self.rv[layer][0]
        usearch  = self.rv[layer][1]
        ubackoff = self.rv[layer][2]
        umin     = self.rv[layer][3]

        bcapture = Bernoulli(ucapture / 1024)
        bsearch  = Bernoulli(usearch / 1024)
        bbackoff = Bernoulli(ubackoff / 1024)
        bmin     = Bernoulli(umin / 1024)

        # Weight-dependent probabilities: stickup -> 1 as w -> wmax,
        # stickdown -> 1 as w -> 0.
        w_norm     = torch.arange(self.wmax + 1, dtype=torch.get_default_dtype()) / self.wmax
        bstickup   = Bernoulli(w_norm * (2 - w_norm))
        bstickdown = Bernoulli((1 - w_norm) * (1 + w_norm))

        # Sampling order is kept fixed so results are reproducible under a seed.
        shape       = [datasize, seq_len, q, p]
        rvcapture   = bcapture.sample(shape)
        rvsearch    = bsearch.sample(shape)
        rvbackoff   = bbackoff.sample(shape)
        rvmin       = bmin.sample(shape)
        rvstickup   = bstickup.sample(shape)
        rvstickdown = bstickdown.sample(shape)

        end = time.time()
        print("Random variables generated in ", end - start)

        return (rvcapture, rvsearch, rvbackoff, rvmin, rvstickup, rvstickdown)

    def train(self, train_loader, datasize, seq_len, epoch, layer, k_hidden):
        """Train for ``epoch`` passes over ``train_loader``.

        Args:
            train_loader: iterable of ``(data, label)`` batches; only ``data[0]``
                          is trained on (a batch size of 1 is presumably expected).
            datasize:     number of batches; sizes the pre-sampled random
                          variables in stochastic mode.
            seq_len:      time steps per sample (stochastic mode only).
            epoch:        number of passes over the loader.
            layer:        ``'hidden'``, ``'out'`` or ``'both'`` -- which column(s) learn.
            k_hidden:     deterministic mode only; passed as the second argument
                          of every hidden-column call.

        Raises:
            ValueError: if ``layer`` is not one of the accepted names.
        """
        if layer not in ('hidden', 'out', 'both'):
            raise ValueError("layer must be 'hidden', 'out' or 'both', got {0!r}".format(layer))
        if self.stdp_random:
            self._train_random(train_loader, datasize, seq_len, epoch, layer)
        else:
            self._train_det(train_loader, epoch, layer, k_hidden)

    def _train_random(self, train_loader, datasize, seq_len, epoch, layer):
        """Stochastic-STDP loop; random variables are sampled once and reused every epoch."""
        rv_hidden = self.generate_rv(datasize, seq_len, self.layer1_input_size, self.hidden_size, 0)
        rv_out    = self.generate_rv(datasize, seq_len, self.layer2_input_size, self.out_size, 1)
        step = {'hidden': self.train_hidden_layer_random,
                'out':    self.train_output_layer_random,
                'both':   self.train_both_layers_random}[layer]

        for _ in range(epoch):
            start = time.time()
            for inx, (data, label) in enumerate(train_loader):
                print("Iteration: {0}\r".format(inx), end="")
                # One [hidden_rv, out_rv] pair per RV kind, in generate_rv's order.
                rv_pairs = [[rv_h[inx], rv_o[inx]] for rv_h, rv_o in zip(rv_hidden, rv_out)]
                step(data[0], label, *rv_pairs)
                endt = time.time()
                print("                                 Time elapsed: {0}\r".format(endt - start), end="")
            end = time.time()
            print("Column training done in ", end - start)

    def _train_det(self, train_loader, epoch, layer, k_hidden):
        """Deterministic-STDP loop; prints the L1 change of the hidden weights per sample."""
        self.lr_capture     = self.lr[0]
        self.lr_capture_min = self.lr[1]
        self.lr_minus       = self.lr[2]
        self.lr_minus_min   = self.lr[3]
        self.lr_backoff     = self.lr[4]
        self.lr_search      = self.lr[5]
        step = {'hidden': self.train_hidden_layer_det,
                'out':    self.train_output_layer_det,
                'both':   self.train_both_layers_det}[layer]

        prev_weights = torch.zeros((self.layer1_out_size, self.layer1_input_size))
        for _ in range(epoch):
            start = time.time()
            for inx, (data, label) in enumerate(train_loader):
                print("Iteration: {0}\r".format(inx), end="")
                step(data[0], label, k_hidden)
                endt = time.time()
                print("                                 Time elapsed: {0}\r".format(endt - start), end="")
                weight_change = torch.sum(torch.abs(prev_weights - self.hidden_layer.ec.weights))
                print("                                                              Conv: {0}\r".format(weight_change), end="")
                prev_weights = self.hidden_layer.ec.weights.clone()
            end = time.time()
            print("Column training done in ", end - start)

    @staticmethod
    def _reward(target, winner):
        """R-STDP reward: 0 if there is no winner (-1), +1 if it equals ``target``, else -1."""
        if winner != -1:
            return 1 if target == winner.item() else -1
        return 0

    def _no_spikes(self):
        """Initial recurrent input: every hidden neuron silent (all ``inf``)."""
        return torch.full((self.recurrent_size,), float('inf'))

    ##################
    ## Stochastic
    ##################

    def train_both_layers_random(self, sample, target, rvcapture, rvsearch, rvbackoff, rvmin, rvstickup, rvstickdown):
        """Stochastic STDP on both columns for one sequence.

        ``sample`` is iterated over its first dimension (time steps). Each ``rv*``
        argument is a pair ``[hidden_rv, out_rv]`` indexed by time step. The
        hidden column always uses plain STDP here; the output column uses
        ``self.out_stdp`` (``'stdp'`` or ``'rstdp'``; other values skip learning).
        """
        out_prev = self._no_spikes()

        for inx, step in enumerate(sample):
            x = torch.cat((step, out_prev))
            out1, winner = self.hidden_layer(x)
            self.hidden_layer.ec.weights = self.hidden_layer.stdp(
                x, out1, self.hidden_layer.ec.weights,
                rvcapture[0][inx], rvsearch[0][inx], rvbackoff[0][inx],
                rvmin[0][inx], rvstickup[0][inx], rvstickdown[0][inx])
            out_prev = out1

            out2, winner = self.out_layer(out1)
            rvs_out = (rvcapture[1][inx], rvsearch[1][inx], rvbackoff[1][inx],
                       rvmin[1][inx], rvstickup[1][inx], rvstickdown[1][inx])
            if self.out_stdp == 'stdp':
                self.out_layer.ec.weights = self.out_layer.stdp(
                    out1, out2, self.out_layer.ec.weights, *rvs_out)
            elif self.out_stdp == 'rstdp':
                reward = self._reward(target, winner)
                self.out_layer.ec.weights = self.out_layer.stdp(
                    reward, out1, out2, self.out_layer.ec.weights, *rvs_out)

    def train_hidden_layer_random(self, sample, target, rvcapture, rvsearch, rvbackoff, rvmin, rvstickup, rvstickdown):
        """Stochastic plain STDP on the hidden column for one sequence.

        Time steps whose input is entirely ``inf`` are skipped; the recurrent
        input then carries over unchanged to the next step.
        """
        out_prev = self._no_spikes()
        for inx, step in enumerate(sample):
            if step[step != float('inf')].shape[0] == 0:
                continue

            x = torch.cat((step, out_prev))
            out, winner = self.hidden_layer(x)
            self.hidden_layer.ec.weights = self.hidden_layer.stdp(
                x, out, self.hidden_layer.ec.weights,
                rvcapture[0][inx], rvsearch[0][inx], rvbackoff[0][inx],
                rvmin[0][inx], rvstickup[0][inx], rvstickdown[0][inx])
            out_prev = out

    def train_output_layer_random(self, sample, target, rvcapture, rvsearch, rvbackoff, rvmin, rvstickup, rvstickdown):
        """Stochastic STDP on the output column only, fed each step of ``sample`` directly.

        NOTE(review): the hidden column is bypassed here, so each step of
        ``sample`` must already be ``layer2_input_size`` wide -- confirm intent.
        """
        for inx, step in enumerate(sample):
            out, winner = self.out_layer(step)
            rvs_out = (rvcapture[1][inx], rvsearch[1][inx], rvbackoff[1][inx],
                       rvmin[1][inx], rvstickup[1][inx], rvstickdown[1][inx])
            if self.out_stdp == 'stdp':
                self.out_layer.ec.weights = self.out_layer.stdp(
                    step, out, self.out_layer.ec.weights, *rvs_out)
            elif self.out_stdp == 'rstdp':
                reward = self._reward(target, winner)
                self.out_layer.ec.weights = self.out_layer.stdp(
                    reward, step, out, self.out_layer.ec.weights, *rvs_out)

    ##################
    ## Deterministic
    ##################

    def train_both_layers_det(self, sample, target, k_hidden):
        """Deterministic STDP on both columns for one sequence (output column uses k=1).

        NOTE(review): the hidden ``stdp`` call here passes 3 learning rates while
        ``train_hidden_layer_det`` passes 6 plus two extra arguments -- confirm
        which matches ``TNNColumn.stdp``.
        """
        out_prev = self._no_spikes()

        for step in sample:
            x = torch.cat((step, out_prev))
            out1, winner = self.hidden_layer(x, k_hidden)
            self.hidden_layer.ec.weights = self.hidden_layer.stdp(
                x, out1, self.hidden_layer.ec.weights,
                self.lr_capture, self.lr_backoff, self.lr_search)
            out_prev = out1

            out2, winner = self.out_layer(out1, 1)
            if self.out_stdp == 'stdp':
                self.out_layer.ec.weights = self.out_layer.stdp(
                    out1, out2, self.out_layer.ec.weights,
                    self.lr_capture, self.lr_backoff, self.lr_search)
            elif self.out_stdp == 'rstdp':
                reward = self._reward(target, winner)
                self.out_layer.ec.weights = self.out_layer.stdp(
                    reward, out1, out2, self.out_layer.ec.weights,
                    self.lr_capture, self.lr_backoff, self.lr_search)

    def train_hidden_layer_det(self, sample, target, k_hidden):
        """Deterministic learning on the hidden column for one sequence.

        ``'stdp'``, ``'stdp_tmod'`` and ``'stdp_smod'`` use the unsupervised
        rule; ``'rstdp'`` uses the reward-modulated rule; any other
        ``hidden_stdp`` value leaves the weights unchanged.
        """
        out_prev = self._no_spikes()
        for step in sample:
            x = torch.cat((step, out_prev))
            out, winner = self.hidden_layer(x, k_hidden)

            # Membership test: the previous `... or 'stdp_smod'` was always
            # truthy and made the 'rstdp' branch unreachable.
            if self.hidden_stdp in ('stdp', 'stdp_tmod', 'stdp_smod'):
                self.hidden_layer.ec.weights = self.hidden_layer.stdp(
                    x, out, self.hidden_layer.ec.weights,
                    self.lr_capture, self.lr_capture_min, self.lr_minus,
                    self.lr_minus_min, self.lr_backoff, self.lr_search,
                    True, self.input_size)
            elif self.hidden_stdp == 'rstdp':
                reward = self._reward(target, winner)
                # NOTE(review): only 3 learning rates are passed here, unlike the
                # branch above -- confirm against TNNColumn.stdp.
                self.hidden_layer.ec.weights = self.hidden_layer.stdp(
                    reward, x, out, self.hidden_layer.ec.weights,
                    self.lr_capture, self.lr_backoff, self.lr_search,
                    True, self.input_size)
            out_prev = out

    def train_output_layer_det(self, sample, target, k_hidden, clamp_time=15):
        """Run the (frozen) hidden column over ``sample``, then train the output
        column once on the final hidden output.

        After every step the latest finite hidden spike time is clamped to
        ``clamp_time`` if it exceeds it (only that one entry, modified in place).
        An empty ``sample`` is a no-op.
        """
        out_prev = self._no_spikes()
        out = None
        for step in sample:
            x = torch.cat((step, out_prev))
            out, winner = self.hidden_layer(x, k_hidden)

            # Indices into `out` itself: the max position must be mapped back
            # from the inf-filtered view before writing.
            finite_idx = torch.nonzero(out != float('inf'), as_tuple=True)[0]
            if finite_idx.numel() > 0:
                max_val, pos = torch.max(out[finite_idx], 0)
                if max_val.item() > clamp_time:
                    out[finite_idx[pos].item()] = clamp_time

            out_prev = out

        if out is None:
            return

        out_final, winner = self.out_layer(out, 1)

        if self.out_stdp == 'stdp' or self.out_stdp == 'stdp_tmod':
            self.out_layer.ec.weights = self.out_layer.stdp(
                out, out_final, self.out_layer.ec.weights,
                self.lr_capture, self.lr_backoff, self.lr_search)
        elif self.out_stdp == 'rstdp' or self.out_stdp == 'rstdp_tmod':
            reward = self._reward(target, winner)
            self.out_layer.ec.weights = self.out_layer.stdp(
                reward, target.item(), out, out_final, self.out_layer.ec.weights,
                self.lr_capture, self.lr_backoff, self.lr_search)

    def predict(self, test_loader, datasize):
        """Classify every sequence in ``test_loader`` (only ``data[0]`` of each batch).

        Returns:
            ``(table, predictions)``: ``table[p, t]`` counts samples predicted as
            ``p`` with label ``t`` (samples without a winner are not counted);
            ``predictions[i]`` is the winner for batch ``i`` (``-1`` if none).
        """
        predictions = torch.zeros(datasize)
        table = torch.zeros((self.layer2_out_size, self.layer2_out_size))
        for inx, (data, label) in enumerate(test_loader):
            predicted = self.predict_sample(data[0])
            predictions[inx] = predicted
            if predicted != -1:
                table[int(predicted), int(label[0])] += 1

        return (table, predictions)

    def predict_sample(self, sample):
        """Run one sequence through the network and return the output winner.

        The hidden column's output is fed back as recurrent input, exactly as in
        training; the output column reads the final hidden output. Returns
        ``-1`` for an empty sequence.
        """
        out_prev = self._no_spikes()
        out1 = None
        for step in sample:
            out1, _ = self.hidden_layer(torch.cat((step, out_prev)))
            out_prev = out1

        if out1 is None:
            return -1
        _, winner = self.out_layer(out1)
        return winner