In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Notes on the conv
* The kernel size along the time axis should be 1; otherwise the time dimension gets convolved as well.
* bias=False
* Replace their pool layer with `nn.AvgPool3d`. Note that the pool they used computes $\text{sum} \cdot \theta$, not an average.
* Each layer should receive the spikes from the previous layer.

## Pipeline
raw data(B,C,H,W) --- spike generation ---> spike(B,C,H,W,T) --- Conv3d on spikes ---> X(t) for every t of T --- eq 1: LIF, threshold and reset ---> intermediate V(t), the membrane potential --- based on V(t) ---> Spikes for the next layer.

In [2]:
# Random non-negative test input laid out as (batch_size, channels, H, W, Time)
sample_batch = torch.rand(2, 1, 4, 4, 6).abs()
# conv3d = nn.Conv3d(1,2,kernel_size=(4,4,1),bias=False)
# conv_spikes = conv3d(sample_batch)
# conv_spikes

## The three classes below should be usable directly
## They need the spikes from the spike generator as input, plus the spike counter; then they should work fine

In [41]:
def LIF(spikes, theta, leak, V_min):
    '''
    Run leaky integrate-and-fire dynamics along the last (time) axis.

    The input is indexed as a 5-D tensor (B,C,H,W,T). One zero step is
    prepended to the time axis so the first real step can read a zero
    initial potential; that extra step is dropped again before returning.

    Params:
        spikes: tensor of shape (B,C,H,W,T); the value added to the membrane
            potential at each time step (the layers in this notebook pass
            their conv / pool output here).
        theta: firing threshold; a spike is emitted wherever V >= theta.
        leak: leakage term, added to the membrane potential at every step.
        V_min: lower bound of the membrane potential; values below it are
            raised to V_min. Because this runs after the reset to zero, a
            positive V_min also lifts freshly reset neurons to V_min.

    return:
        (V, next_spikes): the membrane potential after reset / clamping and
        the emitted 0/1 spikes, each with the same shape as `spikes`.
    '''
    # pad pairs are listed from the last dim backwards, so (1, 0, ...) adds one leading time step
    padded = F.pad(spikes, (1, 0, 0, 0, 0, 0), 'constant', 0)

    membrane = torch.zeros_like(padded)
    fired = torch.zeros_like(padded)

    for step in range(1, padded.shape[-1]):
        # equation (1a): previous potential + leak + current input
        membrane[:, :, :, :, step] = (
            membrane[:, :, :, :, step - 1] + leak + padded[:, :, :, :, step]
        )
        v_now = membrane[:, :, :, :, step]  # view: writes below land in `membrane`

        # equation (1b): fire where the threshold is reached, then reset to zero
        crossed = v_now >= theta
        fired[:, :, :, :, step][crossed] = 1
        v_now[crossed] = 0

        # equation (1c): anything that dropped below V_min is lifted to V_min
        v_now[v_now < V_min] = V_min

    return (membrane[:, :, :, :, 1:], fired[:, :, :, :, 1:])



