In [None]:
### Importing libraries ###

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.distributions.bernoulli import Bernoulli

import torchvision
from torchvision.datasets import MNIST
from torchvision.datasets import FashionMNIST
from torchvision import transforms

import matplotlib.pyplot as plt
import time

In [None]:
### Author: Harideep Nair ###
### Excitatory Column (EC) ###

    # Consists of 'q' excitatory neurons with 'p' synapses each.
    # Implements an SRM0 neuron model.
    
    # Args: tin_max - Maximum input spiketime.
    #       p       - Number of synapses per neuron.
    #       q       - Number of neurons.
    #       wres    - Bit resolution for synaptic weights.
    #       theta   - Excitation threshold for neuron. An output spike is generated when neuron's body potential reaches
    #                 this threshold.
    #       ntype   - Type of neuron response function: 'snl' (step-no-leak) or 'rnl' (ramp-no-leak).
    #       w_init  - Type of initialization for synaptic weights: 'zero', 'uniform' (random uniform) or
    #                 'normal' (random normal).
    #       ramp    - Slope of ramp for ramp-no-leak response function. Default is 1.

class ExcitatoryColumn():
    """Column of ``q`` SRM0 excitatory neurons with ``p`` synapses each.

    Args:
        tin_max: Maximum input spike time.
        p:       Number of synapses per neuron.
        q:       Number of neurons.
        wres:    Bit resolution of the synaptic weights (weights lie in [0, 2**wres - 1]).
        theta:   Firing threshold; a neuron spikes at the first time step where its body
                 potential is >= theta.
        ntype:   Response function, "snl" (step-no-leak) or "rnl" (ramp-no-leak).
        w_init:  Weight initialization, "zero", "uniform" or "normal".
        ramp:    Slope of the ramp for the "rnl" response function.

    Raises:
        ValueError: if ``ntype`` or ``w_init`` is not one of the supported values. Previously an
            unsupported value left ``self.time`` / ``self.weights`` undefined and failed later with
            an obscure AttributeError.
    """

    _NTYPES  = ("snl", "rnl")
    _W_INITS = ("zero", "uniform", "normal")

    def __init__(self, tin_max, p, q, wres, theta, ntype="rnl", w_init="zero", ramp=1):
        if ntype not in self._NTYPES:
            raise ValueError(f"Unsupported ntype {ntype!r}; expected one of {self._NTYPES}")
        if w_init not in self._W_INITS:
            raise ValueError(f"Unsupported w_init {w_init!r}; expected one of {self._W_INITS}")

        self.p             = p
        self.q             = q
        self.wmax          = 2**wres-1
        self.theta         = theta
        self.ntype         = ntype
        self.ramp          = ramp

        # Synaptic weights, shape [q, p], integer-valued floats in [0, wmax].
        if w_init == "zero":
            self.weights   = torch.zeros(self.q, self.p)
        elif w_init == "uniform":
            self.weights   = torch.randint(low=0, high=self.wmax+1, size=(self.q, self.p)).float()
        else:  # "normal": centred on (wmax+1)/2 with unit std, then rounded and clamped into range.
            self.weights   = torch.round(((self.wmax+1)/2+torch.randn(self.q, self.p)).clamp_(0,self.wmax))

        # Length of the time axis; self.time = tout_max + 1.
        # "snl": the potential can only change at input spike times, all of which are <= tin_max.
        # "rnl": a synapse whose input spikes at tin_max keeps ramping for up to wmax more steps
        #        (this sizing assumes ramp >= 1 - TODO confirm behaviour for fractional ramps).
        if self.ntype == "snl":
            self.time      = tin_max + 1
        else:
            self.time      = tin_max + self.wmax

        self.pot           = torch.zeros(self.time, self.q)
        self.ec_spiketimes = float('Inf')*torch.ones(self.q)
        # const[t, i, j] == t: the time index broadcast over [time, q, p].
        self.const         = torch.arange(self.time).repeat(self.q, self.p, 1).permute(2,0,1)

    def _input_steps(self, input_spiketimes):
        """Return a [time, q, p] tensor that is 1 from each input's spike time onward and 0 before.

        An input spike time of Inf never turns on.
        """
        elapsed = self.const - input_spiketimes.repeat(self.time, self.q, 1)
        return (elapsed >= 0).to(elapsed.dtype)

    def _first_crossing(self, pot):
        """Return, per neuron, the first time index with pot >= theta (Inf if it never crosses).

        pot is non-decreasing in time because weights are non-negative. The first crossing
        therefore equals self.time minus the number of steps spent at or above threshold.
        """
        steps_above   = (pot >= self.theta).sum(dim=0).to(pot.dtype)
        ec_spiketimes = self.time - steps_above
        ec_spiketimes[ec_spiketimes == self.time] = float('Inf')
        return ec_spiketimes

    # Implements a step-no-leak response function model.
    def StepNoLeak(self, input_spiketimes):
        """Step-no-leak response: each synapse adds its full weight from its input spike time on.

        Args:
            input_spiketimes: Tensor of input spike times with shape [self.p].

        Returns:
            (output spike times with shape [self.q], body potentials with shape [self.time, self.q]).
        """
        responses = torch.mul(self._input_steps(input_spiketimes), self.weights)
        pot       = torch.sum(responses, dim=2)
        return self._first_crossing(pot), pot

    # Implements a ramp-no-leak response function model.
    def RampNoLeak(self, input_spiketimes, ramp=1):
        """Ramp-no-leak response.

        After its input spike, each synapse's response rises by ``ramp`` per time step until it
        saturates at the synapse's weight.

        Args:
            input_spiketimes: Tensor of input spike times with shape [self.p].
            ramp:             Slope of the ramp.

        Returns:
            (output spike times with shape [self.q], body potentials with shape [self.time, self.q]).
        """
        ramps     = ramp * torch.cumsum(self._input_steps(input_spiketimes), dim=0)
        responses = torch.min(ramps, self.weights)
        pot       = torch.sum(responses, dim=2)
        return self._first_crossing(pot), pot

    def __call__(self, data):
        """Run the column on input spike times ``data`` (shape [self.p]).

        Caches the output spike times and potentials on the instance and returns the spike times.
        """
        if self.ntype == "snl":
            self.ec_spiketimes, self.pot = self.StepNoLeak(data)
        elif self.ntype == "rnl":
            self.ec_spiketimes, self.pot = self.RampNoLeak(data, self.ramp)
        return self.ec_spiketimes


