# ResNet Pytorch implementation for FashionMNIST classification
First we import the required packages.

In [1]:
%matplotlib inline
import torch
import torch.nn as nn
from torch.nn import functional as F
from matplotlib import pyplot as plt
import numpy as np
import torchvision
import torchvision.datasets as datasets

import torchvision.models as models
from torchvision import transforms
import torch.optim as optim
import time
import tqdm as tqdm
from torch.autograd import Variable
from torch.utils.data import random_split

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer

from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


from utils.LogME import LogME

## TODO:
[FreezeOut](https://github.com/ajbrock/FreezeOut)

[RandomErasing](https://github.com/zhunzhong07/Random-Erasing)   

- `transforms.RandomErasing(probability = args.p, sh = args.sh, r1 = args.r1, mean = [0.4914]),`

Other datasets: 
- CIFAR10
- ImageNet?
- Others

## Config

In [2]:
batch_size = 256
num_epochs = 60
augmentation = True
logme_test = False
use_pretrained = False
in_chan = 3 if use_pretrained else 1
train_val_split = 0.8
num_workers = 0

## Load Dataset
We can load data from pytorch dataset and preprocess it using transform function.

Note that the ResNet implemented in torchvision take RGB images as inputs, which has three channels. So, here we repeat the single-channel grey scale digits image three times to fit the torchvision model.

In [3]:
# Shuffle indices
indices = np.arange(60000)
np.random.shuffle(indices)

train_val_split = round((len(indices) * train_val_split))

### Augmentations

In [4]:
norm_mean = 0.2854  # try 0.1307   # 0.2854
norm_std = 0.3528   # try 0.3081   #0.3528
normalize = transforms.Normalize((norm_mean,), (norm_std,))
expand_transform = transforms.Lambda(lambda x: x.repeat(3, 1, 1))

# norm_mean = torch.mean(mnist_train.dataset.data[mnist_train.indices] / 255.)  # 0.1307, 0.3081 also used
# norm_std = torch.std(mnist_train.dataset.data[mnist_train.indices] / 255.)


# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],   # IMAGENET FOR PRETRAINED MODEL
#                                  std=[0.229, 0.224, 0.225])

if augmentation:
    # Prepare transforms and data augmentation
    train_transform = transforms.Compose([
        # transforms.RandomResizedCrop(224),
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
else:
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
        #expand_transform
    ])
    
test_transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
    
    transforms.ToTensor(),
    normalize,
    #expand_transform
])

if use_pretrained:
    train_transform.transforms.append(expand_transform)
    test_transform.transforms.append(expand_transform)

### Datasets and loaders

In [5]:
# download dataset
## Train
mnist_train = datasets.FashionMNIST(root='./data', train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
    shuffle=False, sampler=torch.utils.data.SubsetRandomSampler(indices[:train_val_split]), num_workers=num_workers)

## Val
mnist_val = datasets.FashionMNIST(root='./data', train=True, download=True, transform=test_transform)

val_loader = torch.utils.data.DataLoader(mnist_val, batch_size=batch_size,
    shuffle=False, sampler=torch.utils.data.SubsetRandomSampler(indices[train_val_split:]))

## Test
mnist_test = datasets.FashionMNIST(root='./data', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size,
    shuffle=False, num_workers=num_workers)

## sizes
print(train_val_split, len(mnist_train) - train_val_split, len(mnist_test))

48000 12000 10000


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## Building the model

In [6]:
# https://github.com/cmasch/zalando-fashion-mnist/blob/master/Simple_Convolutional_Neural_Network_Fashion-MNIST.ipynb
# This should get around 0.934 accuracy with data augmentation

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 4, padding='same')
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 64, 4)
        
        self.fc1 = nn.Linear(64 * 5 * 5, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [7]:
if use_pretrained:
    # Import the module
    import torchvision

    # Download resnet18
    model = torchvision.models.resnet18(pretrained=True)

#     # Freeze all the layers bar the last one
#     for param in model.parameters():
#         param.requires_grad = False
    
    model.fc = nn.Linear(model.fc.in_features, 10)
else:
    model = Classifier()

## Optional: LogME test

In [8]:
if use_pretrained:
    feature_extractor = nn.Sequential(*list(model.children())[:-2])
else:
    feature_extractor = nn.Sequential(*list(model.children())[:-1])

In [9]:
if logme_test:
    score_list = []

    with torch.no_grad():
        for (x, y) in train_loader:
            if use_pretrained:
                features = feature_extractor(x)
            else:
                features = feature_extractor(x)
            score = LogME(features.squeeze(), y)
            score_list.append(score)

    print('LogME score is {}'.format(np.mean(score_list)))
    
    del feature_extractor

In [10]:
class LitMNIST(LightningModule):

    def __init__(self, model):
        super().__init__()
        
        self.model = model

    def forward(self, x):
        x = self.model(x)
        x = F.log_softmax(x, dim=1)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        
        # accuracy
        proba = torch.exp(logits)
        pred_class = torch.argmax(proba, dim=1)
        acc = (pred_class == y).float().mean()
        self.log('val_acc', acc)
        
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('test_loss', loss)
        
        # accuracy
        proba = torch.exp(logits)
        pred_class = torch.argmax(proba, dim=1)
        acc = (pred_class == y).float().mean()
        self.log('test_acc', acc)
    
    def configure_optimizers(self):
        self.opt = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=0.001)
        
#         optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.001)

        self.reduce_lr_on_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.opt,
            mode='max',
            factor=0.1,
            patience=3,
            verbose=True,
#             cooldown=5,
            min_lr=1e-8,
        )

        return {"optimizer": self.opt, "lr_scheduler": self.reduce_lr_on_plateau, "monitor": "val_acc"} 

In [11]:
# model = Classifier()
# x = torch.randn(2, 1, 28, 28)
# out = model(x)
# x, y = next(iter(train_loader))
# out = model(x)
# print(out.shape)
# print(model)

## Train model

In [12]:
checkpoint_callback = ModelCheckpoint(monitor='val_loss')
earlystop_callback = EarlyStopping(monitor='val_loss', patience=5)

In [13]:
model = LitMNIST(model)
trainer = Trainer(max_epochs=num_epochs, callbacks=[checkpoint_callback, earlystop_callback])
trainer.fit(model, train_loader, val_loader)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
  "GPU available but not used. Set the gpus flag in your trainer"

  | Name  | Type       | Params
-------------------------------------
0 | model | Classifier | 493 K 
-------------------------------------
493 K     Trainable params
0         Non-trainable params
493 K     Total params
1.975     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  self.padding, self.dilation, self.groups)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch    31: reducing learning rate of group 0 to 1.0000e-04.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch    40: reducing learning rate of group 0 to 1.0000e-05.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch    46: reducing learning rate of group 0 to 1.0000e-06.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch    50: reducing learning rate of group 0 to 1.0000e-07.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')


In [14]:
# run test set
result = trainer.test(test_dataloaders=test_loader)

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.9118000268936157, 'test_loss': 0.2417163997888565}
--------------------------------------------------------------------------------