class convLayer(nn.Conv3d):
    '''
    Spiking convolution layer: a bias-free Conv3d over (B,C,H,W,T) input
    whose output is fed through the LIF dynamics.

    The kernel is (kernelSize, kernelSize, 1), i.e. its time extent is 1, so
    the convolution itself does not mix time steps.
    '''

    def __init__(self, inChannels, outChannels, kernelSize, theta, leak=0, V_min=0):
        '''
        Params:
            inChannels: number of input channels.
            outChannels: number of output channels.
            kernelSize: spatial kernel size, used for both H and W.
            theta, leak, V_min: parameters handed to LIF in forward().
        '''
        super().__init__(
            inChannels,
            outChannels,
            (kernelSize, kernelSize, 1),
            bias=False,
        )

        # params for the LIF
        self.theta, self.leak, self.V_min = theta, leak, V_min

    def forward(self, input):
        '''
        Convolve `input`, then run LIF on the result.

        return:
            (conv_spikes, V, next_spikes): the raw convolution output followed
            by the two tensors returned by LIF.
        '''
        # X(t), namely eq(2)
        conv_spikes = F.conv3d(
            input,
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

        membrane, out_spikes = LIF(conv_spikes, self.theta, self.leak, self.V_min)

        return (conv_spikes, membrane, out_spikes)

    
class poolLayer(nn.AvgPool3d):
    '''
    Spiking pooling layer: average pooling over (B,C,H,W,T) input followed by
    the LIF dynamics.
    '''

    def __init__(self, kernel_size, theta, leak=0, V_min=0):
        '''
        Params:
            kernel_size: pooling window, passed straight to nn.AvgPool3d.
            theta, leak, V_min: parameters handed to LIF in forward().
        '''
        super().__init__(kernel_size)

        # params for the LIF
        self.theta, self.leak, self.V_min = theta, leak, V_min

    def forward(self, input):
        '''
        Average-pool `input`, then run LIF on the result.

        return:
            (pool_spikes, V, next_spikes): the pooled tensor followed by the
            two tensors returned by LIF.
        '''
        # X(t), namely eq(2)
        pool_spikes = F.avg_pool3d(input, kernel_size=self.kernel_size)

        membrane, out_spikes = LIF(pool_spikes, self.theta, self.leak, self.V_min)

        return (pool_spikes, membrane, out_spikes)

class denseLayer(nn.Conv3d):
    '''
    Spiking fully connected layer, built as a bias-free Conv3d whose kernel is
    derived from `inFeatures`, followed by the LIF dynamics.
    '''

    def __init__(self, inFeatures, outFeatures, theta, leak=0, V_min=0):
        '''
        Params:
            inFeatures: an int (number of input channels, with a 1x1x1 kernel)
                or a sequence of 2 or 3 entries. For a sequence the kernel is
                (inFeatures[1], inFeatures[0], 1); a third entry, if present,
                is the number of input channels (otherwise 1).
                Presumably the sequence is ordered (W, H[, C]) -- TODO confirm.
            outFeatures: int, number of output channels.
            theta, leak, V_min: parameters handed to LIF in forward().

        Raises:
            ValueError: inFeatures is neither an int nor a 2-/3-entry sequence.
            TypeError: outFeatures is not an int.
        '''
        # extract information for kernel and inChannels
        if isinstance(inFeatures, int):
            kernel = (1, 1, 1)
            inChannels = inFeatures
        else:
            try:
                nDims = len(inFeatures)
            except TypeError:
                # not a sized object at all; reported by the ValueError below
                nDims = None

            if nDims == 2:
                kernel = (inFeatures[1], inFeatures[0], 1)
                inChannels = 1
            elif nDims == 3:
                kernel = (inFeatures[1], inFeatures[0], 1)
                inChannels = inFeatures[2]
            else:
                # Format the value itself: ints, tuples and lists have no `.shape`,
                # so reading it here would raise AttributeError instead of this error.
                raise ValueError(
                    'inFeatures should be an int or a sequence of 2 or 3 dimensions. '
                    'It was: {}'.format(inFeatures)
                )

        if not isinstance(outFeatures, int):
            raise TypeError('outFeatures should be an int. It was: {}'.format(outFeatures))
        outChannels = outFeatures

        super(denseLayer, self).__init__(inChannels, outChannels, kernel, bias=False)

        # params for the LIF
        self.theta = theta
        self.leak = leak
        self.V_min = V_min

    def forward(self, input):
        '''
        Apply the dense (convolution) weights, then run LIF on the result.

        return:
            (fc_spikes, V, next_spikes): the raw weighted output followed by
            the two tensors returned by LIF.
        '''
        fc_spikes = F.conv3d(input,
                             self.weight, self.bias,
                             self.stride, self.padding, self.dilation, self.groups)

        output = LIF(fc_spikes, self.theta, self.leak, self.V_min)

        return (fc_spikes, *output)

# One instance of each spiking layer, all with a firing threshold of 0.5
slayer_conv = convLayer(inChannels=1, outChannels=2, kernelSize=4, theta=0.5)
slayer_pool = poolLayer(kernel_size=(2, 2, 1), theta=0.5)
slayer_dense = denseLayer(inFeatures=(1, 1, 2), outFeatures=10, theta=0.5)

In [28]:
# convLayer.forward returns (conv output, membrane potential, spikes); index 2 selects the spikes.
slayer_conv(sample_batch)[2].shape

torch.Size([2, 2, 1, 1, 6])

In [21]:
# Chain the layers: the conv layer's output spikes (index 2) are the dense layer's input;
# the shape shown is that of the dense layer's own output spikes.
slayer_dense(slayer_conv(sample_batch)[2])[2].shape

torch.Size([2, 10, 1, 1, 6])

In [37]:
# Index 1 is the membrane potential returned by LIF (index 2 would be the spikes);
# LIF returns both with the same shape, so the shape shown applies to either.
slayer_pool(sample_batch)[1].shape

torch.Size([2, 1, 2, 2, 6])

In [38]:
# Input shape for reference: (batch_size, channels, H, W, Time)
sample_batch.shape

torch.Size([2, 1, 4, 4, 6])

In [23]:
# Plain nn.AvgPool3d with the same (2,2,1) kernel as slayer_pool, without the LIF step
# (presumably to inspect the raw pooled values -- confirm intent).
pool = nn.AvgPool3d((2,2,1))
pool(sample_batch)

tensor([[[[[0.3815, 0.4638, 0.7046, 0.7369, 0.6491, 0.4967],
           [0.6913, 0.5906, 0.6854, 0.2956, 0.3159, 0.2807]],

          [[0.5666, 0.3653, 0.6465, 0.5139, 0.3165, 0.4897],
           [0.5576, 0.6426, 0.5479, 0.5369, 0.6438, 0.4931]]]],



        [[[[0.5946, 0.5295, 0.4178, 0.6139, 0.3208, 0.5947],
           [0.3775, 0.6151, 0.5779, 0.6122, 0.7627, 0.4354]],

          [[0.3299, 0.5560, 0.3300, 0.5514, 0.7094, 0.5810],
           [0.5793, 0.3344, 0.7445, 0.6099, 0.7501, 0.4990]]]]])

In [24]:
# The (2,2,1) window halves H and W and leaves the time axis untouched.
pool(sample_batch).shape

torch.Size([2, 1, 2, 2, 6])

In [5]:
# theta = 0.5
# V_min = -0.1
# Leak = 0

# conv_spikes = conv3d(sample_batch)
# print(sample_batch.shape)
# print(conv_spikes.shape)

# # the padding controls where to pad, first two of the tuple control the last dim of the tensor
# _pad = nn.ConstantPad3d((1,0,0,0,0,0), 0)
# conv_pad_spikes = _pad(conv_spikes)

# V = torch.zeros_like(conv_pad_spikes)
# next_Spikes = torch.zeros_like(conv_pad_spikes)

# T = conv_pad_spikes.shape[-1]


# for t in range(1, T):
#     print('_______calculate membrane_________')
#     # equation 1a
#     V[:,:,:,:,t] = V[:,:,:,:,t-1] + Leak + conv_pad_spikes[:,:,:,:,t]
#     print(V)
#     # thresholding and fire spike
#     mask_threshold = V[:,:,:,:,t] >= theta
#     next_Spikes[:,:,:,:,t][mask_threshold] = 1        
#     # reset the potential to zero 
#     V[:,:,:,:,t][mask_threshold] = 0
    
# #     mask_min = (V[:,:,:,:,t]<V_min)
# #     V[:,:,:,:,t][mask_min] = V_min
#     print('_____Spike!________')
#     print(V)

In [6]:
theta = 0.5  # firing threshold; passed explicitly to the LIF demo calls in this notebook
V_min = -0.1  # NOTE(review): no visible LIF call reads this global (they pass V_min=0.1 or use the default) -- confirm it is still needed
Leak = 0  # NOTE(review): no visible LIF call passes this global; LIF's own `Leak` parameter defaults to 0

def LIF(spikes, theta=0.5, Leak=0, V_min=0):
    '''
    Leaky integrate-and-fire over the time axis of a (B,C,H,W,T) tensor.

    A single zero step is prepended to the time axis so that the update at
    the first real step starts from a zero potential; the prepended step is
    cut off again in the returned tensors.

    Params:
        spikes: tensor of shape (B,C,H,W,T) whose values are accumulated
            into the membrane potential step by step.
        theta: threshold to fire a spike (fires where V >= theta).
        Leak: leakage parameter, added to the potential at every step.
        V_min: floor of the membrane potential; values below it are set to
            V_min. It is applied after the reset to zero.

    return:
        (V, next_spikes): membrane potential and 0/1 spike tensor, both with
        the same shape as the input.
    '''
    # pad pairs start at the last dim, so the leading (1, 0) prepends one time step
    with_t0 = F.pad(spikes, (1, 0, 0, 0, 0, 0), 'constant', 0)

    potential = torch.zeros_like(with_t0)
    out_spikes = torch.zeros_like(with_t0)

    n_steps = with_t0.shape[-1]
    step = 1
    while step < n_steps:
        # equation 1a
        potential[:, :, :, :, step] = (
            potential[:, :, :, :, step - 1] + Leak + with_t0[:, :, :, :, step]
        )
        current = potential[:, :, :, :, step]  # view into `potential`

        # thresholding: emit a spike and reset the potential to zero
        above = current >= theta
        out_spikes[:, :, :, :, step][above] = 1
        current[above] = 0

        # keep the potential from dropping below V_min
        below = current < V_min
        current[below] = V_min

        step += 1

    return potential[:, :, :, :, 1:], out_spikes[:, :, :, :, 1:]

In [8]:
# Random non-negative input laid out as (batch_size, channels, H, W, Time)
sample_batch = torch.rand(2, 1, 4, 4, 6).abs()

# Bias-free convolution whose 4x4 kernel spans the whole H x W plane and a single time step
conv3d = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=(4, 4, 1), bias=False)
conv_spikes = conv3d(sample_batch)
conv_spikes

