# Spiking Neural Network with Eve

## build a spiking neural network
In this notebook, we will define a spiking neural network with Eve and
apply it to the CIFAR-10 classification task.

In [1]:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import eve.core as core
import eve.core.node
import eve.core.layer
from eve.core.state import State
from torch import Tensor
import os

class SpikingNeuralNetwork(core.eve.Eve):
    """VGG-style convolutional classifier whose activations are Eve nodes.

    The input goes through an encoder taken from ``eve.core.layer`` and then
    through six ``Conv2d`` + ``BatchNorm2d`` stages (stages 3 and 5 start with
    a 2x2 max-pooling), each followed by a node taken from ``eve.core.node``.
    A last 2x2 max-pooling, a flatten and a linear layer produce 10 logits.
    The linear layer expects ``512 * 16`` features, which is what three 2x2
    poolings leave of a 32x32 input.

    Args:
        node: Name of the node class to look up in ``eve.core.node``.
        node_kwargs: Keyword arguments forwarded to every node constructor.
            ``None`` selects the defaults listed in ``__init__``.
        encoder: Name of the encoder class to look up in ``eve.core.layer``.
        encoder_kwargs: Keyword arguments forwarded to the encoder
            constructor. ``None`` selects ``{"timesteps": 8}``.
    """

    def __init__(
        self,
        node: str = "IFNode",
        node_kwargs: "dict | None" = None,
        encoder: str = "PoissonEncoder",
        encoder_kwargs: "dict | None" = None,
    ):
        super().__init__()

        # Build the default dictionaries per call instead of sharing one
        # mutable default object between every instance of the class.
        if node_kwargs is None:
            node_kwargs = {
                "voltage_threshold": 1.0,
                "voltage_reset": 0.0,
                "learnable_threshold": False,
                "learnable_reset": False,
                "time_dependent": True,
                "neuron_wise": False,
                "surrogate_fn": "Sigmoid",
                "binary": True,
            }
        if encoder_kwargs is None:
            encoder_kwargs = {
                "timesteps": 8,
            }

        # Don't forget to reset the global state before creating new States.
        State.reset()

        # Resolve the class names into the actual Eve classes.
        node = getattr(eve.core.node, node)
        encoder = getattr(eve.core.layer, encoder)

        # build encoder
        self.encoder = encoder(**encoder_kwargs)

        # Convolution stages. Modules are created in the same order as before
        # (conv1, node1, conv2, node2, ...) and every Sequential keeps its
        # original child indices, so state-dict keys are unchanged.
        self.conv1 = self._conv_bn_block(3, 128)
        self.node1 = node(State(self.conv1), **node_kwargs)

        self.conv2 = self._conv_bn_block(128, 128)
        self.node2 = node(State(self.conv2), **node_kwargs)

        self.conv3 = self._conv_bn_block(128, 256, pool=True)
        self.node3 = node(State(self.conv3), **node_kwargs)

        self.conv4 = self._conv_bn_block(256, 256)
        self.node4 = node(State(self.conv4), **node_kwargs)

        self.conv5 = self._conv_bn_block(256, 512, pool=True)
        self.node5 = node(State(self.conv5), **node_kwargs)

        self.conv6 = self._conv_bn_block(512, 512)
        self.node6 = node(State(self.conv6), **node_kwargs)

        self.classifier = nn.Linear(512 * 16, 10)

    @staticmethod
    def _conv_bn_block(in_channels: int, out_channels: int,
                       pool: bool = False) -> nn.Sequential:
        """Return ``[MaxPool2d] -> Conv2d(3x3, no bias) -> BatchNorm2d``."""
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
        layers.append(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
                      bias=False))
        layers.append(nn.BatchNorm2d(out_channels))
        return nn.Sequential(*layers)

    def set_timesteps(self, timesteps):
        """Set the encoder's ``timesteps``.

        ``forward`` reads this value as the number of passes to run in
        spiking mode.
        """
        self.encoder.timesteps = timesteps

    def spiking_forward(self, x: Tensor) -> Tensor:
        """Run one pass through encoder, conv/node stages and classifier.

        Args:
            x: Input batch handed to the encoder.

        Returns:
            The output of the linear classifier for this single pass.
        """
        out = self.encoder(x)

        stages = (
            (self.conv1, self.node1),
            (self.conv2, self.node2),
            (self.conv3, self.node3),
            (self.conv4, self.node4),
            (self.conv5, self.node5),
            (self.conv6, self.node6),
        )
        for conv, node in stages:
            out = node(conv(out))

        feat = F.max_pool2d(out, kernel_size=2, stride=2)
        feat = th.flatten(feat, 1)

        return self.classifier(feat)

    def non_spiking_forward(self, x: Tensor) -> Tensor:
        """Run the same layer pipeline as ``spiking_forward`` exactly once."""
        return self.spiking_forward(x)

    def forward(self, x: Tensor) -> Tensor:
        """Compute class logits for ``x``.

        In spiking mode (``self.spiking`` is truthy) ``spiking_forward`` is
        evaluated ``self.encoder.timesteps`` times and the outputs are
        averaged; otherwise a single ``non_spiking_forward`` pass is returned.
        """
        if self.spiking:
            # Don't forget to reset membrane voltage every forward.
            self.reset()
            res = [self.spiking_forward(x) for _ in range(self.encoder.timesteps)]
            return th.stack(res, dim=0).mean(dim=0)
        else:
            return self.non_spiking_forward(x)

    def load_pretrained_model(self, pretrained: str = "vgg.pth") -> None:
        """Load pretrained VGG weights into the conv/BN/classifier layers.

        Args:
            pretrained: Path of the checkpoint file. When no file exists at
                that path the checkpoint is downloaded there first (this
                needs the third-party ``wget`` package).

        Raises:
            KeyError: If the checkpoint lacks one of the mapped entries.

        Entries of this model's state dict that have no counterpart in the
        key map keep their current values.
        """
        if not os.path.isfile(pretrained):
            import wget  # only needed when the checkpoint has to be fetched

            model_url = 'https://github.com/rhhc/zxd_releases/releases/download/Re/cifar10-vggsmall-zxd-93.4-8943fa3.pth'
            target_dir = os.path.dirname(pretrained)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            wget.download(model_url, pretrained)

        # Load after the (optional) download so ``ckpt`` is always defined.
        # map_location="cpu" keeps the load independent of the device the
        # checkpoint was saved from; load_state_dict copies the values onto
        # the device of this model's own parameters.
        ckpt = th.load(pretrained, map_location="cpu")

        # Map this model's state-dict keys onto the checkpoint's keys.
        key_map = {
            'classifier.bias': 'classifier.bias',
            'classifier.weight': 'classifier.weight',
        }
        bn_fields = (
            'bias',
            'num_batches_tracked',
            'running_mean',
            'running_var',
            'weight',
        )
        # (stage name, index of the Conv2d inside that stage's Sequential,
        #  index of the matching Conv2d inside the checkpoint's ``features``).
        # In both layouts the BatchNorm2d sits directly after its Conv2d.
        layout = (
            ('conv1', 0, 0),
            ('conv2', 0, 3),
            ('conv3', 1, 7),
            ('conv4', 0, 10),
            ('conv5', 1, 14),
            ('conv6', 0, 17),
        )
        for stage, conv_idx, feat_idx in layout:
            key_map[f'{stage}.{conv_idx}.weight'] = f'features.{feat_idx}.weight'
            for field in bn_fields:
                key_map[f'{stage}.{conv_idx + 1}.{field}'] = (
                    f'features.{feat_idx + 1}.{field}')

        new_state_dict = {
            k: ckpt[key_map[k]] if k in key_map else v
            for k, v in self.state_dict().items()
        }
        self.load_state_dict(new_state_dict)

