FeedForward Propagation Through Time on Spiking Neural Networks (SNNs)

byin-cwi/sFPTT


Training spiking neural networks with Forward Propagation Through Time (FPTT)


This repository contains code to reproduce the key findings of "Training spiking neural networks with Forward Propagation Through Time". The code implements spiking recurrent networks with Liquid Time-Constant (LTC) spiking neurons in PyTorch, trained via FPTT on various tasks. A notebook is included to illustrate the functionality of LTC spiking neurons.

This is scientific software and, as such, subject to ongoing modification; we aim to make it more user-friendly and extensible in the future.
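At a glance, an LTC spiking neuron is a leaky integrate-and-fire unit whose decay (its time constant) is itself computed from the current input rather than being fixed. The following minimal single-neuron sketch is illustrative only: the parameter names (`w_tau`, `b_tau`) and the exact update form are assumptions for exposition, not the repository's parameterization.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def ltc_lif_step(v, x, w_tau, b_tau, v_th=1.0):
    """One timestep of a toy Liquid Time-Constant (LTC) spiking neuron.

    The decay factor tau is computed from the current input, so the
    effective time constant varies over time ("liquid"). This is a
    sketch of the idea, not the repository's exact update rule.
    """
    tau = sigmoid(w_tau * x + b_tau)   # input-dependent decay in (0, 1)
    v = tau * v + (1.0 - tau) * x      # leaky integration with liquid tau
    spike = 1.0 if v >= v_th else 0.0  # threshold crossing emits a spike
    v = v - spike * v_th               # soft reset after a spike
    return v, spike
```

Driving this neuron with a constant supra-threshold input produces a regular spike train, while the input-dependent `tau` lets different inputs integrate at different speeds.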

Datasets


  1. S/P-MNIST, R-MNIST: available via torchvision.datasets.MNIST (MNIST)
  2. Fashion-MNIST: available via torchvision.datasets.FashionMNIST (FashionMNIST)
  3. DVS datasets: SpikingJelly includes the neuromorphic datasets Gesture128-DVS and Cifar10-DVS; you can also download them from the official sites. Our preprocessing of the DVS datasets is also supported in SpikingJelly.
  4. PASCAL VOC: the PASCAL Visual Object Classes (VOC) dataset contains 20 object categories. Each image has pixel-level segmentation annotations, bounding box annotations, and object class annotations, and the dataset is a widely used benchmark for object detection, semantic segmentation, and classification. In this paper, the Spiking-YOLO (SPYv4) network was trained and tested on VOC07+12.
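For the sequential MNIST variants, each 28x28 image is unrolled into a length-784 sequence fed to the network one pixel per timestep (S-MNIST), optionally under a single fixed random permutation shared by all images (P-MNIST). A minimal sketch of that preprocessing, using a random array in place of a downloaded MNIST image:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((28, 28))     # stand-in for one 28x28 MNIST image

# S-MNIST: one pixel per timestep, in row-major scan order
s_mnist = image.reshape(-1)      # shape (784,)

# P-MNIST: the same pixels under one fixed permutation,
# reused for every image in the dataset
perm = rng.permutation(784)
p_mnist = s_mnist[perm]          # shape (784,)
```

In practice the image would come from torchvision.datasets.MNIST; the random stand-in above just keeps the sketch self-contained.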

Requirements


  1. Python 3.8.10
  2. A working installation of Python and PyTorch. This should be easy: either use Google Colab, or install locally, probably via pip (Start Locally | PyTorch). torch==1.7.1
  3. SpikingJelly (SpikingJelly)
  4. The object detection tasks require OpenCV (cv2)

FPTT pseudocode


for e in range(epochs):  # epoch loop
    for i in range(sequence_len):  # step through the sequence
        if i == 0:
            model.init_h(x_in.shape[0])  # initialize the hidden states at the first step
        else:
            model.h = [v.detach() for v in model.h]  # detach the computation graph from the previous timestep
        out = model.forward_t(x_in[:, :, i])  # read one input step and generate output
        loss_c = (i / sequence_len) * criterion(out, targets)  # timestep-weighted prediction loss
        loss_r = get_regularizer_named_params(named_params, _lambda=1.0)  # dynamic regularizer loss
        loss = loss_c + loss_r
        optimizer.zero_grad()
        loss.backward()  # gradient of the current timestep only
        optimizer.step()  # update the network
        post_optimizer_updates(named_params, epoch)  # update the traces \bar{w} and \delta{l}
    reset_named_parameter(named_params)  # reset the traces after each sequence
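The helper functions above maintain running parameter traces: at every timestep the per-step loss is augmented with a dynamic regularizer that pulls the weights toward their trace, and after each optimizer step the trace is moved toward the current weights. The single-parameter toy below illustrates that interplay; the function names, the quadratic regularizer form, and the averaging rule are simplifying assumptions here, not the exact forms of the repository's get_regularizer_named_params / post_optimizer_updates.

```python
def regularizer(w, w_bar, alpha=0.5):
    # quadratic pull of the current weight toward its running trace
    return 0.5 * alpha * (w - w_bar) ** 2

def post_step_update(w, w_bar, beta=0.5):
    # running-average trace update applied after each optimizer step
    return (1.0 - beta) * w_bar + beta * w

# toy per-timestep loss (w - 3)^2, optimized one step at a time
w, w_bar = 1.0, 0.0
for _ in range(10):
    grad = 2.0 * (w - 3.0) + 0.5 * (w - w_bar)  # d/dw of loss + regularizer (alpha=0.5)
    w -= 0.1 * grad                             # gradient step on the current timestep
    w_bar = post_step_update(w, w_bar)          # the trace follows the weights
```

The trace acts as a slow-moving anchor: the weights descend the per-timestep loss while the regularizer keeps successive updates consistent, which is what lets FPTT update at every timestep without backpropagating through time.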

Object detection Demo


A video demo of Spiking-YOLO (SPYv4):

SPYv4

Running code


You can find more details in the README file of each task.

  1. Adding task
  2. P/S-MNIST task
  3. Image and DVS task
  4. Spiking YOLO Demo

Finally, we’d love to hear from you if you have any comments or suggestions.

References


[1]. https://github.com/bubbliiiing/yolov4-tiny-pytorch

License

MIT
