# Training on MNIST with XABY

This notebook demonstrates how to train a fully connected network (not convolutional!) on MNIST with the XABY framework. I'll also compare it to PyTorch so you can see how the two differ in API and performance.

I'm going to use torchvision to load in the MNIST data, because it's super great.

In [1]:
import time

import xaby
import xaby.nn.functional as xf

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_train = datasets.MNIST("~/.pytorch", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST("~/.pytorch", train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=128, num_workers=2)
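For reference, here's a minimal standard-library sketch of the per-pixel arithmetic the transform above performs. `ToTensor()` scales 8-bit pixels from [0, 255] to [0.0, 1.0], and `Normalize(mean, std)` then computes `(x - mean) / std`. The constants 0.1307 and 0.3081 are the commonly quoted mean and standard deviation of MNIST's training pixels after that scaling. The helper names below are just for illustration.

```python
# Per-pixel arithmetic of ToTensor() followed by Normalize((0.1307,), (0.3081,)).
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def to_tensor_value(pixel: int) -> float:
    """Scale an 8-bit pixel to [0, 1], as ToTensor does for uint8 images."""
    return pixel / 255.0

def normalize(x: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Shift and scale one value, as Normalize does per channel."""
    return (x - mean) / std

black = normalize(to_tensor_value(0))    # background pixel
white = normalize(to_tensor_value(255))  # full-intensity stroke
print(f"{black:.4f} {white:.4f}")  # -0.4242 2.8215
```

So the network sees inputs roughly centered on zero, with the background at about -0.42 rather than 0.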

## Defining Models

First up, I'll define two models with the same architecture: one with XABY, the other with PyTorch.

XABY models are defined as a sequence of operations. When a model is defined, it is compiled behind the scenes into a single function, which you call with some input like `inputs >> model`. I had a lot of fun messing with Python operators. My reasoning: when you chain a lot of functions with ordinary calls, like `f(g(h(x)))`, the last function called is the first one you read. The `>>` operator lets you write the chain of functions in the order they are called.
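To make the idea concrete, here's a toy sketch of `>>`-style chaining using Python's operator overloading (`__rshift__` and the reflected `__rrshift__`). This is not XABY's implementation; `Op`, `double`, and `increment` are hypothetical names used only to show how writing `increment >> double` can produce a function that runs `increment` first, matching the reading order.

```python
# Toy sketch of `>>` chaining via operator overloading.
# NOT XABY's implementation; `Op` is a hypothetical class for illustration.
class Op:
    def __init__(self, fn, name=None):
        self.fn = fn
        self.name = name or fn.__name__

    def __call__(self, x):
        return self.fn(x)

    def __rshift__(self, other):
        # op >> op: build a new Op that calls self first, then other.
        return Op(lambda x: other(self(x)), f"{self.name} >> {other.name}")

    def __rrshift__(self, value):
        # value >> op: Python falls back to this when the left operand
        # can't handle `>>` with an Op, so the value is passed straight in.
        return self(value)

double = Op(lambda x: 2 * x, "double")
increment = Op(lambda x: x + 1, "increment")

# Nested calls read inside-out: increment runs first but is written last.
nested = double(increment(3))
# With >>, the chain is written in the order the functions run.
model = increment >> double
print(model.name)          # increment >> double
print(3 >> model, nested)  # 8 8
```

Composing two `Op`s only builds a new function; nothing runs until a value is fed in.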

You can define the PyTorch model with `torch.nn.Sequential`, but subclassing `torch.nn.Module` is the preferred method, so I'll do that.

In [3]:
## XABY model ##
xaby_model = xf.flatten >> xf.linear(784, 256) >> xf.relu >> xf.linear(256, 10) >> xf.log_softmax

## PyTorch Model ##
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc2(x), dim=1)
        return x
    
torch_model = Model()
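As a quick sanity check that the two models really match, here's the parameter count worked out by hand. Each `Linear(n_in, n_out)` layer has an `n_out × n_in` weight matrix plus `n_out` biases, and flatten, ReLU, and log-softmax have no parameters. This assumes `xf.linear` includes a bias term like `nn.Linear` does, which the notebook doesn't show directly; for the PyTorch model, `sum(p.numel() for p in torch_model.parameters())` should give the same total.

```python
# Parameter count for the 784 -> 256 -> 10 network (weights + biases per Linear layer).
def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out

layers = [(784, 256), (256, 10)]
total = sum(linear_params(i, o) for i, o in layers)
print(total)  # 203530
```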

## Let's run data through the models!

Just a small example of using XABY models for inference.

In [4]:
# Get data from the image loader
images, labels = next(iter(train_loader))

# Convert PyTorch Tensor to a XABY Tensor
inputs = xaby.Tensor(images)

# Thanks to JAX, XABY tensors are automatically placed on the GPU when one is available
print(f"XABY device: {inputs.device()}")

# Call the model in a fun manner
log_p = inputs >> xaby_model

# Normal function call... boring....
log_p = xaby_model(inputs)

XABY device: GPU_0


I should also note that you can run XABY tensors through operations without creating a model. If you start the sequence with an operation, you get a model; if you start with a tensor, the tensor runs through each operation and you get a Tensor back.
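Here's a toy sketch of that eager, tensor-first path. Because `>>` is left-associative, `t >> f >> g` evaluates as `(t >> f) >> g`, so each operation runs immediately and hands a new tensor to the next one. `Tensor`, `relu`, and `scale` below are hypothetical stand-ins working on plain lists, not `xaby.Tensor` or XABY's operations.

```python
# Toy sketch of the eager path: a chain that starts with a tensor
# applies each operation immediately and returns a new tensor.
# `Tensor` is a hypothetical stand-in, not xaby.Tensor.
class Tensor:
    def __init__(self, values):
        self.values = list(values)

    def __rshift__(self, op):
        # `>>` is left-associative: t >> f >> g == (t >> f) >> g,
        # so f runs first and its output tensor is fed to g.
        return Tensor(op(self.values))

    def __repr__(self):
        return f"Tensor({self.values})"

def relu(xs):
    return [max(0.0, x) for x in xs]

def scale(xs):
    return [10 * x for x in xs]

out = Tensor([-1.0, 0.5, 2.0]) >> relu >> scale
print(out)  # Tensor([0.0, 5.0, 20.0])
```

In this toy the operations are plain functions, so `relu >> scale` on its own would raise a `TypeError`; XABY's operations support both the tensor-first and operation-first forms, as the cells in this notebook show.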

In [5]:
inputs >> xf.flatten >> xf.linear(784, 10) >> xf.log_softmax

Tensor([[ -0.3852086  -10.208656    -4.454486   ...  -2.8513358   -5.7166343
   -7.173292  ]
 [ -0.16765976  -5.399795    -5.7430463  ...  -3.962609    -6.4383087
   -6.3500037 ]
 [ -5.0375943   -5.699171   -10.508496   ...  -4.827385    -5.69089
   -6.901062  ]
 ...
 [ -3.5123837   -4.5302687   -6.137228   ...  -5.2335353   -1.0158901
   -9.262663  ]
 [ -0.6661906   -9.394583    -2.7893095  ...  -5.8804502   -5.424393
   -3.649283  ]
 [ -3.7532034  -12.018103   -11.665488   ...  -7.9599133  -14.067234
  -14.053231  ]], dtype=float32)

## Timing XABY and PyTorch

Below I'll time inference with these models: PyTorch on the CPU, PyTorch on the GPU, then XABY.

In [6]:
# First on CPU
torch_model = torch_model.requires_grad_(False)
torch_model.to("cpu")
images = images.to("cpu")

%timeit -n 1000 torch_model(images)

327 µs ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [7]:
# Now on GPU
torch_model.to("cuda")
images = images.to("cuda")

%timeit -n 1000 torch_model(images)

108 µs ± 460 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [8]:
# Now the XABY model, runs on GPU!
inputs = xaby.Tensor(images.to("cpu"))

%timeit -n 1000 inputs >> xaby_model

170 µs ± 5.18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


XABY is roughly 1.6× slower than PyTorch on the GPU here (170 µs vs. 108 µs per call). This might be because JAX is slower, or I may be able to optimize XABY further.
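One caveat on these numbers: PyTorch's CUDA backend and JAX both dispatch GPU work asynchronously, and the `%timeit` lines above don't explicitly wait for results, so they may partly measure dispatch overhead rather than compute. Here's a minimal standard-library sketch of a timing helper that waits on each result and runs a few warm-up calls first (to absorb one-time costs such as JIT compilation). `benchmark` and its `block` hook are hypothetical names; for PyTorch you could pass `torch.cuda.synchronize()`, and for a raw JAX array its `block_until_ready()` method. Whether XABY tensors expose an equivalent method isn't shown in this notebook, so treat that part as an assumption.

```python
import statistics
import time

def benchmark(fn, *args, block=lambda out: None, warmup=10, iters=100, repeats=7):
    """Time fn(*args), calling block(out) on each result inside the timed region.

    Returns the mean and standard deviation (in seconds) of the per-call time
    across `repeats` runs of `iters` calls each.
    """
    # Warm-up calls absorb one-time costs such as JIT compilation.
    for _ in range(warmup):
        block(fn(*args))

    per_call = []
    for _ in range(repeats):
        start = time.perf_counter()
        for _ in range(iters):
            # Waiting on every result means asynchronous work is counted.
            block(fn(*args))
        per_call.append((time.perf_counter() - start) / iters)
    return statistics.mean(per_call), statistics.stdev(per_call)

# CPU-only demo with a trivial function. On a GPU you might instead pass
#   block=lambda out: torch.cuda.synchronize()   # PyTorch
#   block=lambda out: out.block_until_ready()    # a JAX array
mean, std = benchmark(sum, range(1000))
print(f"{mean * 1e6:.1f} µs ± {std * 1e6:.1f} µs per call")
```

In IPython, appending `; torch.cuda.synchronize()` to the timed `%timeit` statement should have a similar effect for the PyTorch GPU cell.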

Either way, time to train the models. First up, XABY. I'll use simple stochastic gradient descent for both. The output of the models is log-softmax, so I'll use the negative log-likelihood loss.
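Concretely, with log-softmax outputs the negative log-likelihood is just the average of `-log p(correct class)` over the batch. Here's a minimal standard-library sketch with made-up logits; `log_softmax` and `nll_loss` are illustrative helpers, not XABY or PyTorch functions.

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability (log-sum-exp trick).
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]

def nll_loss(log_probs_batch, targets):
    # Negative log-likelihood: mean of -log p(target class) over the batch.
    return -sum(lp[t] for lp, t in zip(log_probs_batch, targets)) / len(targets)

# Two examples, three classes, made-up logits.
batch = [log_softmax([2.0, 1.0, 0.1]), log_softmax([0.5, 2.5, 0.3])]
targets = [0, 1]
loss = nll_loss(batch, targets)
print(f"{loss:.4f}")
```

This pairing is why, in PyTorch, `F.log_softmax` followed by `nn.NLLLoss` gives the same value as `nn.CrossEntropyLoss` applied to the raw logits.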

In XABY, composing the model with a loss function (`model << xaby.losses.nlloss`) creates a `backprop` function that takes the inputs and targets and returns the loss and the gradients.
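To show what a "take inputs and targets, return loss and gradients" function computes, here's a hand-written standard-library sketch for a single linear layer followed by log-softmax and NLL (smaller than the notebook's 784-256-10 network so the gradients fit on one screen), plus a plain SGD update, `param <- param - lr * grad`. This is not XABY's implementation; XABY is built on JAX, where `jax.value_and_grad` plays a similar role by deriving the gradients automatically. `loss_and_grads` and `sgd_step` are hypothetical names, and the four-point dataset is made up.

```python
import math

def log_softmax(z):
    m = max(z)
    log_z = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_z for v in z]

def loss_and_grads(params, inputs, targets):
    """Mean NLL of a linear + log-softmax model and its gradients w.r.t. params.

    params = (W, b), where W[k][j] is the weight for class k, feature j.
    """
    W, b = params
    n_classes, n_features = len(W), len(W[0])
    dW = [[0.0] * n_features for _ in range(n_classes)]
    db = [0.0] * n_classes
    total = 0.0
    for x, t in zip(inputs, targets):
        logits = [sum(w * xi for w, xi in zip(W[k], x)) + b[k] for k in range(n_classes)]
        log_p = log_softmax(logits)
        total -= log_p[t]
        # d(-log p_t) / d(logit_k) = softmax_k - [k == t]
        for k in range(n_classes):
            g = math.exp(log_p[k]) - (1.0 if k == t else 0.0)
            db[k] += g
            for j in range(n_features):
                dW[k][j] += g * x[j]
    n = len(targets)
    grads = ([[g / n for g in row] for row in dW], [g / n for g in db])
    return total / n, grads

def sgd_step(params, grads, lr):
    # Plain SGD: param <- param - lr * grad
    W, b = params
    dW, db = grads
    new_W = [[w - lr * g for w, g in zip(row, grow)] for row, grow in zip(W, dW)]
    new_b = [v - lr * g for v, g in zip(b, db)]
    return new_W, new_b

# Tiny made-up problem: 2 features, 2 classes.
inputs = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]
targets = [0, 1, 0, 1]
params = ([[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0])

losses = []
for step in range(50):
    loss, grads = loss_and_grads(params, inputs, targets)
    params = sgd_step(params, grads, lr=0.5)
    losses.append(loss)
print(f"{losses[0]:.4f} -> {losses[-1]:.4f}")
```

With all-zero parameters both classes get probability 0.5, so the first loss is log 2 ≈ 0.6931, and it falls as SGD updates the weights.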

In [9]:
batch_size = train_loader.batch_size
print_every = 100

In [10]:
# Define a fresh model
model = xf.flatten >> xf.linear(784, 256) >> xf.relu >> xf.linear(256, 10) >> xf.log_softmax

# Backpropagate the loss through the model
backprop = model << xaby.losses.nlloss

# Create the optimizer
optimize = xaby.optim.SGD(model, lr=0.003)

step = 0
start = time.time()
for images, labels in train_loader:
    step += 1
    
    inputs, targets = xaby.Tensor(images), xaby.Tensor(labels)
    
    # Backprop to get gradients!
    train_loss, grads = inputs >> backprop << targets
    # Update model parameters with gradients
    optimize(grads)
    
    if step % print_every == 0:
        stop = time.time()
        test_losses = []
        test_accuracy = []
        for images, labels in test_loader:
            inputs, targets = xaby.Tensor(images), xaby.Tensor(labels)
            log_p = inputs >> model
            loss = log_p >> xaby.losses.nlloss << targets
            accuracy = xaby.metrics.accuracy(log_p, targets)
            
            test_losses.append(loss.item())
            test_accuracy.append(accuracy.item())
            
        print(f"Train loss: {train_loss.item():.3f}  "
              f"Test loss: {sum(test_losses)/len(test_losses):.3f}  "
              f"Test accuracy: {sum(test_accuracy)/len(test_accuracy):.3f}  "
              f"Images/sec: {print_every*batch_size/(stop - start):.3f}")
        start = time.time()

Train loss: 1.208  Test loss: 1.276  Test accuracy: 0.590  Images/sec: 5919.404
Train loss: 0.903  Test loss: 0.852  Test accuracy: 0.727  Images/sec: 14378.808
Train loss: 0.760  Test loss: 0.695  Test accuracy: 0.781  Images/sec: 14424.387
Train loss: 0.599  Test loss: 0.605  Test accuracy: 0.810  Images/sec: 14358.304


In [11]:
# Start with a fresh model
torch_model = torch.nn.Sequential(
                    torch.nn.Flatten(),
                    torch.nn.Linear(784, 256),
                    torch.nn.ReLU(),
                    torch.nn.Linear(256, 10),
                    torch.nn.LogSoftmax(1))
torch_model.to("cuda")
optimizer = torch.optim.SGD(torch_model.parameters(), lr=0.003)
criterion = torch.nn.NLLLoss()

step = 0
start = time.time()
for images, labels in train_loader:
    step += 1
    
    inputs, targets = images.to("cuda"), labels.to("cuda")
    
    optimizer.zero_grad()
    log_p = torch_model(inputs)
    loss = criterion(log_p, targets)
    loss.backward()
    optimizer.step()
    
    train_loss = loss.item()
    
    if step % print_every == 0:
        stop = time.time()
        test_losses = []
        test_accuracy = []
        for images, labels in test_loader:
            with torch.no_grad():
                inputs, targets = images.to("cuda"), labels.to("cuda")
                log_p = torch_model(inputs)
                loss = criterion(log_p, targets)
                accuracy = (log_p.argmax(dim=1) == targets).float().mean()
            
            test_losses.append(loss.item())
            test_accuracy.append(accuracy.item())
            
        print(f"Train loss: {train_loss:.3f}  "
              f"Test loss: {sum(test_losses)/len(test_losses):.3f}  "
              f"Test accuracy: {sum(test_accuracy)/len(test_accuracy):.3f}  "
              f"Images/sec: {print_every*batch_size/(stop - start):.3f}")
        start = time.time()

Train loss: 1.786  Test loss: 1.790  Test accuracy: 0.632  Images/sec: 13148.382
Train loss: 1.348  Test loss: 1.331  Test accuracy: 0.759  Images/sec: 14306.074
Train loss: 1.029  Test loss: 1.018  Test accuracy: 0.805  Images/sec: 14353.490
Train loss: 0.874  Test loss: 0.828  Test accuracy: 0.832  Images/sec: 14343.788
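A small caveat that applies to both evaluation loops: they average per-batch loss and accuracy, which gives the final partial test batch (10,000 = 78 × 128 + 16 images) the same weight as a full batch of 128. The effect is small here, but weighting each batch by its size gives a true per-sample average. Here's a minimal sketch with made-up per-batch accuracies; `weighted_average` is a hypothetical helper. In the loops above you'd append `loss.item() * len(images)` and the number of correct predictions, then divide by the total number of test images.

```python
def weighted_average(batch_values, batch_sizes):
    """Average per-batch means weighted by batch size (i.e. a per-sample mean)."""
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)

# MNIST's 10,000 test images in batches of 128: 78 full batches + one batch of 16.
sizes = [128] * 78 + [16]
# Made-up per-batch accuracies: 0.80 on every full batch, 1.00 on the last one.
accs = [0.80] * 78 + [1.00]

unweighted = sum(accs) / len(accs)
weighted = weighted_average(accs, sizes)
print(f"{unweighted:.4f} {weighted:.4f}")  # 0.8025 0.8003
```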