In [None]:
### Lateral Inhibition (LI) ###

# Implements k-WTA (k = 1 by default): the k earliest-spiking neurons win; ties are meant to go to the lowest index.

class LateralInhibition():
    """k-winner-take-all lateral inhibition (k = 1 by default).

    Keeps the k earliest output spike times and sets every other neuron's time to Inf.
    Ties between equal spike times are broken in favour of the lowest neuron index.
    """

    def __call__(self,ec_out, k=1):
        """Apply k-WTA to the excitatory-column output spike times.

        Args:
            ec_out: Tensor of output spike times, assumed 1-D with shape [q]; Inf means no spike.
            k:      Number of winners to keep.

        Returns:
            (li_out, win_idx):
            - li_out: tensor shaped like ec_out, with every non-winner set to Inf.
            - win_idx: LongTensor of the k winner indices, or -1 if no neuron spiked.
              If k exceeds the number of spiking neurons, the remaining "winners" are
              non-spiking neurons, and their li_out entries stay Inf.
        """
        if torch.min(ec_out) == float('Inf'):
            # No neuron fired: nothing to inhibit. Return a copy so callers never alias the input.
            return ec_out.clone(), -1

        # The default torch.sort is not guaranteed stable. stable=True keeps equal spike times in
        # index order, which is what makes the documented "lowest index wins" tie-break hold.
        sort_times, sort_idx = torch.sort(ec_out, stable=True)
        win_times, win_idx   = sort_times[:k], sort_idx[:k]
        li_out               = float('Inf')*torch.ones(ec_out.shape)
        li_out[win_idx]      = win_times
        return li_out, win_idx

In [None]:
### Unsupervised Spike Timing Dependent Plasticity (STDP) - Modified ###

# Modification of baseline RSTDP - implements much more stochasticity.
# Each synapse has a separate Bernoulli random variable associated with it.

