In [91]:
import torch
from torch import nn
from spikingjelly.activation_based import neuron as snn, surrogate as sg

In [92]:
# Seed the RNG so Restart & Run All reproduces the same input, spikes and gradients.
SEED = 0
torch.manual_seed(SEED)
# Leaf input tensor; later cells iterate over its first dimension as time steps
# and compare the resulting x.grad between the two LIF implementations.
x = torch.randn(8, 1, 4, requires_grad=True)
x

tensor([[[-2.1987,  1.4351,  0.4654,  1.3688]],

        [[-1.2085, -1.5836, -2.6538,  0.7185]],

        [[ 0.1599, -0.9066, -1.1468, -0.9280]],

        [[-0.6140,  0.3344, -0.4832,  0.2014]],

        [[-0.1100,  1.3366,  2.1782,  0.4904]],

        [[-1.8840, -1.9816,  1.3575,  0.0211]],

        [[-0.3933,  0.2061, -0.3247, -0.5327]],

        [[-0.6359, -0.5674,  0.0686, -0.3607]]], requires_grad=True)

In [93]:
class LIF2(torch.autograd.Function):
    """Single-step LIF neuron with hard reset to 0 and a sigmoid surrogate gradient.

    Forward, elementwise:
        v1 = v0 + (x0 - v0) / tau      # leaky integration of the input
        x1 = 1[v1 >= v_threshold]      # spike (0/1, same dtype as v1)
        v2 = v1 * (1 - x1)             # membrane reset to 0 where a spike fired

    Backward replaces the (almost-everywhere zero) derivative of the step function
    with d/du sigmoid(alpha * u) = alpha * s * (1 - s), evaluated at
    u = v1 - v_threshold. The reset path is not detached, so gradient arriving at
    v2 also flows into the spike through ``v2 = v1 - v1 * x1``.

    Usage: ``x1, v2 = LIF2.apply(x0, v0, v_threshold, tau[, alpha])``.
    ``v_threshold``, ``tau`` and ``alpha`` may be Python numbers or 0-d tensors;
    they receive no gradient.
    """

    @staticmethod
    def forward(ctx, x0, v0, v_threshold, tau, alpha=1.0):
        v1 = v0 + (x0 - v0) / tau
        # Keep the spike in the membrane's dtype; ``.float()`` would force float32
        # and make v2 (and the next step's state) change dtype for fp16/bf16/fp64.
        x1 = (v1 >= v_threshold).to(v1.dtype)
        v2 = v1 * (1 - x1)
        # save_for_backward accepts tensors only; the hyper-parameters may be plain
        # numbers, so they are stored on ctx directly.
        ctx.save_for_backward(x1, v1)
        ctx.v_threshold = v_threshold
        ctx.tau = tau
        ctx.alpha = alpha
        return x1, v2

    @staticmethod
    def backward(ctx, grad_x1, grad_v2):
        x1, v1 = ctx.saved_tensors
        # Total gradient w.r.t. the spike: its own output plus the reset path
        # (d v2 / d x1 = -v1).
        grad_spike = grad_x1 - grad_v2 * v1
        # Surrogate derivative of the step function.
        s = torch.sigmoid(ctx.alpha * (v1 - ctx.v_threshold))
        surrogate = ctx.alpha * s * (1 - s)
        # d v2 / d v1 through the (1 - x1) factor, plus the spike path.
        grad_v1 = grad_v2 * (1 - x1) + grad_spike * surrogate
        # Chain rule through v1 = x0 / tau + v0 * (1 - 1 / tau).
        inv_tau = 1 / ctx.tau
        grad_x0 = grad_v1 * inv_tau
        grad_v0 = grad_v1 * (1 - inv_tau)
        # No gradients for v_threshold, tau, alpha. Autograd accepts extra trailing
        # Nones, so this also works when alpha was not passed.
        return grad_x0, grad_v0, None, None, None

In [94]:
# Reference run: spikingjelly's LIFNode with a sigmoid surrogate (alpha=1), fed one
# slice of x per step; the node's state `v` is printed after every step.
lif1 = snn.LIFNode(surrogate_function=sg.Sigmoid(alpha=1))
x.grad = None  # drop any gradient left on x by an earlier backward()
out = []
for t in range(len(x)):
    spike_t = lif1(x[t])
    out.append(spike_t)
    print(lif1.v)
