# Introduction to Pytorch Lightning ⚡
- https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb

In [1]:
!pip install pytorch-lightning --quiet

In [2]:
import time
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'available device: {device}')
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

available device: cuda


In [3]:
class MyAccurateEye(pl.LightningModule):
    def __init__(self, learning_rate, batch_size=100, num_workers=8, data_dir='../dat/'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.learning_rate = learning_rate
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.dataset = dict()
        self.train_metric = pl.metrics.Accuracy()
        self.valid_metric = pl.metrics.Accuracy()
        self.test_metric = pl.metrics.Accuracy()

        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = torch.nn.ReLU()
        self.conv12 = torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.relu12 = torch.nn.ReLU()
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = torch.nn.ReLU()
        self.conv22 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.relu22 = torch.nn.ReLU()
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu3 = torch.nn.ReLU()
        self.conv32 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.relu32 = torch.nn.ReLU()
        self.conv33 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.relu33 = torch.nn.ReLU()

        self.fc = torch.nn.Linear(7 * 7 * 128, 10, bias=True)
        self.fc_bn = torch.nn.BatchNorm1d(10)
        torch.nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu1(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.pool1(out)

        out = self.conv2(out)
        out = self.relu2(out)
        out = self.conv22(out)
        out = self.relu22(out)
        out = self.pool2(out)

        out = self.conv3(out)
        out = self.relu3(out)
        out = self.conv32(out)
        out = self.relu32(out)
        out = self.conv33(out)
        out = self.relu33(out)

        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = self.fc_bn(out)
        return out

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            data = MNIST(self.data_dir, train=True, transform=self.transform)
            self.dataset['train'], self.dataset['valid'] = \
                random_split(data, lengths=[int(len(data) * 0.9), len(data) - int(len(data) * 0.9)])
        if stage == 'test' or stage is None:
            data = MNIST(self.data_dir, train=False, transform=self.transform)
            self.dataset['test'] = data

    def train_dataloader(self):
        return DataLoader(self.dataset['train'], batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.dataset['valid'], batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.dataset['test'], batch_size=self.batch_size, num_workers=self.num_workers)

    def calc_loss(self, batch, accuracy: pl.metrics.Accuracy):
        x, y = batch
        pred = self(x)
        accuracy.update(pred, y)
        return F.cross_entropy(pred, y)

    def training_step(self, batch, batch_idx):
        return self.calc_loss(batch, self.train_metric)

    def validation_step(self, batch, batch_idx):
        self.calc_loss(batch, self.valid_metric)

    def test_step(self, batch, batch_idx):
        self.calc_loss(batch, self.test_metric)

    def training_epoch_end(self, outputs):
        show_accuracy('Train', self.train_metric)

    def validation_epoch_end(self, outputs):
        show_accuracy('Valid', self.valid_metric)

    def test_epoch_end(self, outputs):
        show_accuracy('Test', self.test_metric)

model = pl.Trainer(gpus=1, max_epochs=2, num_sanity_val_steps=0, progress_bar_refresh_rate=40)
def show_accuracy(name, acc):
    detail = f'(={acc.correct}/{acc.total})'
    metric = acc.compute()
    print(f'{name} Accuracy: {metric * 100:.2f}% {detail}')
    return metric

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [4]:
t0 = time.time()
model.fit(MyAccurateEye(learning_rate=0.001))
print(f"Train Time: {time.time() - t0:.3f}s")


   | Name         | Type        | Params
----------------------------------------------
0  | train_metric | Accuracy    | 0     
1  | valid_metric | Accuracy    | 0     
2  | test_metric  | Accuracy    | 0     
3  | conv1        | Conv2d      | 320   
4  | relu1        | ReLU        | 0     
5  | conv12       | Conv2d      | 9 K   
6  | relu12       | ReLU        | 0     
7  | pool1        | MaxPool2d   | 0     
8  | conv2        | Conv2d      | 18 K  
9  | relu2        | ReLU        | 0     
10 | conv22       | Conv2d      | 36 K  
11 | relu22       | ReLU        | 0     
12 | pool2        | MaxPool2d   | 0     
13 | conv3        | Conv2d      | 73 K  
14 | relu3        | ReLU        | 0     
15 | conv32       | Conv2d      | 147 K 
16 | relu32       | ReLU        | 0     
17 | conv33       | Conv2d      | 147 K 
18 | relu33       | ReLU        | 0     
19 | fc           | Linear      | 62 K  
20 | fc_bn        | BatchNorm1d | 20    


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Valid Accuracy: 99.05% (=5943/6000)
Train Accuracy: 97.74% (=52778/54000)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Valid Accuracy: 99.13% (=5948/6000)
Train Accuracy: 99.14% (=53537/54000)

Train Time: 14.684s


In [5]:
t0 = time.time()
model.test()
print(f"Test Time: {time.time() - t0:.3f}s")

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

Test Accuracy: 99.22% (=9922/10000)
--------------------------------------------------------------------------------

Test Time: 0.610s
