In [None]:
import os
import random

import numpy as np
import psutil
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.datasets import FashionMNIST

import utils as utils

DATA_DIR = './fashionMNIST/'

# Number of DataLoader worker processes: the physical CPU core count (not
# logical/hyper-threaded cores).  psutil returns None when it cannot determine
# the count, so fall back to the logical count from os.cpu_count(), and finally
# to 0 (load in the main process) so DataLoader always receives an integer.
NUM_CPUS = psutil.cpu_count(logical=False) or os.cpu_count() or 0

In [None]:
# Seed every random number generator in play (Python, NumPy, PyTorch) so runs are
# reproducible.  That way we can make tweaks and see their effect on performance;
# without fixed seeds, accuracy varies between runs, which makes it hard to
# measure the effect of our changes.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
# Make cuDNN use deterministic algorithms rather than benchmarking for the
# fastest (possibly non-deterministic) one on each run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
class LitMNIST(LightningModule):
    """Small CNN classifier for Fashion-MNIST (28x28 grayscale images, 10 classes).

    Three 3x3 convolutions (32, 64, 64 channels) with ReLU, 2x2 max-pooling after
    the first two, then two fully connected layers.  ``forward`` returns
    log-probabilities, which pair with ``F.nll_loss`` in ``cross_entropy_loss``.

    Args:
        batch_size: Mini-batch size used by the train/val/test dataloaders.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, batch_size=64, learning_rate=1e-3):
        super().__init__()
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        # Inputs are (batch, 1, 28, 28) = (batch, channels, height, width).
        # Spatial size: 28 -conv-> 26 -pool-> 13 -conv-> 11 -pool-> 5 -conv-> 3,
        # hence 64 * 3 * 3 features entering fc1.
        self.layer_1 = torch.nn.Conv2d(1, 32, 3)
        self.layer_2 = torch.nn.MaxPool2d(2)
        self.layer_3 = torch.nn.Conv2d(32, 64, 3)
        self.layer_4 = torch.nn.MaxPool2d(2)
        self.layer_5 = torch.nn.Conv2d(64, 64, 3)
        self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
        self.fc2 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) image tensor to (batch, 10) log-probabilities."""
        x = torch.relu(self.layer_1(x))  # conv + relu
        x = self.layer_2(x)              # max-pool
        x = torch.relu(self.layer_3(x))  # conv + relu
        x = self.layer_4(x)              # max-pool
        x = torch.relu(self.layer_5(x))  # conv + relu

        # flatten everything except the batch dimension
        x = torch.flatten(x, start_dim=1)

        # densely connected layers
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)

        # log-probability distribution over the labels
        return torch.log_softmax(x, dim=1)

    def cross_entropy_loss(self, logits, labels):
        """Mean cross-entropy loss.

        ``logits`` must already be log-probabilities (the output of ``forward``),
        so the negative log-likelihood of them equals cross-entropy.
        """
        return F.nll_loss(logits, labels)

    def prepare_data(self):
        """Download Fashion-MNIST and build the train/val/test datasets.

        The training set is randomly split into a training part and a 5,000-image
        validation set (55,000 / 5,000 for the standard 60,000-image set).  The
        split draws from torch's global RNG, so it follows the notebook seed.
        """
        # NOTE(review): newer Lightning versions recommend assigning dataset state in
        # `setup()`, because `prepare_data()` may run on only one process. Confirm
        # against the installed version before using multiple GPUs/nodes.
        transform = transforms.Compose([transforms.ToTensor()])
        fmnist_train = FashionMNIST(DATA_DIR, train=True, download=True, transform=transform)
        self.fmnist_test = FashionMNIST(DATA_DIR, train=False, download=True, transform=transform)

        val_size = 5000
        train_size = len(fmnist_train) - val_size
        self.fmnist_train, self.fmnist_val = random_split(fmnist_train, [train_size, val_size])

    def _make_dataloader(self, dataset, shuffle=False):
        """Build a DataLoader using this module's batch size and NUM_CPUS workers."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=NUM_CPUS)

    def train_dataloader(self):
        # Reshuffle every epoch so the model does not see batches in a fixed order.
        return self._make_dataloader(self.fmnist_train, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.fmnist_val)

    def test_dataloader(self):
        return self._make_dataloader(self.fmnist_test)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log it as ``train_loss``."""
        x, y = batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)

        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def validation_step(self, batch, batch_idx):
        """Return the loss of one validation batch."""
        x, y = batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log the mean as ``val_loss``."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """Return the loss of one test batch."""
        x, y = batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)

        return {'test_loss': loss}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses and log the mean as ``test_loss``."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'test_loss': avg_loss, 'log': tensorboard_logs}

In [None]:
# Smoke-test the architecture on one dummy image before training.  torch.zeros
# gives well-defined input values (torch.Tensor(...) returns uninitialized
# memory that may hold NaN/inf) and does not consume any RNG state.
net = LitMNIST()
x = torch.zeros(1, 1, 28, 28)
out = net(x)
assert out.shape == (1, 10), f"unexpected output shape {tuple(out.shape)}"

In [None]:
# Display the model output from the previous cell: log-probabilities (log_softmax) over the 10 classes.
out

In [None]:
# Download Fashion-MNIST (if not already in DATA_DIR) and build the train/val/test
# datasets on `net`, so its dataloaders can be used in the cells below.
net.prepare_data()

In [None]:
# Label index -> class name, as listed in the Fashion-MNIST repository:
# https://github.com/zalandoresearch/fashion-mnist
classes = dict(enumerate([
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle Boot',
]))

In [None]:
# Show sample training images using the project helper `utils.display_grid_data`
# (presumably a grid of images labelled via `classes`; see utils.py).
dl = net.train_dataloader()
utils.display_grid_data(dl, classes)

In [None]:
# Train a fresh model (a separate instance from the smoke-test `net`) for 5 epochs.
model = LitMNIST()
trainer = Trainer(max_epochs=5)
trainer.fit(model)

In [None]:
# Evaluate on the test split; with no argument, Lightning tests the model passed to `fit`.
trainer.test()

In [None]:
# Inspect the logged train/val/test losses in TensorBoard (Lightning's default log dir).
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Test-set loader for computing per-sample predictions in the next cell.
# `utils` is already imported in the first cell; a second `import` would not pick
# up edits to utils.py (use importlib.reload(utils) for that).
dl = model.test_dataloader()

In [None]:
# Test-set accuracy.  The set of possible labels is derived from the class map
# (keys 0..9) instead of a hard-coded 10.
labels, predictions = utils.model_predictions(dl, model)
utils.measure_accuracy(labels, predictions, all_possible_labels=range(len(classes)))