In [1]:
import torch
from torch import nn
import timeit

In [2]:
device = torch.device("cuda")
runs = 10
neurons = [1024, 2048, 4096, 8192, 16384]
x = torch.randn(500, 16, neurons[2], requires_grad=True).to(device)

In [14]:
class LIF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, v, th, tau):
        v += (x - v) / tau
        x = (v >= th).float()
        ctx.save_for_backward(x, v, th, tau)
        v = v * (1 - x)
        return x, v

    @staticmethod
    def backward(ctx, grad_x, grad_v):
        x, v, th, tau = ctx.saved_tensors
        sg = torch.sigmoid(v - th)
        grad_v = grad_v * (1 - x) + (grad_x + grad_v * -v) * sg * (1 - sg)
        grad_x = grad_v * (1 / tau)
        grad_v *= (1 - 1 / tau)
        return grad_x, grad_v, None, None

lif = LIF.apply
# lif = torch.compile(lif)
v_threshold = torch.tensor(1.0)
tau = torch.tensor(2)


def run():
    v = torch.zeros_like(x)
    x.grad = None
    v.grad = None
    out = []
    for xt in x:
        spike, v = lif(xt, v, v_threshold, tau)
        out += [spike]
    out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024 ** 2

(0.09028430079999908, 1386.00537109375)