In [20]:
import torch 
import torch.nn as nn           #  Feedforward and Loss function
import torch.nn.functional as F #  Functions that dont accept parameters 
from torch.utils.data import DataLoader  #  Data set Management i.e. creat mini-Batches

import tonic
import tonic.transforms as transforms

# Hyperparameters
in_channels = 2      # channels per event frame (sensor is 2 x 128 x 128, see below)
num_classes = 11     # number of output classes (size of the final layer)
learning_rate = 1e-5
batch_size = 8
SEED = 42            # fixes the shuffled training order so runs are reproducible

########## SNN ###########
T_BIN = 15       # time bins each recording is sliced into = simulation steps per sample
VTH = 0.3        # firing threshold: a neuron spikes when its membrane potential exceeds VTH
DECAY = 0.3      # multiplicative leak applied to the membrane potential at every step
########## Surrogate ###########
alpha = 0.5      # width of the rectangular surrogate-gradient window around VTH (alpha = 2 * lens)

torch.manual_seed(SEED)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"NOW device is using: {device}")


# Dataset - DVS-Gesture, 2 x 128 x 128 sensor
DATA_DIR = "../../Torch_condaENV/Working_folder/dataset/"
sensor_size = tonic.datasets.DVSGesture.sensor_size
# Denoise the event stream (filter_time presumably in microseconds -- see the tonic docs),
# then bin each recording into T_BIN frames.
frame_transform = transforms.Compose([
    transforms.Denoise(filter_time=10000),
    transforms.ToFrame(sensor_size=sensor_size, n_time_bins=T_BIN),
])

trainset = tonic.datasets.DVSGesture(save_to=DATA_DIR, transform=frame_transform, train=True)
testset = tonic.datasets.DVSGesture(save_to=DATA_DIR, transform=frame_transform, train=False)

# Collation: pad shorter recordings so every sample in a batch has the same dimensions;
# batch_first=False yields time-major batches (T, N, ...).
# drop_last=True keeps every batch at exactly batch_size samples; on the test set this
# means any final partial batch is left out of evaluation.
train_loader = DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
    shuffle=True,
    drop_last=True,
)

test_loader = DataLoader(
    dataset=testset,
    batch_size=batch_size,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
    shuffle=False,
    drop_last=True,
)

