Adapting the ResNet-18 architecture to the MNIST dataset and training it from scratch.

Based on:<br>
https://zablo.net/blog/post/pytorch-resnet-mnist-jupyter-notebook-2021/

In [1]:
# install  PyTorch
!pip install torch torchvision pytorch-lightning

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/c4/99/68da5c6ca999de560036d98c492e507d17996f5eeb7e76ba64acd4bbb142/pytorch_lightning-1.2.8-py3-none-any.whl (841kB)
[K     |████████████████████████████████| 849kB 8.1MB/s 
Collecting torchmetrics>=0.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/3a/42/d984612cabf005a265aa99c8d4ab2958e37b753aafb12f31c81df38751c8/torchmetrics-0.2.0-py3-none-any.whl (176kB)
[K     |████████████████████████████████| 184kB 16.7MB/s 
[?25hCollecting fsspec[http]>=0.8.1
[?25l  Downloading https://files.pythonhosted.org/packages/e9/91/2ef649137816850fa4f4c97c6f2eabb1a79bf0aa2c8ed198e387e373455e/fsspec-2021.4.0-py3-none-any.whl (108kB)
[K     |████████████████████████████████| 112kB 24.7MB/s 
[?25hCollecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |███████████████

In [2]:
import torch
# Display the installed PyTorch version so the run's environment is recorded in the output.
torch.__version__

'1.8.1+cu101'

In [3]:
from torchvision.models import resnet18
from torch import nn
from torch.utils.data import DataLoader

## Create the model

In [4]:
# Build an untrained ResNet-18 whose classifier head emits one logit per MNIST digit class.
NUM_CLASSES = 10
model = resnet18(num_classes=NUM_CLASSES)

In [5]:
# Inspect the architecture. Note that conv1 takes 3 input channels (RGB, as for
# ImageNet), which does not match single-channel MNIST images.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

To use this model on MNIST, the input layer must accept a single channel instead of three: MNIST images are single-channel (grayscale), whereas the model is set up for ImageNet, which is 3-channel (RGB).

In [6]:
# Swap the stem convolution for one that reads 1 channel; every other
# hyper-parameter (64 filters, 7x7 kernel, stride 2, padding 3, no bias) is kept.
model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)

## Load the dataset

In [7]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [8]:
# Both splits are stored under the same "mnist" root and share one ToTensor
# transform, which converts each PIL image into a tensor.
to_tensor = ToTensor()
train_ds = MNIST("mnist", train=True, download=True, transform=to_tensor)
test_ds = MNIST("mnist", train=False, download=True, transform=to_tensor)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [9]:
BATCH_SIZE = 64
# Training batches are reshuffled every epoch; the test loader keeps the
# dataset's fixed order (shuffle defaults to False).
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE)

**PyTorch Lightning**<br>
A framework that structures PyTorch code and abstracts away the training-loop boilerplate (device placement, epoch loops, optimizer steps, checkpointing).

In [10]:
import pytorch_lightning as pl
from pytorch_lightning.core.decorators import auto_move_data

In [11]:
class ResNetMNIST(pl.LightningModule):
  """ResNet-18 adapted to single-channel (grayscale) MNIST images.

  Args:
    num_classes: number of output logits (10 for the MNIST digits).
    lr: learning rate for the RMSprop optimizer.

  Both arguments are recorded with `save_hyperparameters()`, so
  `load_from_checkpoint` rebuilds the module with the same values. Checkpoints
  written without them fall back to these defaults.
  """

  def __init__(self, num_classes=10, lr=0.005):
    super().__init__()
    self.save_hyperparameters()
    self.model = resnet18(num_classes=num_classes)
    # MNIST is grayscale, so the stem conv reads 1 input channel instead of the
    # 3 (RGB) that torchvision's ResNet is built for.
    self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    self.loss = nn.CrossEntropyLoss()

  # `auto_move_data` moves the input onto the module's device, so callers
  # (e.g. get_prediction) can pass CPU tensors.
  # NOTE(review): deprecated in later Lightning releases — revisit when upgrading.
  @auto_move_data
  def forward(self, x):
    """Return the raw logits of the wrapped ResNet (no softmax applied)."""
    return self.model(x)

  def training_step(self, batch, batch_no):
    """Compute and log the cross-entropy loss for one (images, labels) batch."""
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    # Make training progress visible in the logger and the progress bar.
    self.log("train_loss", loss, prog_bar=True)
    return loss

  def configure_optimizers(self):
    """Optimise all parameters with RMSprop at the configured learning rate."""
    return torch.optim.RMSprop(self.parameters(), lr=self.hparams.lr)

In [12]:
# Seed the Python, NumPy and PyTorch RNGs before building the model. This makes
# weight initialisation and the training-set shuffle order repeatable.
# (cuDNN kernels on GPU may still be non-deterministic.)
pl.seed_everything(42)
model = ResNetMNIST()

In [22]:
trainer = pl.Trainer(
    # Use one GPU when available; fall back to CPU instead of failing at construction.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=10,
    progress_bar_refresh_rate=20,  # refresh the progress bar every 20 steps
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [23]:
# Train on the MNIST training split only. No validation dataloader is passed and
# ResNetMNIST defines no validation_step, so no validation runs during fit.
trainer.fit(model, train_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 11.2 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.701    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [25]:
# Persist the trained model; this file is reloaded below with load_from_checkpoint.
# (Lightning checkpoints conventionally use `.ckpt`; the `.pt` name is kept as-is.)
trainer.save_checkpoint(filepath="resnet18_mnist.pt")

In [24]:
def get_prediction(x, model: pl.LightningModule):
  """Classify a batch of inputs with a LightningModule.

  `model.freeze()` is called first (eval mode, parameters stop requiring
  gradients). The model stays frozen after this function returns.

  Returns:
    A `(predicted_class, probabilities)` pair. `probabilities` is the softmax of
    the model output over dim 1; `predicted_class` is its argmax over dim 1.
  """
  model.freeze()
  logits = model(x)
  probabilities = logits.softmax(dim=1)
  return probabilities.argmax(dim=1), probabilities

In [26]:
from tqdm.autonotebook import tqdm

In [27]:
# Reload the trained weights for evaluation. Tensors are mapped to the GPU when
# one is present, otherwise to the CPU, so this cell also works without CUDA.
inference_model = ResNetMNIST.load_from_checkpoint(
    "resnet18_mnist.pt",
    map_location="cuda" if torch.cuda.is_available() else "cpu",
)

In [28]:
# Collect ground-truth and predicted labels for the whole test split.
# Labels are stored as plain Python ints so sklearn receives simple lists
# rather than thousands of 0-d tensors.
true_y, pred_y = [], []
for x, y in tqdm(test_dl):  # DataLoader has len(), so tqdm infers the total itself
  true_y.extend(y.tolist())
  preds, _ = get_prediction(x, inference_model)
  pred_y.extend(preds.cpu().tolist())

HBox(children=(FloatProgress(value=0.0, max=157.0), HTML(value='')))




In [29]:
from sklearn.metrics import classification_report

In [30]:
# Per-class precision/recall/F1 on the test split, rounded to 3 decimals.
# The report is a preformatted string, so print it to keep the column layout.
report = classification_report(true_y, pred_y, digits=3)
print(report)

              precision    recall  f1-score   support

           0      0.995     0.994     0.994       980
           1      0.994     0.998     0.996      1135
           2      0.996     0.994     0.995      1032
           3      0.993     0.997     0.995      1010
           4      0.997     0.983     0.990       982
           5      0.992     0.989     0.990       892
           6      0.991     0.997     0.994       958
           7      0.989     0.995     0.992      1028
           8      0.998     0.987     0.992       974
           9      0.982     0.992     0.987      1009

    accuracy                          0.993     10000
   macro avg      0.993     0.993     0.993     10000
weighted avg      0.993     0.993     0.993     10000