## Define dataset and dataloader

In [2]:
from torchvision import transforms
from torchvision.datasets import CIFAR10

# Directory that already holds the CIFAR10 files (download is disabled below).
data_root = "/media/densechen/data/dataset/"

# Per-channel normalisation constants applied to both splits.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip, then tensor + normalise.
cifar10_transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)
# Test pipeline: tensor + normalise only.
cifar10_transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)

train_dataset = CIFAR10(
    root=data_root,
    train=True,
    download=False,
    transform=cifar10_transform_train,
)
test_dataset = CIFAR10(
    root=data_root,
    train=False,
    download=False,
    transform=cifar10_transform_test,
)

# Settings shared by both loaders; only the shuffling differs.
_loader_kwargs = {"batch_size": 128, "num_workers": 4}

train_dataloader = th.utils.data.DataLoader(
    train_dataset, shuffle=True, **_loader_kwargs
)
test_dataloader = th.utils.data.DataLoader(
    test_dataset, shuffle=False, **_loader_kwargs
)

## trainer and tester

In [3]:
from tqdm import notebook

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook still runs on machines without a usable GPU.
device = "cuda:0" if th.cuda.is_available() else "cpu"

def trainer(net, optimizer):
    """Train ``net`` for one pass over ``train_dataloader``.

    Every batch is moved to ``device``, the cross-entropy loss is
    back-propagated and ``optimizer`` takes one step. The progress bar shows
    the top-1 accuracy and loss of the current batch.

    Args:
        net: Model to optimise; it is switched to training mode.
        optimizer: Optimizer that updates the parameters of ``net``.
    """
    net.train()
    progress = notebook.tqdm(train_dataloader)
    for batch in progress:
        x, y = [tensor.to(device) for tensor in batch]

        optimizer.zero_grad()
        y_hat = net(x)
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()

        # Per-batch statistics, shown on the progress bar only.
        acc_percent = top_one_accuracy(y_hat, y).item() * 100
        loss_value = loss.item()
        progress.set_description(
            f"train: acc={acc_percent:.2f}%, loss={loss_value:.2f}"
        )
        
