<a href="https://colab.research.google.com/github/ganesh3/pytorch-work/blob/master/mnist_resnet_pytorch_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pytorch-lightning

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/ed/af/2f10c8ee22d7a05fe8c9be58ad5c55b71ab4dd895b44f0156bfd5535a708/pytorch_lightning-0.9.0-py3-none-any.whl (408kB)
[K     |████████████████████████████████| 409kB 5.6MB/s 
[?25hCollecting PyYAML>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 17.2MB/s 
[?25hCollecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 27.4MB/s 
Collecting tensorboard==2.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/54/f5/d75a6f7935e4a4870d85770bc9976b12e7024fbceb83a1a6bc50e6deb7c4/tensorboard-2.2.0-py3-none-any.whl (2.8MB)
[K     |████████████████████████████████| 2.8M

In [24]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

In [43]:
class Resnet(pl.LightningModule):
  def __init__(self):
    super().__init__()
    self.l1 = nn.Linear(28 * 28, 64)
    self.l2 = nn.Linear(64, 64)
    self.l3 = nn.Linear(64, 10)
    self.do = nn.Dropout(0.1)
    self.loss = nn.CrossEntropyLoss()

  def forward(self, x):
    h1 = nn.functional.relu(self.l1(x))
    h2 = nn.functional.relu(self.l2(h1))
    do = self.do(h2 + h1)
    logits = self.l3(do)
    return logits

  def configure_optimizers(self):
    optimizer =  optim.SGD(self.parameters(), lr=1e-2)
    return optimizer

  def training_step(self, batch, batch_idx):
    x, y = batch
    
    b = x.size(0)
    x = x.view(b, -1)
    
    logit = self(x)
    
    J = self.loss(logit, y)

    acc = accuracy(logit, y)
    pbar = {'train_accuracy': acc}

    return {'loss': J, 'progress_bar': pbar}

  def validation_step(self, batch, batch_idx):
    results = self.training_step(batch, batch_idx)
    results['progress_bar']['val_acc'] = results['progress_bar']['train_accuracy']
    del results['progress_bar']['train_accuracy']
    return results

  def validation_epoch_end(self, val_step_outputs):
    avg_val_loss = torch.tensor([x['loss'] for x in val_step_outputs]).mean()
    avg_val_acc = torch.tensor([x['progress_bar']['val_acc'] for x in val_step_outputs]).mean()
    pbar = {'avg_val_acc': avg_val_acc}
    return {'val_loss': avg_val_loss, 'progress_bar': pbar}

  def prepare_data(self):
    datasets.MNIST('data', train=True, download=True, transform = transforms.ToTensor())

  def setup(self):
    dataset = datasets.MNIST('data', train=True, download=False, transform = transforms.ToTensor())
    self.train, self.val = random_split(dataset, [55000, 5000])
  
  def train_dataloader(self):
    train_loader = DataLoader(self.train, batch_size=32)
    return train_loader

  def val_dataloader(self):
    val_loader = DataLoader(self.val, batch_size=32)
    return val_loader

model = Resnet()

In [44]:
trainer = pl.Trainer(progress_bar_refresh_rate=20, max_epochs=5, gpus=8, num_nodes=32)
trainer.fit(model)

MisconfigurationException: ignored

In [32]:
!ls lightning_logs/version_12/checkpoints/

'epoch=4.ckpt'