class CSNN_Model(nn.Module):
    """Convolutional spiking neural network unrolled over ``T_BIN`` time steps.

    Four spiking conv layers (each followed by a 2x2 max-pool) feed two spiking
    fully connected layers. Every layer is updated with ``mem_update``; the output
    is the firing rate of each class neuron over the simulated steps.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        # Shared 2x2 max-pool applied after every conv layer.
        self.pool = nn.MaxPool2d(2, 2)

        # Spatial size of each conv's input -> output (square maps, 128x128 sensor input).
        self.conv0 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 128x128 -> 128x128
        self.conv1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)          # 64x64 -> 64x64
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)         # 32x32 -> 32x32
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)         # 16x16 -> 8x8 (stride 2)

        # After the final pool the map is 256 x 4 x 4 = 4096 features.
        self.fc1 = nn.Linear(4 * 4 * 256, 1024, bias=False)
        self.fc2 = nn.Linear(1024, num_classes, bias=False)

    def forward(self, input):
        """Simulate the network for ``T_BIN`` steps and return output firing rates.

        Args:
            input: time-major tensor indexed as ``input[t]`` for ``t < T_BIN``. Each
                slice is fed to ``conv0``, so it is presumably (N, in_channels, 128, 128);
                a 128x128 frame is what makes the flattened features match ``fc1``.

        Returns:
            (N, num_classes) tensor: spike count of each output neuron divided by ``T_BIN``.
        """
        # Inputs are moved to the weights' device *and* dtype. Frames from the data
        # pipeline may be integer tensors (TODO confirm), and Conv2d rejects integer input.
        weight = self.conv0.weight

        # Fresh simulation on every call: all membrane potentials and spikes start at 0.
        # A scalar 0.0 broadcasts inside mem_update, so each state takes the shape of its
        # layer's output on the first step. Batch size, feature-map sizes, class count
        # and device therefore never need to be hard-coded here.
        c0_mem = c0_spike = 0.0
        c1_mem = c1_spike = 0.0
        c2_mem = c2_spike = 0.0
        c3_mem = c3_spike = 0.0
        h1_mem = h1_spike = 0.0
        h2_mem = h2_spike = 0.0
        h2_sumspike = 0.0

        for t in range(T_BIN):
            x = input[t].to(device=weight.device, dtype=weight.dtype)

            c0_mem, c0_spike = mem_update(self.conv0, x, c0_mem, c0_spike)
            p0_spike = self.pool(c0_spike)

            c1_mem, c1_spike = mem_update(self.conv1, p0_spike, c1_mem, c1_spike)
            p1_spike = self.pool(c1_spike)

            c2_mem, c2_spike = mem_update(self.conv2, p1_spike, c2_mem, c2_spike)
            p2_spike = self.pool(c2_spike)

            c3_mem, c3_spike = mem_update(self.conv3, p2_spike, c3_mem, c3_spike)
            p3_spike = self.pool(c3_spike)

            # (N, 256, 4, 4) -> (N, 4096), using the actual batch dimension.
            flat = p3_spike.flatten(1)

            h1_mem, h1_spike = mem_update(self.fc1, flat, h1_mem, h1_spike)
            h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
            # Out-of-place add, so the accumulator never aliases another state tensor.
            h2_sumspike = h2_sumspike + h2_spike

        # Firing rate per output neuron, shape (N, num_classes).
        return h2_sumspike / T_BIN


class ActivationFun(torch.autograd.Function):
    """Spike non-linearity with a rectangular surrogate gradient.

    Forward: emits 1.0 wherever the membrane potential is strictly greater than
    ``VTH``, else 0.0.
    Backward: the step function has no useful derivative, so the incoming gradient
    is passed through, scaled by ``1/alpha``, only where ``|input - VTH| < alpha/2``
    (the rectangular surrogate h1 of Wu et al., 2018). It is zero everywhere else.
    """

    @staticmethod
    def forward(ctx, input):
        # Keep the pre-threshold potential; the surrogate window is defined on it.
        ctx.save_for_backward(input)
        fired = input > VTH
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        (membrane,) = ctx.saved_tensors
        # Boolean mask of potentials inside the window of width alpha centred on VTH.
        in_window = (membrane - VTH).abs() < alpha / 2
        return grad_output * in_window.float() / alpha

# Functional handle for the custom autograd spike function (Functions are invoked via .apply).
act_fun = ActivationFun.apply

def mem_update(fc, x, volt, spike):
    """Advance one spiking layer by a single time step.

    The previous potential ``volt`` is scaled by ``DECAY`` and zeroed wherever the
    layer spiked on the previous step (the ``1 - spike`` factor). The layer's
    response ``fc(x)`` is then added.

    Returns:
        ``(new_volt, new_spike)``, where ``new_spike = act_fun(new_volt)``.
    """
    leaked = volt * DECAY * (1 - spike)
    new_volt = leaked + fc(x)
    return new_volt, act_fun(new_volt)




NOW device is using: cpu


In [21]:
# Network initialization: build the CSNN on `device`, then restore the pretrained weights.
model = CSNN_Model(in_channels=in_channels, num_classes=num_classes).to(device)

weight_path = "./pretrained_DVS_csnn_128e_91a.t7"
# The checkpoint is a dict; its 'net' entry is the model's state_dict.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(weight_path, map_location=device)
model.load_state_dict(checkpoint['net'], strict=True)

# Model evaluation
def check_accuracy(loader, model):
    """Print and return the classification accuracy of ``model`` over ``loader``.

    Args:
        loader: DataLoader yielding ``(image, label)`` batches. If its dataset has a
            truthy ``train`` attribute, the run is reported as training data;
            otherwise it is reported as testing data.
        model: network returning (N, num_classes) scores; the predicted class is
            the arg-max over dim 1.

    Returns:
        Fraction of correct predictions as a 0-dim tensor (``num_correct / num_sample``).

    Raises:
        ValueError: if the loader yields no samples.

    Evaluation runs in eval mode with gradients disabled. The model's previous
    train/eval mode is restored afterwards, even if evaluation raises.
    """
    if getattr(loader.dataset, "train", False):
        print("Checking on training data")
    else:
        print("Checking on testing data")

    # Send each batch to wherever the model's weights live.
    try:
        target_device = next(model.parameters()).device
    except StopIteration:  # parameter-free model: fall back to the notebook-wide device
        target_device = device

    was_training = model.training
    num_correct = 0
    num_sample = 0
    model.eval()
    try:
        with torch.no_grad():  # inference only: no autograd graph needed
            for image, label in loader:
                image = image.to(target_device)
                label = label.to(target_device)

                out_firing = model(image)           # (N, num_classes)
                _, prediction = out_firing.max(1)   # (N,) index of the highest score
                num_correct += (prediction == label).sum()
                num_sample += prediction.size(0)    # N, the batch size
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    if num_sample == 0:
        raise ValueError("check_accuracy: loader yielded no samples")

    print(f'Got {num_correct}/{num_sample} with accuracy {float(num_correct)/float(num_sample)*100:.2f}')
    return num_correct / num_sample


In [39]:
a = checkpoint['net']  # the pretrained state_dict (parameter name -> tensor); same object passed to model.load_state_dict

In [40]:

# Names of the tensors stored in the checkpoint's state_dict, in stored order.
layer_names = [key for key in a]
print(layer_names)


['conv0.weight', 'conv1.weight', 'conv2.weight', 'conv3.weight', 'fc1.weight', 'fc2.weight']


In [26]:
# Print the shape of every learnable parameter (registration order), then the full list.
i = []
for weight in model.parameters():
    shape = weight.data.shape
    print(shape)
    i.append(shape)
print(i)



torch.Size([64, 2, 3, 3])
torch.Size([128, 64, 3, 3])
torch.Size([128, 128, 3, 3])
torch.Size([256, 128, 3, 3])
torch.Size([1024, 4096])
torch.Size([11, 1024])
[torch.Size([64, 2, 3, 3]), torch.Size([128, 64, 3, 3]), torch.Size([128, 128, 3, 3]), torch.Size([256, 128, 3, 3]), torch.Size([1024, 4096]), torch.Size([11, 1024])]


In [None]:
import numpy as np

# `a` is the checkpoint's state_dict: a mapping of parameter name -> tensor (its keys
# are printed above). Mappings have no .append, so `a.append(3)` always raised
# AttributeError and has been replaced.
# NOTE(review): the intent of this scratch cell is unclear; exporting the pretrained
# weights as NumPy arrays is an assumption based on the numpy import -- confirm.
weights_np = {name: tensor.detach().cpu().numpy() for name, tensor in a.items()}