[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/labml/blob/master/samples/wandb/cifar10.ipynb)

## CIFAR-10 Sample

This notebook trains a VGG model on the CIFAR-10 dataset.

[W&B dashboard for a sample run](https://wandb.ai/vpj/labml/runs/f2u6ip41?workspace=user-vpj)

[labml.ai monitoring](https://app.labml.ai/run/451082b89e7f11ebbc450242ac1c0002)

Install `labml` and `wandb` packages for monitoring and organizing experiments.

In [None]:
!pip install labml wandb

Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms

from labml import lab, tracker, experiment, monit, logger
from labml.logger import Text

VGG model

In [2]:
class Net(nn.Module):
    """VGG-style convolutional classifier with 3 input channels and 10 output logits.

    The feature extractor consists of five stages of
    3x3 convolution -> batch norm -> ReLU units, each stage closed by a 2x2
    max-pool. A single linear layer maps the flattened features to 10 logits;
    it expects 512 features, which is what a 32x32 input yields after the
    five pooling steps.
    """

    def __init__(self):
        super().__init__()
        # Output channels of every convolution, grouped by stage.
        stage_widths = (
            (64, 64),
            (128, 128),
            (256, 256, 256),
            (512, 512, 512),
            (512, 512, 512),
        )

        modules = []
        prev_width = 3  # number of channels of the network input
        for widths in stage_widths:
            for width in widths:
                modules.append(nn.Conv2d(prev_width, width, kernel_size=3, padding=1))
                modules.append(nn.BatchNorm2d(width))
                modules.append(nn.ReLU(inplace=True))
                prev_width = width
            # Halve the spatial resolution at the end of each stage.
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # A 1x1 average pool with stride 1 leaves the tensor values unchanged.
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))

        self.layers = nn.Sequential(*modules)
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        """Return the logits, shape (batch, 10), for the input batch `x`."""
        features = self.layers(x)
        batch = features.shape[0]
        # Flatten everything except the batch dimension before the classifier.
        return self.fc(features.view(batch, -1))

A simple class to create the training and validation data loaders.

In [3]:
class DataLoaderFactory:
    """Creates shuffled data loaders over the two CIFAR-10 splits.

    Both splits are downloaded (if necessary) into the labml data path when
    the factory is constructed.
    """

    def __init__(self):
        # Convert images to tensors, then normalise every channel with
        # mean 0.5 and standard deviation 0.5.
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # Index 0 holds the `train=False` split and index 1 the `train=True`
        # split, so a boolean `train` flag can index the list directly.
        self.dataset = [
            datasets.CIFAR10(str(lab.get_data_path()),
                             train=is_train,
                             download=True,
                             transform=preprocess)
            for is_train in (False, True)
        ]

    def __call__(self, train, batch_size):
        """Return a shuffling `DataLoader` for the selected split.

        Args:
            train: `True` selects the training split, `False` the other split.
            batch_size: number of samples per batch.
        """
        split = self.dataset[train]
        return torch.utils.data.DataLoader(split, shuffle=True, batch_size=batch_size)

Model training function for a single epoch.

In [5]:
def train(model, optimizer, train_loader, device):
    """Run one optimisation pass over `train_loader`, tracking the loss per batch.

    Args:
        model: network to optimise; switched to training mode here.
        optimizer: optimiser that updates the parameters of `model`.
        train_loader: iterable yielding `(data, target)` batches.
        device: device every batch is moved to before the forward pass.
    """
    model.train()
    for _, (inputs, labels) in monit.enum("Train", train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = F.cross_entropy(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # The global step advances by the number of samples, not by batches.
        tracker.add_global_step(inputs.shape[0])
        tracker.save({'loss.train': batch_loss})

Function to evaluate the model on the validation data.

In [6]:
def validate(model, valid_loader, device):
    """Evaluate `model` on `valid_loader` and track mean loss and accuracy.

    Args:
        model: network to evaluate; switched to evaluation mode here.
        valid_loader: loader yielding `(data, target)` batches; its `dataset`
            length is used as the number of samples.
        device: device every batch is moved to before the forward pass.
    """
    model.eval()
    total_loss = 0
    num_correct = 0
    with torch.no_grad():
        for inputs, labels in monit.iterate("valid", valid_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            # Sum (rather than average) per batch so the final mean is exact
            # even when the last batch is smaller.
            total_loss += F.cross_entropy(logits, labels, reduction='sum').item()
            predicted = logits.argmax(dim=1, keepdim=True)
            num_correct += predicted.eq(labels.view_as(predicted)).sum().item()

    num_samples = len(valid_loader.dataset)
    tracker.save({
        'loss.valid': total_loss / num_samples,
        'accuracy.valid': 100. * num_correct / num_samples,  # percentage
    })

Main function

In [9]:
def main():
    """Train and validate the VGG model on CIFAR-10 as a labml experiment.

    Builds the data loaders, model and Adam optimiser from the configuration
    below, registers them with a labml experiment named 'cifar10', runs the
    train/validate cycle for the configured number of epochs and finally
    saves a checkpoint.
    """
    configs = dict(
        epochs=50,
        learning_rate=2.5e-4,
        # Use the first GPU when available, otherwise fall back to the CPU.
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        batch_size=1024,
    )

    device = torch.device(configs['device'])
    batch_size = configs['batch_size']

    make_loader = DataLoaderFactory()
    train_loader = make_loader(True, batch_size)
    valid_loader = make_loader(False, batch_size)

    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=configs['learning_rate'])

    experiment.create(name='cifar10')
    experiment.configs(configs)
    experiment.add_pytorch_models({'model': model})

    num_epochs = configs['epochs']
    with experiment.start():
        for _ in monit.loop(range(1, num_epochs + 1)):
            torch.cuda.empty_cache()
            train(model, optimizer, train_loader, device)
            validate(model, valid_loader, device)
            # NOTE(review): presumably ends the current output line so each
            # epoch is logged on its own line; confirm against labml docs.
            logger.log()

    experiment.save_checkpoint()

In [None]:
# Launch the full training run defined in `main`.
main()

Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Currently logged in as: [33mvpj[0m (use `wandb login --relogin` to force relogin)
