<a href="https://colab.research.google.com/github/bbinibini/Pytorch-Tutorial/blob/main/Pytorch_Lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pytorch Lightning

PyTorch Lightning은 크게 2가지 영역으로 추상화하여, 코드 스타일의 혁신을 추구하고 있는데요.
이 2가지 영역의 핵심 요소, LightningModule과 Trainer에 대해 더 자세히 살펴보도록 하겠습니다.

<코드 출처>
* [Pytorch-Lightning-tutorial-1](https://baeseongsu.github.io/posts/pytorch-lightning-introduction/)  
* [Pytorch-Lightning-tutorial-2](https://www.secmem.org/blog/2021/01/07/pytorch-lightning-tutorial/)

## LightningModule 클래스

#### 1) 모델의 기본적인 구조정의

In [14]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

In [24]:
class LightningMNISTClassifier(pl.LightningModule):
  def __init__(self):
    super(LightningMNISTClassifier, self).__init__()

    # mnist images are (1, 28, 28) (channels, width, height) 
    self.layer_1 = torch.nn.Linear(28 * 28, 128)
    self.layer_2 = torch.nn.Linear(128, 256)
    self.layer_3 = torch.nn.Linear(256, 10)

  def forward(self, x):
    batch_size, channels, width, height = x.size()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)

    # layer 1 (b, 1*28*28) -> (b, 128)
    x = self.layer_1(x)
    x = torch.relu(x)

    # layer 2 (b, 128) -> (b, 256)
    x = self.layer_2(x)
    x = torch.relu(x)

    # layer 3 (b, 256) -> (b, 10)
    x = self.layer_3(x)

    # probability distribution over labels
    x = torch.log_softmax(x, dim=1)

    return x

  def cross_entropy_loss(self, logits, labels):
    return F.nll_loss(logits, labels) 

  def training_step(self, train_batch, batch_idx):
    x, y = train_batch
    logits = self.forward(x)
    loss = self.cross_entropy_loss(logits, y)

    logs = {'train_loss': loss}
    return {'loss': loss, 'log': logs}

  def validation_step(self, val_batch, batch_idx):
    x, y = val_batch
    logits = self.forward(x)
    loss = self.cross_entropy_loss(logits, y)
    return {'val_loss': loss}

  def validation_epoch_end(self, outputs):
    # called at the end of the validation epoch
    # outputs is an array with what you returned in validation_step for each batch
    # outputs = [{'loss': batch_0_loss}, {'loss': batch_1_loss}, ..., {'loss': batch_n_loss}] 
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    tensorboard_logs = {'val_loss': avg_loss}
    return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

  def prepare_data(self):
    # transforms for images
    transform=transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.1307,), (0.3081,))])
        
    # prepare transforms standard to MNIST
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
      
    self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=0.02)

  def train_dataloader(self):
    return DataLoader(self.mnist_train, batch_size=64)

  def val_dataloader(self):
    return DataLoader(self.mnist_val, batch_size=64)

  def test_dataloader(self):
    return DataLoader(self,mnist_test, batch_size=64)


#### 2) 모델 학습 루프
* (Training, validation, test loop) * (___step(스텝마다), ___step_end(스텝 종료), ___epoch_end(1 epoch 종료))
* 해당되는 이름에 루프 패턴을 붙여서 정의

In [17]:
  # def training_step(self, train_batch, batch_idx):
  #   x, y = train_batch
  #   logits = self.forward(x)
  #   loss = self.cross_entropy_loss(logits, y)

  #   logs = {'train_loss': loss}
  #   return {'loss': loss, 'log': logs}

  # def validation_step(self, val_batch, batch_idx):
  #   x, y = val_batch
  #   logits = self.forward(x)
  #   loss = self.cross_entropy_loss(logits, y)
  #   return {'val_loss': loss}

  # def validation_epoch_end(self, outputs):
  #   # called at the end of the validation epoch
  #   # outputs is an array with what you returned in validation_step for each batch
  #   # outputs = [{'loss': batch_0_loss}, {'loss': batch_1_loss}, ..., {'loss': batch_n_loss}] 
  #   avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
  #   tensorboard_logs = {'val_loss': avg_loss}
  #   return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

#### 3) 데이터 준비
* PyTorch의 데이터 준비하는 과정을 크게 5가지 형태로 구조화
  * 1) 다운로드
  * 2) 데이터 정리 혹은 메모리 저장
  * 3) 데이터셋 로드
  * 4) 데이터 전처리 (특히, transforms를 의미)
  * 5) dataloader 형태로 wrapping
* 이에 맞게 코드를 추상화
  * prepare_data()
  * train_dataloader, val_dataloader, test_dataloader

In [7]:
# def prepare_data(self):
#   # transforms for images
#   transform=transforms.Compose([transforms.ToTensor(), 
#                               transforms.Normalize((0.1307,), (0.3081,))])
      
#   # prepare transforms standard to MNIST
#   mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
#   mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
    
#   self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

#   def train_dataloader(self):
#     return DataLoader(self.mnist_train, batch_size=64)

#   def val_dataloader(self):
#     return DataLoader(self.mnist_val, batch_size=64)

#   def test_dataloader(self):
#     return DataLoader(self,mnist_test, batch_size=64)

In [8]:
# def configure_optimizers(self):
#   return torch.optim.Adam(self.parameters(), lr=0.02)

LightningModule 클래스는 위와 같은 함수들을 순서에 따라 실행하는데, 이를 바로 **Lifecycle**이라고 부릅니다. (즉, 해당하는 순서에 따라 함수를 작성하는 것이 중요합니다.)

1.   `__init__`
2.   `prepare_data`
3.   `configure_optimizers`
4.   `train_dataloader`
5.   `val_dataloader`
6.   `test_dataloader` (`.test()`가 호출될 때 호출)

또한, 각 배치와 에폭마다 루프 메소드는 함수 이름에 맞게 정해진 순서대로 호출됩니다.

* `validation_step` : 배치마다 실행
* `validation_epoch_end` : 에폭마다 실행

## Tainer 클래스

#### 기본사용

In [25]:
from pytorch_lightning import Trainer
  
model = LightningMNISTClassifier()
  
trainer = Trainer()
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /content/MNIST/raw/train-images-idx3-ubyte.gz to /content/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /content/MNIST/raw/train-labels-idx1-ubyte.gz to /content/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /content/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /content/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33.0 K
2 | layer_3 | Linear | 2.6 K 
-----------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params
0.544     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


#### main.py 로 작성 시

In [26]:
# from argparse import ArgumentParser
  
# def main(hparams):
#     model = LightningModule()
#     trainer = Trainer(gpus=hparams.gpus)
#     trainer.fit(model)
  
# if __name__ == '__main__':
#     parser = ArgumentParser()
#     parser.add_argument('--gpus', default=None)
#     args = parser.parse_args()
  
#     main(args)

In [27]:
# # 실행
# $ python main.py --gpus 2


#### Testing

In [29]:
trainer.test()

  rank_zero_warn(f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop")


[]

#### Deployment / prediction

In [None]:
# # load model
# pretrained_model = LightningModule.load_from_checkpoint(PATH)
# pretrained_model.freeze()
  
# # use it for finetuning
# def forward(self, x):
#     features = pretrained_model(x)
#     classes = classifier(features)
  
# # or for prediction
# out = pretrained_model(x)
# api_write({'response': out}