# Train on images with queryable metadata

In the previous notebooks, we ingested an ML dataset, prepared a PyTorch dataset, and trained a simple autoencoder as part of our use case demonstration of LaminDB applied to ML workflows.

When preparing a PyTorch dataset, we manually fetched labels from a `labels.csv` file based on the filename of each data object. It would be much easier, however, if we could query the label of a given file with the LaminDB API or, conversely, query a collection of data objects by a given label.

LaminDB allows users to create custom entities and link them to any existing schema entities through simple extensions of its core schema.

In this notebook, we extend the LaminDB schema with a `HandwrittenNumber` entity, which represents the labels of our MNIST images. We will see how this simplifies handling features and their respective labels in an ML workflow.

In [None]:
!lamin load mnist-100

## Creating a `HandwrittenNumber` entity linked to a `File`

We can extend the core LaminDB schema by writing our own SQLModel ORMs.

We'd like to link `HandwrittenNumber` records to `File` records (the images).

As `HandwrittenNumber` is not present in any of the default schema modules, we need to go through 3 steps:
- Create a link model (`FileHandwrittenNumber`) between the new entity (`HandwrittenNumber`) and the target entity (`File`).
- Create the `HandwrittenNumber` ORM with a relationship attribute referencing the target entity.
- Add the corresponding relationship attribute to the target entity.

Let's see how we can do this in practice.

```{note}        
For more details about `SQLModel`, see the [SQLModel docs](https://sqlmodel.tiangolo.com/).

For more details about defining relationships with `SQLModel`, please refer to [Relationship Attributes - Intro](https://sqlmodel.tiangolo.com/tutorial/relationship-attributes/).
```

In [None]:
from sqlmodel import SQLModel, Field, Relationship
from lnschema_core import File
from lnschema_core.dev.sqlmodel import add_relationships
from lamindb.setup import settings
from typing import Optional, List


class FileHandwrittenNumber(SQLModel, table=True):
    """Link table between File and HandwrittenNumber."""

    __tablename__ = "file_handwritten_number"

    file_id: str = Field(foreign_key=File.id, primary_key=True)
    handwrittennumber_id: int = Field(
        foreign_key="handwritten_number.id", primary_key=True
    )


class HandwrittenNumber(SQLModel, table=True):
    """Handwritten number entity."""

    __tablename__ = "handwritten_number"

    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(index=True)  # label value stored as a string, e.g. "2"
    files: List[File] = Relationship(
        back_populates="handwritten_numbers",
        link_model=FileHandwrittenNumber,
    )


# Step 3: add the corresponding relationship attribute to File
add_relationships(HandwrittenNumber)

# Create tables in the database
SQLModel.metadata.create_all(settings.instance.engine)
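At the SQL level, the link model amounts to a table with a composite primary key whose columns are foreign keys to the two linked tables. The sketch below reproduces that layout with Python's built-in `sqlite3` module rather than SQLModel, so it illustrates the table structure only and is not the DDL LaminDB actually emits; the `file` table and its columns are simplified, hypothetical stand-ins. The relationship attributes (steps 2 and 3) live on the ORM side and have no counterpart here.

```python
import sqlite3

# In-memory database standing in for the LaminDB instance (illustration only).
conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")

conn.executescript(
    """
    -- Simplified stand-in for the core File table (hypothetical columns).
    CREATE TABLE file (
        id TEXT PRIMARY KEY,
        name TEXT NOT NULL
    );

    -- The new entity: one row per distinct label.
    CREATE TABLE handwritten_number (
        id INTEGER PRIMARY KEY,
        name TEXT NOT NULL UNIQUE
    );

    -- Step 1: the link table. The composite primary key makes each
    -- (file, label) pair unique, and both columns are foreign keys.
    CREATE TABLE file_handwritten_number (
        file_id TEXT NOT NULL REFERENCES file(id),
        handwrittennumber_id INTEGER NOT NULL REFERENCES handwritten_number(id),
        PRIMARY KEY (file_id, handwrittennumber_id)
    );
    """
)

conn.execute("INSERT INTO file VALUES ('f1', '0001.png')")
conn.execute("INSERT INTO handwritten_number (name) VALUES ('7')")  # gets id 1
conn.execute("INSERT INTO file_handwritten_number VALUES ('f1', 1)")

# Linking the same pair twice violates the composite primary key.
try:
    conn.execute("INSERT INTO file_handwritten_number VALUES ('f1', 1)")
    duplicate_rejected = False
except sqlite3.IntegrityError:
    duplicate_rejected = True
print(duplicate_rejected)  # True
```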

## Linking images to their labels

Now that the schema has been extended, let's loop through our data objects once again and link each of them to its respective `HandwrittenNumber` label.

In [None]:
import lamindb as ln

# Select MNIST folder
mnist_folder = ln.select(ln.Folder).one()

# Query and load the data object containing the labels dataframe
labels = ln.select(ln.File, suffix=".csv").one()
labels_df = labels.load()

# Ingest label entities (HandwrittenNumber)
for label in labels_df["label"].unique():
    ln.add(HandwrittenNumber, name=str(label))

# Query all feature data objects (MNIST images)
# ORM relationship attributes are lazy loaded, so we need to bind data objects
# to a session in order to access and assign values to File.handwritten_numbers
with ln.Session() as session:
    feature_files = (
        session.select(ln.File)
        .join(ln.File.folders)
        .where(ln.Folder.id == mnist_folder.id)
    ).all()

    # Loop through the feature data objects and link them to their respective labels
    updated_files = []
    for file in feature_files:
        label = labels_df.loc[labels_df["filename"] == file.name, "label"].item()
        # labels were ingested as strings, so cast before matching on name
        handwritten_number = session.select(HandwrittenNumber, name=str(label)).one()
        file.handwritten_numbers = [handwritten_number]
        updated_files += [file]

    # Ingest data objects and their linked labels
    session.add(updated_files)
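Stripped of the ORM calls, the loop above follows a simple pattern: get or create one label record per distinct value, then emit one `(file_id, label_id)` link row per file. The plain-Python sketch below uses hypothetical in-memory rows in place of `labels.csv` and the queried files, and builds the filename-to-label lookup once as a dict instead of filtering the DataFrame for every file.

```python
# Hypothetical stand-ins for labels.csv rows and the queried image files.
label_rows = [("0001.png", 7), ("0002.png", 2), ("0003.png", 7)]
file_records = [
    {"id": "f1", "name": "0001.png"},
    {"id": "f2", "name": "0002.png"},
    {"id": "f3", "name": "0003.png"},
]

# Build the filename -> label lookup once. Labels are stored as strings,
# matching the notebook's name=str(label).
label_by_filename = {filename: str(label) for filename, label in label_rows}

# Get-or-create: assign one id per distinct label value.
label_ids: dict[str, int] = {}
for name in sorted(set(label_by_filename.values())):
    label_ids.setdefault(name, len(label_ids) + 1)

# One (file_id, label_id) link row per file.
links = [(f["id"], label_ids[label_by_filename[f["name"]]]) for f in file_records]
print(links)  # [('f1', 2), ('f2', 1), ('f3', 2)]
```

Note that label `"2"` receives id `1` in this sketch: auto-incremented ids say nothing about a label's value, so it is safer to filter labels by `name` than by `id`.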

## Querying images by labels

Now that all data objects are linked to their `HandwrittenNumber` labels, we can query them with the LaminDB API.

Let's query all data objects linked to the label `2`.

In [None]:
files = (
    ln.select(ln.File)
    .join(ln.File.handwritten_numbers)
    .where(HandwrittenNumber.name == "2")  # filter on the label value, not the row id
).all()
files
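For intuition, the query above corresponds roughly to a two-hop join: from `file` through the link table to `handwritten_number`, then a filter on the label. The sketch below writes that join in plain SQL against a small `sqlite3` database; the table names mirror this notebook, but the data and the `file` columns are hypothetical, and the SQL that SQLModel actually generates may differ in detail.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE file (id TEXT PRIMARY KEY, name TEXT);
    CREATE TABLE handwritten_number (id INTEGER PRIMARY KEY, name TEXT UNIQUE);
    CREATE TABLE file_handwritten_number (
        file_id TEXT, handwrittennumber_id INTEGER,
        PRIMARY KEY (file_id, handwrittennumber_id)
    );
    INSERT INTO file VALUES ('f1', '0001.png'), ('f2', '0002.png'), ('f3', '0003.png');
    INSERT INTO handwritten_number VALUES (1, '2'), (2, '7');
    INSERT INTO file_handwritten_number VALUES ('f1', 2), ('f2', 1), ('f3', 2);
    """
)

# Hop from file through the link table to the label, then filter on the label name.
rows = conn.execute(
    """
    SELECT file.id, file.name
    FROM file
    JOIN file_handwritten_number AS link ON link.file_id = file.id
    JOIN handwritten_number ON handwritten_number.id = link.handwrittennumber_id
    WHERE handwritten_number.name = ?
    ORDER BY file.id
    """,
    ("2",),
).fetchall()
print(rows)  # [('f2', '0002.png')]
```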

Let's pick a random data object and get its label through the `handwritten_numbers` relationship attribute:

In [None]:
import random

with ln.Session() as session:
    files = session.select(ln.File).join(ln.File.handwritten_numbers).all()
    file = random.choice(files)
    file.handwritten_numbers  # lazy load handwritten_number relationships inside of session
file.handwritten_numbers
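The `with` block matters because relationship attributes are lazy loaded: the related rows are only fetched on first access, which requires an open session. In SQLAlchemy, accessing an unloaded lazy relationship on an instance that has been detached from its session raises `DetachedInstanceError`. The toy model below, built from hypothetical `ToySession` and `ToyFile` classes rather than SQLAlchemy's implementation, mimics that behavior: a value loaded inside the session stays cached afterwards, while an unloaded one fails.

```python
from typing import List, Optional


class DetachedError(RuntimeError):
    """Raised when an unloaded relationship is accessed after the session closed."""


class ToySession:
    """Hypothetical session that owns access to the link 'table'."""

    def __init__(self, link_table: dict):
        self._links = link_table
        self.open = True

    def load_labels(self, file_id: str) -> List[str]:
        if not self.open:
            raise DetachedError("session is closed")
        return self._links.get(file_id, [])


class ToyFile:
    """Hypothetical record with a lazily loaded relationship attribute."""

    def __init__(self, file_id: str, session: ToySession):
        self.id = file_id
        self._session = session
        self._labels: Optional[List[str]] = None  # not loaded yet

    @property
    def handwritten_numbers(self) -> List[str]:
        # First access issues the query; later accesses reuse the cached value.
        if self._labels is None:
            self._labels = self._session.load_labels(self.id)
        return self._labels


session = ToySession({"f1": ["7"], "f2": ["2"]})
loaded = ToyFile("f1", session)
unloaded = ToyFile("f2", session)

loaded.handwritten_numbers  # triggers the load while the session is open
session.open = False  # equivalent of leaving the `with` block

print(loaded.handwritten_numbers)  # ['7'] (served from the cache)
try:
    unloaded.handwritten_numbers
except DetachedError as err:
    print("unloaded:", err)  # unloaded: session is closed
```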

## Extending our custom PyTorch `Dataset` with the LaminDB API

Instead of writing custom logic to fetch labels from a pandas `DataFrame`, we can now use the LaminDB API to get the label directly from a data object.
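The essential change is that `__getitem__` reads the label from the record itself. The sketch below shows this map-style dataset protocol (`__len__` plus `__getitem__`) in plain Python; `Label`, `FileRecord`, and `RecordLabelDataset` are hypothetical stand-ins for the LaminDB records and the PyTorch dataset, and each item returns the filename instead of an image tensor to keep the example dependency-free.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Label:
    """Stand-in for a HandwrittenNumber record (hypothetical)."""

    name: str


@dataclass
class FileRecord:
    """Stand-in for a File record with its linked labels (hypothetical)."""

    name: str
    handwritten_numbers: List[Label] = field(default_factory=list)


class RecordLabelDataset:
    """Map-style dataset: the label comes from the record, not from a DataFrame."""

    def __init__(self, files: List[FileRecord]):
        # Keep only files that actually carry a label (skips labels.csv).
        self.files = [f for f in files if f.handwritten_numbers]

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> Tuple[str, int]:
        file = self.files[idx]
        # Each image is linked to exactly one label; convert its name to an int target.
        label = int(file.handwritten_numbers[0].name)
        return file.name, label


files = [
    FileRecord("0001.png", [Label("7")]),
    FileRecord("0002.png", [Label("2")]),
    FileRecord("labels.csv"),  # no label linked, so it is skipped
]
ds = RecordLabelDataset(files)
print(len(ds), ds[1])  # 2 ('0002.png', 2)
```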

In [None]:
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from PIL import Image


class LNDataset(Dataset):
    def __init__(self, folder: ln.Folder):
        # query files in the data folder
        self.files = (
            ln.select(ln.File).join(ln.File.folders).where(ln.Folder.id == folder.id)
        ).all()

        # keep only the feature files (images); labels now come from the LaminDB API
        self.feature_files = [file for file in self.files if file.name != "labels"]

        # set key torch.utils.data.Dataset attributes
        self.transform = ToTensor()
        self.target_transform = None

    def __len__(self):
        return len(self.feature_files)

    def __getitem__(self, idx):
        # bind the data object to a session in order to lazy load the File.handwritten_numbers relationship
        with ln.Session() as session:
            # get feature file
            file = session.select(ln.File, id=self.feature_files[idx].id).one()
            # each image is linked to exactly one label; use its name as an integer target
            label = int(file.handwritten_numbers[0].name)
        # get feature (image) from feature file
        path = file.load()
        image = Image.open(path)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label