[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merantix/mxlabs-datasets/blob/main/examples/Squirrel_Tutorial_Pytorch_Model_Training.ipynb)

# Install Squirrel and Squirrel Datasets

In [None]:
!pip install keyring keyrings.google-artifactregistry-auth
from google.colab import auth

auth.authenticate_user()

In [None]:
!pip install squirrel-core squirrel-datasets --extra-index-url=https://europe-west1-python.pkg.dev/mx-labs-devops/labs-pypi-registry/simple/ --ignore-requires-python --upgrade

# MNIST dataset construction using Squirrel

In [None]:
import typing as t

import PIL.Image
from tqdm.auto import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as tud
import torchvision.transforms as tr

from squirrel.catalog import Catalog
from squirrel.iterstream.torch_composables import TorchIterable

This tutorial gives a brief overview of how to load a standard dataset with Squirrel and use it to train a simple neural network.
We use the seminal MNIST dataset as our running example.

The basic entry point for loading datasets in Squirrel is the `Catalog` API.
`Catalog.from_plugins()` instantiates a catalog of all datasets registered by the installed plugins (here, `squirrel-datasets`), and we can list their names directly.

In [None]:
ca = Catalog.from_plugins()
print(sorted(ca.keys()))

To load MNIST, we use the catalog's dictionary-style lookup and call `.get_driver()` on the entry to obtain a `Driver`.
In this case, the resulting driver is a lightweight wrapper around the Hugging Face dataset.

In [None]:
mnist = ca["mnist"].get_driver()
type(mnist)

Now let's load some data and look at the first few samples.

In [None]:
mnist.get_iter("train").take(5).collect()
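Squirrel's iterators are designed to be lazy: `get_iter`, `take` and `map` only describe a pipeline, and records are fetched once `collect()` consumes it. The stdlib-only sketch below illustrates that chaining pattern under this assumption. `MiniStream` and `load` are hypothetical stand-ins, not Squirrel's implementation; the point is that `take(5)` stops after five records even though the source is infinite.

```python
import itertools
import typing as t


class MiniStream:
    """Hypothetical stand-in for a lazy, chainable iterator (not Squirrel's class)."""

    def __init__(self, source: t.Iterable[t.Any]) -> None:
        self._source = source

    def map(self, fn: t.Callable[[t.Any], t.Any]) -> "MiniStream":
        # Nothing is computed yet: the source is only wrapped in a lazy map.
        return MiniStream(map(fn, self._source))

    def take(self, n: int) -> "MiniStream":
        # Stop after the first n items, even if the source never ends.
        return MiniStream(itertools.islice(self._source, n))

    def collect(self) -> t.List[t.Any]:
        # Only here is the chain actually consumed.
        return list(self._source)


calls: t.List[int] = []


def load(i: int) -> t.Dict[str, t.Any]:
    # Hypothetical record loader; records which indices were actually fetched.
    calls.append(i)
    return {"image": f"img{i}", "label": i % 10}


labels = MiniStream(load(i) for i in itertools.count()).take(5).map(lambda r: r["label"]).collect()
print(labels)  # [0, 1, 2, 3, 4]
print(calls)   # [0, 1, 2, 3, 4]
```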

The result is a list of dictionaries.
Each dictionary contains a `PIL.Image.Image` under the _image_ key and an integer class label under the _label_ key.
So let's look at the data.
For this, we first write a function that arranges PIL images in a grid, then use the MNIST driver to fetch images.
Note that we use `map` to select only the _image_ key from each sample.

In [None]:
def grid(imgs: t.List[PIL.Image.Image], nrows: int, ncols: int) -> PIL.Image.Image:
    """Paste equally sized images into an nrows x ncols grid, filling row by row."""
    w, h = imgs[0].size
    canvas = PIL.Image.new("RGB", size=(ncols * w, nrows * h))

    for i, img in enumerate(imgs):
        # Image i lands in column i % ncols and row i // ncols.
        canvas.paste(img, box=(i % ncols * w, i // ncols * h))
    return canvas


imgs = mnist.get_iter("train").take(25).map(lambda x: x["image"]).collect()
grid(imgs, 5, 5)
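The box arithmetic in `grid` fills the canvas row by row. This small stdlib-only sketch lists the top-left corners used for a 5 x 5 grid of 28 x 28 MNIST digits; the helper `grid_offsets` is hypothetical and exists only for illustration.

```python
import typing as t


def grid_offsets(n_imgs: int, ncols: int, w: int, h: int) -> t.List[t.Tuple[int, int]]:
    # Image i goes to column i % ncols and row i // ncols (row-major order),
    # exactly as in the box=(...) argument of grid().
    return [((i % ncols) * w, (i // ncols) * h) for i in range(n_imgs)]


offsets = grid_offsets(n_imgs=25, ncols=5, w=28, h=28)
print(offsets[:6])  # [(0, 0), (28, 0), (56, 0), (84, 0), (112, 0), (0, 28)]
print(offsets[-1])  # (112, 112)
```

The last digit starts at (112, 112), so it ends exactly at the 140 x 140 canvas border that `grid` allocates.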

Looks indeed like MNIST!

Let's move on to creating a training dataloader.
Using the `map` method of the iterator returned by `get_iter`, we can easily incorporate the necessary data augmentations.
We rely on `torchvision` transforms to construct the augmentation pipeline.
We first define the train and test transforms, which differ only in that the training transform adds uniform noise to the pixel values as a regularizer.
Otherwise, both convert the PIL image to a `torch.Tensor` and rescale the values to lie roughly in [-0.5, 0.5), i.e. centered around zero.
In the next step, the `augmentation` function is called inside the `lambda` that we pass to `map`.

In [None]:
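# Training: add uniform noise in [0, 1) to the 0..255 pixel values, then rescale and center.
train_augment = tr.Compose([tr.ToTensor(), tr.Lambda(lambda x: (255 * x + torch.rand_like(x)) / 256 - 0.5)])

# Test: the same rescaling and centering, without noise.
test_augment = tr.Compose([tr.ToTensor(), tr.Lambda(lambda x: 255 * x / 256 - 0.5)])


def augmentation(image: PIL.Image.Image, transform: tr.Compose) -> torch.Tensor:
    """Apply a torchvision transform pipeline to a single PIL image."""
    return transform(image)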

Now that we have constructed the augmentations, let's put the pieces together.
We create the train and test iterators as demonstrated above, then map the data augmentation over the individual samples.
For simplicity and training speed, we use only a subset of the training and test data.
Finally, we compose the stream with `TorchIterable` to turn it into an iterable PyTorch dataset.
This is necessary to use the standard `torch.utils.data.DataLoader` API.
With this, data loading with Squirrel becomes a drop-in replacement in any training loop that relies on PyTorch's data loading mechanism.

One detail to be aware of: when applying the data augmentation, we turn each dictionary into a tuple.
This keeps the data consistent with the standard PyTorch MNIST training examples, which yield samples as `(sample, label)` tuples.

In [None]:
mnist_train = (
    ca["mnist"]
    .get_driver()
    .get_iter("train")
    .map(lambda r: (augmentation(r["image"], train_augment), r["label"]))
    .take(4000)
    .compose(TorchIterable)
)

mnist_test = (
    ca["mnist"]
    .get_driver()
    .get_iter("test")  # evaluate on the held-out test split, not on training data
    .map(lambda r: (augmentation(r["image"], test_augment), r["label"]))
    .take(200)
    .compose(TorchIterable)
)

train_loader = tud.DataLoader(mnist_train, batch_size=20)
test_loader = tud.DataLoader(mnist_test, batch_size=50)
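Because both streams are truncated with `take`, it is easy to work out how much work one pass over the loaders involves. A quick back-of-the-envelope sketch in plain Python, assuming exactly 4000 training and 200 test samples are yielded and using the evaluation schedule of the training loop further down in this tutorial:

```python
import math

# Sizes and batch sizes from the cell above.
n_train, train_batch_size = 4000, 20
n_test, test_batch_size = 200, 50

# DataLoader keeps a final partial batch by default (drop_last=False), hence ceil.
train_steps = math.ceil(n_train / train_batch_size)
eval_batches = math.ceil(n_test / test_batch_size)

# The training loop evaluates whenever idx % 20 == 0, and once more after the loop.
n_evaluations = sum(1 for idx in range(train_steps) if idx % 20 == 0) + 1

print(train_steps, eval_batches, n_evaluations)  # 200 4 11
```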

The remaining parts are fairly standard; for details on setting up training, we refer the reader to the PyTorch tutorials.
In summary: we first define an evaluation function that measures the accuracy on a given test set, then define a standard MLP with GELU activations and BatchNorm layers.
Finally, we define the loss function for multi-class classification (cross-entropy) and our optimizer (SGD with momentum).

In [None]:
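def evaluate(net: nn.Module, loader: tud.DataLoader) -> float:
    """Return the classification accuracy of `net` over all batches in `loader`."""
    net.eval()  # use BatchNorm running statistics instead of batch statistics
    with torch.no_grad():
        accs = []
        for b, lbl in tqdm(loader, desc="eval", leave=False):
            pred = net(b.reshape(-1, 28 ** 2))  # flatten 1x28x28 images to 784-dim vectors
            accs += (pred.argmax(-1) == lbl.flatten()).numpy().tolist()

    return float(np.mean(accs))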


net = nn.Sequential(
    nn.Linear(28 ** 2, 1024),
    nn.GELU(),
    nn.BatchNorm1d(1024),
    nn.Linear(1024, 1024),
    nn.GELU(),
    nn.BatchNorm1d(1024),
    nn.Linear(1024, 1024),
    nn.GELU(),
    nn.Linear(1024, 10),
)

xent = nn.CrossEntropyLoss()
opter = optim.SGD(params=net.parameters(), lr=0.01, momentum=0.5)
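As a sanity check on the architecture, we can count its trainable parameters by hand: each `nn.Linear(n_in, n_out)` has an `n_in x n_out` weight matrix plus a bias vector, each `nn.BatchNorm1d(n)` has a learnable scale and shift of size `n` (its running statistics are buffers, not parameters), and `nn.GELU` has none. The helpers below are hypothetical, written just for this count; the total should agree with `sum(p.numel() for p in net.parameters())`.

```python
def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weight matrix + bias vector


def batchnorm_params(n: int) -> int:
    return 2 * n  # learnable scale (weight) and shift (bias)


# Mirrors the layers of `net` in order; GELU layers contribute nothing.
total = (
    linear_params(28 ** 2, 1024) + batchnorm_params(1024)
    + linear_params(1024, 1024) + batchnorm_params(1024)
    + linear_params(1024, 1024)
    + linear_params(1024, 10)
)
print(total)  # 2917386
```

So the MLP has roughly 2.9 million parameters, most of them in the two hidden 1024 x 1024 layers.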

Now let's train the model.
As you can see, the code is a standard PyTorch training loop; nothing in it is specific to Squirrel.

In [None]:
for idx, (b, lbl) in tqdm(enumerate(train_loader)):
    if idx % 20 == 0:
        print(f"step: {idx:03d}, accuracy: {evaluate(net, test_loader)}")
    net.train()
    opter.zero_grad()
    pred = net(b.reshape(-1, 28 ** 2))
    loss = xent(pred, lbl)
    loss.backward()
    opter.step()

print(f"step: {idx:03d}, accuracy: {evaluate(net, test_loader)}")
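For readers curious about what `optim.SGD(..., momentum=0.5)` does at each `opter.step()`, here is a scalar sketch of the update rule as described in the PyTorch documentation for the default settings used here (no dampening, no Nesterov, no weight decay): the momentum buffer starts as the first gradient and is then updated as `buf = momentum * buf + grad`, and the parameter moves by `-lr * buf`. The toy objective `(theta - 3) ** 2` and the helpers `sgd_momentum` and `grad_f` are illustrative assumptions, not part of the tutorial.

```python
import typing as t


def sgd_momentum(
    grad: t.Callable[[float], float],
    theta0: float,
    lr: float = 0.01,
    momentum: float = 0.5,
    steps: int = 200,
) -> float:
    """Hypothetical scalar version of SGD with (heavy-ball) momentum."""
    theta, buf = theta0, None
    for _ in range(steps):
        g = grad(theta)
        # First step: buffer = gradient; afterwards: buffer = momentum * buffer + gradient.
        buf = g if buf is None else momentum * buf + g
        theta -= lr * buf
    return theta


def grad_f(theta: float) -> float:
    # Gradient of the toy objective f(theta) = (theta - 3) ** 2.
    return 2 * (theta - 3.0)


with_momentum = sgd_momentum(grad_f, theta0=0.0)
without_momentum = sgd_momentum(grad_f, theta0=0.0, momentum=0.0)
print(abs(with_momentum - 3.0), abs(without_momentum - 3.0))
```

With the tutorial's settings (`lr=0.01`, `momentum=0.5`), momentum roughly doubles the effective step size on this toy problem, so after the same number of steps it ends up much closer to the minimum than plain SGD.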

This concludes the _Hello World!_ example of training a deep neural network with Squirrel dataloaders. Play around and have fun!