# Training a Convolutional Network on MNIST with XABY

This notebook demonstrates how to train a convolutional network on MNIST with the XABY framework.

I'm going to use torchvision to load the MNIST data, because it takes care of downloading the dataset and applying the preprocessing transforms.

In [1]:
import time
import xaby as xb
import xaby.nn as xn

import jax
import jax.numpy as np

# For loading MNIST data
import torch
from torchvision import transforms, datasets



In [2]:
# Preprocessing pipeline: convert each image to a tensor, then normalize it
# with the given mean / std (one-element tuples).
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Both splits use the same root directory and the same transform;
# download=True is passed for each of them.
data_root = "~/.pytorch"
mnist_train = datasets.MNIST(data_root, train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(data_root, train=False, transform=transform, download=True)

# Options common to both loaders. Only the training loader shuffles.
loader_options = {"batch_size": 64, "num_workers": 2}
train_loader = torch.utils.data.DataLoader(mnist_train, shuffle=True, **loader_options)
test_loader = torch.utils.data.DataLoader(mnist_test, **loader_options)

## Train a convolutional model

Below I'll define a convolutional network and train it on the MNIST dataset.

In [4]:
### Define a model
# Two convolution stages, each followed by a ReLU.
# NOTE(review): the positional arguments look like
# (in_channels, out_channels, kernel, stride, padding) -- confirm against XABY.
conv1 = xn.conv2d(1, 32, 3, 2, 1) >> xn.relu
conv2 = xn.conv2d(32, 64, 3, 2, 1) >> xn.relu

# Hidden layer takes 3136 flattened features
# (presumably 64 channels of 7x7 feature maps -- TODO confirm).
fc = xn.linear(64 * 7 * 7, 128) >> xn.relu

# Output layer: 10 values passed through log-softmax.
classifier = xn.linear(128, 10) >> xn.log_softmax(axis=0)

# Chain the stages left to right: convolutions, flatten, then the two linear layers.
feature_extractor = conv1 >> conv2 >> xb.flatten(axis=0)
model = feature_extractor >> fc >> classifier

# loss function: the model and xb.skip are combined with xb.split, and the
# result is fed into the NLL loss.
# NOTE(review): presumably xb.skip passes the labels through unchanged -- confirm.
model_and_labels = xb.split(model, xb.skip)
loss = model_and_labels >> xn.losses.nll_loss()

# Update function: SGD with a fixed learning rate.
learning_rate = 0.003
update = xb.optim.sgd(lr=learning_rate)

# Training loop.
#
# Makes `epochs` passes over `train_loader`, computing the loss and gradients
# for each batch and applying one optimizer update per batch. Every
# `print_every` steps it runs the whole test set through the model and prints
# the latest training loss, the test loss, the test accuracy and the training
# throughput since the previous report.
step = 0
epochs = 2
batch_size = train_loader.batch_size
print_every = 100

# perf_counter() is monotonic, so the measured intervals are not distorted by
# system clock adjustments the way time.time() differences can be.
start = time.perf_counter()
for e in range(epochs):
    for images, labels in train_loader:
        step += 1

        # Wrap up the inputs
        inputs = xb.pack(xb.array(images), xb.array(labels))

        # Get the gradients
        train_loss, grads = loss << inputs

        # And update our parameters
        update(loss, grads)

        if step % print_every == 0:
            # Stop the clock before evaluating so that test-set time is not
            # counted against training throughput.
            stop = time.perf_counter()

            test_losses = []
            num_correct = 0
            num_seen = 0
            # Separate names for the test batch so the training batch
            # variables above are not overwritten.
            for test_images, test_labels in test_loader:
                test_inputs = xb.pack(xb.array(test_images), xb.array(test_labels))

                log_p, = test_inputs >> xb.select(0) >> model
                pred_label = xb.jnp.argmax(log_p, axis=1)

                # Count correct predictions and samples rather than averaging
                # per-batch accuracies: the final test batch may be smaller
                # than the others, and a mean of batch means would give its
                # samples extra weight.
                num_correct += (test_inputs[1] == pred_label).sum()
                num_seen += len(test_labels)

                # NOTE(review): test loss is still averaged per batch below;
                # whether that equals the per-sample mean depends on how
                # nll_loss reduces over the batch -- confirm.
                test_losses.append(test_inputs >> loss)

            test_acc = num_correct / num_seen

            print(f"Epoch: {e+1}/{epochs}  "
                  f"Train loss: {train_loss:.3f}  "
                  f"Test loss: {sum(test_losses)/len(test_losses):.3f}  "
                  f"Test acc.: {test_acc:.3f}  "
                  f"Images/sec: {print_every*batch_size/(stop - start):.3f}")

            # Restart the clock only after evaluation and printing are done.
            start = time.perf_counter()

Epoch: 1/2  Train loss: 4.121  Test loss: 4.114  Test acc.: 0.458  Images/sec: 1261.911
Epoch: 1/2  Train loss: 4.085  Test loss: 4.057  Test acc.: 0.599  Images/sec: 2737.229
Epoch: 1/2  Train loss: 3.949  Test loss: 3.944  Test acc.: 0.675  Images/sec: 2843.397
Epoch: 1/2  Train loss: 3.731  Test loss: 3.697  Test acc.: 0.732  Images/sec: 2738.279
Epoch: 1/2  Train loss: 3.400  Test loss: 3.272  Test acc.: 0.791  Images/sec: 2845.163
Epoch: 1/2  Train loss: 3.042  Test loss: 2.940  Test acc.: 0.825  Images/sec: 2772.471
Epoch: 1/2  Train loss: 2.948  Test loss: 2.775  Test acc.: 0.844  Images/sec: 2772.951
Epoch: 1/2  Train loss: 2.661  Test loss: 2.679  Test acc.: 0.854  Images/sec: 2641.313
Epoch: 1/2  Train loss: 2.761  Test loss: 2.620  Test acc.: 0.861  Images/sec: 2409.088
Epoch: 2/2  Train loss: 2.589  Test loss: 2.578  Test acc.: 0.870  Images/sec: 1427.188
Epoch: 2/2  Train loss: 2.641  Test loss: 2.549  Test acc.: 0.876  Images/sec: 2397.966
Epoch: 2/2  Train loss: 2.738  T