[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merantix/mxlabs-datasets/blob/main/examples/Squirrel_Tutorial_PytorchLightning_Model_Training.ipynb)


# Install Squirrel and Squirrel Datasets

In [None]:
# Install keyring plus the Google Artifact Registry keyring backend
# (presumably so pip can authenticate against the Artifact Registry index
# used by the squirrel install below -- confirm).
!pip install keyring keyrings.google-artifactregistry-auth
from google.colab import auth

# Authenticate the current Colab user; this cell only works inside Google Colab.
auth.authenticate_user()

In [None]:
# Install/upgrade squirrel-core and squirrel-datasets, additionally searching
# the given Artifact Registry index. --ignore-requires-python skips pip's
# check of each package's declared Python-version requirement.
!pip install squirrel-core squirrel-datasets --extra-index=https://europe-west1-python.pkg.dev/mx-labs-devops/labs-pypi-registry/simple/ --ignore-requires-python --upgrade

In [None]:
# Install PyTorch Lightning, imported as `pl` in the training cells below.
!pip install pytorch_lightning

# Train a PyTorch Lightning model on MNIST

In [None]:
from squirrel.catalog import Catalog
from squirrel.iterstream.torch_composables import TorchIterable

from torch.nn import functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision import transforms
import pytorch_lightning as pl


def preprocess_sample(sample):
    """Turn one MNIST sample dict into an ``(image_tensor, label)`` pair.

    The value under ``sample["image"]`` is converted with ``ToTensor`` and then
    normalized with mean 0.1307 and std 0.3081; the value under
    ``sample["label"]`` is returned unchanged.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    pipeline = transforms.Compose([to_tensor, normalize])

    image = pipeline(sample["image"])
    return image, sample["label"]


# define model and training
class LitMNIST(pl.LightningModule):
    """Fully connected MNIST classifier packaged as a LightningModule.

    The network flattens each input to 28 * 28 values and passes it through
    linear layers of width 128 and 256 (each followed by ReLU) to 10 outputs,
    emitting log-probabilities. The module also defines its own optimizer and
    training dataloader, so it can be fitted without passing a dataloader.
    """

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total 28 * 28 values.

        Returns:
            Tensor of shape ``(batch, 10)`` holding log-softmax outputs.
        """
        # Only the batch size is needed to flatten, so the input is not
        # required to be exactly 4-D (e.g. a missing channel axis is fine).
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        # log_softmax output pairs with F.nll_loss in training_step.
        return F.log_softmax(self.layer_3(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the negative log-likelihood loss for one ``(x, y)`` batch."""
        x, y = batch
        log_probs = self(x)
        return F.nll_loss(log_probs, y)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        """Build the training DataLoader from the ``"mnist"`` catalog entry.

        Items from ``get_iter("train")`` are mapped through
        ``preprocess_sample``, wrapped with ``TorchIterable`` and served in
        batches of 64.
        """
        it = (
            Catalog.from_plugins()["mnist"].get_driver().get_iter("train").map(preprocess_sample).compose(TorchIterable)
        )
        return DataLoader(it, batch_size=64)

In [None]:
# Train: instantiate the model and fit it for a handful of steps.
model = LitMNIST()
# max_steps=5 keeps this demo run short.
trainer = pl.Trainer(max_steps=5)
# No dataloader is passed here; LitMNIST defines its own train_dataloader().
trainer.fit(model)