# PyTorch Lightning example on MNIST

## References

* [Step-by-step walk-through — PyTorch Lightning 1.5.0dev documentation](https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html)
* [Introduction to Pytorch Lightning — PyTorch Lightning 1.5.0dev documentation](https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/mnist-hello-world.html)
* [Mastering-PyTorch/pytorch_lightning.ipynb at master · PacktPublishing/Mastering-PyTorch](https://github.com/PacktPublishing/Mastering-PyTorch/blob/master/Chapter14/pytorch_lightning.ipynb)

---

* [torch.nn.functional](https://pytorch.org/docs/stable/nn.functional.html)
* [torchvision.datasets](https://pytorch.org/vision/stable/datasets.html)
* [torchvision.transforms](https://pytorch.org/vision/stable/transforms.html)
* [PyTorch Lightning Documentation — PyTorch Lightning 1.5.0dev documentation](https://pytorch-lightning.readthedocs.io/en/latest/)

In [1]:
import os
import argparse

import torch
import torch.nn.functional as F
import torchvision
import pytorch_lightning as pl

# For type hinting
from typing import Union, Dict, List, Any
from torch import Tensor

## References

* [pytorch_lightning.LightningModule](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html)
    - PyTorch のコードを 5つのセクションに整理する
        1. Computations (init)
        2. Train loop (training_step)
        3. Validation loop (validation_step)
        4. Test loop (test_step)
        5. Optimizers (configure_optimizers)
    - 特徴は以下の通り
        1. PyTorch Lightning のコードは PyTorch のコードと同じである。
        2. PyTorch のコードが抽象化されるわけではなく、整理される。
        3. `LightningModule` に無いコードは `Trainer` によって自動化されている。
        4. Lightning が処理をするため `.cuda()` や `.to()` といったコールは不要。
        5. `DataLoader` において、デフォルトでは `DistributedSampler` が設定される。
        6. `LightningModule` は `torch.nn.Module` の一つであり、それに機能を追加したもの。

### `__init__`

* [torch.nn.Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d)
* [torch.nn.MaxPool2d](https://pytorch.org/docs/1.9.0/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d)
* [torch.nn.Dropout2d](https://pytorch.org/docs/stable/generated/torch.nn.Dropout2d.html#torch.nn.Dropout2d)
* [torch.nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear)

### `forward`

順伝播型ネットワークの定義および処理の実行

* [forward](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html?highlight=forward#forward)
    - [torch.nn.forward](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.forward)
* [torch.nn.functional.relu](https://pytorch.org/docs/1.9.0/generated/torch.nn.functional.relu.html)
    - [torch.nn.ReLU](https://pytorch.org/docs/1.9.0/generated/torch.nn.ReLU.html#torch.nn.ReLU)
    - 正規化線形関数を適用
* [torch.flatten](https://pytorch.org/docs/stable/generated/torch.flatten.html)
    - 次元を 1 + `start_dim` 次元に変更
* [torch.nn.functional.log_softmax](https://pytorch.org/docs/1.9.0/generated/torch.nn.functional.log_softmax.html#torch.nn.functional.log_softmax)
    - [torch.nn.LogSoftmax](https://pytorch.org/docs/1.9.0/generated/torch.nn.LogSoftmax.html#torch.nn.LogSoftmax)
    - ソフトマックス関数（活性化関数）を適用

---

### Training loop

#### `training_step`

学習を実行し、損失を返却する。 `train_dataloader` から学習用データを取得し、バッチごとに処理をする。処理には順伝播，勾配の最適化，逆伝播，パラメータの最適化が含まれる。
GPUごとに処理をしたい場合は、これに加えて `training_step_end` をオーバーライドし、各GPUを利用した `training_step` の結果を結合する処理を記述する。

* [training_step](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#training-step)
* [torch.nn.functional.cross_entropy](https://pytorch.org/docs/1.9.0/generated/torch.nn.functional.cross_entropy.html?highlight=cross_entropy#torch.nn.functional.cross_entropy)
    - [torch.nn.CrossEntropyLoss](https://pytorch.org/docs/1.9.0/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss)
* [LightningModule.log](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#log)

#### `training_step_end`

#### `training_epoch_end`

---

### Validation loop

#### `validation_step`

バリデーションを実行する。 `validation_epoch_end` への入力として集約したい値を返却する。ここでは汎化性能を確認するため、損失を返却する。

* [validation_step](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#validation-step)

#### `validation_step_end`

* [validation_step_end](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#validation-step-end)

#### `validation_epoch_end`

バリデーションのエポック終了時にコールされる。すべての `validation_step` の出力を入力として受け取る。

* [validation_epoch_end](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#validation-epoch-end)
* [torch.stack](https://pytorch.org/docs/1.9.0/generated/torch.stack.html?highlight=stack#torch.stack)

---

### Test loop

#### `test_step`

テストを実行する。`test_epoch_end` への入力として集約したい値を返却する。ここでは汎化性能を確認するため、損失を返却する。
テストループは `pytorch_lightning.Trainer.test(model)` が実行された場合のみ、実行される。

* [test_step](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#test-step)

#### `test_step_end`

* [test_step_end](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#test-step-end)

#### `test_epoch_end`

テストのエポック終了時にコールされる。すべての `test_step` の出力を入力として受け取る。

* [test_epoch_end](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#test-epoch-end)

---

### Prediction

#### `predict_step`

`pytorch_lightning.Trainer.predict(model)` が実行された場合のみ、実行される。

* [predict_step](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#predict-step)

---

### Optimizer

#### `configure_optimizers`

最適化のために使用するオプティマイザを選択する。

* [configure_optimizers](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers)
* [torch.optim.Adadelta](https://pytorch.org/docs/1.9.0/generated/torch.optim.Adadelta.html?highlight=adadelta#torch.optim.Adadelta)
    - [[1212.5701] ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)

---

### DataLoaders

#### `train_dataloader`

* [train_dataloader](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#train-dataloader)
* [torchvision.transforms.ToTensor](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor)
* [torchvision.transforms.Normalize](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Normalize)
* [torchvision.transforms.Compose](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Compose)
* [torchvision.datasets.MNIST](https://pytorch.org/vision/stable/datasets.html#mnist)
* [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

#### `val_dataloader`

* [val_dataloader](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#val-dataloader)

#### `test_dataloader`

* [test_dataloader](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#test-dataloader)

---

### Others


In [2]:
class MNISTConvNet(pl.LightningModule):
    def __init__(self,
            conv1_in_channels: int, conv1_out_channels: int, conv1_kernel_size: int, conv1_stride: int,
            conv2_in_channels: int, conv2_out_channels: int, conv2_kernel_size: int, conv2_stride: int,
            pool1_kernel_size: int, dropout1_p: float, dropout2_p: float,
            fullconn1_in_features: int, fullconn1_out_features: int, fullconn2_in_features: int, fullconn2_out_features: int,
            adadelta_lr: float, adadelta_rho: float, adadelta_eps: float, adadelta_weight_decay: float,
            dataset_root: str, dataset_download: bool,
            dataloader_mean: tuple, dataloader_std: tuple, dataloader_batch_size: int, dataloader_num_workers: int
            ) -> None:
        super(MNISTConvNet, self).__init__()
        
        self.conv1 = torch.nn.Conv2d(in_channels=conv1_in_channels, out_channels=conv1_out_channels, kernel_size=conv1_kernel_size, stride=conv1_stride)
        self.conv2 = torch.nn.Conv2d(in_channels=conv2_in_channels, out_channels=conv2_out_channels, kernel_size=conv2_kernel_size, stride=conv2_stride)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=pool1_kernel_size)
        self.dropout1 = torch.nn.Dropout2d(p=dropout1_p, inplace=False)
        self.dropout2 = torch.nn.Dropout2d(p=dropout2_p, inplace=False)
        self.fullconn1 = torch.nn.Linear(in_features=fullconn1_in_features, out_features=fullconn1_out_features)
        self.fullconn2 = torch.nn.Linear(in_features=fullconn2_in_features, out_features=fullconn2_out_features)

        self.adadelta_params = {
            'lr': adadelta_lr,
            'rho': adadelta_rho,
            'eps': adadelta_eps,
            'weight_decay': adadelta_weight_decay,
        }

        self.dataset_params = {
            'root': dataset_root,
            'download': dataset_download,
        }

        self.dataloader_params = {
            'mean': dataloader_mean,
            'std': dataloader_std,
            'batch_size': dataloader_batch_size,
            'num_workers': dataloader_num_workers,
        }

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = F.relu(input=x)
        x = self.conv2(x)
        x = F.relu(input=x)
        x = self.pool1(x)
        x = self.dropout1(x)
        x = torch.flatten(input=x, start_dim=1)
        x = self.fullconn1(x)
        x = F.relu(input=x)
        x = self.dropout2(x)
        x = self.fullconn2(x)
        return F.log_softmax(input=x, dim=1)
    
    def _common_step(self, batch: Any, log_name: str,
            log_on_step: Any = None, log_on_epoch: Any = None, log_prog_bar: bool = False
            ) -> Union[Tensor, Dict[str, Any], None]:
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(input=y_pred, target=y)

        self.log(name=log_name, value=loss, prog_bar=log_prog_bar, on_step=log_on_step, on_epoch=log_on_epoch)

        return loss


    # Training loop

    # def on_train_epoch_start(self) -> None:
    #     return super().on_train_epoch_start()

    # def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    #     return super().on_train_batch_start(batch, batch_idx, dataloader_idx)

    def training_step(self, batch: Any, batch_idx: int) -> Union[Tensor, Dict[str, Any]]:
        return self._common_step(batch=batch, log_name="train_loss", log_prog_bar=True, log_on_epoch=True)
    
    # def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    #     return super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)

    # def training_step_end(self, *args, **kwargs) -> Any:
    #     return super().training_step_end(*args, **kwargs)
    
    # def training_epoch_end(self, outputs: Any) -> None:
    #     return super().training_epoch_end(outputs)

    # def on_train_epoch_end(self) -> None:
    #     return super().on_train_epoch_end()


    # Validation loop

    # def on_validation_epoch_start(self) -> None:
    #     return super().on_validation_epoch_start()

    # def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    #     return super().on_validation_batch_start(batch, batch_idx, dataloader_idx)

    def validation_step(self, batch: Any, batch_idx: int) -> Union[Tensor, Dict[str, Any], None]:
        return self._common_step(batch=batch, log_name="val_loss", log_prog_bar=True, log_on_step=True, log_on_epoch=True)
    
    # def on_validation_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    #     return super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

    # def validation_step_end(self, *args, **kwargs) -> Optional[Any]:
    #     return super().validation_step_end(*args, **kwargs)

    def validation_epoch_end(self, outputs: List[Union[Tensor, Dict[str, Any]]] ) -> None:
        return torch.stack(tensors=outputs).mean() # NOTE: Average loss
    
    # def on_validation_epoch_end(self) -> None:
    #     return super().on_validation_epoch_end()


    # Test loop

    # def on_test_epoch_start(self) -> None:
    #     return super().on_test_epoch_start()

    # def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    #     return super().on_test_batch_start(batch, batch_idx, dataloader_idx)

    def test_step(self, batch: Any, batch_idx: int) -> Union[Tensor, Dict[str, Any], None]:
        return self._common_step(batch=batch, log_name="test_loss", log_prog_bar=True, log_on_step=True, log_on_epoch=True)
    
    # def on_test_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    #     return super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)

    # def test_step_end(self, *args, **kwargs) -> Optional[Any]:
    #     return super().test_step_end(*args, **kwargs)
    
    def test_epoch_end(self, outputs: List[Union[Tensor, Dict[str, Any]]] ) -> None:
        return torch.stack(tensors=outputs).mean() # NOTE: Average loss
    
    # def on_test_epoch_end(self) -> None:
    #     return super().on_test_epoch_end()
    

    # Prediction

    # def on_predict_epoch_start(self) -> None:
    #     return super().on_predict_epoch_start()

    # def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    #     return super().on_predict_batch_start(batch, batch_idx, dataloader_idx)

    # def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> Any:
    #     return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
    
    # def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    #     return super().on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx)
    
    # def on_predict_epoch_end(self, results: List[Any]) -> None:
    #     return super().on_predict_epoch_end(results)


    # Optimizer

    def configure_optimizers(self) -> Any:
        return torch.optim.Adadelta(params=self.parameters(),
            lr=self.adadelta_params['lr'],
            rho=self.adadelta_params['rho'],
            eps=self.adadelta_params['eps'],
            weight_decay=self.adadelta_params['weight_decay'])
    
    # def on_before_optimizer_step(self, optimizer: Any, optimizer_idx: int) -> None:
    #     return super().on_before_optimizer_step(optimizer, optimizer_idx)

    # def optimizer_step(self, epoch: int, batch_idx: int, optimizer: Any, optimizer_idx: int, optimizer_closure: Optional[Any], on_tpu: bool, using_native_amp: bool, using_lbfgs: bool) -> None:
    #     return super().optimizer_step(epoch=epoch, batch_idx=batch_idx, optimizer=optimizer, optimizer_idx=optimizer_idx, optimizer_closure=optimizer_closure, on_tpu=on_tpu, using_native_amp=using_native_amp, using_lbfgs=using_lbfgs)


    # Dataloaders

    def _get_dataloader(self, train: bool) -> Any:
        transform_objects = [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=self.dataloader_params['mean'], std=self.dataloader_params['std'])
        ]
        transform = torchvision.transforms.Compose(transforms=transform_objects)
        dataset = torchvision.datasets.MNIST(root=self.dataset_params['root'],
            train=train,
            download=self.dataset_params['download'],
            transform=transform)
        dataloader = torch.utils.data.DataLoader(dataset=dataset,
            batch_size=self.dataloader_params['batch_size'],
            num_workers=self.dataloader_params['num_workers'])
        return dataloader

    # def on_train_dataloader(self) -> None:
    #     return super().on_train_dataloader()

    def train_dataloader(self) -> Any:
        return self._get_dataloader(train=True)
    
    # def on_val_dataloader(self) -> None:
    #     return super().on_val_dataloader()

    def val_dataloader(self) -> Any:
        return self._get_dataloader(train=True)
    
    # def on_test_dataloader(self) -> None:
    #     return super().on_test_dataloader()

    def test_dataloader(self) -> Any:
        return self._get_dataloader(train=False)
    
    # def on_predict_dataloader(self) -> None:
    #     return super().on_predict_dataloader()

    # def predict_dataloader(self) -> Any:
    #     return super().predict_dataloader()


    # Others

    # def setup(self, stage: Optional[str]) -> None:
    #     return super().setup(stage=stage)

    # def teardown(self, stage: Optional[str]) -> None:
    #     return super().teardown(stage=stage)
    
    # def prepare_data(self) -> None:
    #     return super().prepare_data()


In [3]:
def get_argparser():
    parser = argparse.ArgumentParser(description='PyTorch Lightning MNIST Example')
    parser.add_argument('--conv1-in-channels', type=int, default=1)
    parser.add_argument('--conv1-out-channels', type=int, default=32)
    parser.add_argument('--conv1-kernel-size', type=int, default=3)
    parser.add_argument('--conv1-stride', type=int, default=1)
    parser.add_argument('--conv2-in-channels', type=int, default=32)
    parser.add_argument('--conv2-out-channels', type=int, default=64)
    parser.add_argument('--conv2-kernel-size', type=int, default=3)
    parser.add_argument('--conv2-stride', type=int, default=1)
    parser.add_argument('--pool1-kernel-size', type=int, default=2)
    parser.add_argument('--dropout1-p', type=float, default=0.25)
    parser.add_argument('--dropout2-p', type=float, default=0.5)
    parser.add_argument('--fullconn1-in-features', type=int, default=12*12*64)
    parser.add_argument('--fullconn1-out-features', type=int, default=128)
    parser.add_argument('--fullconn2-in-features', type=int, default=128)
    parser.add_argument('--fullconn2-out-features', type=int, default=10)
    parser.add_argument('--adadelta-lr', type=float, default=1.0)
    parser.add_argument('--adadelta-rho', type=float, default=0.9)
    parser.add_argument('--adadelta-eps', type=float, default=1e-06)
    parser.add_argument('--adadelta-weight-decay', type=float, default=0)
    parser.add_argument('--dataset-root', type=str, default=os.getcwd())
    parser.add_argument('--dataset-download', action='store_true', default=True)
    parser.add_argument('--dataloader-mean', type=tuple, default=(0.1302,))
    parser.add_argument('--dataloader-std', type=tuple, default=(0.3069,))
    parser.add_argument('--dataloader-batch-size', type=int, default=32)
    parser.add_argument('--dataloader-num-workers', type=int, default=4)
    parser.add_argument('--progress-bar-refresh-rate', type=int, default=20)
    parser.add_argument('--max-epochs', type=int, default=1)
    return parser

## References

* [pytorch_lightning.Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html?highlight=trainer)
    - [fit](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#fit)

In [4]:
def main(args=None) -> None:
    if not args:
        args = get_argparser().parse_args()
    
    model = MNISTConvNet(conv1_in_channels=args.conv1_in_channels, conv1_out_channels=args.conv1_out_channels, conv1_kernel_size=args.conv1_kernel_size, conv1_stride=args.conv1_stride,
        conv2_in_channels=args.conv2_in_channels, conv2_out_channels=args.conv2_out_channels, conv2_kernel_size=args.conv2_kernel_size, conv2_stride=args.conv2_stride,
        pool1_kernel_size=args.pool1_kernel_size, dropout1_p=args.dropout1_p, dropout2_p=args.dropout2_p,
        fullconn1_in_features=args.fullconn1_in_features, fullconn1_out_features=args.fullconn1_out_features, fullconn2_in_features=args.fullconn2_in_features, fullconn2_out_features=args.fullconn2_out_features,
        adadelta_lr=args.adadelta_lr, adadelta_rho=args.adadelta_rho, adadelta_eps=args.adadelta_eps, adadelta_weight_decay=args.adadelta_weight_decay,
        dataset_root=args.dataset_root, dataset_download=args.dataset_download,
        dataloader_mean=args.dataloader_mean, dataloader_std=args.dataloader_std, dataloader_batch_size=args.dataloader_batch_size, dataloader_num_workers=args.dataloader_num_workers)
    
    trainer = pl.Trainer(progress_bar_refresh_rate=args.progress_bar_refresh_rate, max_epochs=args.max_epochs)
    trainer.fit(model)
    trainer.test(model)
    # trainer.predict(model)

In [5]:
if __name__ == "__main__":
    argparser = get_argparser()
    args = argparser.parse_args([
        "--max-epochs", str(1),
        "--adadelta-lr", str(0.5),
    ])
    main(args)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type      | Params
----------------------------------------
0 | conv1     | Conv2d    | 320   
1 | conv2     | Conv2d    | 18.5 K
2 | pool1     | MaxPool2d | 0     
3 | dropout1  | Dropout2d | 0     
4 | dropout2  | Dropout2d | 0     
5 | fullconn1 | Linear    | 1.2 M 
6 | fullconn2 | Linear    | 1.3 K 
----------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.800     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 3750/3750 [01:41<00:00, 37.06it/s, loss=nan, v_num=54, train_loss_step=nan.0, val_loss_step=nan.0, val_loss_epoch=nan.0]
Testing: 100%|██████████| 313/313 [00:04<00:00, 80.02it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': nan, 'test_loss_epoch': nan}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 313/313 [00:04<00:00, 75.81it/s]
