# Training on MNIST with XABY

This notebook demonstrates how to train a fully connected network (not convolutional!) on MNIST with the XABY framework. I'll also compare it to PyTorch so you can see the differences in API and performance.

I'm going to use torchvision to load in the MNIST data, because it's super great.

In [1]:
import time
import xaby as xb
import xaby.nn as xn

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms



In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_train = datasets.MNIST("~/.pytorch", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST("~/.pytorch", train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=128, num_workers=2)
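
Here is the arithmetic behind these settings, in plain Python. The constants 0.1307 and 0.3081 are the commonly used mean and standard deviation of MNIST pixel values after `ToTensor()` has scaled them to [0, 1]. MNIST's standard split is 60,000 training images and 10,000 test images. The helper `normalize` is just for illustration.

```python
import math

# Normalize((0.1307,), (0.3081,)) computes (x - mean) / std per channel,
# after ToTensor() has already scaled raw pixels from 0..255 to 0..1.
MEAN, STD = 0.1307, 0.3081

def normalize(pixel_0_255):
    x = pixel_0_255 / 255.0      # the ToTensor() scaling
    return (x - MEAN) / STD      # the Normalize() shift and scale

black = normalize(0)     # background pixels
white = normalize(255)   # full-intensity strokes
print(f"black -> {black:.4f}, white -> {white:.4f}")

# Batch bookkeeping for batch_size=128. DataLoader's default drop_last=False
# keeps the final, smaller batch.
batch_size = 128
n_train, n_test = 60_000, 10_000
train_batches = math.ceil(n_train / batch_size)
test_batches = math.ceil(n_test / batch_size)
last_test_batch = n_test - (test_batches - 1) * batch_size
print(train_batches, test_batches, last_test_batch)
```

With 469 training batches per epoch and `print_every = 50`, each of the one-epoch training loops below prints 9 progress lines.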

## Defining Models

First up, I'll define two models with the same architecture, one with XABY and the other with PyTorch.

XABY models are defined as a sequence of operations. When a model is defined, it is compiled behind the scenes into a single function, which you call on some input like `inputs >> model`. I had a lot of fun messing with Python operators here. My reasoning: when you chain a lot of functions with ordinary calls, like `f(g(h(x)))`, the last function called is the first one you read. The `>>` operator lets you write the chain of functions in the order they are called.
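
The reading-order point is easy to see in plain Python, without any operator overloading. This illustration uses made-up toy functions and a hypothetical `pipeline` helper. It is not how XABY composes operations.

```python
from functools import reduce

# Three toy "layers"; any single-argument functions would do.
def double(x):
    return 2 * x

def add_one(x):
    return x + 1

def square(x):
    return x * x

# Nested calls: square runs last but is read first.
nested = square(add_one(double(3)))

# Left-to-right pipeline: the functions are listed in the order they run.
def pipeline(*funcs):
    """Return a function that applies funcs from left to right."""
    return lambda x: reduce(lambda acc, f: f(acc), funcs, x)

model = pipeline(double, add_one, square)
print(nested, model(3))
```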

You can define the PyTorch model with `torch.nn.Sequential`, but subclassing `torch.nn.Module` is the preferred method, so I'll do that.

In [3]:
## XABY model ##
xaby_model = xb.flatten(axis=0) >> xn.linear(784, 256) >> xn.relu \
          >> xn.linear(256, 10) >> xn.log_softmax(axis=0)

## PyTorch Model ##
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc2(x), dim=1)
        return x
    
torch_model = Model()
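
Both models have the same shape: 784 inputs (a flattened 28 × 28 image), 256 hidden units, and 10 outputs. The following is a quick count of the trainable parameters. It assumes each linear layer has a bias vector. PyTorch's `nn.Linear` does by default (`bias=True`), and I'm assuming XABY's `xn.linear` does as well.

```python
# Parameters of a fully connected layer: an (in x out) weight matrix plus
# an out-sized bias vector (assumes every linear layer has a bias).
def linear_params(n_in, n_out):
    return n_in * n_out + n_out

layers = [(784, 256), (256, 10)]   # the two linear layers defined above
per_layer = [linear_params(i, o) for i, o in layers]
total = sum(per_layer)
print(per_layer, total)
# For the PyTorch model, sum(p.numel() for p in torch_model.parameters())
# should report the same total.
```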

## Let's run data through the models!

Just a small example of using XABY models for inference.

In [4]:
# Get data from the image loader
images, labels = next(iter(train_loader))

# Convert PyTorch Tensor to a XABY array (actually a JAX DeviceArray)
inputs = xb.array(images)

# Thanks to JAX, XABY tensors are automatically on the GPU (if one is available)
print(f"XABY device: {inputs.device_buffer.device()}")

# Call the model in a fun manner
log_p = xb.pack(inputs) >> xaby_model

# Normal function call... boring....
log_p = xaby_model([inputs])

XABY device: CPU_0


I should also note you can run XABY tensors through operations without creating models. This returns another tensor. If you start the sequence with an operation, it'll create a model. If you start with a tensor, it'll run through the operations and return a tensor.
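
Below is a minimal sketch of how `>>` can behave differently depending on its left operand, using Python's `__rshift__` and `__rrshift__` hooks. The `Op` class is hypothetical and only illustrates the idea. It is not XABY's implementation, and it uses plain lists in place of tensors.

```python
class Op:
    """A hypothetical pipeline step (NOT XABY's real class)."""

    def __init__(self, *funcs):
        self.funcs = funcs

    def __call__(self, x):
        # Apply each stored function in order.
        for f in self.funcs:
            x = f(x)
        return x

    def __rshift__(self, other):
        # op >> op: build a new, longer pipeline; nothing runs yet.
        if not isinstance(other, Op):
            return NotImplemented
        return Op(*self.funcs, *other.funcs)

    def __rrshift__(self, value):
        # value >> op: Python falls back to this when the left operand
        # (here a list) can't handle >>, so the pipeline runs immediately.
        return self(value)


double = Op(lambda x: [2 * v for v in x])
shift = Op(lambda x: [v + 1 for v in x])

model = double >> shift               # starts with an Op: builds a model
result = [1, 2, 3] >> model           # starts with data: runs the model
also = [1, 2, 3] >> double >> shift   # evaluated as ([1, 2, 3] >> double) >> shift
print(type(model).__name__, result, also)
```

Because `>>` is left-associative, a chain that starts with data is evaluated one step at a time, and every intermediate value is already a result rather than a model.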

## Timing XABY and PyTorch

Below I'll time inference with both models.
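
`%timeit` only works in IPython/Jupyter. In a plain script, the standard library's `timeit.repeat` gives comparable numbers. The `time_call` helper below is a hypothetical sketch that mimics `%timeit`'s "mean ± std. dev. of 7 runs" report. The timed workload is a stand-in.

```python
import statistics
import timeit

def time_call(fn, number=1000, repeat=7):
    """Mimic %timeit: `repeat` runs of `number` calls each.

    Returns the per-call mean and standard deviation in seconds.
    The callable should only return once its work is actually finished.
    """
    totals = timeit.repeat(fn, number=number, repeat=repeat)
    per_call = [t / number for t in totals]
    return statistics.mean(per_call), statistics.stdev(per_call)

# Stand-in workload; swap in e.g. `lambda: torch_model(images)`.
mean, std = time_call(lambda: sum(i * i for i in range(1000)), number=100)
print(f"{mean * 1e6:.1f} µs ± {std * 1e6:.1f} µs per loop")
```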

In [5]:
# First on CPU
torch_model = torch_model.requires_grad_(False)
torch_model.to("cpu")
images = images.to("cpu")

%timeit -n 1000 torch_model(images)

632 µs ± 160 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [6]:
# Uncomment the code below to time things on a GPU
# torch_model.to("cuda")
# images = images.to("cuda")

# %timeit -n 1000 torch_model(images)

In [7]:
# Now the XABY model, which automatically runs on a GPU if one is available
inputs = xb.pack(xb.array(images))

%timeit -n 1000 inputs >> xaby_model

637 µs ± 42.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


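On average, XABY is marginally slower than PyTorch in this test (637 µs vs. 632 µs per call). However, that gap is far smaller than the run-to-run spread (±160 µs for PyTorch), so the two are effectively tied here. Any real difference might come from JAX itself or from optimizations I haven't done yet in XABY. One caveat: JAX dispatches work asynchronously, so unless XABY waits for the result, `%timeit` may be measuring dispatch rather than the complete computation.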

Either way, time to train the models. First up, XABY. I'll use plain stochastic gradient descent for both. The models output log-probabilities, so I'll use the negative log-likelihood (NLL) loss.
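
As a worked example of the loss, here is log-softmax followed by NLL in plain Python. The helper names are mine, not XABY's or PyTorch's. With all-equal logits, every class is equally likely, so a correctly normalized 10-class model reports a loss of ln(10) ≈ 2.303. This makes a useful sanity check. If the softmax is accidentally taken over the batch axis instead of the class axis, the same inputs give ln(batch size), which is ln(128) ≈ 4.852 here.

```python
import math

def log_softmax(row):
    """Log-probabilities for one vector of logits (numerically stable)."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def nll_loss(log_probs, labels):
    """Mean negative log-likelihood of the true labels."""
    return -sum(lp[y] for lp, y in zip(log_probs, labels)) / len(labels)

# All logits equal: a model with no preference between classes.
batch, classes = 128, 10
logits = [[0.0] * classes for _ in range(batch)]
labels = [i % classes for i in range(batch)]

# Normalizing over the class axis (each row) gives the baseline ln(10).
per_class = nll_loss([log_softmax(r) for r in logits], labels)

# Normalizing over the batch axis (each column) instead gives ln(128).
cols = [log_softmax(list(c)) for c in zip(*logits)]   # one list per class
per_batch_rows = [list(r) for r in zip(*cols)]        # back to batch x classes
per_batch = nll_loss(per_batch_rows, labels)

print(f"{per_class:.3f} vs ln(10) = {math.log(10):.3f}")
print(f"{per_batch:.3f} vs ln(128) = {math.log(128):.3f}")
```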

In [8]:
batch_size = train_loader.batch_size
print_every = 50

In [9]:
# Define a fresh model, split over two lines for readability
model = xb.flatten(axis=0) >> xn.linear(784, 256) >> xn.relu \
                           >> xn.linear(256, 10) >> xn.log_softmax(axis=0)

# Loss function: run the model on the images, pass the labels through unchanged, then compute the NLL loss
loss = xb.split(model, xb.skip) >> xn.losses.nll_loss()

# Update function
update = xb.optim.sgd(lr=0.003)

step = 0
start = time.time()
for images, labels in train_loader:
    step += 1
    
    # Wrap up our input data
    inputs = xb.pack(xb.array(images), xb.array(labels))
    
    # Compute the loss and its gradients (`loss << inputs` returns both)
    train_loss, grads = loss << inputs
    
    # Then update the model's parameters using the gradients
    update(loss, grads)
    
    if step % print_every == 0:

        stop = time.time()
        test_losses = []
        test_accuracy = []
        
        for images, labels in test_loader:
            inputs = xb.pack(xb.array(images), xb.array(labels))
            
            log_p, = inputs >> xb.select(0) >> model
            pred_label = xb.jnp.argmax(log_p, axis=1)
            test_accuracy.append((inputs[1] == pred_label).mean())
            
            test_loss = inputs >> loss
            test_losses.append(test_loss)
            
        print(f"Train loss: {train_loss:.3f}  "
              f"Test loss: {sum(test_losses)/len(test_losses):.3f}  "
              f"Test acc.: {sum(test_accuracy)/len(test_accuracy):.3f}  "
              f"Images/sec: {print_every*batch_size/(stop - start):.3f}")
        start = time.time()

Train loss: 4.637  Test loss: 4.584  Test acc.: 0.513  Images/sec: 3541.424
Train loss: 4.366  Test loss: 4.338  Test acc.: 0.687  Images/sec: 8872.720
Train loss: 4.143  Test loss: 4.107  Test acc.: 0.746  Images/sec: 8731.576
Train loss: 3.969  Test loss: 3.912  Test acc.: 0.778  Images/sec: 8306.100
Train loss: 3.720  Test loss: 3.762  Test acc.: 0.802  Images/sec: 8700.962
Train loss: 3.657  Test loss: 3.650  Test acc.: 0.820  Images/sec: 7908.449
Train loss: 3.688  Test loss: 3.566  Test acc.: 0.831  Images/sec: 8166.549
Train loss: 3.543  Test loss: 3.503  Test acc.: 0.839  Images/sec: 8024.907
Train loss: 3.559  Test loss: 3.455  Test acc.: 0.849  Images/sec: 9572.297
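
Both training loops use stochastic gradient descent with `lr=0.003`. With its default `momentum=0`, PyTorch's `torch.optim.SGD` applies the plain update `param <- param - lr * grad`. I'm assuming `xb.optim.sgd` does the same. The following is a minimal sketch of that rule on a made-up toy quadratic, not MNIST. Here the gradient is exact. In the notebook it is "stochastic" because each step uses the gradient from one minibatch.

```python
# Plain SGD (no momentum, no weight decay): param <- param - lr * grad.
def sgd_step(params, grads, lr):
    return [p - lr * g for p, g in zip(params, grads)]

# Toy problem: minimize f(w, b) = (w - 3)^2 + (b + 1)^2, minimum at (3, -1).
def grad_f(params):
    w, b = params
    return [2 * (w - 3), 2 * (b + 1)]

params = [0.0, 0.0]
for _ in range(200):
    params = sgd_step(params, grad_f(params), lr=0.1)
print([round(p, 6) for p in params])
```

With `lr=0.1`, each step shrinks the distance to the minimum by a factor of 0.8. A much smaller learning rate like the notebook's 0.003 converges more slowly but more cautiously.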


In [10]:
# Start with a fresh model
torch_model = torch.nn.Sequential(
                    torch.nn.Flatten(),
                    torch.nn.Linear(784, 256),
                    torch.nn.ReLU(),
                    torch.nn.Linear(256, 10),
                    torch.nn.LogSoftmax(1))
torch_model.to("cpu")
optimizer = torch.optim.SGD(torch_model.parameters(), lr=0.003)
criterion = torch.nn.NLLLoss()

step = 0
start = time.time()
for images, labels in train_loader:
    step += 1
    
    inputs, targets = images.to("cpu"), labels.to("cpu")
    
    optimizer.zero_grad()
    log_p = torch_model(inputs)
    loss = criterion(log_p, targets)
    loss.backward()
    optimizer.step()
    
    train_loss = loss.item()
    
    if step % print_every == 0:
        stop = time.time()
        test_losses = []
        test_accuracy = []
        for images, labels in test_loader:
            with torch.no_grad():
                inputs, targets = images.to("cpu"), labels.to("cpu")
                log_p = torch_model(inputs)
                loss = criterion(log_p, targets)
                accuracy = (log_p.argmax(dim=1) == targets).float().mean()
            
            test_losses.append(loss.item())
            test_accuracy.append(accuracy.item())
            
        print(f"Train loss: {train_loss:.3f}  "
              f"Test loss: {sum(test_losses)/len(test_losses):.3f}  "
              f"Test accuracy: {sum(test_accuracy)/len(test_accuracy):.3f}  "
              f"Images/sec: {print_every*batch_size/(stop - start):.3f}")
        start = time.time()

Train loss: 2.090  Test loss: 2.047  Test accuracy: 0.489  Images/sec: 6658.876
Train loss: 1.827  Test loss: 1.792  Test accuracy: 0.657  Images/sec: 7181.570
Train loss: 1.590  Test loss: 1.551  Test accuracy: 0.727  Images/sec: 9431.559
Train loss: 1.370  Test loss: 1.337  Test accuracy: 0.770  Images/sec: 8120.559
Train loss: 1.254  Test loss: 1.159  Test accuracy: 0.795  Images/sec: 8477.698
Train loss: 0.958  Test loss: 1.020  Test accuracy: 0.813  Images/sec: 7962.972
Train loss: 0.891  Test loss: 0.911  Test accuracy: 0.826  Images/sec: 6625.582
Train loss: 0.802  Test loss: 0.825  Test accuracy: 0.835  Images/sec: 6090.134
Train loss: 0.733  Test loss: 0.757  Test accuracy: 0.844  Images/sec: 8422.017
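
Both evaluation loops average per-batch losses and accuracies. The last test batch holds only 16 of the 10,000 images, yet it counts as much as a full batch of 128, which skews the averages slightly. Weighting each batch by its size gives the exact figure over the whole test set. The following sketch uses made-up per-batch accuracies, not values from the runs above.

```python
# 78 full batches of 128 plus a final batch of 16 = 10,000 test images.
batch_sizes = [128] * 78 + [16]
batch_acc = [0.85] * 78 + [0.5]   # hypothetical per-batch accuracies

# What the loops above do: every batch counts equally.
unweighted = sum(batch_acc) / len(batch_acc)

# Weighting by batch size gives the accuracy over all test images.
weighted = sum(a * n for a, n in zip(batch_acc, batch_sizes)) / sum(batch_sizes)
print(f"unweighted: {unweighted:.4f}  weighted: {weighted:.4f}")
```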