out = torch.stack(out)
out.mean().backward()
out, x.grad

tensor([[-1.0994,  0.7175,  0.2327,  0.6844]],
       grad_fn=<DifferentiableGraphBackward>)
tensor([[-1.1539, -0.4331, -1.2105,  0.7015]],
       grad_fn=<DifferentiableGraphBackward>)
tensor([[-0.4970, -0.6698, -1.1787, -0.1133]],
       grad_fn=<DifferentiableGraphBackward>)
tensor([[-0.5555, -0.1677, -0.8309,  0.0441]],
       grad_fn=<DifferentiableGraphBackward>)
tensor([[-0.3328,  0.5844,  0.6736,  0.2672]],
       grad_fn=<DifferentiableGraphBackward>)
tensor([[-1.1084, -0.6986,  0.0000,  0.1441]],
       grad_fn=<DifferentiableGraphBackward>)
tensor([[-0.7508, -0.2462, -0.1624, -0.1943]],
       grad_fn=<DifferentiableGraphBackward>)
tensor([[-0.6934, -0.4068, -0.0469, -0.2775]],
       grad_fn=<DifferentiableGraphBackward>)


(tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 1., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]], grad_fn=<StackBackward0>),
 tensor([[[0.0038, 0.0060, 0.0051, 0.0064]],
 
         [[0.0041, 0.0052, 0.0036, 0.0063]],
 
         [[0.0048, 0.0052, 0.0040, 0.0060]],
 
         [[0.0046, 0.0057, 0.0047, 0.0061]],
 
         [[0.0043, 0.0056, 0.0052, 0.0059]],
 
         [[0.0032, 0.0042, 0.0033, 0.0053]],
 
         [[0.0031, 0.0040, 0.0044, 0.0042]],
 
         [[0.0021, 0.0025, 0.0030, 0.0027]]]))

In [95]:
# Same experiment with the hand-written LIF2 Function. The recorded outputs of this
# cell (potentials, spikes, x.grad) match the LIFNode reference cell.
lif2 = LIF2.apply
# Initial membrane state: a fresh zero tensor (never requires grad, so there is no
# .grad to clear); it is rebound to the running state inside the loop.
v = torch.zeros_like(x[0])
x.grad = None  # drop any gradient left on x by an earlier backward()
v_threshold = torch.tensor(1.0)
tau = torch.tensor(2.0)  # float time constant; don't rely on int-tensor true division
out = []
for xt in x:
    spike, v = lif2(xt, v, v_threshold, tau)
    print(v)
    out.append(spike)
out = torch.stack(out)
out.mean().backward()
out, x.grad

tensor([[-1.0994,  0.7175,  0.2327,  0.6844]], grad_fn=<LIF2Backward>)
tensor([[-1.1539, -0.4331, -1.2105,  0.7015]], grad_fn=<LIF2Backward>)
tensor([[-0.4970, -0.6698, -1.1787, -0.1133]], grad_fn=<LIF2Backward>)
tensor([[-0.5555, -0.1677, -0.8309,  0.0441]], grad_fn=<LIF2Backward>)
tensor([[-0.3328,  0.5844,  0.6736,  0.2672]], grad_fn=<LIF2Backward>)
tensor([[-1.1084, -0.6986,  0.0000,  0.1441]], grad_fn=<LIF2Backward>)
tensor([[-0.7508, -0.2462, -0.1624, -0.1943]], grad_fn=<LIF2Backward>)
tensor([[-0.6934, -0.4068, -0.0469, -0.2775]], grad_fn=<LIF2Backward>)


(tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 1., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]], grad_fn=<StackBackward0>),
 tensor([[[0.0038, 0.0060, 0.0051, 0.0064]],
 
         [[0.0041, 0.0052, 0.0036, 0.0063]],
 
         [[0.0048, 0.0052, 0.0040, 0.0060]],
 
         [[0.0046, 0.0057, 0.0047, 0.0061]],
 
         [[0.0043, 0.0056, 0.0052, 0.0059]],
 
         [[0.0032, 0.0042, 0.0033, 0.0053]],
 
         [[0.0031, 0.0040, 0.0044, 0.0042]],
 
         [[0.0021, 0.0025, 0.0030, 0.0027]]]))