In [1]:
# spikingjelly.activation_based.examples.conv_fashion_mnist
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from spikingjelly.activation_based import encoding
from spikingjelly.activation_based import neuron, functional, surrogate, layer
from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing

device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


Define the model

In [2]:
class LeNetSNN(nn.Module):
    def __init__(self, T: int):
        super().__init__()
        self.T = T
        self.conv_and_fc = nn.Sequential(
            layer.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            neuron.IFNode(surrogate_function=surrogate.LeakyKReLU()),
            layer.MaxPool2d(kernel_size=2, stride=2),

            layer.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            neuron.IFNode(surrogate_function=surrogate.LeakyKReLU()),
            layer.MaxPool2d(kernel_size=2, stride=2),

            layer.Flatten(),
            layer.Linear(5*5*16, 120),
            neuron.IFNode(surrogate_function=surrogate.LeakyKReLU()),
            layer.Linear(120, 84),
            neuron.IFNode(surrogate_function=surrogate.LeakyKReLU()),
            layer.Linear(84, 10),
            neuron.IFNode(surrogate_function=surrogate.ATan())
            )
        functional.set_step_mode(self, step_mode='m')
    def forward(self, x):
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        z = self.conv_and_fc(x_seq)
        fr = z.mean(0)
        return fr

Set Variables

In [3]:
# tau = 2.0
timesteps = 20
model = LeNetSNN(T=timesteps).to(device=device)
EPOCHS=100
AMP=True #automatic mixed precision training
lr= 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
out_dir = "./outputs/CSNN"
encoder = encoding.PoissonEncoder()
batch_size=64
num_workers = 4


Download the sets

In [4]:
root = './FMNIST'
train_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

# train_loader = torch.utils.data.DataLoader(
#     dataset=train_set,
#     batch_size=batch_size,
#     shuffle=True
# )

test_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

# test_loader = torch.utils.data.DataLoader(
#     dataset=test_set,
#     batch_size=batch_size,
#     shuffle=False
# )

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=num_workers,
    pin_memory=True
)

Train the CSNN

In [5]:
scaler = None
if AMP:
    scaler = amp.GradScaler()
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

writer = SummaryWriter(out_dir, purge_step=0)
with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))

In [None]:
functional.reset_net(model)
for epoch in range(EPOCHS):
    start_time = time.time()
    model.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for x, label in train_loader:
        optimizer.zero_grad()
        x = x.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        if scaler is None:
            out_fr = 0.
            # for t in range(timesteps):
            #     encoded_img = encoder(x)
            #     out_fr += model(encoded_img)
            # out_fr = out_fr / timesteps
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()
        else:
            with amp.autocast():

                out_fr = model(x)
                # out_fr = out_fr / timesteps
                loss = F.mse_loss(out_fr, label_onehot)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        # print(out_fr.shape)
        # print(label.shape)
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(model) #need to reset the snn before reuse
    print('epoch: ' + str(epoch) + '; loss' + str(train_loss))

epoch: 0; loss6239.7500829696655
epoch: 1; loss6303.500084400177
epoch: 2; loss6299.350085258484
epoch: 3; loss6298.550093173981
epoch: 4; loss6293.550090789795
epoch: 5; loss6291.550089836121
epoch: 6; loss6295.2500920295715
epoch: 7; loss6300.100090503693
epoch: 8; loss6292.0500864982605
epoch: 9; loss6292.650089263916
epoch: 10; loss6295.9000878334045
epoch: 11; loss6302.750090122223
epoch: 12; loss6304.0000829696655
epoch: 13; loss6306.050095081329
epoch: 14; loss6303.100093841553
epoch: 15; loss6296.200090885162
epoch: 16; loss6296.800085544586
epoch: 17; loss6306.700092792511
epoch: 18; loss6300.600088119507
epoch: 19; loss6290.650085449219
epoch: 20; loss6297.350081443787
epoch: 21; loss6299.700090408325
epoch: 22; loss6289.050083637238
epoch: 23; loss6299.350088596344
epoch: 24; loss6301.700087070465
epoch: 25; loss6301.200090408325
epoch: 26; loss6301.7500829696655
epoch: 27; loss6291.000094413757
epoch: 28; loss6296.35008764267
epoch: 29; loss6298.6000900268555
epoch: 30; los

: 