class STDP():
    """Unsupervised stochastic STDP in which every synapse has its own Bernoulli random variables.

    Args:
        wres: Bit resolution of the synaptic weights; updated weights are clamped to [0, 2**wres - 1].
    """

    def __init__(self, wres):
        self.wmax       = 2**(wres)-1

    @staticmethod
    def _stick(rvstick, mask, weights):
        """Look up rvstick[i, j, weights[i, j]] for every synapse (i, j) selected by ``mask``.

        This is equivalent to ``torch.diagonal(rvstick[mask][:, weights[mask].long()], 0)``. It uses
        ``gather`` instead, so time and memory are O(n) rather than O(n^2) in the number of
        selected synapses.
        """
        table = rvstick[mask]                        # [n, levels]; presumes rvstick is [q, p, levels]
        idx   = weights[mask].long().unsqueeze(1)    # [n, 1]: current weight value per synapse
        return table.gather(1, idx).squeeze(1)       # [n]

    def __call__(self, intimes, outtimes, weights, rvcapture, rvsearch, rvbackoff, rvmin, rvstickup, rvstickdown):
        """Apply one STDP update to ``weights`` (in place) and return a clamped copy.

        Args:
            intimes:  Input spike times; flattened to [p].
            outtimes: Lateral-inhibition output spike times, shape [q].
            weights:  Synaptic weights, shape [q, p]; modified in place (unclamped).
            rvcapture, rvsearch, rvbackoff, rvmin: Per-synapse Bernoulli samples shaped like weights.
            rvstickup, rvstickdown: Per-synapse tables indexed by the current weight value.

        Returns:
            ``weights`` clamped to [0, wmax] (a new tensor).
        """
        intimes         = torch.flatten(intimes)
        q               = outtimes.shape[0]
        p               = intimes.shape[0]
        ec_in           = intimes.repeat(q,1)
        li_out          = outtimes.repeat(p,1).permute(1,0)

        # The four masks are disjoint, so the updates below never read a value another case wrote.
        in_spk  = ec_in != float('Inf')
        out_spk = li_out != float('Inf')
        capture = in_spk * out_spk * (ec_in <= li_out)
        minus   = in_spk * out_spk * (ec_in > li_out)
        search  = in_spk * ~out_spk
        backoff = ~in_spk * out_spk

        # Case 1 (capture): input spiked no later than the output -> potentiate.
        weights[capture] += rvcapture[capture] \
                            * torch.max(rvmin[capture], self._stick(rvstickup, capture, weights))

        # Case 2 (minus): input spiked after the output -> depress.
        weights[minus]   -= rvcapture[minus] \
                            * torch.max(rvmin[minus], self._stick(rvstickdown, minus, weights))

        # Case 3 (search): input spiked but the neuron produced no output spike.
        weights[search]  += rvsearch[search]

        # Case 4 (backoff): output spiked without this input.
        weights[backoff] -= rvbackoff[backoff] \
                            * torch.max(rvmin[backoff], self._stick(rvstickdown, backoff, weights))

        return weights.clamp(0, self.wmax)

In [None]:
### Author: Harideep Nair ###
### Spike Timing Dependent Plasticity (STDP) with Reward for TNN Column ###

# Adds partial reinforcement to baseline unsupervised STDP.
# If the winner neuron's index matches the desired index corresponding to the assigned label, then proceed with
# conventional STDP cases except for search.
# Else if the winner neuron does not correspond to the assigned label, do reverse STDP for capture case with
# backoff probability. Backoff and minus cases are not required. Search is executed as usual.

    # Args: wres          - Bit resolution for weights.
    #       stochasticity - Decides how correlated or uncorrelated weight updates will be, for a column's synapses. Could
    #                       be low or high.
    #       layer         - Determines if STDP is done for a single column (layer = 0) or on a layer level (layer = 1).
    #       reward        - Reward signal to guide STDP to learn the desired output label assignments.
    #                       When the winning neuron index matches the desired label, reward is '1' -> perform conventional STDP.
    #                       When there is no winning neuron, reward is '0' -> just perform search.
    #                       When the winning neuron index doesn't match the desired label, reward is '-1' -> perform anti-STDP.
    #       intimes       - Tensor of input spiketimes with shape [1,in_channels,height,width]. However, the actual shape
    #                       doesn't matter since it's flattened later. After flattening, the shape becomes [self.p].
    #       outtimes      - Tensor of LI's output spiketimes with shape [self.q].
    #       weights       - Tensor of synaptic weights with shape [self.q,self.p].
    #       rvcapture     - BRV for 'capture' and 'minus' cases.
    #       rvsearch      - BRV for 'search' case.
    #       rvbackoff     - BRV for 'backoff' case.
    #       rvmin         - BRV for enforcing a minimum probability of update.
    #       rvstickup     - BRV for sticking the weights towards wmax. Helps in generating a bimodal weight distribution.
    #       rvstickdown   - BRV for sticking the weights towards 0. Helps in generating a bimodal weight distribution.
    # Returns a tensor of RSTDP-updated synaptic weights of shape [self.q,self.p].