def top_one_accuracy(y_hat, y):
    """Return the fraction of rows of ``y_hat`` whose arg-max equals ``y``.

    Args:
        y_hat: Scores; the maximum is taken along the last dimension.
        y: Target indices compared against the predicted indices.

    Returns:
        A scalar float tensor holding the mean of the element-wise matches.
    """
    _, predicted = y_hat.max(dim=-1)
    hits = predicted == y
    return hits.float().mean()

def tester(net):
    """Evaluate ``net`` on ``test_dataloader`` and print its top-1 accuracy.

    The progress bar shows the accuracy of the current batch. The final
    figure is the number of correct predictions divided by the number of
    evaluated samples, so a smaller last batch is weighted correctly.

    Args:
        net: Model to evaluate; it is switched to evaluation mode.
    """
    net.eval()
    progress = notebook.tqdm(test_dataloader)
    correct = 0
    total = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # keeping the activations of every layer alive for a backward pass.
    with th.no_grad():
        for data in progress:
            x, y = [t.to(device) for t in data]
            y_hat = net(x)

            batch_size = y.numel()
            batch_correct = (y_hat.max(dim=-1)[1] == y).sum().item()
            correct += batch_correct
            total += batch_size

            batch_acc = batch_correct / batch_size if batch_size else 0.0
            progress.set_description(f"test: acc={batch_acc * 100:.2f}%")

    # An empty loader yields NaN instead of raising on a division by zero.
    accuracy = correct / total if total else float("nan")
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
        
    

## load and test checkpoints

In [4]:
# Location of the pretrained checkpoint on the local disk.
checkpoint_path = "/media/densechen/data/code/eve-mli/examples/checkpoint/cifar10-vggsmall-zxd-93.4-8943fa3.pth"

# Build the network with its default node and encoder, then move it to the
# target device.
spiking_neural_net = SpikingNeuralNetwork()
spiking_neural_net.to(device)

# Load the pretrained weights; their accuracy is checked in non-spiking mode.
spiking_neural_net.load_pretrained_model(pretrained=checkpoint_path)

# Don't forget to switch to non-spiking mode before testing.
spiking_neural_net.non_spike()

# Evaluate on the test set.
tester(spiking_neural_net)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 93.33%


## spiking neural network

In [5]:
# Number of train/test rounds to run.
num_epochs = 10

# Adam over the parameters returned by torch_parameters().
optimizer = th.optim.Adam(spiking_neural_net.torch_parameters(), lr=1e-3)

# Switch to spiking mode and use 2 timesteps per forward pass.
spiking_neural_net.spike()
spiking_neural_net.set_timesteps(2)

# Train for one epoch, then evaluate, and repeat.
for epoch in range(num_epochs):
    print(f"Epoch: {epoch}")
    trainer(spiking_neural_net, optimizer)
    tester(spiking_neural_net)

Epoch: 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 16.00%
Epoch: 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 15.14%
Epoch: 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 17.18%
Epoch: 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 12.35%
Epoch: 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 25.03%
Epoch: 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 28.52%
Epoch: 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 23.94%
Epoch: 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 17.83%
Epoch: 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 18.82%
Epoch: 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 22.32%


## try with RateEncoder

