In [1]:
# pylint: disable=redefined-outer-name

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

import pytorch_spiking

torch.manual_seed(0)
np.random.seed(0)

In [9]:
# download MNIST dataset
BATCH_SIZE = 256

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        #torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

train_data = torchvision.datasets.MNIST(
    root=".",
    train=True,
    download=True,
    transform=transform,
)

test_data = torchvision.datasets.MNIST(
    root=".",
    train=False,
    download=True,
    transform=transform,
    )

train_data = train_data.data
test_data = test_data.data

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=BATCH_SIZE, shuffle=False
)


# flatten images
train_data = train_data.reshape((train_data.shape[0], -1))/255.0
test_data = test_data.reshape((test_data.shape[0], -1))/255.0

print(test_data.shape)

torch.Size([10000, 784])


In [16]:
def train(input_model, train_x, test_x):
    minibatch_size = 32
    optimizer = torch.optim.Adam(input_model.parameters())

    input_model.train()
    for j in range(10):
        train_acc = 0
        for i in range(train_x.shape[0] // minibatch_size):
            input_model.zero_grad()

            batch_in = train_x[i * minibatch_size : (i + 1) * minibatch_size]
            # flatten images
            batch_in = batch_in.reshape((-1,) + train_x.shape[1:-2] + (784,))
            batch_label = train_labels[i * minibatch_size : (i + 1) * minibatch_size]
            output = input_model(torch.tensor(batch_in))

            # compute sparse categorical cross entropy loss
            logp = torch.nn.functional.log_softmax(output, dim=-1)
            logpy = torch.gather(logp, 1, torch.tensor(batch_label).view(-1, 1))
            loss = -logpy.mean()

            loss.backward()
            optimizer.step()

            train_acc += torch.mean(
                torch.eq(torch.argmax(output, dim=1), torch.tensor(batch_label)).float()
            )

        train_acc /= train_x.shape[0] // minibatch_size + 1
        print(f"Train accuracy ({j}): {train_acc}")

    # compute test accuracy
    input_model.eval()
    test_acc = 0
    for i in range(test_x.shape[0] // minibatch_size):
        batch_in = test_x[i * minibatch_size : (i + 1) * minibatch_size]
        batch_in = batch_in.reshape((-1,) + test_x.shape[1:-2] + (784,))
        batch_label = test_labels[i * minibatch_size : (i + 1) * minibatch_size]
        output = input_model(torch.tensor(batch_in))

        test_acc += torch.mean(
            torch.eq(torch.argmax(output, dim=1), torch.tensor(batch_label)).float()
        )

    test_acc /= train_x.shape[0] // minibatch_size + 1

    print(f"Test accuracy: {test_acc}")


In [17]:
# repeat the images for n_steps
n_steps = 10
train_sequences = np.tile(train_data[:, None], (1, n_steps, 1, 1))
test_sequences = np.tile(test_data[:, None], (1, n_steps, 1, 1))

spiking_model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    # wrap ReLU in SpikingActivation
    pytorch_spiking.SpikingActivation(torch.nn.ReLU(), spiking_aware_training=False),
    # use average pooling layer to average spiking output over time
    pytorch_spiking.TemporalAvgPool(),
    torch.nn.Linear(128, 10),
)

# train the model, identically to the non-spiking version,
# except using the time sequences as input
train(spiking_model, train_sequences, test_sequences)

Train accuracy (0): 0.0
Train accuracy (1): 0.0
Train accuracy (2): 0.0
Train accuracy (3): 0.0
Train accuracy (4): 0.0
Train accuracy (5): 0.0
Train accuracy (6): 0.0
Train accuracy (7): 0.0
Train accuracy (8): 0.0
Train accuracy (9): 0.0
Test accuracy: 0.0
