[![Github](https://img.shields.io/github/stars/labmlai/labml?style=social)](https://github.com/labmlai/labml/tree/master/samples/pytorch/mnist)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/labml/blob/master/samples/pytorch/mnist/mnist.ipynb)

## MNIST PyTorch

### Install labml

In [1]:
!pip install labml



### Import libraries

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from labml import lab, tracker, experiment

### Model definition

In [3]:
class Model(nn.Module):
  """Convolutional classifier: two conv + max-pool stages, then two linear layers.

  `forward` flattens the pooled feature maps to 4 * 4 * 50 values per sample,
  which matches single-channel 28 x 28 inputs. It returns one row of 10
  un-normalised scores per sample.
  """

  def __init__(self):
    super().__init__()

    # Spatial sizes in the comments assume a 28 x 28 input image.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=20,
                           kernel_size=5, stride=1)  # 28 x 28 -> 24 x 24
    self.pool1 = nn.MaxPool2d(kernel_size=2)  # 24 x 24 -> 12 x 12
    self.conv2 = nn.Conv2d(in_channels=20, out_channels=50,
                           kernel_size=5, stride=1)  # 12 x 12 -> 8 x 8
    self.pool2 = nn.MaxPool2d(kernel_size=2)  # 8 x 8 -> 4 x 4

    # Classifier head operating on the flattened 4 x 4 x 50 feature maps
    self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
    self.fc2 = nn.Linear(in_features=500, out_features=10)
    self.activation = nn.ReLU()

  def forward(self, x):
    """Return the 10 output scores for each sample in the batch `x`."""
    features = self.pool1(self.activation(self.conv1(x)))
    features = self.pool2(self.activation(self.conv2(features)))

    # Flatten each sample's feature maps into a single vector
    flat = features.view(-1, 4 * 4 * 50)

    hidden = self.activation(self.fc1(flat))
    return self.fc2(hidden)

### Training code

This trains the model for one epoch.

We increment the step by the number of samples processed.
The loss is saved on every batch and the model stats are saved every `model_log_interval` batches.

In [4]:
def train(model, loss_func, optimizer, loader, device, model_log_interval):
  """Run one optimisation pass over every batch in `loader`.

  For each batch: move it to `device`, compute the loss, back-propagate and
  take an optimizer step. The global step is advanced by the batch's sample
  count and the loss is saved on every batch; model stats are saved once
  every `model_log_interval` batches.
  """
  model.train()

  for batch_no, (inputs, labels) in enumerate(loader, start=1):
    inputs = inputs.to(device)
    labels = labels.to(device)

    batch_loss = loss_func(model(inputs), labels)

    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    # ✨ Advance the global step by the number of samples in this batch
    tracker.add_global_step(len(inputs))
    # ✨ Save the training loss of this batch
    tracker.save({'loss.train': batch_loss})

    # ✨ Save model stats on every `model_log_interval`-th batch
    if batch_no % model_log_interval == 0:
      tracker.save(model=model)

### Validation code

This evaluates the model on the validation dataset and saves the stats at the end.

In [5]:
def validate(model, loss_func, loader, device):
  """Evaluate `model` on every batch in `loader` and save the accuracy.

  The loss of each batch is added to the tracker under 'loss.valid'. After
  the last batch, the percentage of correctly classified samples is saved
  under 'accuracy.valid'. Gradients are disabled for the whole pass.
  """
  model.eval()

  correct = 0
  with torch.no_grad():
    for data, target in loader:
      data, target = data.to(device), target.to(device)

      output = model(data)
      tracker.add('loss.valid', loss_func(output, target))

      # The predicted class is the index of the largest output score
      pred = output.argmax(dim=1, keepdim=True)
      correct += pred.eq(target.view_as(pred)).sum().item()

  # Divide by the size of the dataset behind the loader that was passed in,
  # not a module-level loader, so the function works for any loader.
  valid_accuracy = 100. * correct / len(loader.dataset)

  # ✨ Save stats
  tracker.save({'accuracy.valid': valid_accuracy})

### \[Optional\] Setup tracker indicators

This tells the tracker to:
* use a queue of length 20 for training loss,
* save the validation losses as a histogram,
* save the validation accuracy as a scalar,

and print each of the metrics to the terminal.

In [6]:
# ✨ Set the types of the stats/indicators.
# They default to scalars if not specified.
# The trailing `True` on each call is presumably the "print to terminal" flag
# mentioned in the description above; confirm against the labml tracker API.

# Training loss: kept in a queue of length 20
tracker.set_queue('loss.train', 20, True)
# Validation loss: saved as a histogram
tracker.set_histogram('loss.valid', True)
# Validation accuracy: saved as a scalar
tracker.set_scalar('accuracy.valid', True)

### Configurations

In [7]:
# Hyper-parameters and run settings, read by the initialisation and run cells.
configs = {
    "epochs": 10,              # number of train + validate passes
    "train_batch_size": 64,    # batch size of the training loader
    "valid_batch_size": 100,   # batch size of the validation loader
    "use_cuda": True,          # use the GPU when one is available
    "seed": 5,                 # value passed to torch.manual_seed
    "train_log_interval": 10,  # passed to train() as model_log_interval
    "learning_rate": 0.01,     # learning rate of the SGD optimizer
}

### Initialize

In [8]:
# Seed PyTorch's random number generator before the model is constructed, so
# that the initial weights are covered by the seed as well.
torch.manual_seed(configs['seed'])

# Use the first CUDA device only when it is both requested and available
is_cuda = configs['use_cuda'] and torch.cuda.is_available()
device = torch.device("cuda:0" if is_cuda else "cpu")

# Convert images to tensors, then normalise with a fixed mean and std
# (presumably the MNIST training-set statistics; confirm).
mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])

# Training split: downloaded to the lab data path if missing, shuffled
train_loader = DataLoader(datasets.MNIST(str(lab.get_data_path()),
                                         train=True,
                                         transform=mnist_transform,
                                         download=True),
                          batch_size=configs['train_batch_size'], shuffle=True)

# Validation split: same transform, fixed order
valid_loader = DataLoader(datasets.MNIST(str(lab.get_data_path()),
                                         train=False,
                                         download=True,
                                         transform=mnist_transform),
                          batch_size=configs['valid_batch_size'], shuffle=False)

model = Model().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=configs['learning_rate'])

<torch._C.Generator at 0x7f7e50c5ebd0>

### Run the experiment

Create the experiment

In [9]:
# ✨ Create the experiment under the name 'mnist_pytorch'
experiment.create(name='mnist_pytorch')

Save experiment configurations/hyper-parameters

In [10]:
# ✨ Save the `configs` dictionary as this experiment's configurations/hyper-parameters
experiment.configs(configs)

Set PyTorch models for checkpoint saving and loading

In [11]:
# ✨ Register the model under the key 'model' for checkpoint saving and loading
experiment.add_pytorch_models({'model': model})

Run the experiment

In [12]:
# ✨ Start the experiment and run the training loop inside it
with experiment.start():
  for epoch in range(1, configs['epochs'] + 1):
    # One pass over the training data, then one over the validation data
    train(model=model,
          loss_func=loss_func,
          optimizer=optimizer,
          loader=train_loader,
          device=device,
          model_log_interval=configs['train_log_interval'])
    validate(model=model,
             loss_func=loss_func,
             loader=valid_loader,
             device=device)

    # Presumably starts a fresh line of tracker output for the next epoch
    tracker.new_line()

    # ✨ Save a checkpoint of the models after every epoch
    experiment.save_checkpoint()