class RSTDP():
    """Reward-modulated STDP (R-STDP) for a TNN column.

    Behaviour by reward signal:
        reward ==  1 : conventional STDP - capture, minus and backoff (no search).
        reward ==  0 : search only.
        reward == -1 : anti-STDP on the capture case (using the backoff variable) plus search.

    Args:
        wres:          Bit resolution of the weights; results are clamped in place to [0, 2**wres - 1].
        stochasticity: "low"  - one set of Bernoulli variables shared by the whole column;
                       "high" - separate Bernoulli variables for every synapse.
        layer:         0 - intimes/outtimes are raw column inputs/outputs and are broadcast here;
                       1 - intimes/outtimes already have the shape of ``weights``.

    Raises:
        ValueError: for an unsupported ``stochasticity`` or ``layer``. Previously an unsupported
            value either silently skipped every update or failed later with a NameError.
    """

    _STOCHASTICITIES = ("low", "high")
    _LAYERS          = (0, 1)

    def __init__(self, wres, stochasticity="high", layer=0):
        if stochasticity not in self._STOCHASTICITIES:
            raise ValueError(f"Unsupported stochasticity {stochasticity!r}; expected 'low' or 'high'")
        if layer not in self._LAYERS:
            raise ValueError(f"Unsupported layer {layer!r}; expected 0 or 1")
        self.wmax           = 2**(wres)-1
        self.stoch          = stochasticity
        self.layer          = layer

    @staticmethod
    def _stick(rvstick, mask, weights):
        """Look up rvstick[i, j, weights[i, j]] for every synapse (i, j) selected by ``mask``.

        This is equivalent to ``torch.diagonal(rvstick[mask][:, weights[mask].long()], 0)``. It uses
        ``gather`` instead, so time and memory are O(n) rather than O(n^2) in the number of
        selected synapses.
        """
        table = rvstick[mask]                        # [n, levels]
        idx   = weights[mask].long().unsqueeze(1)    # [n, 1]: current weight value per synapse
        return table.gather(1, idx).squeeze(1)       # [n]

    def __call__(self, reward, intimes, outtimes, weights, rvcapture, rvsearch, rvbackoff, rvmin, rvstickup, rvstickdown):
        """Apply one R-STDP update to ``weights`` in place and return it clamped to [0, wmax].

        Args:
            reward:   1, 0 or -1 (see class docstring). Any other value only clamps.
            intimes:  Input spike times (flattened to [p] when layer == 0).
            outtimes: Lateral-inhibition output spike times (shape [q] when layer == 0).
            weights:  Synaptic weights, shape [q, p]; modified in place.
            rvcapture, rvsearch, rvbackoff, rvmin: Bernoulli samples. For "low" stochasticity they are
                      shared values broadcast over the selected synapses; for "high" they are shaped
                      like ``weights``.
            rvstickup, rvstickdown: Tables indexed by the current weight value. For "low" they are
                      1-D over weight levels; for "high" they presumably have shape
                      [q, p, wmax + 1] - TODO confirm.
        """
        if self.layer == 0:
            intimes         = torch.flatten(intimes)
            q               = outtimes.shape[0]
            p               = intimes.shape[0]
            ec_in           = intimes.repeat(q,1)
            li_out          = outtimes.repeat(p,1).permute(1,0)
        else:
            ec_in           = intimes
            li_out          = outtimes

        # Disjoint update masks, computed once.
        in_spk  = ec_in != float('Inf')
        out_spk = li_out != float('Inf')
        capture = in_spk * out_spk * (ec_in <= li_out)
        minus   = in_spk * out_spk * (ec_in > li_out)
        search  = in_spk * ~out_spk
        backoff = ~in_spk * out_spk

        # Low stochasticity - all Bernoulli random variables are shared across the entire column.
        if self.stoch == "low":
            if reward == 1:
                weights[capture] += rvcapture * torch.max(rvmin, rvstickup[weights[capture].long()])    # Case 1 (capture)
                weights[minus]   -= rvcapture * torch.max(rvmin, rvstickdown[weights[minus].long()])    # Case 2 (minus)
                weights[backoff] -= rvbackoff * torch.max(rvmin, rvstickdown[weights[backoff].long()])  # Case 4 (backoff)
            elif reward == 0:
                weights[search]  += rvsearch                                                           # Case 3 (search)
            elif reward == -1:
                weights[capture] -= rvbackoff * torch.max(rvmin, rvstickdown[weights[capture].long()])  # anti-capture
                weights[search]  += rvsearch                                                           # Case 3 (search)

        # High stochasticity - each synapse has a separate Bernoulli random variable associated with it.
        else:
            if reward == 1:
                # Case 1 (capture)
                weights[capture] += rvcapture[capture] \
                                    * torch.max(rvmin[capture], self._stick(rvstickup, capture, weights))
                # Case 2 (minus)
                weights[minus]   -= rvcapture[minus] \
                                    * torch.max(rvmin[minus], self._stick(rvstickdown, minus, weights))
                # Case 4 (backoff)
                weights[backoff] -= rvbackoff[backoff] \
                                    * torch.max(rvmin[backoff], self._stick(rvstickdown, backoff, weights))
            elif reward == 0:
                # Case 3 (search)
                weights[search]  += rvsearch[search]
            elif reward == -1:
                # Anti-STDP on the capture case, using the backoff variable.
                weights[capture] -= rvbackoff[capture] \
                                    * torch.max(rvmin[capture], self._stick(rvstickdown, capture, weights))
                # Case 3 (search)
                weights[search]  += rvsearch[search]

        return weights.clamp_(0, self.wmax)

