## Introduction to Pytorch Lightning ⚡
  * https://nbviewer.jupyter.org/github/PyTorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb

In [1]:
!pip install pytorch-lightning --quiet

import os
os.environ['CURRENT_FILE'] = 'MNIST-with-lightning.ipynb'
!date "+[%F %R:%S] [INIT] $CURRENT_FILE (on $CONDA_DEFAULT_ENV)"

import time
t0 = time.time()

[2020-10-29 14:32:09] [INIT] MNIST-with-lightning.ipynb (on lightn)


In [2]:
from typing import Any, List

import torch
from pytorch_lightning import Trainer, LightningModule, LightningDataModule
from pytorch_lightning.metrics import Accuracy
from torch import nn, optim, Tensor
from torch.nn.functional import cross_entropy
from torch.utils.data import random_split, DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


def str_loss(loss: List[Tensor]):
    metric = torch.mean(torch.stack(loss))
    return f'{metric:.4f}'


def str_accuracy(acc: Accuracy, detail: bool = False):
    backup = acc.correct, acc.total
    metric = acc.compute()
    acc.correct, acc.total = backup
    return f'{metric * 100:.2f}%' if not detail else f'{metric * 100:.2f}%(={acc.correct}/{acc.total})'


torch.manual_seed(777)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(777)


class DataMNIST(LightningDataModule):
    def __init__(self, data_dir: str = '/dat/data/', batch_size: int = 100, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.ToTensor()
        self.dataset = dict()

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        self.dataset['train'], self.dataset['valid'] = random_split(MNIST(self.data_dir, train=True, transform=self.transform), [55000, 5000])
        self.dataset['test'] = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.dataset['train'], batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.dataset['valid'], batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.dataset['test'], batch_size=self.batch_size, num_workers=self.num_workers)


class ModelMNIST(LightningModule):
    def __init__(self, learning_rate, metric_detail=True):
        super().__init__()
        self.learning_rate = learning_rate
        self.metric_detail = metric_detail
        self.metric = {
            'train': {"loss": list(), "acc": Accuracy()},
            'valid': {"loss": list(), "acc": Accuracy()},
            'test': {"loss": list(), "acc": Accuracy()},
        }

        self.conv1A = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1A = nn.ReLU()
        self.conv1B = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.relu1B = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2A = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2A = nn.ReLU()
        self.conv2B = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.relu2B = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3A = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu3A = nn.ReLU()
        self.conv3B = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.relu3B = nn.ReLU()

        self.fc = nn.Linear(7 * 7 * 128, 10, bias=True)
        self.fc_bn = nn.BatchNorm1d(10)
        nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, inp):
        out = self.conv1A(inp)
        out = self.relu1A(out)
        out = self.conv1B(out)
        out = self.relu1B(out)
        out = self.pool1(out)

        out = self.conv2A(out)
        out = self.relu2A(out)
        out = self.conv2B(out)
        out = self.relu2B(out)
        out = self.pool2(out)

        out = self.conv3A(out)
        out = self.relu3A(out)
        out = self.conv3B(out)
        out = self.relu3B(out)

        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = self.fc_bn(out)
        return out

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch: List[Tensor], batch_idx: int):
        inputs: Tensor = batch[0]
        labels: Tensor = batch[1]
        logits: Tensor = self(inputs)
        loss = cross_entropy(logits, labels)
        logits = logits.detach().cpu()
        labels = labels.detach().cpu()
        self.metric['train']['acc'].update(preds=logits, target=labels)
        self.metric['train']['loss'].append(loss.detach().cpu())
        return {'loss': loss, 'logits': logits, 'labels': labels}

    def validation_step(self, batch: List[Tensor], batch_idx: int):
        inputs: Tensor = batch[0]
        labels: Tensor = batch[1]
        logits: Tensor = self(inputs)
        loss = cross_entropy(logits, labels)
        logits = logits.detach().cpu()
        labels = labels.detach().cpu()
        self.metric['valid']['acc'].update(preds=logits, target=labels)
        self.metric['valid']['loss'].append(loss.detach().cpu())
        return {'loss': loss, 'logits': logits, 'labels': labels}

    def test_step(self, batch: List[Tensor], batch_idx: int):
        inputs: Tensor = batch[0]
        labels: Tensor = batch[1]
        logits: Tensor = self(inputs)
        loss = cross_entropy(logits, labels)
        logits = logits.detach().cpu()
        labels = labels.detach().cpu()
        self.metric['test']['acc'].update(preds=logits, target=labels)
        self.metric['test']['loss'].append(loss.detach().cpu())
        return {'loss': loss, 'logits': logits, 'labels': labels}

    def test_epoch_end(self, outputs: List[Any]):
        pass

    def on_epoch_start(self):
        for k in self.metric.keys():
            self.metric[k]['loss'] = list()
            self.metric[k]['acc'].reset()

    def on_epoch_end(self):
        print()
        print(f"| Loss     | {{"
              f" train: {str_loss(self.metric['train']['loss'])},"
              f" valid: {str_loss(self.metric['valid']['loss'])} }}")
        print(f"| Accuracy | {{"
              f" train: {str_accuracy(self.metric['train']['acc'], self.metric_detail)},"
              f" valid: {str_accuracy(self.metric['valid']['acc'], self.metric_detail)} }}")
        print("=" * 5 + f" [DONE] [Epoch {self.current_epoch + 1}/{self.trainer.max_epochs}] " + "=" * 70)
        print()

    def on_test_epoch_end(self):
        print()
        print(f"| Loss     | {{"
              f" test: {str_loss(self.metric['test']['loss'])},"
              f" valid: {str_loss(self.metric['valid']['loss'])} }}")
        print(f"| Accuracy | {{"
              f" test: {str_accuracy(self.metric['test']['acc'], self.metric_detail)},"
              f" valid: {str_accuracy(self.metric['valid']['acc'], self.metric_detail)} }}")
        print("=" * 5 + f" [DONE] [Test Epoch] " + "=" * 70)
        print()


