In [1]:
import torch
from torch import nn
import timeit
from functools import partial

In [2]:
device = torch.device("cuda")
runs = 20
neurons = [1024, 2048, 4096, 8192, 16384]
x = torch.rand(500, 16, neurons[3], requires_grad=True).to(device)
x = torch.rand(8, 1, 4, requires_grad=True).to(device)
x.retain_grad()

In [3]:
# SpikingJelly
torch.cuda.reset_max_memory_allocated()
from spikingjelly.activation_based import (
    neuron as snn,
    surrogate as sg,
    functional as sf,
)

step_mode = "m"
backend = "cupy"
lif = snn.IFNode(
    surrogate_function=sg.Sigmoid(alpha=1.0), step_mode=step_mode, v_reset=0.0
).to(device)
sf.set_backend(lif, backend)


def run():
    global out
    lif.reset()
    x.grad = None
    if step_mode == "m":
        out = lif(x)
    else:
        out = []
        for xt in x:
            out += [lif(xt)]
        out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad



(0.008588854999999996,
 0.00390625,
 tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 1.]],
 
         [[1., 1., 1., 0.]],
 
         [[0., 0., 0., 1.]],
 
         [[1., 1., 1., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[1., 0., 1., 1.]],
 
         [[0., 1., 0., 0.]]], device='cuda:0', grad_fn=<ViewBackward0>),
 tensor([[[ 0.0163,  0.0171,  0.0165,  0.0108]],
 
         [[ 0.0104,  0.0111,  0.0102,  0.0038]],
 
         [[ 0.0034,  0.0042,  0.0032,  0.0096]],
 
         [[ 0.0108,  0.0073,  0.0103,  0.0030]],
 
         [[ 0.0042, -0.0007,  0.0032,  0.0176]],
 
         [[ 0.0121,  0.0194,  0.0125,  0.0114]],
 
         [[ 0.0054,  0.0137,  0.0058,  0.0047]],
 
         [[ 0.0072,  0.0077,  0.0074,  0.0071]]], device='cuda:0'))

In [4]:
# Function s
torch.cuda.reset_max_memory_allocated()


class LIF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, v, th, tau):
        v = v + x
        x = (v >= th).to(x)
        ctx.save_for_backward(x, v, th)
        v = v * (1 - x)
        return x, v

    @staticmethod
    def backward(ctx, grad_x, grad_v):
        x, v, th = ctx.saved_tensors
        grad_x = grad_x + grad_v * -v
        sg = torch.sigmoid(v - th)
        grad_v = grad_v * (1 - x) + grad_x * sg * (1 - sg)
        grad_x = grad_v
        return grad_x, grad_v, None, None


lif = LIF.apply
th = torch.tensor(1.0)
tau = torch.tensor(2)


def run():
    global out
    v = torch.zeros_like(x[0])
    x.grad = None
    v.grad = None
    out = []
    for xt in x:
        xt, v = lif(xt, v, th, tau)
        out += [xt]
    out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad



(0.003124664999999993,
 0.01513671875,
 tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 1.]],
 
         [[1., 1., 1., 0.]],
 
         [[0., 0., 0., 1.]],
 
         [[1., 1., 1., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[1., 0., 1., 1.]],
 
         [[0., 1., 0., 0.]]], device='cuda:0', grad_fn=<StackBackward0>),
 tensor([[[ 0.0163,  0.0171,  0.0165,  0.0108]],
 
         [[ 0.0104,  0.0111,  0.0102,  0.0038]],
 
         [[ 0.0034,  0.0042,  0.0032,  0.0096]],
 
         [[ 0.0108,  0.0073,  0.0103,  0.0030]],
 
         [[ 0.0042, -0.0007,  0.0032,  0.0176]],
 
         [[ 0.0121,  0.0194,  0.0125,  0.0114]],
 
         [[ 0.0054,  0.0137,  0.0058,  0.0047]],
 
         [[ 0.0072,  0.0077,  0.0074,  0.0071]]], device='cuda:0'))

In [5]:
# Function m
torch.cuda.reset_accumulated_memory_stats()


class Floor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.floor()

    @staticmethod
    def backward(ctx, grad_x):
        x = ctx.saved_tensors[0]
        x = x % 1
        sg0 = torch.sigmoid(10 * x)
        sg1 = torch.sigmoid(10 * (x - 1))
        grad_x = grad_x * (sg0 * (1 - sg0) + sg1 * (1 - sg1))
        return grad_x


class IF(nn.Module):
    def __init__(self, th=1, tau=2):
        super().__init__()
        self.v = None
        self.th = th
        self.tau = tau

    def forward(self, x):
        self.v = x.cumsum(0).relu()
        x = torch.nn.functional.pad(
            Floor.apply(self.v / self.th).to(x), (0, 0, 0, 0, 1, 0)
        )
        x = x[1:] - x[:-1]
        self.v = self.v - x * self.th
        return x


lif = IF()
th = torch.tensor(1.0)
tau = torch.tensor(2)


def run():
    global out
    v = torch.zeros_like(x)
    x.grad = None
    lif.v = None
    out = lif(x)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad

(0.001877059999999986,
 0.01513671875,
 tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 1.]],
 
         [[1., 1., 1., 0.]],
 
         [[0., 1., 1., 1.]],
 
         [[1., 1., 1., 0.]],
 
         [[1., 0., 0., 1.]],
 
         [[0., 1., 1., 1.]],
 
         [[1., 0., 0., 0.]]], device='cuda:0', grad_fn=<SubBackward0>),
 tensor([[[0.0005, 0.0015, 0.0012, 0.0005]],
 
         [[0.0005, 0.0015, 0.0012, 0.0005]],
 
         [[0.0005, 0.0015, 0.0012, 0.0005]],
 
         [[0.0005, 0.0015, 0.0012, 0.0005]],
 
         [[0.0005, 0.0015, 0.0012, 0.0005]],
 
         [[0.0005, 0.0015, 0.0012, 0.0005]],
 
         [[0.0005, 0.0015, 0.0012, 0.0005]],
 
         [[0.0005, 0.0015, 0.0012, 0.0005]]], device='cuda:0'))