[![Github](https://img.shields.io/github/stars/lab-ml/samples?style=social)](https://github.com/lab-ml/samples)                

## MNIST with PyTorch Lightning

Install the libraries (labml and PyTorch Lightning)

In [1]:
!pip install labml pytorch_lightning

Collecting labml
[?25l  Downloading https://files.pythonhosted.org/packages/7f/00/199e012664feda73d863b2c45e6809a9b99473c46b3530e8b4d34a129074/labml-0.4.102-py3-none-any.whl (101kB)
[K     |████████████████████████████████| 102kB 6.7MB/s 
[?25hCollecting pytorch_lightning
[?25l  Downloading https://files.pythonhosted.org/packages/e7/d4/d2751586c7961f238a6077a6dc6e4a9214445da3219f463aa44b29fe4b42/pytorch_lightning-1.1.8-py3-none-any.whl (696kB)
[K     |████████████████████████████████| 696kB 15.0MB/s 
[?25hCollecting gitpython
[?25l  Downloading https://files.pythonhosted.org/packages/fb/67/47a04d8a9d7f94645676fe683f1ee3fe9be01fe407686c180768a92abaac/GitPython-3.1.13-py3-none-any.whl (159kB)
[K     |████████████████████████████████| 163kB 18.9MB/s 
Collecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 12.

Import the libraries

In [2]:
import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from labml import lab, experiment
from labml.utils.lightning import LabMLLightningLogger

Define the model: a single linear layer that maps each flattened 28×28 image to 10 class scores.

In [3]:
class MNISTModel(pl.LightningModule):
    """Minimal MNIST classifier: one linear layer over flattened 28x28 images.

    Args:
        lr: Learning rate for the Adam optimizer (defaults to 0.02, the value
            the original example used).
    """

    def __init__(self, lr: float = 0.02):
        super().__init__()
        self.lr = lr
        # 28*28 input pixels -> 10 digit classes.
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, 10).

        Each sample is flattened, so the input must contain 28*28 elements per
        sample (e.g. the image tensors produced by ``transforms.ToTensor``).
        """
        # No activation on the output: F.cross_entropy expects raw logits.
        # A ReLU here would clamp every negative logit to 0, zeroing its
        # gradient and limiting how strongly the model can reject wrong classes.
        return self.l1(x.view(x.size(0), -1))

    def training_step(self, batch, batch_nb):
        """Compute, log and return the cross-entropy loss for one (images, labels) batch."""
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log('loss', loss)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

Run the experiment: train on the MNIST training set and record the run with labml.

In [4]:
def main(max_epochs: int = 3, batch_size: int = 32, seed: int = 42):
    """Train MNISTModel on the MNIST training split and record the run with labml.

    Args:
        max_epochs: Number of training epochs.
        batch_size: Mini-batch size for the training DataLoader.
        seed: Seed passed to ``pl.seed_everything`` so reruns are reproducible.
    """
    # Seed the RNGs before building the model so weight init and shuffling repeat across runs.
    pl.seed_everything(seed)

    # Init our model
    mnist_model = MNISTModel()

    # Init DataLoader from MNIST Dataset, stored under labml's data directory.
    train_ds = MNIST(str(lab.get_data_path()), train=True, download=True, transform=transforms.ToTensor())
    # Reshuffle every epoch; without it each epoch replays the batches in the same order.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    # Use one GPU when available; otherwise fall back to CPU instead of
    # failing at Trainer construction on a CPU-only machine.
    gpus = 1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(gpus=gpus, max_epochs=max_epochs, progress_bar_refresh_rate=20,
                         logger=LabMLLightningLogger())

    # Train the model ⚡
    with experiment.record(name='mnist_lightning', disable_screen=True):
        trainer.fit(mnist_model, train_loader)


if __name__ == '__main__':
    main()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/MNIST/raw/train-images-idx3-ubyte.gz to /content/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/MNIST/raw/train-labels-idx1-ubyte.gz to /content/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/data/MNIST/raw
Processing...
Done!





  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7.9 K 
--------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…


