# Pytorch_Lightningの基本的なデモ

In [1]:
import os
from sklearn import datasets
import numpy as np

import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torchmetrics import Accuracy
from torch import nn
from torch.nn import functional as F

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from torch.utils.data import DataLoader, random_split
from torchvision import transforms

from omegaconf import DictConfig, ListConfig

# DataModuleの準備

In [2]:
# 1. DataSetの作成
#   ここは従来通り、任意のデータに対応するDatasetを作成する
class IrisDataset(torch.utils.data.Dataset):
    def __init__(self, transforms=None):
        super().__init__()
        iris = datasets.load_iris()
        self.X, self.y = iris["data"], iris["target"]
        self.transforms = transforms

    def __getitem__(self, idx):
        data, label = self.X[idx], self.y[idx]
        if self.transforms is not None:
            data  = self.transforms(data)
            label = self.transforms(label)
        return data, label

    def __len__(self):
        return self.X.shape[0]

In [10]:
# 2. pl.DataModuleの準備
#   DataLoadersを作成するpl.DataModuleを作成する

class PLIrisData(pl.LightningDataModule):
    def __init__(self, BATCH_SIZE=16):
        super().__init__()
        self.batch_size = BATCH_SIZE
        self.transforms=None
         #self.transforms = transforms.Compose( [ transforms.ToTensor() ] )  #画像などで使う
    
    def setup(self, stage=None): #stageの引数は必須　
        all_data = IrisDataset(transforms=self.transforms)
        self.trn_data, self.val_data = random_split(all_data, [120,30])
    
    def train_dataloader(self):
        return DataLoader( dataset=self.trn_data, batch_size=self.batch_size ,shuffle=True)
    
    def val_dataloader(self):
        return DataLoader( dataset=self.val_data, batch_size=self.batch_size ,shuffle=False)

# Modelの作成

In [11]:
# 1. 従来通りのModelを作成する

class IrisNet(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.hidden_size  = cfg.model.hidden_size

        self.x1   = nn.Linear(in_features=4, out_features=self.hidden_size)
        self.act1 = nn.ReLU()
        self.x2   = nn.Linear(in_features=self.hidden_size, out_features=3)
        self.act2 = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = self.x1(x)
        x = self.act1(x)
        x = self.x2(x)
        x = self.act2(x)
        return x

In [23]:
#2. train/valid stepを設定する、plmoduleを作成する

class PLIrisModel(pl.LightningModule):
    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg    = cfg
        self.net    = IrisNet(cfg=cfg)
        self.mtrics = Accuracy()

    def forward(self, x):
        return self.net(x.float())

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = F.nll_loss(pred, y)
        batch_loss = loss * x.size(0)
        return {"loss": loss, "y": y, "pred": pred.detach(), "batch_loss": batch_loss.detach()}
    
    def training_epoch_end(self, train_step_outputs):
        preds      = torch.cat( [trn["pred"] for trn in train_step_outputs], dim=0 )
        ys         = torch.cat( [trn["y"] for trn in train_step_outputs], dim=0 )
        epoch_loss = sum( [trn["batch_loss"] for trn in train_step_outputs] ) / ys.size(0)

        acc = self.mtrics(preds, ys)
        print('-------- Current Epoch {} --------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = F.nll_loss(pred, y)
        batch_loss = loss * x.size(0)
        return {"y": y, "pred": pred.detach(), "batch_loss": batch_loss.detach()}
    
    def validation_epoch_end(self, valid_step_outputs):
        preds      = torch.cat( [val["pred"] for val in valid_step_outputs], dim=0 )
        ys         = torch.cat( [val["y"] for val in valid_step_outputs], dim=0 )
        epoch_loss = sum( [val["batch_loss"] for val in valid_step_outputs] ) / ys.size(0)

        acc = self.mtrics(preds, ys)

        # for EarlyStopping
        self.log("val_loss", epoch_loss)
        self.log("val_acc", acc)

        print('-------- Current Epoch {} --------'.format(self.current_epoch + 1))
        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))
    
    def configure_optimizers(self):
        optimizer = getattr(torch.optim, self.cfg.optim.optim_name)(self.parameters(), lr=self.cfg.optim.lr)
        return optimizer

# 学習の実行

In [25]:
# Hyper-Params
from omegaconf import OmegaConf

params = {
    "model":{
        "hidden_size": 16
    },
    "optim": {
        "optim_name": "Adam",
        "lr"        : 1e-2
    },
}

params = OmegaConf.create(params)

In [26]:
# CallBacksの設定

# モデルチェックポイント val_lossが最低となるモデルを保存
checkpoint_callback = ModelCheckpoint(
        dirpath=f"./models",
        filename="best-checkpoint",
        monitor="val_loss",
        mode="min",
    )

# EarlyStop 一定エポックval_lossの改善がなければ学習打ち切り
early_stopping_callback = EarlyStopping(
        monitor="val_loss", patience=3, verbose=True, mode="min"
    )

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


In [29]:
def main():
    model  = PLIrisModel(cfg=params)
    data   = PLIrisData()

    trainer = Trainer(
        gpus=1,
        max_epochs=30,
        callbacks=[ checkpoint_callback, early_stopping_callback ],
    )
    trainer.fit(model, data)

In [30]:
main()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type     | Params
------------------------------------
0 | net    | IrisNet  | 131   
1 | mtrics | Accuracy | 0     
------------------------------------
131       Trainable params
0         Non-trainable params
131       Total params
0.001     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

-------- Current Epoch 1 --------
valid Loss: -0.3596 valid Acc: 0.6333


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

-------- Current Epoch 1 --------
valid Loss: -0.4748 valid Acc: 0.6333
-------- Current Epoch 1 --------
train Loss: -0.4166 train Acc: 0.6667


Validating: 0it [00:00, ?it/s]

-------- Current Epoch 2 --------
valid Loss: -0.5347 valid Acc: 0.6333
-------- Current Epoch 2 --------
train Loss: -0.5354 train Acc: 0.6750


Validating: 0it [00:00, ?it/s]

-------- Current Epoch 3 --------
valid Loss: -0.5834 valid Acc: 0.6333


Monitored metric val_loss did not improve in the last 3 records. Best score: -0.655. Signaling Trainer to stop.


-------- Current Epoch 3 --------
train Loss: -0.5991 train Acc: 0.6750
