In [1]:
import torch, torch.nn as nn
from norse.torch import LICell             # Leaky integrator
from norse.torch import LIFCell            # Leaky integrate-and-fire
from norse.torch import SequentialState    # Stateful sequential layers

model = SequentialState(
    nn.Conv2d(1, 20, 5, 1),      # Convolve from 1 -> 20 channels
    LIFCell(),                   # Spiking activation layer
    nn.MaxPool2d(2, 2),
    nn.Conv2d(20, 50, 5, 1),     # Convolve from 20 -> 50 channels
    LIFCell(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),                # Flatten to 800 units
    nn.Linear(800, 10),
    LICell(),                    # Non-spiking integrator layer
)

data = torch.randn(8, 1, 28, 28) # 8 batches, 1 channel, 28x28 pixels
output, state = model(data)      # Provides a tuple (tensor (8, 10), neuron state)

In [2]:
output
state

[None,
 LIFFeedForwardState(v=tensor([[[[ 9.3430e-03,  1.4734e-01, -1.0415e-01,  ..., -5.8091e-02,
             7.7243e-02, -2.1763e-03],
           [ 2.8457e-02, -4.7374e-02,  4.1512e-02,  ..., -2.2892e-02,
            -1.5391e-02, -3.0367e-02],
           [-6.3488e-02, -1.4106e-02, -1.4465e-02,  ..., -5.4619e-02,
            -5.0325e-02,  2.0751e-02],
           ...,
           [ 4.0946e-02, -2.4511e-02,  6.0611e-02,  ...,  2.0317e-02,
             6.4698e-04, -4.5325e-02],
           [-1.2428e-02, -2.2056e-03, -5.4270e-02,  ..., -1.9613e-02,
             1.0644e-01, -1.9704e-02],
           [ 2.9310e-02, -9.8151e-02, -6.6272e-02,  ..., -5.2045e-02,
             5.2665e-02, -2.0760e-02]],
 
          [[-7.5537e-02,  3.6055e-02, -7.4513e-02,  ...,  6.2591e-02,
             2.5240e-02, -1.8935e-02],
           [ 8.9327e-03,  3.4610e-03,  8.9551e-02,  ..., -2.7209e-02,
             1.4528e-02, -1.2328e-02],
           [-3.1664e-02,  4.4654e-02,  1.2239e-02,  ...,  7.1043e-02,
          