In [None]:
class STDP_Det():
    """Deterministic unsupervised STDP with fixed learning rates.

    Weights are kept in [0, 1]. When recurrent inputs are present (``recc`` true; flattened input
    positions ``recc_start:`` onward), those recurrent weights are additionally capped at
    ``recc_wmax``.

    Args:
        wres:      Unused; kept for interface compatibility with the stochastic STDP classes.
        recc_wmax: Upper bound for recurrent weights (default 0.5, the previously hard-coded value).
    """

    def __init__(self, wres, recc_wmax=0.5):
        self.wmax       = 1
        self.recc_wmax  = recc_wmax

    def __call__(self, intimes, outtimes, weights, lr_capture, lr_backoff, lr_search,
                   recc, recc_start):
        """Apply one deterministic STDP update to ``weights`` in place and return it.

        Args:
            intimes:    Input spike times; flattened to [p].
            outtimes:   Lateral-inhibition output spike times, shape [q].
            weights:    Synaptic weights, shape [q, p]; modified in place.
            lr_capture: Step for capture (added) and minus (subtracted).
            lr_backoff: Step subtracted for backoff.
            lr_search:  Step added for search.
            recc:       If true, columns ``recc_start:`` are recurrent. They are never searched; backoff
                        is applied only to them, and they are capped at ``recc_wmax``.
            recc_start: First recurrent column (only used when ``recc`` is true).
        """
        intimes         = torch.flatten(intimes)
        q               = outtimes.shape[0]
        p               = intimes.shape[0]
        ec_in           = intimes.repeat(q,1)
        li_out          = outtimes.repeat(p,1).permute(1,0)

        in_spk  = ec_in != float('Inf')
        out_spk = li_out != float('Inf')
        capture = in_spk * out_spk * (ec_in <= li_out)
        minus   = in_spk * out_spk * (ec_in > li_out)
        search  = in_spk * ~out_spk
        backoff = ~in_spk * out_spk

        weights[capture] += lr_capture   # Case 1 (capture)
        weights[minus]   -= lr_capture   # Case 2 (minus)

        if recc:
            # Case 3 (search): don't search the recurrent section.
            weights[:,:recc_start][search[:,:recc_start]]   += lr_search
            # Case 4 (backoff): only back off the recurrent section.
            weights[:,recc_start:][backoff[:,recc_start:]]  -= lr_backoff
            weights[:,:recc_start] = weights[:,:recc_start].clamp(0, self.wmax)
            weights[:,recc_start:] = weights[:,recc_start:].clamp(0, self.recc_wmax)
        else:
            weights[search]  += lr_search
            weights[backoff] -= lr_backoff
            # No recurrent section: the recurrent cap must not apply. Previously it did whenever
            # recc_start was 0 or None.
            weights.clamp_(0, self.wmax)
        return weights

In [None]:
class RSTDP_Det():
    """Deterministic reward-modulated STDP with fixed learning rates; weights are clamped in place to [0, 1].

    reward ==  1 : capture (+lr_capture) and minus (-lr_capture) on all synapses, plus backoff
                   (-lr_backoff), which is restricted to the recurrent columns when ``recc`` is true.
    reward ==  0 : search (+lr_search), which skips the recurrent columns when ``recc`` is true.
    reward == -1 : depress the capture case by lr_capture, then search as for reward == 0.
    Any other reward only clamps.

    Args:
        wres:  Unused; kept for interface compatibility.
        layer: 0 - intimes/outtimes are broadcast here to [q, p];
               1 - they already have the shape of ``weights``.
    """

    def __init__(self, wres, layer=0):
        self.wmax  = 1
        self.layer = layer

    def __call__(self, reward, intimes, outtimes, weights, lr_capture, lr_backoff, lr_search, \
                recc, recc_start):
        if reward not in (1, 0, -1):
            # Nothing to learn from this reward value.
            return weights.clamp_(0, self.wmax)

        if self.layer == 0:
            flat_in     = torch.flatten(intimes)
            n_out, n_in = outtimes.shape[0], flat_in.shape[0]
            ec_in       = flat_in.repeat(n_out, 1)
            li_out      = outtimes.repeat(n_in, 1).permute(1, 0)
        elif self.layer == 1:
            ec_in, li_out = intimes, outtimes

        inf         = float('Inf')
        both_fired  = (ec_in != inf) * (li_out != inf)
        captured    = both_fired * (ec_in <= li_out)
        minus_cases = both_fired * (ec_in > li_out)
        searching   = (ec_in != inf) * (li_out == inf)
        backing_off = (ec_in == inf) * (li_out != inf)

        def search():
            # The recurrent section (columns recc_start:) is never searched.
            if recc:
                weights[:, :recc_start][searching[:, :recc_start]] += lr_search
            else:
                weights[searching] += lr_search

        if reward == 1:
            weights[captured]    += lr_capture
            weights[minus_cases] -= lr_capture
            if recc:
                # Only the recurrent section backs off.
                weights[:, recc_start:][backing_off[:, recc_start:]] -= lr_backoff
            else:
                weights[backing_off] -= lr_backoff
        elif reward == 0:
            search()
        else:  # reward == -1: anti-STDP on the capture case, then search
            weights[captured] -= lr_capture
            search()

        return weights.clamp_(0, self.wmax)

