In [1]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torchvision
import sys
import time

print('Pytorch version: ', torch.__version__)
print('GPU availability: ', torch.cuda.is_available())

Pytorch version:  2.0.1+cu118
GPU availability:  True


In [2]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_and_fc = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2), #including bias, evern though the SNN did not
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Flatten(),
            nn.Linear(5*5*12, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
            
            )
    def forward(self, x):
        return self.conv_and_fc(x)

In [3]:
model = CNN().cuda()
EPOCHS=50 #set to 50 epochs bc of diminishing returns
AMP=True #automatic mixed precision training
lr= 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
out_dir = "./outputs/CNN"
batch_size=64

In [4]:
root = './FMNIST'
train_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

test_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=False,
    download=True,
    transform=transforms.ToTensor()
)


train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
#training
full_path = './Models/CNN.pt'
best_loss = 1000000.0
model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        x, target = x.cuda(), target.cuda()
        out = model(x)
        target_onehot = F.one_hot(target, 10).float()
        loss = F.mse_loss(out, target_onehot) #this is done in order to match the SNN loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    avg_loss = total_loss / len(train_set)
    if avg_loss < best_loss:
        torch.save(model.state_dict(), f=full_path)
        best_loss = avg_loss
        print('new best model saved')
    print(f'==>>> epoch: {epoch}, train loss: {avg_loss:.6f}')

new best model saved
==>>> epoch: 0, train loss: 0.000406
new best model saved
==>>> epoch: 1, train loss: 0.000282
new best model saved
==>>> epoch: 2, train loss: 0.000252
new best model saved
==>>> epoch: 3, train loss: 0.000235
new best model saved
==>>> epoch: 4, train loss: 0.000220
new best model saved
==>>> epoch: 5, train loss: 0.000210
new best model saved
==>>> epoch: 6, train loss: 0.000200
new best model saved
==>>> epoch: 7, train loss: 0.000192
new best model saved
==>>> epoch: 8, train loss: 0.000184
new best model saved
==>>> epoch: 9, train loss: 0.000178
new best model saved
==>>> epoch: 10, train loss: 0.000171
new best model saved
==>>> epoch: 11, train loss: 0.000164
new best model saved
==>>> epoch: 12, train loss: 0.000159
new best model saved
==>>> epoch: 13, train loss: 0.000153
new best model saved
==>>> epoch: 14, train loss: 0.000148
new best model saved
==>>> epoch: 15, train loss: 0.000143
new best model saved
==>>> epoch: 16, train loss: 0.000139
new bes

In [6]:
#test
full_path = './Models/CNN.pt'
checkpoint = torch.load(f=full_path)
model.load_state_dict(checkpoint)
total_loss = 0
correct_cnt = 0
model.eval()
for batch_idx, (x, target) in enumerate(test_loader):
    x, target = x.cuda(), target.cuda()
    out = model(x)
    target_onehot = F.one_hot(target, 10).float()
    loss = F.mse_loss(out, target_onehot)

    _, pred_label = torch.max(out, 1)
    correct_cnt += (pred_label == target).sum()
    # smooth average
    total_loss += loss.item()

#time just inference
start = time.time()
for batch_idx, (x, target) in enumerate(test_loader):
    x, target = x.cuda(), target.cuda()
    out = model(x)
end = time.time()

avg_loss = total_loss / len(test_set)
avg_acc = correct_cnt / len(test_set)
print('Time: ' + str(end-start))
print(f'test loss: {avg_loss:.6f}, test accuracy: {avg_acc:.6f}')

Time: 0.5599989891052246
test loss: 0.000261, test accuracy: 0.897500


In [7]:
#infinite loop to see power difference
while True:
    start = time.time()
    for batch_idx, (x, target) in enumerate(test_loader):
        x, target = x.cuda(), target.cuda()
        out = model(x)
    end = time.time()

KeyboardInterrupt: 