![](https://learnopencv.com/wp-content/uploads/2020/05/PTL-1024x408.png)

---

Lightning is a lightweight wrapper around PyTorch, so there is no new library to learn: a Lightning model is still an ordinary PyTorch module. You write the core training and validation logic, and Lightning automates the rest, such as the epoch and batch loops, the backward pass and optimizer step, and device placement. The automated parts are tested and follow current best practices. To train a model we need the following five pieces:

1. Model
2. Optimizer
3. Data
4. Training Loop
5. Validation Loop
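For comparison, the five pieces above can be sketched in plain PyTorch. This is a minimal illustration, not the notebook's model: the random stand-in data, the small two-layer network, and the names `train_loader`, `val_loader`, and `val_losses` are assumptions made for the example. Steps 4 and 5 are the loops that Lightning takes over.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)

# 3. Data: random stand-in tensors shaped like flattened 28x28 images (hypothetical data)
features = torch.randn(256, 28 * 28)
labels = torch.randint(0, 10, (256,))
train_set, val_set = random_split(TensorDataset(features, labels), [192, 64])
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)

# 1. Model and 2. Optimizer
model = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

val_losses = []
for epoch in range(2):
    # 4. Training loop: forward pass, loss, backward pass, parameter update
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

    # 5. Validation loop: no gradients, average the per-batch losses
    model.eval()
    with torch.no_grad():
        batch_losses = [loss_fn(model(x), y) for x, y in val_loader]
    val_losses.append(torch.stack(batch_losses).mean().item())

print(val_losses)
```

With Lightning, the body of the inner training loop becomes `training_step`, the validation body becomes `validation_step`, and the surrounding loops disappear from user code.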

In [None]:
%%capture
!pip install pytorch-lightning

In [None]:
import torch
import pytorch_lightning as pl
from torch import nn 
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

torch.manual_seed(42)

<torch._C.Generator at 0x7efe47b3dad0>

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
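# train_accuracy returns the accuracy of each batch when it is called.
# valid_accuracy (compute_on_step=False) only accumulates state; read it with .compute().
# Note: pl.metrics was deprecated in Lightning 1.3 in favor of the torchmetrics package.
train_accuracy = pl.metrics.Accuracy()
valid_accuracy = pl.metrics.Accuracy(compute_on_step=False)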

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # download only
        datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
        datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())

    def setup(self, stage):
        # transform
        transform = transforms.Compose([transforms.ToTensor()])
        training_dataset = datasets.MNIST('data', train=True, download=False, transform=transform)
        test_dataset = datasets.MNIST('data', train=False, download=False, transform=transform)

        # train/val split
        mnist_train, mnist_val = random_split(training_dataset, [55000, 5000])

        # assign to use in dataloaders
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = test_dataset

    def train_dataloader(self):
        # reshuffle the training set every epoch
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

class ImageClassifier(pl.LightningModule):
  def __init__(self):
    super().__init__()
    self.l1 = nn.Linear(28 * 28, 64)
    self.l2 = nn.Linear(64, 64)
    self.l3 = nn.Linear(64, 10)
    self.do = nn.Dropout(0.1)

    self.loss = nn.CrossEntropyLoss()

  def forward(self, x):
    h1 = nn.functional.relu(self.l1(x))
    h2 = nn.functional.relu(self.l2(h1))
    do = self.do(h2 + h1)
    logits = self.l3(do)
    return logits

  def configure_optimizers(self):
    optimizer = optim.SGD(self.parameters(), lr=1e-2)
    return optimizer

  def training_step(self, batch, batch_idx):
    x, y = batch

    # x: b x 1 x 28 x 28 (b & w image)
    b = x.size(0)
    x = x.view(b, -1)

    # 1 forward
    l = self(x) # l: logits
    # import pdb; pdb.set_trace() # debugging
    # print size: p l.size()
    # print output value (argmax): p l[0].detach().argmax()
    
    # 2 compute the objective function
    J = self.loss(l, y)
    acc = train_accuracy(l, y)
    pbar = {'train_acc': acc}
    return {'loss': J, 'progress_bar': pbar}
    # return J

  def validation_step(self, batch, batch_idx):
    x, y = batch
    x = x.view(x.size(0), -1)  # flatten b x 1 x 28 x 28 to b x 784

    l = self(x) # l: logits
    J = self.loss(l, y)
    # valid_accuracy only accumulates here (compute_on_step=False), so
    # validation batches no longer update train_accuracy
    valid_accuracy(l, y)
    return {'loss': J}

  def validation_epoch_end(self, val_step_outputs):
    avg_val_loss = torch.stack([x['loss'] for x in val_step_outputs]).mean()
    # accuracy over the whole validation set, then clear the state for the next epoch
    avg_val_acc = valid_accuracy.compute()
    valid_accuracy.reset()
    pbar = {'avg_val_acc': avg_val_acc}
    return {'val_loss': avg_val_loss, 'progress_bar': pbar}
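
The flattening step and the skip connection in `forward` are easiest to follow by tracking tensor shapes. The sketch below repeats the layer sizes of `ImageClassifier` in a plain `nn.Module`; the class name `TinyResidualMLP` and the random input batch are assumptions made for illustration, and the weights are untrained.

```python
import torch
from torch import nn


class TinyResidualMLP(nn.Module):  # hypothetical name; same layer sizes as ImageClassifier
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 64)
        self.l2 = nn.Linear(64, 64)
        self.l3 = nn.Linear(64, 10)
        self.do = nn.Dropout(0.1)

    def forward(self, x):
        h1 = nn.functional.relu(self.l1(x))   # (b, 64)
        h2 = nn.functional.relu(self.l2(h1))  # (b, 64)
        # the skip connection h2 + h1 works because both tensors are (b, 64)
        return self.l3(self.do(h2 + h1))      # (b, 10)


batch = torch.rand(8, 1, 28, 28)      # stand-in for a batch of MNIST images
flat = batch.view(batch.size(0), -1)  # (8, 784): one row per image
net = TinyResidualMLP().eval()        # eval() switches dropout off
with torch.no_grad():
    logits = net(flat)
preds = logits.argmax(dim=1)          # predicted class index per image

print(flat.shape, logits.shape, preds.shape)
```

The addition only type-checks because `l2` maps 64 features to 64 features; changing its output width would require a projection on the skip path.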


In [None]:
model = ImageClassifier()
mnist_dm = MNISTDataModule()
# No gpus flag is set, so training runs on the CPU even if a GPU is available
trainer = pl.Trainer(progress_bar_refresh_rate=20, max_epochs=5)
trainer.fit(model, mnist_dm)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  "GPU available but not used. Set the gpus flag in your trainer"

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

  | Name | Type             | Params
------------------------------------------
0 | l1   | Linear           | 50.2 K
1 | l2   | Linear           | 4.2 K 
2 | l3   | Linear           | 650   
3 | do   | Dropout          | 0     
4 | loss | CrossEntropyLoss | 0     
------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
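
The numbers in this summary can be checked by hand. The sketch below recomputes them from the layer sizes in `ImageClassifier`, assuming 4 bytes per parameter (float32) and 1 MB = 10^6 bytes; the variable names are illustrative.

```python
# (in_features, out_features) for l1, l2, l3, as defined in ImageClassifier
layers = {"l1": (28 * 28, 64), "l2": (64, 64), "l3": (64, 10)}

# A Linear layer holds an in x out weight matrix plus one bias per output
counts = {name: n_in * n_out + n_out for name, (n_in, n_out) in layers.items()}
total = sum(counts.values())

# Assumption: 4 bytes per parameter (float32) and 1 MB = 10**6 bytes
size_mb = total * 4 / 1e6

print(counts)   # {'l1': 50240, 'l2': 4160, 'l3': 650}
print(total)    # 55050
print(size_mb)  # 0.2202
```

These match the table after rounding: 50.2 K, 4.2 K, and 650 parameters, 55.1 K in total, and about 0.220 MB. Dropout and the loss function have no parameters.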



In [None]:
trainer

<pytorch_lightning.trainer.trainer.Trainer at 0x7efe2d2e0d90>