trainer = Trainer(max_epochs=5, num_sanity_val_steps=0, progress_bar_refresh_rate=20, gpus=1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [3]:
t = time.time()
trainer.fit(model=ModelMNIST(learning_rate=0.001), datamodule=DataMNIST())
print(f"* Train Time: {time.time() - t:.3f}s")


   | Name   | Type        | Params
----------------------------------------
0  | conv1A | Conv2d      | 320   
1  | relu1A | ReLU        | 0     
2  | conv1B | Conv2d      | 9 K   
3  | relu1B | ReLU        | 0     
4  | pool1  | MaxPool2d   | 0     
5  | conv2A | Conv2d      | 18 K  
6  | relu2A | ReLU        | 0     
7  | conv2B | Conv2d      | 36 K  
8  | relu2B | ReLU        | 0     
9  | pool2  | MaxPool2d   | 0     
10 | conv3A | Conv2d      | 73 K  
11 | relu3A | ReLU        | 0     
12 | conv3B | Conv2d      | 147 K 
13 | relu3B | ReLU        | 0     
14 | fc     | Linear      | 62 K  
15 | fc_bn  | BatchNorm1d | 20    


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…


| Loss     | { train: 0.3325, valid: 0.1345 }
| Accuracy | { train: 97.26%(=53491/55000), valid: 99.10%(=4955/5000) }


| Loss     | { train: 0.1270, valid: 0.0816 }
| Accuracy | { train: 99.11%(=54511/55000), valid: 99.36%(=4968/5000) }


| Loss     | { train: 0.0731, valid: 0.0550 }
| Accuracy | { train: 99.41%(=54674/55000), valid: 99.32%(=4966/5000) }


| Loss     | { train: 0.0466, valid: 0.0380 }
| Accuracy | { train: 99.63%(=54796/55000), valid: 99.34%(=4967/5000) }


| Loss     | { train: 0.0315, valid: 0.0342 }
| Accuracy | { train: 99.78%(=54879/55000), valid: 99.22%(=4961/5000) }


* Train Time: 34.975s


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

In [4]:
t = time.time()
trainer.test()
print(f"* Test Time: {time.time() - t:.3f}s")

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------

| Loss     | { test: 0.0285, valid: 0.0342 }
| Accuracy | { test: 99.42%(=9942/10000), valid: 99.22%(=4961/5000) }


* Test Time: 0.742s


In [5]:
os.environ['ELASPED_TIME'] = f"{time.time() - t0:.3f}s"
!date "+[%F %R:%S] [EXIT] $CURRENT_FILE (on $CONDA_DEFAULT_ENV) (in $ELASPED_TIME)"


[2020-10-29 14:32:47] [EXIT] MNIST-with-lightning.ipynb (on lightn) (in 38.140s)