In [6]:
# Location of the pretrained checkpoint on the local disk.
checkpoint_path = "/media/densechen/data/code/eve-mli/examples/checkpoint/cifar10-vggsmall-zxd-93.4-8943fa3.pth"
# Number of train/test rounds to run.
num_epochs = 10

# Same network as before, but with RateEncoder as the input encoder.
spiking_neural_net = SpikingNeuralNetwork(encoder="RateEncoder")
spiking_neural_net.to(device)

# Start from the pretrained weights.
spiking_neural_net.load_pretrained_model(pretrained=checkpoint_path)

# Adam over the parameters returned by torch_parameters().
optimizer = th.optim.Adam(spiking_neural_net.torch_parameters(), lr=1e-3)

# Switch to spiking mode and use 2 timesteps per forward pass.
spiking_neural_net.spike()
spiking_neural_net.set_timesteps(2)

# Train for one epoch, then evaluate, and repeat.
for epoch in range(num_epochs):
    print(f"Epoch: {epoch}")
    trainer(spiking_neural_net, optimizer)
    tester(spiking_neural_net)

Epoch: 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 15.52%
Epoch: 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 22.91%
Epoch: 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 26.49%
Epoch: 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 21.06%
Epoch: 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 30.97%
Epoch: 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 30.19%
Epoch: 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 33.96%
Epoch: 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 35.16%
Epoch: 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 35.84%
Epoch: 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Test Accuracy: 35.85%


## try with LIFNode

In [7]:
# Location of the pretrained checkpoint on the local disk.
checkpoint_path = "/media/densechen/data/code/eve-mli/examples/checkpoint/cifar10-vggsmall-zxd-93.4-8943fa3.pth"
# Number of train/test rounds to run.
num_epochs = 10

# Same network as before, but with LIFNode as the node type.
spiking_neural_net = SpikingNeuralNetwork(node="LIFNode")
spiking_neural_net.to(device)

# Start from the pretrained weights.
spiking_neural_net.load_pretrained_model(pretrained=checkpoint_path)

# Adam over the parameters returned by torch_parameters().
optimizer = th.optim.Adam(spiking_neural_net.torch_parameters(), lr=1e-3)

# Switch to spiking mode and use 4 timesteps per forward pass.
spiking_neural_net.spike()
spiking_neural_net.set_timesteps(4)

# Train for one epoch, then evaluate, and repeat.
for epoch in range(num_epochs):
    print(f"Epoch: {epoch}")
    trainer(spiking_neural_net, optimizer)
    tester(spiking_neural_net)

Epoch: 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




RuntimeError: CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 3.95 GiB total capacity; 3.14 GiB already allocated; 27.25 MiB free; 3.34 GiB reserved in total by PyTorch)

## test with flexible learnable parameters

In [8]:
# Location of the pretrained checkpoint on the local disk.
checkpoint_path = "/media/densechen/data/code/eve-mli/examples/checkpoint/cifar10-vggsmall-zxd-93.4-8943fa3.pth"
# Number of train/test rounds to run.
num_epochs = 10

# Node settings with learnable threshold and reset, configured neuron-wise.
learnable_node_kwargs = {
    "voltage_threshold": 1.0,
    "voltage_reset": 0.0,
    "learnable_threshold": True,
    "learnable_reset": True,
    "time_dependent": True,
    "neuron_wise": True,
    "surrogate_fn": "Sigmoid",
    "binary": True,
}

spiking_neural_net = SpikingNeuralNetwork(
    encoder="RateEncoder",
    node_kwargs=learnable_node_kwargs,
)
spiking_neural_net.to(device)

# Start from the pretrained weights.
spiking_neural_net.load_pretrained_model(pretrained=checkpoint_path)

# Adam over the parameters returned by torch_parameters().
optimizer = th.optim.Adam(spiking_neural_net.torch_parameters(), lr=1e-3)

# Switch to spiking mode and use 2 timesteps per forward pass.
spiking_neural_net.spike()
spiking_neural_net.set_timesteps(2)

# Train for one epoch, then evaluate, and repeat.
for epoch in range(num_epochs):
    print(f"Epoch: {epoch}")
    trainer(spiking_neural_net, optimizer)
    tester(spiking_neural_net)

Epoch: 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 3.95 GiB total capacity; 3.05 GiB already allocated; 59.25 MiB free; 3.31 GiB reserved in total by PyTorch)