# Integrating LaminDB with PyTorch and training an autoencoder

As part of our demonstration of a simple ML workflow with LaminDB and PyTorch, we now proceed to query the ingested data, prepare a PyTorch dataset, and train our simple autoencoder.

This notebook focuses on the key interface between LaminDB and MLOps tools: the dataset abstraction.

LaminDB provides fundamental building blocks (metadata-enriched object storage, expressive querying and streaming, and built-in data lineage) that simplify and enhance the process of building modular, special-purpose datasets for any ML workflow.

```{note}
- For an introduction to this four-part demonstration, please see [LaminDB use case: integrating with PyTorch to train a model on the MNIST dataset](./mnist-intro.ipynb).
- For ingesting the MNIST dataset stored locally, please see [Ingesting a local ML dataset](./mnist-ingest-local.ipynb).
- For ingesting the MNIST dataset stored in the cloud, please see [Ingesting a remote ML dataset](./mnist-ingest-remote.ipynb).
- For extending the LaminDB schema, please see [Extending the LaminDB schema](./mnist-extend-schema.ipynb).
```

## Accessing data

Let's load the LaminDB instance we created when ingesting the data objects (here, the `mnist-remote` instance).

In [None]:
import lndb

lndb.load("mnist-remote")

In [None]:
import lamindb as ln

ln.nb.header()

## Building a custom PyTorch dataset with LaminDB

Most MLOps tools provide a modular, customizable abstraction over datasets. 

This is exactly where LaminDB comes into play: it not only simplifies the process of building these abstractions, but also integrates crucial features of decentralized data management (shared metadata layer, built-in data lineage, universal biological/biomedical ontologies) into the ML lifecycle.

In the case of PyTorch, [that abstraction is the `torch.utils.data.Dataset` class](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files).

In order to build a custom PyTorch `Dataset` class, one must inherit from `torch.utils.data.Dataset` and implement three key methods:
- `__init__`: initialize and load relevant data objects.
- `__len__`: return the size of the dataset.
- `__getitem__`: return a sample (feature-label pair) for model training or testing.
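
As a minimal sketch of these three methods, the toy class below keeps its samples in memory. The name `ToyDataset` and the values are made up for illustration, and the base class is omitted so that the sketch runs without PyTorch installed; in real use the class would inherit from `torch.utils.data.Dataset`.

```python
class ToyDataset:
    """Map-style dataset sketch; in real use, subclass torch.utils.data.Dataset."""

    def __init__(self, features, labels):
        # __init__: load (here: simply store) the relevant data objects
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = features
        self.labels = labels

    def __len__(self):
        # __len__: the number of samples
        return len(self.features)

    def __getitem__(self, idx):
        # __getitem__: one (feature, label) pair
        return self.features[idx], self.labels[idx]


# made-up feature vectors and labels
toy = ToyDataset(features=[[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]], labels=[7, 2, 1])
print(len(toy))  # 3
print(toy[1])    # ([0.2, 0.3], 2)
```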

Let's see how we can build a custom PyTorch `Dataset` by leveraging the LaminDB API.

In [None]:
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from PIL import Image

# To maximize performance, let's disable tracking of each data object as inputs to this notebook run
ln.settings.track_run_inputs_upon_load = False


class LNDataset(Dataset):
    def __init__(self, dfolder: ln.DFolder):
        # query dobjects in the data folder
        self.dobjects = (
            ln.select(ln.DObject)
            .join(ln.DObject.dfolders)
            .where(ln.DFolder.id == dfolder.id)
        ).all()

        # define features and labels
        self.feature_dobjects = []
        for dobject in self.dobjects:
            if dobject.name == "labels":  # load and define dataframe with labels
                self.img_labels = dobject.load()
            else:
                self.feature_dobjects.append(dobject)

        # set key torch.utils.data.Dataset attributes
        self.transform = ToTensor()
        self.target_transform = None

    def __len__(self):
        # one sample per image dobject (the labels dobject is not a sample)
        return len(self.feature_dobjects)

    def __getitem__(self, idx):
        # get feature dobject
        dobject = self.feature_dobjects[idx]
        # get image from dobject
        path = dobject.load()
        image = Image.open(path)
        # get label from dobject
        filename = dobject.name + dobject.suffix
        label = self.img_labels.loc[
            self.img_labels["filename"] == filename, "label"
        ].item()
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
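
`__getitem__` filters the labels table on every call. If that lookup ever becomes a bottleneck, one option is to build a filename-to-label mapping once in `__init__`. The sketch below uses a plain list of rows as a stand-in for the labels dataframe; the helper name `build_label_index` and the filenames are hypothetical.

```python
def build_label_index(rows):
    """Map each filename to its label, rejecting duplicate filenames."""
    index = {}
    for row in rows:
        filename = row["filename"]
        if filename in index:
            raise ValueError(f"duplicate filename: {filename}")
        index[filename] = row["label"]
    return index


# hypothetical stand-in for the labels table (columns: filename, label)
rows = [
    {"filename": "img_0.png", "label": 5},
    {"filename": "img_1.png", "label": 0},
]
label_index = build_label_index(rows)
print(label_index["img_1.png"])  # 0
```

With a pandas dataframe `df`, a comparable mapping can be built with `dict(zip(df["filename"], df["label"]))`.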

We can now use our custom dataset class to instantiate a dataset and pass it to train/test dataloaders for training.

In [None]:
from torch.utils.data import random_split, DataLoader

# create our custom Dataset
mnist_folder = ln.select(ln.DFolder, name="mnist").one()
mnist_dataset = LNDataset(mnist_folder)

# define train and test splits as 80%/20% fractions (fractions require torch >= 1.13)
train_subset, test_subset = random_split(mnist_dataset, [0.8, 0.2])

# create train and test DataLoaders based on the splits
# (pass the subsets themselves; `.dataset` would point back to the full dataset)
train_loader = DataLoader(train_subset)
test_loader = DataLoader(test_subset)
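
A split does not copy data: each subset holds a reference to the full dataset plus the indices assigned to it, so a loader only sees a subset's share of samples when it receives the subset itself. The pure-Python sketch below mimics that behaviour; `ToySubset` and `toy_random_split` are illustrative stand-ins, not PyTorch classes.

```python
import random


class ToySubset:
    """Illustrative stand-in for a dataset subset: full dataset + selected indices."""

    def __init__(self, dataset, indices):
        self.dataset = dataset  # reference to the *full* dataset
        self.indices = indices  # positions that belong to this subset

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]


def toy_random_split(dataset, train_fraction, seed=0):
    """Shuffle all indices and cut them into a train and a test subset."""
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)
    n_train = int(train_fraction * len(dataset))
    return ToySubset(dataset, indices[:n_train]), ToySubset(dataset, indices[n_train:])


data = list(range(10))  # made-up dataset of 10 samples
train, test = toy_random_split(data, train_fraction=0.8)
print(len(train), len(test))  # 8 2
print(len(train.dataset))     # 10
```

The last line shows why the attribute `dataset` of a subset is not a substitute for the subset: it still has all 10 samples.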

## Train the model (PyTorch Lightning)

We can now finally train a simple model, just like we would train any other model with PyTorch.

Here we train a simple autoencoder for illustration purposes.

In [None]:
from torch import optim, nn
from torchmetrics import Accuracy
import pytorch_lightning as pl

encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)  # MNIST has 10 digit classes (0-9)

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


autoencoder = LitAutoEncoder(encoder, decoder)
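
The training step flattens each 28 × 28 image into a vector of 784 values, compresses it to 3 values, reconstructs 784 values, and scores the reconstruction with the mean squared error averaged over all elements. The sketch below reproduces the flattening and the loss on a tiny 2 × 2 "image" in plain Python; `flatten` and `mse` are illustrative helpers and the numbers are made up.

```python
def flatten(image):
    """Turn a 2D list of pixel rows into one flat list."""
    return [pixel for row in image for pixel in row]


def mse(x_hat, x):
    """Mean of the squared element-wise differences."""
    if len(x_hat) != len(x):
        raise ValueError("inputs must have the same length")
    return sum((a - b) ** 2 for a, b in zip(x_hat, x)) / len(x)


image = [[0.0, 1.0], [1.0, 0.0]]           # made-up 2 x 2 "image"
reconstruction = [[0.0, 0.5], [1.0, 0.5]]  # made-up decoder output
loss = mse(flatten(reconstruction), flatten(image))
print(loss)  # 0.125
```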

In [None]:
trainer = pl.Trainer(limit_train_batches=100, max_epochs=5)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
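
`limit_train_batches=100` caps each epoch at 100 batches, and the loaders above rely on `DataLoader`'s default `batch_size=1`, so this configuration processes at most 100 images per epoch. A small arithmetic sketch of the resulting upper bound follows; the helper name `max_training_steps` and the sample counts are made up, and it assumes one optimizer step per batch.

```python
import math


def max_training_steps(n_samples, batch_size, limit_train_batches, max_epochs):
    """Upper bound on optimizer steps, assuming one step per batch."""
    batches_per_epoch = math.ceil(n_samples / batch_size)
    return min(batches_per_epoch, limit_train_batches) * max_epochs


# sample counts are made-up numbers for illustration
print(max_training_steps(n_samples=80, batch_size=1, limit_train_batches=100, max_epochs=5))    # 400
print(max_training_steps(n_samples=8000, batch_size=1, limit_train_batches=100, max_epochs=5))  # 500
```

Passing a larger `batch_size` to `DataLoader` lets the same 100 batches cover more images per epoch.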