# Training a Convolutional Network on MNIST with XABY

This notebook demonstrates how to train a convolutional network on MNIST with the XABY framework.

I'm going to use torchvision to load the MNIST data, because it takes care of downloading and preprocessing the dataset for us.

In [None]:
import time
import xaby as xb
import xaby.nn as xn

import jax
import jax.numpy as np

# For loading MNIST data
import torch
from torchvision import transforms, datasets

In [None]:
# Preprocessing pipeline: turn each image into a tensor, then normalize its
# single channel with mean 0.1307 and standard deviation 0.3081.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform = transforms.Compose(preprocessing_steps)

# Both splits live under the same root, share the preprocessing pipeline, and
# are downloaded on first use.
mnist_options = dict(transform=transform, download=True)
mnist_train = datasets.MNIST("~/.pytorch", train=True, **mnist_options)
mnist_test = datasets.MNIST("~/.pytorch", train=False, **mnist_options)

# Only the training loader shuffles; batch size and worker count are shared.
loader_options = dict(batch_size=64, num_workers=2)
train_loader = torch.utils.data.DataLoader(mnist_train, shuffle=True, **loader_options)
test_loader = torch.utils.data.DataLoader(mnist_test, **loader_options)

## Train a convolutional model

Below I'll define a convolutional network and train it on the MNIST dataset.

In [None]:
### Define a model
# Two convolution + ReLU stages followed by two fully connected layers.
# NOTE(review): the conv2d arguments are presumably
# (in_channels, out_channels, kernel_size, stride, padding) -- confirm against
# xaby.nn.conv2d. Under that reading, 28x28 inputs shrink to 14x14 and then
# 7x7, which matches the 3136 = 64 * 7 * 7 inputs expected by `fc`.
conv1 = xn.conv2d(1, 32, 3, 2, 1) >> xn.relu
conv2 = xn.conv2d(32, 64, 3, 2, 1) >> xn.relu
fc = xn.linear(3136, 128) >> xn.relu
# NOTE(review): axis=0 here (and in flatten below) presumably refers to the
# per-example axis, with XABY handling the batch dimension -- confirm.
classifier = xn.linear(128, 10) >> xn.log_softmax(axis=0)
model = conv1 >> conv2 >> xb.flatten(axis=0) >> fc >> classifier

# Loss function: wraps the model; calling it returns (loss value, gradients).
loss = xb.nn.nll_loss(model)

# Update function: applies the gradients to the model with SGD.
update = xb.optim.sgd(lr=0.003)

step = 0
start = time.time()
epochs = 2
batch_size = train_loader.batch_size
print_every = 100
for e in range(epochs):
    for images, labels in train_loader:
        step += 1
        inputs = xb.pack(xb.array(images), xb.array(labels))

        train_loss, grads = loss(inputs)
        update(model, grads)

        if step % print_every == 0:
            # Stop the clock before evaluating so the throughput figure only
            # covers training steps.
            stop = time.time()
            test_losses = []
            n_correct = 0
            n_seen = 0
            # Use distinct names so the current training batch is not shadowed.
            for test_images, test_labels in test_loader:
                test_inputs = xb.pack(xb.array(test_images), xb.array(test_labels))

                log_p, = model([test_inputs[0]])
                pred_label = xb.jnp.argmax(log_p, axis=1)
                # Count correct predictions instead of averaging per-batch
                # means: the final test batch is smaller than the others, so
                # an unweighted mean of batch accuracies would be biased.
                n_correct += (test_inputs[1] == pred_label).sum()
                n_seen += len(test_labels)

                # TODO: evaluate the loss without computing gradients, for speed.
                test_loss, _ = loss(test_inputs)
                # NOTE(review): per-batch losses are averaged unweighted below;
                # if nll_loss returns a per-batch mean, the smaller final batch
                # is slightly over-weighted -- confirm and weight if needed.
                test_losses.append(test_loss)

            print(f"Epoch: {e+1}/{epochs}  "
                  f"Train loss: {train_loss:.3f}  "
                  f"Test loss: {sum(test_losses)/len(test_losses):.3f}  "
                  f"Test acc.: {n_correct/n_seen:.3f}  "
                  f"Images/sec: {print_every*batch_size/(stop - start):.3f}")

            # Restart the clock after evaluation for the next reporting window.
            start = time.time()