In [1]:
import torch
from torch import nn
import timeit

In [2]:
device = torch.device("cuda"if True else "cpu")
runs = 100
neurons = [1024, 2048, 4096, 8192, 16384]
x = torch.randn(500, 16, neurons[3], requires_grad=True).to(device)
# x = torch.randn(8, 1, 4, requires_grad=True).to(device)
x.retain_grad()

In [3]:
# SpikingJelly
x.grad=None
v=None
out=None
torch.cuda.reset_accumulated_memory_stats()
from spikingjelly.activation_based import (
    neuron as snn,
    surrogate as sg,
    functional as sf,
)

step_mode = "m"
lif = snn.LIFNode(
    surrogate_function=sg.Sigmoid(alpha=1.0), step_mode=step_mode, v_reset=0.0
).to(device)
if step_mode == "m":
    sf.set_backend(lif, "cupy")


def run():
    global out
    lif.reset()
    x.grad = None
    if step_mode == "m":
        out = lif(x)
    else:
        out = []
        for xt in x:
            out += [lif(xt)]
        out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad

(0.13817385927999568,
 2001.5,
 tensor([[[0., 0., 1.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 1.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[1., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
  

In [None]:
from torchspike import LIF
x.grad=None
v=None
out=None
torch.cuda.reset_accumulated_memory_stats()

lif = LIF.apply
th = torch.tensor(1.0)
tau = torch.tensor(2)


def run():
    global out
    v = torch.zeros_like(x[0])
    x.grad = None
    v.grad = None
    out = []
    for xt in x:
        xt, v = lif(xt, v, th, tau)
        out += [xt]
    out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad

(0.23552153757998895,
 1754.0009765625,
 tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 1.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
        

In [None]:
from torchspike import LIF_CPP
x.grad=None
v=None
out=None
torch.cuda.reset_accumulated_memory_stats()

lif = LIF_CPP.apply
th = torch.tensor(1.0)
tau = torch.tensor(2)


def run():
    global out
    v = torch.zeros_like(x[0])
    x.grad = None
    v.grad = None
    out = []
    for xt in x:
        xt, v = lif(xt, v, th, tau)
        out += [xt]
    out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad

(0.21837792891999924,
 2254.5009765625,
 tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 1.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
        

In [None]:
from torchspike import LIF_CUDA
x.grad=None
v=None
out=None
torch.cuda.reset_accumulated_memory_stats()

lif = LIF_CUDA.apply
th =torch.tensor(1.0)
tau = torch.tensor(2.0)


def run():
    global out
    v = torch.zeros_like(x[0])
    x.grad = None
    v.grad = None
    out = []
    for xt in x:
        xt, v = lif(xt, v, th, tau)
        out += [xt]
    out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad

(0.14886225096001,
 2501.0009765625,
 tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          .