<a href="https://colab.research.google.com/github/misakiyanan/Deep-Learning-Projects/blob/main/pytorch_lightning_MNIST_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# `%%capture` (cell magic, must be the first line) hides pip's log output.
# Pinned because this notebook uses Lightning 1.x APIs that later releases removed:
# `pytorch_lightning.metrics`, Trainer(progress_bar_refresh_rate=..., gpus=...).
# NOTE(review): the recorded outputs look like a 1.1.x run; 1.1.x may fail to import on
# recent torch builds -- if so, choose a 1.x release compatible with the installed torch.
%pip install -q pytorch-lightning==1.1.8



In [None]:
## https://blog.csdn.net/weixin_43792166/article/details/97952312
## https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb#scrollTo=4DNItffri95Q
# https://github.com/pytorch/examples/blob/master/mnist/main.py

## https://www.youtube.com/watch?v=DbESHcCoWbM&t=513s
## https://www.youtube.com/watch?v=tgp56S2eGFE

## https://www.youtube.com/watch?v=OMDn66kM9Qc
## https://www.youtube.com/watch?v=vD5iQkdqMqU

import torch
from torch import nn
import torch.nn.functional as F

from torch import optim

from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
# Print tensors with 10 decimal places (affects how logged/test metric tensors are shown).
torch.set_printoptions(precision=10)

# Seed torch's global RNG so weight init, the train/val split and DataLoader shuffling are
# reproducible across Restart & Run All (all randomness in this notebook goes through torch).
SEED = 42
torch.manual_seed(SEED)


In [None]:
BATCH_SIZE = 64

# MNIST is downloaded into ./data on first run; ToTensor converts images to float tensors.
train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())

# Hold out 10k of the 60k training images for validation. A fixed generator keeps the split
# identical across runs, so validation metrics stay comparable between runs.
mnist_train, mnist_val = random_split(
    train_data, [50000, 10000], generator=torch.Generator().manual_seed(42)
)

# Reshuffle the training set every epoch (DataLoader's default is shuffle=False, which would
# replay the same batch order each epoch). Evaluation order does not matter.
train_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

In [None]:
class CNNClassifier(pl.LightningModule):
  """Small CNN for 10-class digit classification, wrapped as a LightningModule.

  Architecture: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool
  -> flatten -> linear(9216 -> 10) -> log_softmax. With 28x28 inputs the spatial size goes
  28 -> 26 -> 24 -> 12, giving 64 * 12 * 12 = 9216 flattened features for fc1.

  Args:
    lr: Adam learning rate; defaults to 1e-2, the value originally hard-coded here.
  """

  ## 1. model
  def __init__(self, lr: float = 1e-2):
    super().__init__()
    self.lr = lr
    self.conv1 = nn.Conv2d(1, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 64, 3, 1)
    # Dropout layers are defined but not currently applied in forward().
    self.dropout1 = nn.Dropout(0.25)
    self.dropout2 = nn.Dropout(0.5)
    self.fc1 = nn.Linear(9216, 10)

    # forward() returns log-probabilities, so NLLLoss is the matching criterion.
    # CrossEntropyLoss would apply log_softmax a second time (numerically the same result,
    # since log_softmax is shift-invariant, but redundant and misleading).
    self.loss = nn.NLLLoss()

  def forward(self, x):
    """Return per-class log-probabilities of shape (batch, 10).

    x: image batch with 1 channel; assumed (batch, 1, 28, 28) so the flattened size matches
    fc1's 9216 inputs.
    """
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, 2)
    x = torch.flatten(x, 1)
    return F.log_softmax(self.fc1(x), dim=1)

  ## 2. optimizer
  def configure_optimizers(self):
    """Adam over all parameters with the learning rate given at construction."""
    return optim.Adam(self.parameters(), lr=self.lr)

  ## 3. train / validation / test share one step
  def _shared_step(self, batch, stage):
    """Compute, log and return the loss for one batch.

    Logs '<stage>_loss' and '<stage>_acc' to the progress bar (e.g. 'val_acc', which the
    EarlyStopping callback monitors).
    """
    x, y = batch
    log_probs = self(x)
    loss = self.loss(log_probs, y)
    # Top-1 accuracy computed directly: works on log-probabilities and does not depend on
    # the pytorch_lightning.metrics API, which was moved to torchmetrics and later removed.
    acc = (log_probs.argmax(dim=1) == y).float().mean()

    self.log(f'{stage}_loss', loss, prog_bar=True)
    self.log(f'{stage}_acc', acc, prog_bar=True)
    return loss

  def training_step(self, batch, batch_idx):
    return self._shared_step(batch, 'train')

  def validation_step(self, batch, batch_idx):
    return self._shared_step(batch, 'val')

  def test_step(self, batch, batch_idx):
    return self._shared_step(batch, 'test')

  ## 4. data -- returns the module-level loaders built in the data cell above
  def train_dataloader(self):
    return train_loader

  def val_dataloader(self):
    return val_loader

  def test_dataloader(self):
    return test_loader


model = CNNClassifier()

In [None]:
from pytorch_lightning.callbacks import EarlyStopping

# Stop when val_acc has not improved for 5 consecutive validation checks. mode='max' is set
# explicitly because accuracy should increase. Later Lightning releases default to
# mode='min'; the recorded run below relied on the old auto-detection ("mode set to max").
early_stop_callback = EarlyStopping(monitor='val_acc', mode='max', patience=5, verbose=True)

# Use one GPU when CUDA is available, otherwise run on CPU instead of erroring on gpus=1.
trainer = pl.Trainer(
    callbacks=[early_stop_callback],
    progress_bar_refresh_rate=20,
    gpus=1 if torch.cuda.is_available() else None,
)

trainer.fit(model)
trainer.test(model)

EarlyStopping mode set to max for monitoring val_acc.
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | conv1    | Conv2d           | 320   
1 | conv2    | Conv2d           | 18.5 K
2 | dropout1 | Dropout          | 0     
3 | dropout2 | Dropout          | 0     
4 | fc1      | Linear           | 92.2 K
5 | loss     | CrossEntropyLoss | 0     
----------------------------------------------
110 K     Trainable params
0         Non-trainable params
110 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.9844999909, device='cuda:0'),
 'test_loss': tensor(0.0534103997, device='cuda:0')}
--------------------------------------------------------------------------------


[{'test_acc': 0.984499990940094, 'test_loss': 0.053410399705171585}]