In [7]:
import torch
from torch import nn
import timeit

In [8]:
device = torch.device("cuda" if True else "cpu")
runs = 100
neurons = [1024, 2048, 4096, 8192, 16384]
x = torch.randn(500, 16, neurons[3], requires_grad=True).to(device)
x = torch.randn(8, 1, 4, requires_grad=True).to(device)
x.retain_grad()

In [9]:
# SpikingJelly
x.grad = None
v = None
out = None
torch.cuda.reset_accumulated_memory_stats()
from spikingjelly.activation_based import (
    neuron as snn,
    surrogate as sg,
    functional as sf,
)

step_mode = "m"
lif = snn.LIFNode(
    surrogate_function=sg.Sigmoid(alpha=1.0), step_mode=step_mode, v_reset=0.0
).to(device)
if step_mode == "m":
    sf.set_backend(lif, "cupy")


def run():
    global out
    lif.reset()
    x.grad = None
    if step_mode == "m":
        out = lif(x)
    else:
        out = []
        for xt in x:
            out += [lif(xt)]
        out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad

(0.0021242391600026166,
 3001.0009765625,
 tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[1., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]], device='cuda:0', grad_fn=<ViewBackward0>),
 tensor([[[0.0042, 0.0060, 0.0029, 0.0044]],
 
         [[0.0044, 0.0061, 0.0026, 0.0049]],
 
         [[0.0044, 0.0059, 0.0030, 0.0062]],
 
         [[0.0029, 0.0050, 0.0038, 0.0060]],
 
         [[0.0056, 0.0048, 0.0048, 0.0059]],
 
         [[0.0053, 0.0047, 0.0055, 0.0052]],
 
         [[0.0042, 0.0046, 0.0043, 0.0049]],
 
         [[0.0031, 0.0027, 0.0024, 0.0028]]], device='cuda:0'))

In [10]:
from torchspike import LIF

x.grad = None
v = None
out = None
torch.cuda.reset_accumulated_memory_stats()

lif = LIF.apply
th = torch.tensor(1.0)
tau = torch.tensor(2)


def run():
    global out
    v = torch.zeros_like(x[0])
    x.grad = None
    v.grad = None
    out = []
    for xt in x:
        xt, v = lif(xt, v, th, tau)
        out += [xt]
    out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad

(0.0038084656899991386,
 3001.0009765625,
 tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[1., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]], device='cuda:0', grad_fn=<StackBackward0>),
 tensor([[[0.0042, 0.0060, 0.0029, 0.0044]],
 
         [[0.0044, 0.0061, 0.0026, 0.0049]],
 
         [[0.0044, 0.0059, 0.0030, 0.0062]],
 
         [[0.0029, 0.0050, 0.0038, 0.0060]],
 
         [[0.0056, 0.0048, 0.0048, 0.0059]],
 
         [[0.0053, 0.0047, 0.0055, 0.0052]],
 
         [[0.0042, 0.0046, 0.0043, 0.0049]],
 
         [[0.0031, 0.0027, 0.0024, 0.0028]]], device='cuda:0'))

In [11]:
from torchspike import LIF_CPU

x.grad = None
v = None
out = None
torch.cuda.reset_accumulated_memory_stats()

lif = LIF_CPU.apply
th = torch.tensor(1.0)
tau = torch.tensor(2)


def run():
    global out
    v = torch.zeros_like(x[0])
    x.grad = None
    v.grad = None
    out = []
    for xt in x:
        xt, v = lif(xt, v, th, tau)
        out += [xt]
    out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad

(0.0039619789400057925,
 3001.0009765625,
 tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[1., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]], device='cuda:0', grad_fn=<StackBackward0>),
 tensor([[[0.0042, 0.0060, 0.0029, 0.0044]],
 
         [[0.0044, 0.0061, 0.0026, 0.0049]],
 
         [[0.0044, 0.0059, 0.0030, 0.0062]],
 
         [[0.0029, 0.0050, 0.0038, 0.0060]],
 
         [[0.0056, 0.0048, 0.0048, 0.0059]],
 
         [[0.0053, 0.0047, 0.0055, 0.0052]],
 
         [[0.0042, 0.0046, 0.0043, 0.0049]],
 
         [[0.0031, 0.0027, 0.0024, 0.0028]]], device='cuda:0'))

In [12]:
from torchspike import LIF_CUDA

x.grad = None
v = None
out = None
torch.cuda.reset_accumulated_memory_stats()

lif = LIF_CUDA.apply
th = torch.tensor(1.0)
tau = torch.tensor(2.0)


def run():
    global out
    v = torch.zeros_like(x[0])
    x.grad = None
    v.grad = None
    out = []
    for xt in x:
        xt, v = lif(xt, v, th, tau)
        out += [xt]
    out = torch.stack(out)
    out.mean().backward()


result = timeit.timeit(run, number=runs)
result / runs, torch.cuda.max_memory_allocated() / 1024**2, out, x.grad

(0.0016004657099983888,
 3001.0009765625,
 tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]], device='cuda:0', grad_fn=<StackBackward0>),
 tensor([[[0.0061, 0.0061, 0.0061, 0.0061]],
 
         [[0.0061, 0.0061, 0.0061, 0.0061]],
 
         [[0.0060, 0.0060, 0.0060, 0.0060]],
 
         [[0.0060, 0.0060, 0.0060, 0.0060]],
 
         [[0.0058, 0.0058, 0.0058, 0.0058]],
 
         [[0.0054, 0.0054, 0.0054, 0.0054]],
 
         [[0.0046, 0.0046, 0.0046, 0.0046]],
 
         [[0.0031, 0.0031, 0.0031, 0.0031]]], device='cuda:0'))