# Training a Convolutional Network on MNIST with XABY

This notebook demonstrates how to train a convolutional network on MNIST with the XABY framework.

I'm going to use torchvision to load the MNIST data, because it takes care of downloading the dataset and applying the preprocessing transforms.

In [1]:
import time
import xaby as xb
import xaby.nn as xn

import jax
import jax.numpy as np

# For loading MNIST data
import torch
from torchvision import transforms, datasets



In [3]:
# Per-image preprocessing: convert to a tensor, then normalize with
# mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Train split first, then test split; both share the same root directory
# and the same preprocessing pipeline.
mnist_train, mnist_test = (
    datasets.MNIST("~/.pytorch", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=64,
    num_workers=2,
)

## Train a convolutional model

Below I'll define a convolutional network and train it on the MNIST dataset.

In [8]:
### Define a model
# NOTE(review): the positional conv2d arguments are presumably
# (in_channels, out_channels, kernel_size, stride, padding) -- confirm against xn.conv2d.
conv1 = xn.conv2d(1, 32, 3, 2, 1) >> xn.relu
conv2 = xn.conv2d(32, 64, 3, 2, 1) >> xn.relu
# 3136 = 64 * 7 * 7, assuming each conv halves the 28x28 input (28 -> 14 -> 7).
fc = xn.linear(3136, 128) >> xn.relu
classifier = xn.linear(128, 10) >> xn.log_softmax

model = conv1 >> conv2 >> xb.flatten >> fc >> classifier

# Model with backpropagation
backprop = model << xn.nlloss

# Update function
update = xb.optim.sgd(model, lr=0.003)

step = 0
start = time.time()
epochs = 2
batch_size = train_loader.batch_size
print_every = 100
for e in range(epochs):
    for images, labels in train_loader:
        step += 1
        inputs, targets = xb.tensor(images), xb.tensor(labels)
        
        train_loss, grads = backprop(inputs, targets)
        grads >> update

        if step % print_every == 0:
            # Stop the clock before evaluating so the test pass is not
            # counted in the training throughput.
            stop = time.time()
            test_losses = []
            test_accuracy = []
            # Number of samples in each test batch. The final batch can be
            # smaller than the rest, so per-batch means must be weighted by
            # batch size to get a true per-sample average.
            test_sizes = []
            for images, labels in test_loader:
                inputs, targets = xb.tensor(images), xb.tensor(labels)

                log_p = inputs >> model
                loss = log_p >> xn.nlloss << targets
                accuracy = xb.metrics.accuracy(log_p, targets)

                test_losses.append(loss.numpy())
                test_accuracy.append(accuracy)
                test_sizes.append(len(labels))

            # Sample-weighted averages over the whole test set.
            n_test = sum(test_sizes)
            mean_test_loss = sum(l * n for l, n in zip(test_losses, test_sizes)) / n_test
            mean_test_acc = sum(a * n for a, n in zip(test_accuracy, test_sizes)) / n_test

            print(f"Epoch: {e+1}/{epochs}  "
                  f"Train loss: {train_loss:.3f}  "
                  f"Test loss: {mean_test_loss:.3f}  "
                  f"Test acc.: {mean_test_acc:.3f}  "
                  f"Images/sec: {print_every*batch_size/(stop - start):.3f}")
            
            # Restart the clock after evaluation for the next reporting window.
            start = time.time()

Epoch: 1/2  Train loss: 2.238  Test loss: 2.222  Test acc.: 0.481  Images/sec: 762.972
Epoch: 1/2  Train loss: 2.085  Test loss: 2.084  Test acc.: 0.658  Images/sec: 1616.803
Epoch: 1/2  Train loss: 1.806  Test loss: 1.758  Test acc.: 0.712  Images/sec: 1552.074
Epoch: 1/2  Train loss: 1.270  Test loss: 1.193  Test acc.: 0.800  Images/sec: 1380.925
Epoch: 1/2  Train loss: 0.909  Test loss: 0.765  Test acc.: 0.839  Images/sec: 1163.274
Epoch: 1/2  Train loss: 0.501  Test loss: 0.572  Test acc.: 0.870  Images/sec: 1209.064
Epoch: 1/2  Train loss: 0.686  Test loss: 0.482  Test acc.: 0.878  Images/sec: 1215.050
Epoch: 1/2  Train loss: 0.485  Test loss: 0.438  Test acc.: 0.881  Images/sec: 1448.099
Epoch: 1/2  Train loss: 0.458  Test loss: 0.406  Test acc.: 0.892  Images/sec: 1148.984
Epoch: 2/2  Train loss: 0.416  Test loss: 0.386  Test acc.: 0.893  Images/sec: 694.396
Epoch: 2/2  Train loss: 0.328  Test loss: 0.373  Test acc.: 0.895  Images/sec: 1291.712
Epoch: 2/2  Train loss: 0.502  Tes