# Training a Convolutional Network on MNIST with XABY

This notebook demonstrates how to train a convolutional network on MNIST with the XABY framework.

I'm using torchvision to load the MNIST data, because it handles downloading, decoding, normalization, and batching for us.

In [None]:
import time
import xaby
import xaby.nn.functional as xf

import jax
import jax.numpy as np

# For loading MNIST data
import torch
from torchvision import transforms, datasets

In [None]:
# Preprocessing: ToTensor turns each image into a float tensor in [0, 1];
# Normalize then standardizes it with the commonly used MNIST mean (0.1307)
# and standard deviation (0.3081).
normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Both splits live under ~/.pytorch; download=True fetches them if missing.
mnist_root = "~/.pytorch"
mnist_train, mnist_test = (
    datasets.MNIST(mnist_root, train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Shared loader settings; only the training loader reshuffles each epoch.
loader_kwargs = dict(batch_size=64, num_workers=2)
train_loader = torch.utils.data.DataLoader(mnist_train, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(mnist_test, **loader_kwargs)

## Train a convolutional model

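Below I define a small convolutional network (two convolutional layers followed by two fully connected layers) and train it with SGD on the negative log-likelihood loss. Every 100 training steps the model is evaluated on the full test set and the loss, accuracy, and training throughput are printed.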

In [7]:
### Define a model
# Two stride-2 convolutions followed by two fully connected layers.
# NOTE(review): conv2d's positional args appear to be
# (in_channels, out_channels, kernel_size, stride, padding) -- confirm in xaby.
# Two stride-2 convs take a 28x28 image down to a 64 x 7 x 7 feature map,
# which is where fc's 3136 (= 64 * 7 * 7) input features come from.
conv1 = xf.conv2d(1, 32, 3, 2, 1) >> xf.relu
conv2 = xf.conv2d(32, 64, 3, 2, 1) >> xf.relu
fc = xf.linear(3136, 128) >> xf.relu
classifier = xf.linear(128, 10) >> xf.log_softmax

model = conv1 >> conv2 >> xf.flatten >> fc >> classifier

# Model with backpropagation: calling it returns (loss, gradients)
backprop = model << xf.nlloss

# Update function: `grads >> update` applies an SGD step to `model`
update = xaby.optim.sgd(model, lr=0.003)

step = 0
start = time.time()
epochs = 2
batch_size = train_loader.batch_size  # kept for reference; throughput uses images_seen
print_every = 100
# Training images processed since `start`. Counted explicitly because the last
# batch of each epoch can be smaller than `batch_size`.
images_seen = 0
for e in range(epochs):
    for images, labels in train_loader:
        step += 1
        images_seen += len(images)
        inputs, targets = xaby.tensor(images), xaby.tensor(labels)

        train_loss, grads = backprop(inputs, targets)
        grads >> update

        if step % print_every == 0:
            stop = time.time()
            # Evaluate on the full test set. The last test batch can be smaller
            # than the rest, so per-batch metrics are weighted by batch size
            # instead of being averaged directly.
            test_losses = []
            test_accuracy = []
            test_sizes = []
            for images, labels in test_loader:
                inputs, targets = xaby.tensor(images), xaby.tensor(labels)

                log_p = inputs >> model
                loss = log_p >> xf.nlloss << targets
                accuracy = xaby.metrics.accuracy(log_p, targets)

                test_losses.append(loss.numpy())
                test_accuracy.append(accuracy)
                test_sizes.append(len(labels))

            n_test = sum(test_sizes)
            mean_test_loss = sum(l * n for l, n in zip(test_losses, test_sizes)) / n_test
            mean_test_acc = sum(a * n for a, n in zip(test_accuracy, test_sizes)) / n_test

            # Train loss is from the most recent batch only, so it is noisy.
            # The first throughput figure likely includes one-off JIT compilation.
            print(f"Epoch: {e+1}/{epochs}  "
                  f"Train loss: {train_loss:.3f}  "
                  f"Test loss: {mean_test_loss:.3f}  "
                  f"Test acc.: {mean_test_acc:.3f}  "
                  f"Images/sec: {images_seen/(stop - start):.3f}")

            # Restart the throughput clock after evaluation so test time isn't counted.
            start = time.time()
            images_seen = 0

Epoch: 1/2  Train loss: 2.248  Test loss: 2.262  Test acc.: 0.175  Images/sec: 1004.998
Epoch: 1/2  Train loss: 2.223  Test loss: 2.194  Test acc.: 0.522  Images/sec: 2194.561
Epoch: 1/2  Train loss: 2.042  Test loss: 2.046  Test acc.: 0.663  Images/sec: 2387.174
Epoch: 1/2  Train loss: 1.728  Test loss: 1.708  Test acc.: 0.730  Images/sec: 2011.710
Epoch: 1/2  Train loss: 1.166  Test loss: 1.157  Test acc.: 0.789  Images/sec: 2442.315
Epoch: 1/2  Train loss: 0.790  Test loss: 0.767  Test acc.: 0.831  Images/sec: 2132.453
Epoch: 1/2  Train loss: 0.389  Test loss: 0.591  Test acc.: 0.858  Images/sec: 1997.072
Epoch: 1/2  Train loss: 0.634  Test loss: 0.509  Test acc.: 0.868  Images/sec: 2300.078
Epoch: 1/2  Train loss: 0.420  Test loss: 0.456  Test acc.: 0.879  Images/sec: 2172.295
Epoch: 2/2  Train loss: 0.520  Test loss: 0.421  Test acc.: 0.887  Images/sec: 1108.372
Epoch: 2/2  Train loss: 0.314  Test loss: 0.402  Test acc.: 0.888  Images/sec: 2206.709
Epoch: 2/2  Train loss: 0.505  T