# Imports

In [1]:
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, GPUStatsMonitor
from pytorch_lightning.loggers import TensorBoardLogger
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
import torchvision
from torchvision import transforms
import torch
import torch.nn.functional as F
import os

# Data module

In [2]:
# Root directory for the MNIST files; used as the default `data_dir` of
# MNISTDataModule. '.' is the current working directory.
PATH = '.'

In [3]:
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module wrapping torchvision's MNIST train and test splits.

    The ``train=True`` split is used for training; the ``train=False`` split
    backs both the validation and the test dataloader (no separate hold-out
    set is carved out of the training data).

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by every dataloader this module returns.
    """

    def __init__(self, data_dir: str = PATH, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir

    @staticmethod
    def _build_transform():
        """Return the preprocessing pipeline shared by all splits."""
        return transforms.Compose([
            transforms.ToTensor(),
            # Normalisation constants kept exactly as in the original notebook
            # (presumably the MNIST pixel mean / std).
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def _build_dataset(self, train):
        """Create the MNIST split selected by ``train`` with the shared transform."""
        # download=True is kept so that calling setup() by hand (without
        # prepare_data()) still works on a machine without the files.
        return MNIST(self.data_dir,
                     train=train,
                     download=True,
                     transform=self._build_transform())

    def _build_loader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          # os.cpu_count() may return None; fall back to
                          # loading in the main process in that case.
                          num_workers=os.cpu_count() or 0)

    def prepare_data(self):
        """Download both MNIST splits.

        Lightning runs this hook before ``setup()``, so the files are already
        on disk by the time the datasets are built.
        """
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train / validation datasets and report their sizes.

        Args:
            stage: Accepted for compatibility with the Lightning hook
                signature; it is not used.
        """
        self.train_ds = self._build_dataset(train=True)
        self.val_ds = self._build_dataset(train=False)

        print(f'[INFO] Training on {len(self.train_ds)}')
        print(f'[INFO] Validating on {len(self.val_ds)}')

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return self._build_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Return the loader over the ``train=False`` split."""
        return self._build_loader(self.val_ds)

    def test_dataloader(self):
        """Return the test loader.

        It iterates the same ``train=False`` split as validation, because
        ``setup()`` builds no other evaluation set.
        """
        return self._build_loader(self.val_ds)

In [4]:
# Smoke test: build the data module with its defaults and call setup() by
# hand, which creates both datasets (downloading them if needed) and prints
# their sizes.
dm = MNISTDataModule()
dm.setup()

[INFO] Training on 60000
[INFO] Validating on 10000


# Module/Model

In [5]:
class CNNDigitClassifier(pl.LightningModule):
    """Small convolutional network that classifies digits into 10 classes.

    Architecture: conv(1->32) -> conv(32->64) -> batch-norm -> max-pool ->
    conv(64->128) -> batch-norm -> max-pool -> linear(128*5*5 -> 10), with a
    ReLU after each convolution and after each pooling stage. All
    convolutions are 3x3 without padding, so a 28x28 single-channel input
    ends up as the 5x5 feature map the classifier expects.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order and under the same attribute
        # names as before, so existing checkpoints keep loading.
        self.conv1 = torch.nn.Conv2d(1, 32, (3, 3))
        self.conv2 = torch.nn.Conv2d(32, 64, (3, 3))
        self.pool1 = torch.nn.MaxPool2d(2, 2)
        self.bn1 = torch.nn.BatchNorm2d(64)

        self.conv3 = torch.nn.Conv2d(64, 128, (3, 3))
        self.pool2 = torch.nn.MaxPool2d(2, 2)
        self.bn2 = torch.nn.BatchNorm2d(128)

        self.classifier = torch.nn.Linear(128 * 5 * 5, 10)

    def forward(self, x):
        """Return unnormalised class scores (logits), one row per sample.

        Args:
            x: Batch of single-channel images, laid out (batch, 1, H, W).
        """
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.pool1(self.bn1(x)))

        x = torch.relu(self.conv3(x))
        x = torch.relu(self.pool2(self.bn2(x)))

        # Flatten everything except the batch axis. Unlike view(-1, 3200),
        # this can never merge or split samples: an unexpected input size
        # surfaces as a shape error in the linear layer instead of a silently
        # wrong batch dimension.
        return self.classifier(torch.flatten(x, start_dim=1))

    def _shared_step(self, batch):
        """Compute cross-entropy loss and accuracy for one (inputs, labels) batch."""
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        # argmax over the logits directly; a softmax would not change the
        # predicted class.
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Run one training batch, log loss and accuracy, return the loss."""
        loss, acc = self._shared_step(batch)

        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log epoch-level loss and accuracy.

        'val_loss' and 'val_acc' are the metric names the checkpoint and
        early-stopping callbacks in this notebook refer to.
        """
        val_loss, val_acc = self._shared_step(batch)

        self.log('val_loss', val_loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('val_acc', val_acc, prog_bar=True, on_step=False, on_epoch=True)

        return val_loss

    def configure_optimizers(self):
        """Return an AdamW optimiser over all parameters with lr=1e-3."""
        return torch.optim.AdamW(self.parameters(), lr=1e-3)

# Training phase

In [6]:
# Fresh data module for training. setup() is not called here; the [INFO]
# lines printed under trainer.fit show that the trainer invokes it.
dm = MNISTDataModule()

In [7]:
# Instantiate the classifier with freshly initialised weights.
model = CNNDigitClassifier()

In [8]:
# Bare expression: the notebook displays the module's repr (its layer list).
model

CNNDigitClassifier(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (classifier): Linear(in_features=3200, out_features=10, bias=True)
)

In [9]:
# Checkpoint on validation loss (lower is better); the file name embeds the
# logged 'val_acc' and 'val_loss' values.
ckpt_cb = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    dirpath='./logs',
    filename='digit_classifier-{val_acc:.5f}-{val_loss:.5f}'
)

