# 1. `LightningModule`

A LightningModule organizes your PyTorch code into 6 sections:
- Computations (`__init__`)
- Train loop (`training_step`)
- Validation loop (`validation_step`)
- Test loop (`test_step`)
- Prediction loop (`predict_step`)
- Optimizers and LR schedulers (`configure_optimizers`)

The LightningModule has many convenience methods, but the core ones you need to know are:

|Name|Description|
|--|--|
|`__init__`|Define computations (layers, submodules)|
|`forward`|Inference only (kept separate from `training_step`)|
|`training_step`|Logic for one training batch (Lightning runs the loop)|
|`validation_step`|Logic for one validation batch|
|`test_step`|Logic for one test batch|
|`predict_step`|Logic for one prediction batch|
|`configure_optimizers`|Define optimizers and LR schedulers|

```python
import torch
from torch.nn import functional as F
from torch import nn
# top-level import; the old path pytorch_lightning.core.lightning was removed in later releases
from pytorch_lightning import LightningModule
```

## 1.1 Define the basic model

```python
class LitMNIST(LightningModule):
    def __init__(self):
        super().__init__()
        # MNIST images are 1 x 28 x 28; three fully connected layers map them to 10 classes
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        # flatten (b, 1, 28, 28) -> (b, 784)
        x = x.view(batch_size, -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        # log-probabilities over the 10 digit classes
        x = F.log_softmax(self.layer_3(x), dim=1)
        return x


net = LitMNIST()
x = torch.randn(1, 1, 28, 28)
out = net(x)
print(out.shape)
```

```
torch.Size([1, 10])
```


## 1.2 Add `training_step`

```python
class LitMNIST(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        x = x.view(batch_size, -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = F.log_softmax(self.layer_3(x), dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        # forward() already applies log_softmax, so these are log-probabilities
        logits = self(x)
        loss = F.nll_loss(logits, y)
        # Lightning calls backward() and optimizer.step() on the returned loss
        return loss
```
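Because `forward` already ends with `log_softmax`, the tensor named `logits` in `training_step` holds log-probabilities, and `F.nll_loss` applied to log-probabilities is the ordinary cross-entropy loss. A quick check on arbitrary random scores (the values and targets below are made up) shows that the two-step form matches `F.cross_entropy` applied to the raw scores:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
raw_scores = torch.randn(4, 10)          # arbitrary unnormalized outputs for 4 samples
targets = torch.tensor([3, 0, 9, 1])     # arbitrary class labels

# What LitMNIST does: log_softmax in forward(), nll_loss in training_step()
loss_nll = F.nll_loss(F.log_softmax(raw_scores, dim=1), targets)

# The same loss computed in one call on the raw scores
loss_ce = F.cross_entropy(raw_scores, targets)

print(torch.allclose(loss_nll, loss_ce))  # True
```

So an equivalent design would drop `log_softmax` from `forward` and use `F.cross_entropy` in `training_step`; the version in this notebook keeps `forward` returning log-probabilities, which is convenient for inference.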

## 1.3 Add `configure_optimizers`

```python
from torch.optim import Adam

class LitMNIST(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        x = x.view(batch_size, -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = F.log_softmax(self.layer_3(x), dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def configure_optimizers(self):
        # LightningModule is a subclass of nn.Module,
        # so self.parameters() can be used directly
        return Adam(self.parameters(), lr=1e-3)
```
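With `training_step` and `configure_optimizers` defined, the Trainer runs the optimization loop itself. As a rough mental model only (this is not Lightning's actual implementation, which also handles devices, precision, gradient accumulation, callbacks and logging), automatic optimization behaves like the plain-PyTorch loop below. The `nn.Sequential` stand-in with LitMNIST's layer sizes and the single random batch are hypothetical.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

torch.manual_seed(0)

# Hypothetical stand-in for LitMNIST: same layer sizes, plain nn.Module
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128), nn.ReLU(),
    nn.Linear(128, 256), nn.ReLU(),
    nn.Linear(256, 10), nn.LogSoftmax(dim=1),
)
optimizer = Adam(model.parameters(), lr=1e-3)  # what configure_optimizers returns

# One fixed random batch of 32 "images", reused every step
x = torch.randn(32, 1, 28, 28)
y = torch.randint(0, 10, (32,))

losses = []
for step in range(20):
    loss = F.nll_loss(model(x), y)  # the body of training_step
    optimizer.zero_grad()           # these three calls are what
    loss.backward()                 # Lightning performs for you
    optimizer.step()                # in automatic optimization
    losses.append(loss.item())

print(losses[0], losses[-1])
```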


### Common patterns
#### 1⃣️. Most cases: no learning rate scheduler
```python
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)
```
#### 2⃣️. Multiple optimizers (e.g. a GAN)
```python
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt
```
#### 3⃣️. With learning rate schedulers
```python
from torch.optim.lr_scheduler import CosineAnnealingLR

def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    # first list: optimizers; second list: schedulers (only dis_opt has one here)
    return [gen_opt, dis_opt], [dis_sch]
```
#### 4⃣️. Step-based learning rate scheduler, with each optimizer having its own scheduler
```python
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step'  # stepped after every training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # stepped every epoch (the default interval)
    return [gen_opt, dis_opt], [gen_sch, dis_sch]
```
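To see what the two `interval` settings above mean in practice, this sketch drives the same two schedulers by hand with plain PyTorch: `gen_sch` advances after every step, `dis_sch` once per epoch. The single-parameter "models" and the loop of 5 epochs with 3 steps each are hypothetical; inside Lightning, the Trainer makes these `scheduler.step()` calls for you.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

# Hypothetical one-parameter "generator" and "discriminator"
gen_param = torch.nn.Parameter(torch.zeros(1))
dis_param = torch.nn.Parameter(torch.zeros(1))
gen_opt = Adam([gen_param], lr=0.01)
dis_opt = Adam([dis_param], lr=0.02)

gen_sch = ExponentialLR(gen_opt, gamma=0.99)    # 'interval': 'step'
dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # default 'interval': 'epoch'

steps_per_epoch = 3  # hypothetical, kept tiny for illustration
for epoch in range(5):
    for step in range(steps_per_epoch):
        # no gradients here, so these steps leave the parameters unchanged;
        # they are called so the schedulers follow the usual optimizer-then-scheduler order
        gen_opt.step()
        dis_opt.step()
        gen_sch.step()  # stepped after every training step
    dis_sch.step()      # stepped once at the end of each epoch

gen_lr = gen_opt.param_groups[0]["lr"]
dis_lr = dis_opt.param_groups[0]["lr"]
print(gen_lr, dis_lr)
```

After 15 steps the generator LR is 0.01 · 0.99^15, while the discriminator LR has moved halfway along its 10-epoch cosine curve, to 0.02 · (1 + cos(π/2)) / 2 = 0.01.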
#### 5⃣️. With optimizer frequencies: see the training procedure in Algorithm 1 of `Improved Training of Wasserstein GANs`, https://arxiv.org/abs/1704.00028
```python
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5  # discriminator (critic) updates per generator update
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )
```
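In Lightning 1.x, `frequency` is the number of consecutive batches an optimizer is used for before the next optimizer in the returned sequence takes over, cycling through them in order (as far as we know, Lightning 2.x no longer supports optimizer frequencies in automatic optimization). The helper below is hypothetical, not part of Lightning; it only reproduces that turn-taking for case 5, where `dis_opt` (index 0) runs 5 batches and then `gen_opt` (index 1) runs 1.

```python
def optimizer_for_batch(batch_idx, frequencies):
    """Hypothetical helper: index of the optimizer used for a given batch when
    optimizers take turns, each for `frequency` consecutive batches."""
    cycle = sum(frequencies)
    position = batch_idx % cycle
    for opt_idx, freq in enumerate(frequencies):
        if position < freq:
            return opt_idx
        position -= freq
    raise ValueError("frequencies must be positive")


# Order matches the tuple above: dis_opt (index 0) x 5, then gen_opt (index 1) x 1
schedule = [optimizer_for_batch(i, [5, 1]) for i in range(12)]
print(schedule)  # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]
```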

# 2. Data

## 2.1 Using PyTorch DataLoaders

```python
import os
import sys
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
from pytorch_lightning import Trainer

# transforms
# standard MNIST normalization (mean 0.1307, std 0.3081)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# data: download MNIST into ./dataset and wrap it in a DataLoader
from pathlib import Path
path_root = os.getcwd()
mnist_train = MNIST(os.path.join(str(path_root), 'dataset'), train=True, download=True, transform=transform)
# shuffle the training set every epoch
mnist_train = DataLoader(mnist_train, batch_size=64, shuffle=True)
```
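The 938 in the training progress bar is the number of batches per epoch: MNIST's training split has 60,000 images, and with `batch_size=64` the DataLoader keeps the final partial batch (`drop_last=False` by default), giving ⌈60000 / 64⌉ = ⌈937.5⌉ = 938. A zero-filled `TensorDataset` of the same length (a stand-in, not real MNIST) confirms this without downloading anything:

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

n_train = 60_000  # size of the MNIST training split
batch_size = 64

# Stand-in dataset with the same length (zeros instead of real images)
fake_train = TensorDataset(torch.zeros(n_train, 1))
loader = DataLoader(fake_train, batch_size=batch_size)

print(len(loader), math.ceil(n_train / batch_size))  # 938 938
```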

Pass the dataloader directly to `.fit()`:

```python
model = LitMNIST()
trainer = Trainer()
trainer.fit(model, mnist_train)
```

```
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Missing logger folder: /Users/baixiang/Desktop/pytorch_lightning_learning/lightning_logs

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K
1 | layer_2 | Linear | 33.0 K
2 | layer_3 | Linear | 2.6 K
-----------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params
0.544     Total estimated model params size (MB)

Epoch 81:   9%|▊         | 81/938 [00:13<02:27,  5.81it/s, loss=0.00642, v_num=0]
Detected KeyboardInterrupt, attempting graceful shutdown...
```

Training was stopped by hand (KeyboardInterrupt) partway through epoch 81: because no `max_epochs` was passed, the Trainer would otherwise keep going up to its default limit of 1000 epochs.

## 2.2