In [None]:
class STDP_Det_TMod():
    """Deterministic STDP whose capture / minus step size depends on the spike-time difference.

    capture[d] and minus[d] (d = 0..tmax) are linearly spaced from the ``*_min`` rate up to the
    full rate. d is t_out - t_in for capture and t_in - t_out for minus.

    Args:
        wres:               Unused; kept for interface compatibility.
        tmax:               Largest time difference with its own step size. Larger differences reuse
                            the step size for tmax instead of raising an IndexError.
        recc_backoff_scale: Multiplier on lr_backoff for recurrent synapses (default 800, the
                            previously hard-coded value).
    """

    def __init__(self, wres, tmax, recc_backoff_scale=800):
        self.wmax               = 1
        self.tmax               = tmax
        self.recc_backoff_scale = recc_backoff_scale

    def __call__(self, intimes, outtimes, weights, lr_capture, lr_capture_min, \
                 lr_minus, lr_minus_min, lr_backoff, lr_search,
                   recc, recc_start):
        """Apply one update to ``weights`` in place and return a copy clamped to [0, 1].

        Args:
            intimes:    Input spike times; flattened to [p]. The caller's tensor is not modified.
            outtimes:   Lateral-inhibition output spike times, shape [q].
            weights:    Synaptic weights, shape [q, p]; modified in place (unclamped).
            recc:       If true, flattened input positions ``recc_start:`` are recurrent.
            recc_start: First recurrent position (only used when ``recc`` is true).
        """
        # Work on a flattened copy. Previously the shift below mutated the caller's tensor, and it
        # sliced dim 0 of the un-flattened input.
        intimes = torch.flatten(intimes).clone()
        if recc:
            # Experimental: recurrent inputs are treated as arriving one step earlier.
            recc_in = intimes[recc_start:]   # view into the local copy
            recc_in[recc_in != float('inf')] -= 1

        q               = outtimes.shape[0]
        p               = intimes.shape[0]
        ec_in           = intimes.repeat(q,1)
        li_out          = outtimes.repeat(p,1).permute(1,0)

        capture         = torch.linspace(lr_capture_min, lr_capture, self.tmax + 1)
        minus           = torch.linspace(lr_minus_min, lr_minus, self.tmax + 1)

        in_spk  = ec_in != float('Inf')
        out_spk = li_out != float('Inf')
        case1   = in_spk * out_spk * (ec_in <= li_out)
        case2   = in_spk * out_spk * (ec_in > li_out)
        case3   = in_spk * ~out_spk
        case4   = ~in_spk * out_spk

        # Case 1 (capture) / Case 2 (minus). The time difference is clamped to tmax: after the
        # recurrent shift it can reach tmax + 1, which used to index past the end of the tables.
        weights[case1] += capture[(li_out - ec_in)[case1].long().clamp(max=self.tmax)]
        weights[case2] -= minus[(ec_in - li_out)[case2].long().clamp(max=self.tmax)]

        # Case 3 (search): never search the recurrent section.
        if recc:
            weights[:,:recc_start][case3[:,:recc_start]] += lr_search
        else:
            weights[case3] += lr_search

        # Case 4 (backoff): both sections back off; the recurrent one at a scaled rate.
        if recc:
            weights[:,:recc_start][case4[:,:recc_start]] -= lr_backoff
            weights[:,recc_start:][case4[:,recc_start:]] -= lr_backoff * self.recc_backoff_scale
        else:
            weights[case4] -= lr_backoff
        return weights.clamp(0, self.wmax)







# class STDP_Det_TMod():
#     def __init__(self, wres, tmax):
#         self.wmax       = 1
#         self.tmax       = tmax
        
#     def __call__(self, intimes, outtimes, weights, lr_capture, lr_minus, lr_backoff, lr_search,
#                    recc, recc_start):
#         intimes         = torch.flatten(intimes)
#         q               = outtimes.shape[0]
#         p               = intimes.shape[0]
#         ec_in           = intimes.repeat(q,1)
#         li_out          = outtimes.repeat(p,1).permute(1,0)

#         capture         = torch.linspace(lr_capture - 0.003, lr_capture, self.tmax + 1)
#         minus           = torch.linspace(lr_minus - 0.003, lr_minus, self.tmax + 1)
        
