In [6]:
# spikingjelly.activation_based.examples.conv_fashion_mnist
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from spikingjelly.activation_based import neuron, functional, surrogate, layer
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing

# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


For ANN-to-SNN conversion, the model needs to account for three things:

1. ANNs commonly use Batch Normalization (BN) for fast training and convergence. BN normalizes a layer's output toward zero mean, which is at odds with the properties of SNNs. Therefore, the BN parameters can be absorbed into the preceding parameterized layer (Linear, Conv2d).

2. According to the conversion theory, the input and output of each ANN layer need to be limited to the range $[0, 1]$, which requires scaling the parameters (model normalization).

3. There is no good way to convert max pooling, so average pooling (AvgPool) is recommended instead.

In [23]:
class LeNet_Modified(nn.Module):
    """LeNet-5 variant for 28x28 single-channel images (used here on Fashion-MNIST).

    Differences from classic LeNet-5: average pooling instead of max pooling,
    BatchNorm after each convolution, and ReLU activations. Outputs 10 logits
    per sample.

    Shape trace for one sample:
        1x28x28 --c1 (5x5, pad 2)--> 6x28x28  --ap1--> 6x14x14
                --c2 (5x5)---------> 16x10x10 --ap2--> 16x5x5
                --flatten--> 400 --fc1--> 120 --fc2--> 84 --fc3--> 10
    """

    def __init__(self):
        super().__init__()
        # NOTE: layers are created in this fixed order; changing it would change
        # which random numbers each layer's initial weights consume.

        # Stage 1: padding=2 keeps a 28x28 input at 28x28 (LeNet was designed for 32x32).
        self.c1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm2d(6, eps=1e-3)
        self.ap1 = nn.AvgPool2d(kernel_size=2, stride=2)

        # Stage 2: unpadded 5x5 conv shrinks 14x14 -> 10x10, pooling then gives 5x5.
        self.c2 = nn.Conv2d(6, 16, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(16, eps=1e-3)
        self.ap2 = nn.AvgPool2d(kernel_size=2, stride=2)

        # Classifier head: 16 channels x 5 x 5 flattened features -> 10 classes.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

        # The Sequential reuses the named layers above (same module objects),
        # interleaving the parameter-free ReLU / Flatten modules.
        self.network = nn.Sequential(
            self.c1, self.bn1, nn.ReLU(), self.ap1,
            self.c2, self.bn2, nn.ReLU(), self.ap2,
            nn.Flatten(),
            self.fc1, nn.ReLU(),
            self.fc2, nn.ReLU(),
            self.fc3,
        )

    def forward(self, x):
        """Map a batch of images to class logits by running the full pipeline."""
        return self.network(x)

    def name(self):
        """Short human-readable model name."""
        return "LeNet"

Define the hyperparameters, model, optimizer, and loss function

In [24]:
# Training configuration. `batch_size` is read by the data-loading cell.
SEED = 0  # fixed seed so the initial model weights are reproducible across runs
EPOCHS = 100
lr = 0.001
batch_size = 32

# Seed before constructing the model: its layers draw their initial weights from torch's global RNG.
torch.manual_seed(SEED)
model_ann = LeNet_Modified().to(device=device)
optimizer = torch.optim.Adam(model_ann.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()  # default reduction='mean': returns the loss averaged over the batch

Download the Fashion-MNIST (FMNIST) dataset and create the data loaders

In [5]:
root = './FMNIST'


def _load_fmnist_split(is_train):
    """Build the Fashion-MNIST train or test split under `root`, downloading it if missing.

    Each split gets its own ToTensor pipeline.
    """
    return torchvision.datasets.FashionMNIST(
        root=root,
        train=is_train,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )


# Training data: reshuffled every epoch.
train_set = _load_fmnist_split(True)
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)

# Test data: fixed order so evaluation is deterministic.
test_set = _load_fmnist_split(False)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./FMNIST\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:26<00:00, 1010507.03it/s]


Extracting ./FMNIST\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./FMNIST\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./FMNIST\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 232737.83it/s]


Extracting ./FMNIST\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./FMNIST\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./FMNIST\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:02<00:00, 1603047.41it/s]


Extracting ./FMNIST\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./FMNIST\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./FMNIST\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<?, ?it/s]


Extracting ./FMNIST\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./FMNIST\FashionMNIST\raw



Train the ANN model

In [None]:
# Train the ANN for EPOCHS epochs and report the mean per-sample training loss of each epoch.
for epoch in range(EPOCHS):
    # training
    model_ann.train()  # BatchNorm layers must use batch statistics while training
    total_loss = 0.0
    for x, target in train_loader:
        # Move data to the detected device instead of hard-coding .cuda(),
        # so the notebook also runs on the CPU fallback.
        x, target = x.to(device), target.to(device)
        optimizer.zero_grad()
        out = model_ann(x)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        # criterion returns the batch *mean*; weight it by the batch size so that
        # dividing by the dataset size below yields a true per-sample average.
        total_loss += loss.item() * x.size(0)
    avg_loss = total_loss / len(train_set)
    print(f'==>>> epoch: {epoch}, train loss: {avg_loss:.6f}')