In [8]:
# spikingjelly.activation_based.examples.conv_fashion_mnist
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from spikingjelly.activation_based import encoding
from spikingjelly.activation_based import neuron, functional, surrogate, layer
from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing

device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


Define the model

In [9]:
class SCNN(nn.Module):
    def __init__(self, T: int):
        super().__init__()
        self.T = T
        self.conv_and_fc = nn.Sequential(
            layer.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            neuron.IFNode(surrogate_function=surrogate.LeakyKReLU()),
            layer.MaxPool2d(kernel_size=2, stride=2),

            layer.Conv2d(in_channels=6, out_channels=12, kernel_size=5),
            neuron.IFNode(surrogate_function=surrogate.LeakyKReLU()),
            layer.MaxPool2d(kernel_size=2, stride=2),

            layer.Flatten(),
            layer.Linear(5*5*12, 10),
            neuron.IFNode(surrogate_function=surrogate.LeakyKReLU()),
            # layer.Linear(120, 84),
            # neuron.IFNode(surrogate_function=surrogate.LeakyKReLU()),
            # layer.Linear(84, 10),
            # neuron.IFNode(surrogate_function=surrogate.ATan())
            )
        functional.set_step_mode(self, step_mode='m')
    def forward(self, x):
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        z = self.conv_and_fc(x_seq)
        fr = z.mean(0)
        return fr

Set Variables

In [10]:
# tau = 2.0
timesteps = 40
model = SCNN(T=timesteps).to(device=device)
EPOCHS=100
AMP=True #automatic mixed precision training
lr= 0.00001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
out_dir = "./outputs/CSNN"
encoder = encoding.PoissonEncoder()
batch_size=64
num_workers = 10


Download the sets

In [11]:
root = './FMNIST'
train_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

# train_loader = torch.utils.data.DataLoader(
#     dataset=train_set,
#     batch_size=batch_size,
#     shuffle=True
# )

test_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

# test_loader = torch.utils.data.DataLoader(
#     dataset=test_set,
#     batch_size=batch_size,
#     shuffle=False
# )

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=num_workers,
    pin_memory=True
)

Train the CSNN

In [12]:
scaler = None
if AMP:
    scaler = amp.GradScaler()
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

writer = SummaryWriter(out_dir, purge_step=0)
with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))
file_dir = './Models/'
if not os.path.exists(file_dir):
    os.makedirs(file_dir)
    print(f'Mkdir {file_dir}.')
full_path = file_dir + '/CSNN.pt'

In [13]:
functional.reset_net(model)
best_loss = 10000000.0
for epoch in range(EPOCHS):
    start_time = time.time()
    model.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for x, label in train_loader:
        optimizer.zero_grad()
        x = x.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        if scaler is None:
            out_fr = 0.
            # for t in range(timesteps):
            #     encoded_img = encoder(x)
            #     out_fr += model(encoded_img)
            # out_fr = out_fr / timesteps
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()
        else:
            with amp.autocast():

                out_fr = model(x)
                # out_fr = out_fr / timesteps
                loss = F.mse_loss(out_fr, label_onehot)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        # print(out_fr.shape)
        # print(label.shape)
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(model) #need to reset the snn before reuse
    if train_loss < best_loss:
        torch.save(model.state_dict(), f=full_path)
        best_loss = train_loss
        print('new best model saved')
    print('epoch: ' + str(epoch) + '; loss' + str(train_loss))

new best model saved
epoch: 0; loss5556.556187152863
new best model saved
epoch: 1; loss4791.5078320503235
new best model saved
epoch: 2; loss4490.977624416351
new best model saved
epoch: 3; loss4340.243544578552
new best model saved
epoch: 4; loss4246.435659408569
new best model saved
epoch: 5; loss4105.152578115463
new best model saved
epoch: 6; loss3975.1210095882416
new best model saved
epoch: 7; loss3905.2128167152405
new best model saved
epoch: 8; loss3851.2688777446747
new best model saved
epoch: 9; loss3805.5682961940765
new best model saved
epoch: 10; loss3763.6618542671204
new best model saved
epoch: 11; loss3738.7917115688324
new best model saved
epoch: 12; loss3723.7766494750977
new best model saved
epoch: 13; loss3706.3287122249603
new best model saved
epoch: 14; loss3689.5630280971527
new best model saved
epoch: 15; loss3673.541717529297
new best model saved
epoch: 16; loss3658.6415803432465
new best model saved
epoch: 17; loss3646.0450201034546
new best model saved
epoch