# Track images for PyTorch

In [None]:
# Initialize a LaminDB instance whose default storage location is the local
# directory "mnist-100" (the same directory the S3 data is downloaded into below).
!lamin init --storage "mnist-100"

In [None]:
import lamindb as ln
import pandas as pd

# Register this notebook with LaminDB so the records added further below can be
# linked back to it as their source (presumably as a tracked run — TODO confirm).
ln.track()

In [None]:
# prepare local data
import boto3
from pathlib import Path

# Mirror every object under the "mnist-100/" prefix of the S3 bucket into a local
# directory tree that uses the same relative paths as the object keys.
BUCKET_NAME = "bernardo-test-bucket-1"
S3_PREFIX = "mnist-100/"

s3 = boto3.resource("s3")
bucket = s3.Bucket(BUCKET_NAME)
for obj in bucket.objects.filter(Prefix=S3_PREFIX):
    if obj.key.endswith("/"):
        # keys ending in "/" are presumably folder placeholders, not files
        continue
    # parents=True: a key like "mnist-100/images/0.png" needs every missing
    # ancestor directory created, not only its immediate parent
    Path(obj.key).parent.mkdir(parents=True, exist_ok=True)
    bucket.download_file(obj.key, obj.key)

Assume we have a local directory of files that we'd like to ingest:

In [None]:
# list what was downloaded into the image directory
!ls mnist-100/images*

And a `.csv` file containing the labels for each of the images.

In [None]:
# Load the label table for a quick look. The dataset class further below reads
# column 0 of this file as an image filename and column 1 as that image's label.
labels_df = pd.read_csv("mnist-100/labels.csv")
labels_df.head()

## Ingest images and labels

Let's ingest each image in the folder as a `File` record by leveraging the `Folder` entity.

In [None]:
# Wrap the local image directory as a LaminDB `Folder` and register it with
# `ln.add`; the trailing `;` suppresses the cell's displayed return value.
img_folder = ln.Folder("mnist-100/images")
ln.add(img_folder);

Let's also ingest the labels file as a single data object.

In [None]:
# Wrap the labels CSV as a single LaminDB `File` and register it with `ln.add`;
# the trailing `;` suppresses the cell's displayed return value.
labels = ln.File("mnist-100/labels.csv")
ln.add(labels);

```{important}

We can equally well pass cloud locations!

```

## Create the PyTorch Dataset

Let's query the relevant data objects to instantiate a [canonical PyTorch custom image dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files).

In [None]:
# define the custom dataset class, as seen in the PyTorch guide

import os
import torch
from torchvision.io import read_image
from torch.utils.data import Dataset


class CustomImageDataset(Dataset):
    """Map-style dataset pairing image files on disk with labels from a CSV table.

    The annotations table is read with ``pd.read_csv``. Column 0 of each row is
    the image filename (joined onto ``img_dir``) and column 1 is its label.

    Args:
        annotations_file: anything ``pd.read_csv`` accepts (path or buffer).
        img_dir: directory the filenames in column 0 are relative to.
        transform: optional callable applied to the image tensor.
        target_transform: optional callable applied to the label.
    """

    def __init__(
        self, annotations_file, img_dir, transform=None, target_transform=None
    ):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of annotated samples, i.e. rows in the annotations table."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; the image is cast to float32."""
        table = self.img_labels
        file_name = table.iloc[idx, 0]
        pixels = read_image(os.path.join(self.img_dir, file_name)).to(torch.float32)
        target = table.iloc[idx, 1]
        # transforms are applied only when set (truthy), matching None defaults
        pixels = self.transform(pixels) if self.transform else pixels
        target = self.target_transform(target) if self.target_transform else target
        return pixels, target

The canonical PyTorch dataset takes as input the path to a folder with images.

In [None]:
# Query back the `Folder` record ingested above; `.one()` presumably expects
# exactly one matching record — TODO confirm behavior when several exist.
img_folder = ln.select(ln.Folder).one()
# path of that folder, used below as the dataset's image directory
img_folderpath = img_folder.path()

It also takes the path to a CSV file with the labels.

In [None]:
# Query back the `File` record with a ".csv" suffix (the labels file ingested
# above); `.one()` presumably expects exactly one match — TODO confirm.
labels = ln.select(ln.File, suffix=".csv").one()
# path of the labels CSV, used below as the dataset's annotations file
labels_filepath = labels.path()

Let's now instantiate the canonical PyTorch custom image dataset.

In [None]:
# Build the dataset from the labels CSV and the image folder queried above.
dataset = CustomImageDataset(annotations_file=labels_filepath, img_dir=img_folderpath)

## Create the PyTorch DataLoaders

In [None]:
from torch.utils.data import random_split, DataLoader

# define train and test splits: 80% / 20% of however many samples the dataset
# holds (for this 100-image MNIST subset that is exactly 80 / 20)
n_train = int(0.8 * len(dataset))
n_test = len(dataset) - n_train
# seed the split so re-running the notebook yields the same partition
train_subset, test_subset = random_split(
    dataset, [n_train, n_test], generator=torch.Generator().manual_seed(42)
)

# create train and test DataLoaders from the subsets themselves — passing
# `subset.dataset` would hand the full, unsplit dataset to both loaders
train_loader = DataLoader(train_subset, shuffle=True)
test_loader = DataLoader(test_subset)

## Train a simple autoencoder

We can now train a [canonical PyTorch Lightning autoencoder](https://lightning.ai/docs/pytorch/stable/starter/introduction.html) on these data.

In [None]:
import torch
from torch import optim, nn
from torchmetrics import Accuracy
import pytorch_lightning as pl

# Each MNIST image is flattened to a 28*28 vector and squeezed through a
# 3-dimensional latent code by a small fully-connected encoder/decoder pair.
IMG_FEATURES = 28 * 28
HIDDEN_DIM = 64
LATENT_DIM = 3

encoder = nn.Sequential(
    nn.Linear(IMG_FEATURES, HIDDEN_DIM), nn.ReLU(), nn.Linear(HIDDEN_DIM, LATENT_DIM)
)
decoder = nn.Sequential(
    nn.Linear(LATENT_DIM, HIDDEN_DIM), nn.ReLU(), nn.Linear(HIDDEN_DIM, IMG_FEATURES)
)


class LitAutoEncoder(pl.LightningModule):
    """Lightning module that trains ``encoder``/``decoder`` as an autoencoder.

    Each training step flattens the image batch, encodes and decodes it, and
    minimizes the mean-squared reconstruction error against the flattened input.
    The labels in each batch are not used by the training objective.

    Args:
        encoder: module mapping flattened images to a latent code.
        decoder: module mapping the latent code back to flattened images.
        num_classes: number of label classes for ``test_accuracy``; MNIST
            digits 0-9 make this 10.
    """

    def __init__(self, encoder, decoder, num_classes=10):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # NOTE(review): no step defined here updates this metric (there is no
        # test_step); it is kept for a later classification evaluation.
        self.test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        """Return the MSE reconstruction loss for one ``(images, labels)`` batch."""
        x, _ = batch
        # flatten each sample to a vector; input presumably (batch, 1, 28, 28)
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam over all module parameters with learning rate 1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)


# Wrap the encoder/decoder pair defined above in the Lightning module.
autoencoder = LitAutoEncoder(encoder=encoder, decoder=decoder)

In [None]:
# Train for 5 epochs, capping each epoch at 100 training batches.
trainer = pl.Trainer(max_epochs=5, limit_train_batches=100)
trainer.fit(autoencoder, train_dataloaders=train_loader)