In [9]:
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
from sinabs.from_torch import from_model
import sinabs.layers as sl
import matplotlib.pyplot as plt
import numpy as np

In [10]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
print(f"Running on device: `{device}`")

Running on device: `mps`


In [16]:
params = torch.load("pretrained_ann_weights.pth")
for k, v in params.items():
    print(k, v.shape)

0.weight torch.Size([16, 2, 5, 5])
2.weight torch.Size([16, 16, 3, 3])
5.weight torch.Size([8, 16, 3, 3])
9.weight torch.Size([256, 128])
11.weight torch.Size([10, 256])


In [38]:
N_INPUTS = 1
ann = nn.Sequential(
    nn.Conv2d(
        N_INPUTS, 16, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False
    ),  # 16, 18, 18
    nn.ReLU(),
    nn.Conv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    ),  # 8, 18,18
    nn.ReLU(),
    sl.SumPool2d(kernel_size=(2, 2)),  # 8, 17,17
    nn.Conv2d(
        16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    ),  # 8, 9, 9
    nn.ReLU(),
    sl.SumPool2d(kernel_size=(2, 2)),
    nn.Flatten(),  # 4, 4, 8 -> 128
    nn.Linear(128, 256, bias=False),
    nn.ReLU(),
    nn.Linear(256, 10, bias=False),
    nn.ReLU(),
)
ann

Sequential(
  (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False)
  (1): ReLU()
  (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (3): ReLU()
  (4): SumPool2d(norm_type=1, kernel_size=(2, 2), stride=None, ceil_mode=False)
  (5): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (6): ReLU()
  (7): SumPool2d(norm_type=1, kernel_size=(2, 2), stride=None, ceil_mode=False)
  (8): Flatten(start_dim=1, end_dim=-1)
  (9): Linear(in_features=128, out_features=256, bias=False)
  (10): ReLU()
  (11): Linear(in_features=256, out_features=10, bias=False)
  (12): ReLU()
)

In [36]:
class MNIST(datasets.MNIST):
    def __init__(self, root, train=True, is_spiking=False, time_window=100):
        super().__init__(
            root=root, train=train, download=True, transform=transforms.ToTensor()
        )
        self.is_spiking = is_spiking
        self.time_window = time_window

    def __getitem__(self, index):
        img, target = self.data[index].unsqueeze(0) / 255, self.targets[index]
        # img is now a tensor of 1x28x28

        if self.is_spiking:
            img = (torch.rand(self.time_window, *img.shape) < img).float()

        return img, target

In [39]:
mnist_train = MNIST("./data", train=True, is_spiking=False)
train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True)

mnist_test = MNIST("./data", train=False, is_spiking=False)
test_loader = DataLoader(mnist_test, batch_size=32, shuffle=False)

ann = ann.to(device)
ann.train()

optim = torch.optim.Adam(ann.parameters(), lr=1e-3)

n_epochs = 2

losses = []

for n in tqdm(range(n_epochs)):
    for data, target in tqdm(iter(train_loader)):
        data, target = data.to(device), target.to(device)
        print(data.shape)
        output = ann(data)
        optim.zero_grad()

        loss = F.cross_entropy(output, target)
        loss.backward()
        losses.append(loss.item())
        optim.step()

  0%|                                                                                                                        | 0/2 [00:00<?, ?it/s]
  0%|                                                                                                                     | 0/1875 [00:00<?, ?it/s][A
  0%|                                                                                                                        | 0/2 [00:00<?, ?it/s]

torch.Size([32, 1, 28, 28])





RuntimeError: linear(): input and weight.T shapes cannot be multiplied (32x72 and 128x256)

In [None]:
correct_predictions = []

for data, target in iter(test_loader):
    data, target = data.to(device), target.to(device)
    output = ann(data)

    # get the index of the max log-probability
    pred = output.argmax(dim=1, keepdim=True)

    # Compute the total correct predictions
    correct_predictions.append(pred.eq(target.view_as(pred)))

correct_predictions = torch.cat(correct_predictions)
print(
    f"Classification accuracy: {correct_predictions.sum().item()/(len(correct_predictions))*100}%"
)

In [None]:
input_shape = (1, 28, 28)
num_timesteps = 100  # per sample

sinabs_model = from_model(
    ann, input_shape=input_shape, add_spiking_output=True, synops=False, num_timesteps=num_timesteps
)


In [None]:
sinabs_model.spiking_model

In [None]:
test_batch_size = 10

spike_mnist_test = MNIST(
    "./data", train=False, is_spiking=True, time_window=num_timesteps
)
spike_test_loader = DataLoader(
    spike_mnist_test, batch_size=test_batch_size, shuffle=True
)


In [None]:
correct_predictions = []

for data, target in tqdm(spike_test_loader):
    data, target = data.to(device), target.to(device)
    data = sl.FlattenTime()(data)
    with torch.no_grad():
        output = sinabs_model(data)
        output = output.unflatten(
            0, (test_batch_size, output.shape[0] // test_batch_size)
        )

    # get the index of the max log-probability
    pred = output.sum(1).argmax(dim=1, keepdim=True)

    # Compute the total correct predictions
    correct_predictions.append(pred.eq(target.view_as(pred)))
    if len(correct_predictions) * test_batch_size >= 300:
        break

correct_predictions = torch.cat(correct_predictions)
print(
    f"Classification accuracy: {correct_predictions.sum().item()/(len(correct_predictions))*100}%"
)

In [None]:
# Get one sample from the dataloader
img, label = spike_mnist_test[10]

%matplotlib inline

plt.imshow(img.sum(0)[0]);

In [None]:
snn_output = sinabs_model(img.to(device))

plt.pcolormesh(snn_output.T.detach().cpu())

plt.ylabel("Neuron ID")
plt.yticks(np.arange(10) + 0.5, np.arange(10))
plt.xlabel("Time");