# Report GPU memory / utilisation / fan speed / temperature to the logger.
gpu_stats = GPUStatsMonitor(
    memory_utilization=True,
    gpu_utilization=True,
    fan_speed=True,
    temperature=True
)

# Stop once 'val_loss' has failed to improve for 2 consecutive checks.
es = EarlyStopping(
    monitor='val_loss',
    patience=2,
    mode='min'
)

# TensorBoard event files are written under ./logs/mnist.
Logger = TensorBoardLogger(
    save_dir='./logs',
    name='mnist'
)

# gpu_stats was previously constructed but never handed to the trainer, so
# it had no effect; register it together with the other callbacks.
Callbacks = [es, ckpt_cb, gpu_stats]

# Optional Trainer switches left disabled: precision=16 and fast_dev_run=True.
trainer = pl.Trainer(
    gpus=-1,  # -1 selects every available GPU
    max_epochs=5,
    callbacks=Callbacks,
    logger=Logger,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [10]:
# Train `model` on the batches served by `dm`, for at most the configured
# number of epochs or until early stopping triggers.
trainer.fit(model=model, datamodule=dm)

[INFO] Training on 60000
[INFO] Validating on 10000



  | Name       | Type        | Params
-------------------------------------------
0 | conv1      | Conv2d      | 320   
1 | conv2      | Conv2d      | 18.5 K
2 | pool1      | MaxPool2d   | 0     
3 | bn1        | BatchNorm2d | 128   
4 | conv3      | Conv2d      | 73.9 K
5 | pool2      | MaxPool2d   | 0     
6 | bn2        | BatchNorm2d | 256   
7 | classifier | Linear      | 32.0 K
-------------------------------------------
125 K     Trainable params
0         Non-trainable params
125 K     Total params


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1

In [12]:
# Load the TensorBoard notebook extension so %tensorboard is available.
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [13]:
# Open TensorBoard inline on ./logs, the save_dir given to TensorBoardLogger.
%tensorboard --logdir ./logs