tensor([[[[[ 0.6851,  0.4956,  0.4907,  0.5220,  0.5770,  0.1868]]],


         [[[-0.4270, -0.3262, -0.2336, -0.4638, -0.1081, -0.4898]]]],



        [[[[ 0.9438,  0.5806,  0.6456,  0.5263,  0.4105,  0.7090]]],


         [[[-0.0026, -0.1107, -0.3080, -0.4042, -0.2310, -0.7148]]]]],
       grad_fn=<ThnnConv3DBackward>)

In [9]:
# Constant input of 0.1 at every position over 15 time steps, used to check LIF by hand
check_spikes = torch.ones(2,1,4,4,15) * 0.1

In [15]:
# V_min is overridden to 0.1 for this call (the global V_min is not used); Leak keeps its default of 0.
LIF(check_spikes,theta, V_min=0.1)

(tensor([[[[[0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000, 0.4000,
             0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000],
            [0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000, 0.4000,
             0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000],
            [0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000, 0.4000,
             0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000],
            [0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000, 0.4000,
             0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000]],
 
           [[0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000, 0.4000,
             0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000],
            [0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000, 0.4000,
             0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000],
            [0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000, 0.4000,
             0.1000, 0.2000, 0.3000, 0

In [18]:
# Element 0 of LIF's result is the membrane potential; Leak and V_min use their defaults of 0.
LIF(conv_spikes,theta)[0]

tensor([[[[[0.0000, 0.4956, 0.0000, 0.0000, 0.0000, 0.1868]]],


         [[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]],



        [[[[0.0000, 0.0000, 0.0000, 0.0000, 0.4105, 0.0000]]],


         [[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]],
       grad_fn=<SliceBackward>)

In [12]:
# Redisplay the raw conv output (presumably to compare it with the membrane potential from LIF).
conv_spikes

tensor([[[[[ 0.6851,  0.4956,  0.4907,  0.5220,  0.5770,  0.1868]]],


         [[[-0.4270, -0.3262, -0.2336, -0.4638, -0.1081, -0.4898]]]],



        [[[[ 0.9438,  0.5806,  0.6456,  0.5263,  0.4105,  0.7090]]],


         [[[-0.0026, -0.1107, -0.3080, -0.4042, -0.2310, -0.7148]]]]],
       grad_fn=<ThnnConv3DBackward>)

In [170]:
# Two-step running sum along time: prepend one zero step to the time axis, then convolve
# with a length-2 kernel of ones, so out[..., t] = x[..., t-1] + x[..., t] (with x[..., -1] = 0).
# `_pad` only exists as a local variable inside LIF, so the padding is built explicitly here
# to keep this cell runnable on a fresh kernel.
F.conv3d(F.pad(sample_batch, (1, 0, 0, 0, 0, 0)), weight=torch.ones(1, 1, 1, 1, 2))

tensor([[[[[0.8275, 1.4816, 0.9640, 0.4944, 1.1763, 1.4621],
           [0.0550, 0.6530, 0.8653, 0.6185, 0.7270, 0.5225],
           [0.4607, 0.6692, 1.0878, 1.1623, 0.9898, 1.2918],
           [0.7881, 0.8658, 0.9896, 1.6471, 0.8356, 1.0450]],

          [[0.6149, 1.1778, 1.3756, 1.1751, 0.6608, 1.1548],
           [0.0797, 0.4623, 0.4922, 0.3192, 1.1126, 1.3341],
           [0.7701, 1.5521, 1.5145, 1.5825, 1.3025, 0.9577],
           [0.4127, 0.5290, 0.5073, 0.9474, 0.6939, 1.0737]],

          [[0.9167, 1.6864, 1.5628, 0.9006, 0.9320, 0.9134],
           [0.9075, 1.2440, 0.6663, 0.4959, 0.6972, 1.4349],
           [0.9342, 1.4226, 0.5464, 0.5636, 1.2331, 1.4368],
           [0.2386, 0.5454, 1.0050, 1.2593, 1.2872, 1.7125]],

          [[0.8743, 1.7199, 1.0557, 0.7375, 0.9895, 0.5349],
           [0.8573, 1.1217, 1.0049, 1.3333, 1.0664, 0.5724],
           [0.0146, 0.2119, 1.0529, 1.5894, 1.6911, 1.5275],
           [0.2773, 1.1240, 1.7146, 1.8235, 1.5088, 0.8226]]]],



        [[[[

In [167]:
# NOTE(review): leftover scratch cell -- no `import numpy as np` is visible in this notebook and
# np.flip is called without the array it operates on, so this line presumably raises as written.
# TODO: remove the cell or complete the call.
np.flip()