# Training on MNIST with XABY

This notebook demonstrates how to train a fully connected network (not convolutional!) on MNIST with the XABY framework. I'll also compare it to PyTorch so you can see how the APIs and performance differ.

I'm going to use torchvision to load in the MNIST data, because it's super great.

In [None]:
import time
import xaby
import xaby.nn as xn

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_train = datasets.MNIST("~/.pytorch", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST("~/.pytorch", train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=128, num_workers=2)

## Defining Models

First up, I'll define two models with the same architecture: one with XABY, the other with PyTorch.

XABY models are defined as a sequence of operations. When a model is defined, it is compiled behind the scenes into a single function. You call the function with some input like `inputs >> model`. I had a lot of fun messing with Python operators. My reasoning is that when you nest a lot of function calls, the last function called is the first one you read. With the `>>` operator, you can write the chain of functions in the order they are called.
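
To make the operator idea concrete, here's a minimal sketch of how `>>` chaining can be built with Python's `__rshift__` and `__rrshift__` methods. This is not XABY's actual implementation; the `Op` class and the `double` and `add_one` operations are hypothetical names made up for illustration.

```python
class Op:
    """Hypothetical wrapper that makes a plain function chainable with >>."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    def __rshift__(self, other):
        # op >> other_op: build a new Op that applies self first, then other
        if isinstance(other, Op):
            return Op(lambda x: other(self(x)))
        return NotImplemented

    def __rrshift__(self, value):
        # value >> op: the left operand is not an Op, so apply this op to it
        return self(value)


double = Op(lambda x: 2 * x)
add_one = Op(lambda x: x + 1)

pipeline = double >> add_one    # starts with an Op, so the result is a new Op
result = 3 >> pipeline          # starts with a value, so the result is a value
same = 3 >> double >> add_one   # evaluated left to right
```

Because `>>` is left-associative, `3 >> double >> add_one` first computes `3 >> double` and then feeds that result to `add_one`, so the chain reads in the order it runs.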

You can define the PyTorch model with `torch.nn.Sequential`, but subclassing `torch.nn.Module` is the preferred method, so I'll do that.

In [None]:
## XABY model ##
xaby_model = xaby.flatten >> xn.linear(784, 256) >> xn.relu \
                          >> xn.linear(256, 10) >> xn.log_softmax

## PyTorch Model ##
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc2(x), dim=1)
        return x
    
torch_model = Model()

## Let's run data through the models!

Just a small example of using XABY models for inference.

In [None]:
# Get data from the image loader
images, labels = next(iter(train_loader))

# Convert PyTorch Tensor to a XABY Tensor
inputs = xaby.tensor(images)

# Thanks to JAX, XABY tensors are placed on the GPU automatically when one is available
print(f"XABY device: {inputs.device}")

# Call the model in a fun manner
log_p = inputs >> xaby_model

# Normal function call... boring....
log_p = xaby_model(inputs)

I should also note you can run XABY tensors through operations without creating models. This returns another tensor. If you start the sequence with an operation, it'll create a model. If you start with a tensor, it'll run through the operations and return a tensor.

In [None]:
inputs >> xaby.flatten >> xn.linear(784, 10) >> xn.log_softmax

## Timing XABY and PyTorch

Below I'll time inference with both models.

In [None]:
# First on CPU
torch_model = torch_model.requires_grad_(False)
torch_model.to("cpu")
images = images.to("cpu")

%timeit -n 1000 torch_model(images)

In [None]:
# Now on GPU
torch_model.to("cuda")
images = images.to("cuda")

%timeit -n 1000 torch_model(images)

In [None]:
# Now the XABY model, runs on GPU!
inputs = xaby.tensor(images.to("cpu"))

%timeit -n 1000 inputs >> xaby_model

XABY is slightly slower than PyTorch on the GPU. That might be because JAX itself is slower here, or there may be more optimization I can do in XABY.
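
One caveat when timing GPU code: PyTorch launches CUDA kernels asynchronously, and JAX dispatches work asynchronously too, so a call can return before the device has finished. Here's a rough sketch of a timing helper that does a few warm-up runs and optionally waits for the device. The `time_call` helper and its `sync` parameter are hypothetical names, not part of XABY or PyTorch. For PyTorch on a GPU you could pass `torch.cuda.synchronize` as `sync`; for JAX arrays the usual tool is `block_until_ready()`, though whether XABY tensors expose it is an assumption to verify.

```python
import time


def time_call(fn, n_runs=100, n_warmup=10, sync=None):
    """Return the mean wall-clock seconds per call of fn().

    sync is an optional zero-argument callable that blocks until queued
    device work has finished (hypothetical hook, e.g. torch.cuda.synchronize).
    """
    for _ in range(n_warmup):
        fn()  # warm-up runs absorb one-time costs such as compilation
    if sync is not None:
        sync()  # start the clock with an idle device
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    if sync is not None:
        sync()  # wait for all timed work to actually finish
    return (time.perf_counter() - start) / n_runs


# CPU-only demonstration with a stand-in workload
seconds = time_call(lambda: sum(range(1000)), n_runs=50)
print(f"{seconds * 1e6:.1f} microseconds per call")
```

Without the final `sync`, the measurement can end up reflecting how fast work is queued rather than how fast it runs.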

Either way, time to train the models. First up, XABY. I'll use simple stochastic gradient descent for both. The output of the models is log-softmax, so I'll use the negative log-likelihood loss.
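
As a quick illustration of why these two fit together, here's a plain-Python sketch of log-softmax followed by the negative log-likelihood loss. The `log_softmax` and `nll_loss` functions below are minimal versions written for this example, not the XABY or PyTorch implementations, and the logits are made-up numbers.

```python
import math


def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def nll_loss(log_probs, targets):
    # Mean over the batch of -log p(correct class)
    return -sum(lp[t] for lp, t in zip(log_probs, targets)) / len(targets)


batch_logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]  # made-up scores, 3 classes
targets = [0, 1]                                    # correct class per example

log_p = [log_softmax(row) for row in batch_logits]
loss = nll_loss(log_p, targets)
print(f"NLL loss: {loss:.3f}")
```

Each row of `log_p` exponentiates to a probability distribution, and the loss just picks out the log-probability of the correct class, which makes the combination equivalent to cross-entropy on the raw scores.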

In XABY, we create a `backprop` function that takes the input and targets, then returns the loss and gradients. When only evaluating, such as in validation, you can get the loss directly:
```python
loss = inputs >> model >> nlloss << targets
```

We also create an `update` function that updates a model given gradients.
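
If the split between `backprop` and `update` feels abstract, here's a toy sketch of the same pattern in plain Python. Everything in it is hypothetical: `make_sgd` stands in for an optimizer factory, `backprop` uses hand-derived gradients for a tiny linear model with squared error (not NLL), and none of it is XABY's real API.

```python
def make_sgd(params, lr=0.003):
    """Hypothetical optimizer factory; params is a dict of name -> list of floats."""

    def update(grads):
        for name, grad in grads.items():
            # Move each parameter a small step against its gradient
            params[name] = [p - lr * g for p, g in zip(params[name], grad)]
        return params

    return update


def backprop(params, x, y):
    """Loss and hand-derived gradients for y_hat = w * x + b with squared error."""
    w, b = params["w"][0], params["b"][0]
    error = (w * x + b) - y
    loss = error ** 2
    grads = {"w": [2 * error * x], "b": [2 * error]}
    return loss, grads


params = {"w": [0.0], "b": [0.0]}
update = make_sgd(params, lr=0.1)

losses = []
for _ in range(20):
    loss, grads = backprop(params, x=1.0, y=2.0)  # gradients first...
    update(grads)                                  # ...then the parameter step
    losses.append(loss)
```

The point of the design is that computing gradients and applying them are separate steps, so you can swap the update rule without touching the model or the loss.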

In [9]:
batch_size = train_loader.batch_size
print_every = 100

In [10]:
### Define a fresh model, in two lines for readability
model = xaby.flatten >> xn.linear(784, 256) >> xn.relu \
                   >> xn.linear(256, 10) >> xn.log_softmax

# Model with backpropagation
backprop = model << xn.nlloss

# Update function
update = xaby.optim.sgd(model, lr=0.003)

step = 0
start = time.time()
for images, labels in train_loader:
    step += 1
    
    inputs, targets = xaby.tensor(images), xaby.tensor(labels)
    
    # Backprop to get gradients!
    train_loss, grads = backprop(inputs, targets)
    # Update model parameters with gradients
    grads >> update
    
    if step % print_every == 0:
        stop = time.time()
        test_losses = []
        test_accuracy = []
        for images, labels in test_loader:
            inputs, targets = xaby.tensor(images), xaby.tensor(labels)
            
            log_p = inputs >> model
            loss = log_p >> xn.nlloss << targets
            accuracy = xaby.metrics.accuracy(log_p, targets)
            
            test_losses.append(loss.numpy())
            test_accuracy.append(accuracy)
            
        print(f"Train loss: {train_loss:.3f}  "
              f"Test loss: {sum(test_losses)/len(test_losses):.3f}  "
              f"Test acc.: {sum(test_accuracy)/len(test_accuracy):.3f}  "
              f"Images/sec: {print_every*batch_size/(stop - start):.3f}")
        start = time.time()

Train loss: 1.818  Test loss: 1.809  Test acc.: 0.662  Images/sec: 3054.412
Train loss: 1.380  Test loss: 1.351  Test acc.: 0.750  Images/sec: 6638.921
Train loss: 1.015  Test loss: 1.034  Test acc.: 0.804  Images/sec: 5866.025
Train loss: 0.920  Test loss: 0.837  Test acc.: 0.833  Images/sec: 6021.412


In [None]:
# Start with a fresh model
torch_model = torch.nn.Sequential(
                    torch.nn.Flatten(),
                    torch.nn.Linear(784, 256),
                    torch.nn.ReLU(),
                    torch.nn.Linear(256, 10),
                    torch.nn.LogSoftmax(1))
torch_model.to("cuda")
optimizer = torch.optim.SGD(torch_model.parameters(), lr=0.003)
criterion = torch.nn.NLLLoss()

step = 0
start = time.time()
for images, labels in train_loader:
    step += 1
    
    inputs, targets = images.to("cuda"), labels.to("cuda")
    
    optimizer.zero_grad()
    log_p = torch_model(inputs)
    loss = criterion(log_p, targets)
    loss.backward()
    optimizer.step()
    
    train_loss = loss.item()
    
    if step % print_every == 0:
        stop = time.time()
        test_losses = []
        test_accuracy = []
        for images, labels in test_loader:
            with torch.no_grad():
                inputs, targets = images.to("cuda"), labels.to("cuda")
                log_p = torch_model(inputs)
                loss = criterion(log_p, targets)
                accuracy = (log_p.argmax(dim=1) == targets).float().mean()
            
            test_losses.append(loss.item())
            test_accuracy.append(accuracy.item())
            
        print(f"Train loss: {train_loss:.3f}  "
              f"Test loss: {sum(test_losses)/len(test_losses):.3f}  "
              f"Test accuracy: {sum(test_accuracy)/len(test_accuracy):.3f}  "
              f"Images/sec: {print_every*batch_size/(stop - start):.3f}")
        start = time.time()