#         case1 = (ec_in!=float('Inf'))*(li_out!=float('Inf'))*(ec_in<=li_out)
#         case2 = (ec_in!=float('Inf'))*(li_out!=float('Inf'))*(ec_in>li_out)
#         case3 = (ec_in!=float('Inf'))*(li_out==float('Inf'))
#         case4 = (ec_in==float('Inf'))*(li_out!=float('Inf'))
        
#         # Case 1 (capture)
#         if ((li_out - ec_in)[case1]).shape[0] > 0:
#             weights[case1] += capture[((li_out - ec_in)[case1]).long()]
        
#         # Case 2 (minus)
#         if (((ec_in - li_out)[:,:recc_start])[case2[:,:recc_start]]).shape[0] > 0:
#             weights[:,:recc_start][case2[:,:recc_start]] \
#                        -= minus[(((ec_in - li_out)[:,:recc_start])[case2[:,:recc_start]]).long()]
              
#         # Case 2 (minus) recurrent
#         if (((ec_in - li_out)[:,recc_start:])[case2[:,recc_start:]]).shape[0] > 0:
#             weights[:,recc_start:][case2[:,recc_start:]] \
#                        -= capture[(((ec_in - li_out)[:,recc_start:])[case2[:,recc_start:]]).long()]
#         if recc:
#             #don't search the recurrent_layer
#             weights[:,:recc_start][case3[:,:recc_start]] \
#                                    += lr_search
            
# #             if outtimes[outtimes != float('inf')].shape[0] > 0:
# #                 #remove from the recc layer
# #                 weights[:,recc_start:][case3[:,recc_start:]] \
# #                                        -= lr_search
#         else:
#             weights[case3] \
#                                    += lr_search
            
        
#         # Case 4 (backoff)
#         if recc:
#             weights[:,:recc_start][case4[:,:recc_start]] \
#                         -= lr_backoff
#             #only backoff for the recurrent_layer
#             weights[:,recc_start:][case4[:,recc_start:]] \
#                         -= lr_backoff * 800
#         else:
#             weights[case4] \
#                            -= lr_backoff
            
#         weights[:,:recc_start] = (weights[:,:recc_start]).clamp(0, self.wmax)
#         weights[:,recc_start:] = (weights[:,recc_start:]).clamp(0, 0.75)
#         return weights

In [None]:
class RSTDP_Det_TMod():
    """Deterministic reward-modulated STDP whose capture step depends on the input spike time.

    The step table has tmax + 1 entries, linearly spaced from lr_capture - 0.008 to lr_capture.
    It is indexed by each synapse's input spike time.

    reward ==  1 : synapses where both input and output spiked gain step[t_in]; synapses whose
                   output spiked without the input lose lr_backoff.
    reward == -1 : synapses where both spiked lose step[t_in]; the ``target`` neuron's spiking
                   inputs gain lr_search.
    reward ==  0 : the ``target`` neuron's spiking inputs gain 10 * lr_search.
    Returns a copy of ``weights`` clamped to [0, 1]; ``weights`` itself is updated in place.
    """

    def __init__(self, wres, tmax):
        self.wmax = 1
        self.tmax = tmax

    def __call__(self, reward, target, intimes, outtimes, weights, lr_capture, lr_backoff, lr_search):
        flat_in = torch.flatten(intimes)
        n_out   = outtimes.shape[0]
        n_in    = flat_in.shape[0]
        ec_in   = flat_in.repeat(n_out, 1)
        li_out  = outtimes.repeat(n_in, 1).permute(1, 0)

        steps     = torch.linspace(lr_capture - 0.008, lr_capture, self.tmax + 1)
        inf       = float('Inf')
        in_fired  = ec_in != inf
        both      = in_fired * (li_out != inf)

        if reward == 1:
            weights[both] += steps[ec_in[both].long()]
            weights[(ec_in == inf) * (li_out != inf)] -= lr_backoff
        elif reward == -1:
            weights[both] -= steps[ec_in[both].long()]
            row = int(target)
            weights[row][in_fired[row]] += lr_search
        elif reward == 0:
            row = int(target)
            weights[row][in_fired[row]] += lr_search * 10
        return weights.clamp(0, self.wmax)

In [3]:
# class STDP_Det_SMod():
#     def __init__(self, wres, tmax):
#         self.wmax       = 1
#         self.tmax       = tmax
        
#     def __call__(self, intimes, outtimes, weights, lr_capture,lr_capture_min, lr_minus,lr_minus_min, \
#                  lr_backoff, lr_search, recc, recc_start):
#         intimes         = torch.flatten(intimes)
#         q               = outtimes.shape[0]
#         p               = intimes.shape[0]
#         ec_in           = intimes.repeat(q,1)
#         li_out          = outtimes.repeat(p,1).permute(1,0)

#         capture         = torch.linspace(lr_capture_min, lr_capture, self.tmax + 1)
#         minus           = torch.linspace(lr_minus_min, lr_minus, self.tmax + 1)
        
