In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Notes on the conv
* kernel size of time should be 1, otherwise it gets convolved as well.
* bias=False
* replace their pool layer with nn.AvgPool3d. The pool they used is a $sum*\theta$ not an average.
* should recieve the Spikes from the previous layer

## Pipeline
raw data(B,C,H,W) --- spike generation ---> spike(B,C,H,W,T) --- Conv3d on spikes ---> X(t) for every t of T --- eq 1: LIF, threshold and reset ---> intermediate V(t), the membrane potential --- based on V(t) ---> Spikes for the next layer.



## The 3 classes we should be able to use them directly
## We need spikes from spike generatior as input and the spike counter, then they should work fine

In [16]:
# def LIF(spikes, theta, leak, V_min):
    
#         '''
#         Integrate-and-fire: given a tensor with shape (B,C,H,W,T), loop over T

#         Params: 
#             spikes: the spikes from previous layer, containing 0 or 1.
#             theta: threshold to fire a spike.
#             l: leakage parameter.
#             V_min: the resting state of membrane potential, usually is set to 0.

#         return:
#             V: the potential to fire (just for check, can be removed later)
#             next_spikes: 0 or 1 tensor with the same shape of input

#         '''
        
#         # the padding controls where to pad, first two of the tuple control the last dim of the tensor
#         _pad = nn.ConstantPad3d((1,0,0,0,0,0), 0)
#         pad_spikes = _pad(spikes)

#         V = torch.zeros_like(pad_spikes)
#         next_Spikes = torch.zeros_like(pad_spikes)

#         T = pad_spikes.shape[-1]

#         for t in range(1, T):
#             # equation (1a)
#             V[:,:,:,:,t] = V[:,:,:,:,t-1] + leak + pad_spikes[:,:,:,:,t]
#             # thresholding and fire spike (1b)
#             mask_threshold = V[:,:,:,:,t] >= theta
#             next_Spikes[:,:,:,:,t][mask_threshold] = 1        
#             # reset the potential to zero 
#             V[:,:,:,:,t][mask_threshold] = 0
            
#             # reset the value to V_min if drops below (1c)
#             mask_min = (V[:,:,:,:,t] < V_min)
#             V[:,:,:,:,t][mask_min] = V_min

#         return (V[:,:,:,:,1:], next_Spikes[:,:,:,:,1:])



# class convLayer(nn.Conv3d):

#     def __init__(self, inChannels, outChannels, kernelSize, theta, leak=0, V_min=0):
        
#         kernel = (kernelSize, kernelSize, 1)
        
#         super(convLayer, self).__init__(inChannels, outChannels, kernel, bias=False)
#         self.theta = theta
#         self.leak = leak
#         self.V_min = V_min

        
#     def forward(self, input):
        
#         # get X, namely eq(2)
#         conv_spikes = F.conv3d(input, 
#                         self.weight, self.bias, 
#                         self.stride, self.padding, self.dilation, self.groups)
        
#         output = LIF(conv_spikes, self.theta, self.leak, self.V_min)
        
#         return (conv_spikes, *output)
    
# class poolLayer(nn.AvgPool3d):
#     def __init__(self, kernelSize, theta, leak=0, V_min=0):
        
#         if type(kernelSize) == int:
#             kernel_size = (kernelSize, kernelSize, 1)
#         elif len(kernelSize) == 2:
#             kernel_size = (kernelSize[0], kernelSize[1], 1)
#         else:
#             raise Exception('kernelSize can only be of 1 or 2 dimension. It was: {}'.format(kernelSize.shape))
#         super(poolLayer, self).__init__(kernel_size)
#         self.theta = theta
#         self.leak = leak
#         self.V_min = V_min
        
#     def forward(self, input):
        
#         # get X, namely eq(2)
#         pool_spikes = F.avg_pool3d(input, self.kernel_size)
        
#         output = LIF(pool_spikes, self.theta, self.leak, self.V_min)
        
#         return (pool_spikes, *output)

# class denseLayer(nn.Conv3d):
#     def __init__(self, inFeatures, outFeatures, theta, leak=0, V_min=0):
#         '''
#         '''
#         # extract information for kernel and inChannels
#         if type(inFeatures) == int:
#             kernel = (1, 1, 1)
#             inChannels = inFeatures 
#         elif len(inFeatures) == 2:
#             kernel = (inFeatures[1], inFeatures[0], 1)
#             inChannels = 1
#         elif len(inFeatures) == 3:
#             kernel = (inFeatures[1], inFeatures[0], 1)
#             inChannels = inFeatures[2]
#         else:
#             raise Exception('inFeatures should not be more than 3 dimension. It was: {}'.format(inFeatures.shape))

        
#         if type(outFeatures) == int:
#             outChannels = outFeatures
#         else:
#             raise Exception('outFeatures should not be more than 1 dimesnion. It was: {}'.format(outFeatures.shape))

        
#         super(denseLayer, self).__init__(inChannels, outChannels, kernel, bias=False)
        
#         # params for the LIF
#         self.theta = theta
#         self.leak = leak
#         self.V_min = V_min
        
    
#     def forward(self, input):
#         fc_spikes = F.conv3d(input, 
#                         self.weight, self.bias, 
#                         self.stride, self.padding, self.dilation, self.groups)
        
#         output = LIF(fc_spikes, self.theta, self.leak, self.V_min)
        
#         return (fc_spikes, *output)

### test for size change
1. convLayer correct
2. poolLayer correct
3. denseLayer correct

In [3]:
from custmized_layer import convLayer, poolLayer, denseLayer

In [7]:
test_batch = torch.ones((128,3,28,28,10))

slayer_conv = convLayer(3,12,5,theta=0.5)
slayer_pool = poolLayer(2,theta=0.5)
# the inFeatures of denselayer should be (H,W,C)
slayer_dense = denseLayer((28,28,3), 2, theta=0.5)

print("test batch size:",test_batch.shape)
print("conv batch size:",slayer_conv(test_batch).shape) # should give (128, 12, 24, 24, 10)
print("pool batch size:",slayer_pool(test_batch).shape) # should give (128, 3, 14, 14, 10)
print("dense(linear) batch size:",slayer_dense(test_batch).shape) # should give (128, 2, 1, 1, 10)


test batch size: torch.Size([128, 3, 28, 28, 10])
conv batch size: torch.Size([12, 24, 24, 10])
pool batch size: torch.Size([3, 14, 14, 10])
dense(linear) batch size: torch.Size([2, 1, 1, 10])


### test for correct spikes
1. convLayer correct
2. poolLayer correct
3. denseLayer correct

In [16]:
test_spikes = torch.ones((2,1,4,4,10)) * 0.1

slayer_conv = convLayer(1,1,4,theta=0.5)
slayer_pool = poolLayer(2,theta=0.5)
slayer_dense = denseLayer((4,4,1), 2, theta = 0.5)

In [17]:
""" 
    the layer will return:  1. spikes after corresponding op (conv, pull, linear), 
                            2. the fire potential 
                            3. the spikes (whenever the potential larger than theta, it has 1 as entry)
""" 
print(slayer_conv(test_spikes))
print()
print(slayer_pool(test_spikes))
print()
print(slayer_dense(test_spikes))

tensor([[[[[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]]]],



        [[[[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]]]]])

tensor([[[[[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
           [0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]],

          [[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
           [0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]]]],



        [[[[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
           [0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]],

          [[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
           [0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]]]]])

tensor([[[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


         [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]],



        [[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


         [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]])
