<a href="https://colab.research.google.com/github/inyong37/Study/blob/master/_Framework/PyTorch/lightning_mnist_lenet5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.2.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)
  Downloading lightning_utilities-0.11.0-py3-none-any.whl (25 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m64.7 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.2.1-py3-none-any.whl (801 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m62.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<4.0,>=1.13.0->lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━ [output truncated]

In [4]:
import os
from lightning.pytorch import LightningDataModule, LightningModule, Trainer
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision import datasets, transforms


class MNIST(LightningDataModule):
    """MNIST data module yielding 32x32 single-channel image tensors.

    The torchvision MNIST training split is divided into 55,000 training and
    5,000 validation samples using a seeded generator, so the split is
    reproducible across runs. The torchvision test split serves both the
    test and predict loaders.

    Args:
        data_path: Directory where torchvision stores / looks up MNIST.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker processes per dataloader. Workers are kept alive
            across epochs only when this is positive, because ``DataLoader``
            rejects ``persistent_workers=True`` together with ``num_workers=0``.
        seed: Seed of the generator used for the train/validation split.
    """

    def __init__(self, data_path: str = '.', batch_size: int = 32,
                 num_workers: int = 4, seed: int = 37):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Upsample the digits to the 32x32 input size expected by LeNet-5.
        self.data_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ])
        # NOTE: data is downloaded and split eagerly in __init__ (not in
        # prepare_data/setup) so the dataset attributes exist right after
        # construction, exactly as before.
        full_train = datasets.MNIST(self.data_path, train=True, download=True,
                                    transform=self.data_transform)
        self.train_data, self.val_data = random_split(
            full_train, [55000, 5000],
            generator=torch.Generator().manual_seed(seed),
        )
        self.test_data = datasets.MNIST(self.data_path, train=False,
                                        transform=self.data_transform)

    def _loader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader using this module's shared batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Reuse worker processes between epochs (previously missing for the
            # training loader); only legal when worker processes exist.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Shuffled loader over the 55,000-sample training subset."""
        return self._loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Loader over the 5,000-sample validation subset, in fixed order."""
        return self._loader(self.val_data)

    def test_dataloader(self):
        """Loader over the MNIST test split, in fixed order."""
        return self._loader(self.test_data)

    def predict_dataloader(self):
        """Loader over the MNIST test split, reused for prediction."""
        return self._loader(self.test_data)


class LeNet5(LightningModule):
    """LeNet-5 classifier for 1-channel 32x32 images with 10 output classes.

    ``forward`` (and therefore ``predict_step``) returns class probabilities,
    as before. The training/validation/test losses, however, are computed from
    the raw logits: ``F.cross_entropy`` applies ``log_softmax`` internally, so
    feeding it softmax outputs applied softmax twice, which squashes the
    gradients and slows or stalls learning.

    Args:
        learning_rate: Adam learning rate, stored in ``self.hparams``.
    """

    def __init__(self, learning_rate: float = 0.0001):
        super().__init__()
        self.save_hyperparameters()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)

    def _logits(self, x):
        """Return unnormalized class scores of shape (batch, 10)."""
        # Spatial size for a 32x32 input: 32 -conv-> 28 -pool-> 14 -conv-> 10
        # -pool-> 5 -conv-> 1, leaving 120 features per sample.
        x = F.avg_pool2d(torch.tanh(self.conv1(x)), 2, 2)
        x = F.avg_pool2d(torch.tanh(self.conv2(x)), 2, 2)
        x = torch.tanh(self.conv3(x))
        # flatten(1) keeps the batch dimension. The former view(-1, 120) would
        # silently fold spatial positions into the batch for larger inputs
        # instead of raising a shape error at fc1.
        x = torch.flatten(x, 1)
        x = torch.tanh(self.fc1(x))
        return self.fc2(x)

    def forward(self, x):
        """Return class probabilities (softmax over the 10 logits)."""
        return F.softmax(self._logits(x), dim=1)

    def _shared_step(self, batch, stage: str):
        """Compute and log the cross-entropy loss of one batch for ``stage``."""
        x, y = batch
        # cross_entropy expects logits, not probabilities.
        loss = F.cross_entropy(self._logits(x), y)
        self.log(f"{stage}_loss", loss, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        """Return class probabilities for a batch; labels are ignored."""
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


def pl_mnist_lenet5_main(data_path=None, max_epochs: int = 10):
    """Train LeNet-5 on MNIST, then test, validate and predict with it.

    Args:
        data_path: MNIST directory. Defaults to ``data`` inside the parent of
            the current working directory (the original hard-coded location).
        max_epochs: Number of training epochs.

    Returns:
        Tuple ``(test_results, val_results, predictions)`` holding the return
        values of ``Trainer.test``, ``Trainer.validate`` and
        ``Trainer.predict``. Previously these were computed and discarded.
    """
    if data_path is None:
        data_path = os.path.join(os.path.dirname(os.getcwd()), 'data')
    datamodule = MNIST(data_path=data_path)
    model = LeNet5()
    trainer = Trainer(max_epochs=max_epochs)
    trainer.fit(model, datamodule=datamodule)
    # No model is passed below, so Lightning reuses the fitted model; the
    # notebook's logged output shows each call reloading the checkpoint
    # written during fit.
    test_results = trainer.test(datamodule=datamodule)
    val_results = trainer.validate(datamodule=datamodule)
    predictions = trainer.predict(datamodule=datamodule)
    return test_results, val_results, predictions


if __name__ == "__main__":
    # Runs when executed as a script, and also when this notebook cell is run
    # (IPython evaluates cells with __name__ set to "__main__").
    pl_mnist_lenet5_main()

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name  | Type   | Params
---------------------------------
0 | conv1 | Conv2d | 156   
1 | conv2 | Conv2d | 2.4 K 
2 | conv3 | Conv2d | 48.1 K
3 | fc1   | Linear | 10.2 K
4 | fc2   | Linear | 850   
---------------------------------
61.7 K    Trainable params
0         Non-trainable params
61.7 K    Total params
0.247     Total estimated model params size (MB)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
INFO: Restoring states from the checkpoint path at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

INFO: Restoring states from the checkpoint path at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Restoring states from the checkpoint path at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/lightning_logs/version_1/checkpoints/epoch=9-step=17190.ckpt


Predicting: |          | 0/? [00:00<?, ?it/s]