#         case1 = (ec_in!=float('Inf'))*(li_out!=float('Inf'))*(ec_in<=li_out)
#         case2 = (ec_in!=float('Inf'))*(li_out!=float('Inf'))*(ec_in>li_out)
#         case3 = (ec_in!=float('Inf'))*(li_out==float('Inf'))
#         case4 = (ec_in==float('Inf'))*(li_out!=float('Inf'))
        
        
#         #Case 1
#         if ((li_out - ec_in)[case1]).shape[0] > 0:
#             weights[case1] \
#                        += capture[((li_out - ec_in)[case1]).long()]
        
#         #Case 2
#         if ((ec_in - li_out)[case2]).shape[0] > 0:
#             weights[case2] \
#                        -= minus[((ec_in - li_out)[case2]).long()]
        
#         #Case 3 backoff in search
#         if outtimes[outtimes != float('inf')].shape[0] > 0:
#             weights[case3] -= lr_backoff
        
#         #Case 4 backoff
#         weights[case4] -= lr_backoff
        
        
#         return weights.clamp(0, self.wmax)



class STDP_Det_SMod():
    """Deterministic STDP with winner-aware backoff; returns a copy clamped to [0, 1].

    - Capture synapses gain lr_capture; minus synapses lose lr_minus.
    - When there is a winner, every other neuron loses lr_backoff on the synapses where the winner
      captured.
    - Synapses whose output spiked without the input lose lr_backoff.
    - Search (+lr_search) happens only when no neuron produced an output spike.

    tmax, lr_capture_min, lr_minus_min, recc and recc_start are accepted for interface
    compatibility but are not used.
    """

    def __init__(self, wres, tmax):
        self.wmax = 1
        self.tmax = tmax

    def __call__(self, intimes, outtimes, winner, weights, lr_capture, lr_capture_min, \
                 lr_minus, lr_minus_min, lr_backoff, lr_search,
                   recc, recc_start):
        flat_in     = torch.flatten(intimes)
        n_out, n_in = outtimes.shape[0], flat_in.shape[0]
        ec_in       = flat_in.repeat(n_out, 1)
        li_out      = outtimes.repeat(n_in, 1).permute(1, 0)

        inf          = float('Inf')
        in_fired     = ec_in != inf
        out_fired    = li_out != inf
        capture_mask = in_fired * out_fired * (ec_in <= li_out)
        minus_mask   = in_fired * out_fired * (ec_in > li_out)
        search_mask  = in_fired * (li_out == inf)
        backoff_mask = (ec_in == inf) * out_fired

        weights[capture_mask] += lr_capture
        weights[minus_mask]   -= lr_minus

        # The winner's captured synapses are backed off on every other neuron.
        if winner != -1:
            w_idx              = winner.item()
            losers_mask        = capture_mask[w_idx].repeat(n_out, 1)
            losers_mask[w_idx] = False
            weights[losers_mask] -= lr_backoff

        weights[backoff_mask] -= lr_backoff

        nothing_fired = not bool((outtimes != inf).any())
        if nothing_fired:
            weights[search_mask] += lr_search

        return weights.clamp(0, self.wmax)

In [None]:
class STDP_Det_Recc():
    """Deterministic STDP that updates only the recurrent synapses (columns ``recc_start:`` onward).

    On the recurrent synapses:
    - capture gains lr_capture;
    - backoff (output spiked without the input) loses lr_backoff;
    - search (input spiked without an output) also loses lr_backoff.

    ``weights`` is updated in place, and a copy clamped to [0, 1] is returned. tmax, winner,
    lr_capture_min, lr_minus, lr_minus_min, lr_search and recc are accepted for interface
    compatibility but are not used.
    """

    def __init__(self, wres, tmax):
        self.wmax = 1
        self.tmax = tmax

    def __call__(self, intimes, outtimes, winner, weights, lr_capture, lr_capture_min, \
                 lr_minus, lr_minus_min, lr_backoff, lr_search,
                   recc, recc_start):
        flat_in = torch.flatten(intimes)
        ec_in   = flat_in.repeat(outtimes.shape[0], 1)
        li_out  = outtimes.repeat(flat_in.shape[0], 1).permute(1, 0)

        inf       = float('Inf')
        in_fired  = ec_in != inf
        out_fired = li_out != inf
        capture   = (in_fired * out_fired * (ec_in <= li_out))[:, recc_start:]
        search    = (in_fired * (li_out == inf))[:, recc_start:]
        backoff   = ((ec_in == inf) * out_fired)[:, recc_start:]

        recurrent = weights[:, recc_start:]   # basic-slice view: writes go through to `weights`
        recurrent[capture] += lr_capture
        recurrent[backoff] -= lr_backoff
        recurrent[search]  -= lr_backoff

        return weights.clamp(0, self.wmax)