![scrna5/6](https://img.shields.io/badge/scrna5/6-lightgrey)
[![Jupyter Notebook](https://img.shields.io/badge/Source%20on%20GitHub-orange)](https://github.com/laminlabs/lamin-usecases/blob/main/docs/scrna4.ipynb)
[![lamindata](https://img.shields.io/badge/Source%20%26%20report%20on%20LaminHub-mediumseagreen)](https://lamin.ai/laminlabs/lamindata/record/core/Transform?uid=Qr1kIHvK506rz8)

# Train an ML model on a dataset

In the [previous tutorial](scrna3), we loaded an entire dataset into memory to perform a simple analysis.

Here, we'll iterate over the files within the dataset to train an ML model.

In [None]:
import lamindb as ln
import anndata as ad
import numpy as np

In [None]:
ln.track()

## Preprocessing

Let us get our dataset:

In [None]:
dataset_v2 = ln.Dataset.filter(name="My versioned scRNA-seq dataset", version="2").one()
dataset_v2

We'll need to make a decision on the features that we want to use for training the model.

Because each file is validated, all files are indexed by `ensembl_gene_id` in the `var` slot of AnnData.

To make our life easy, we'll intersect features across files:

In [None]:
files = dataset_v2.files.all()
# the gene sets are stored in the "var" slot of features
shared_genes = files[0].features["var"]
for file in files[1:]:
    # QuerySet objects allow set operations
    shared_genes = shared_genes & file.features["var"]
shared_genes_ensembl = shared_genes.list("ensembl_gene_id")
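
The intersection step above can be illustrated without a database. The following is a minimal sketch using plain Python sets; the file names and the gene IDs assigned to each file are illustrative placeholders, not values from this dataset.

```python
from functools import reduce

# Hypothetical per-file gene sets, keyed by an illustrative file name
genes_per_file = {
    "file_a": {"ENSG00000000003", "ENSG00000000419", "ENSG00000000457"},
    "file_b": {"ENSG00000000419", "ENSG00000000457", "ENSG00000000460"},
    "file_c": {"ENSG00000000419", "ENSG00000000457"},
}

# Intersect all sets, mirroring the pairwise `&` used in the cell above
shared = reduce(set.intersection, genes_per_file.values())

# Sort to get a deterministic column order for subsetting
shared_sorted = sorted(shared)
print(shared_sorted)
```

Only genes present in every file survive, so every file can later be subset to the same columns.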

We'll now store the raw representations and create a training dataset:

In [None]:
raw_files = []
for file in files:
    adata_raw = file.load().raw[:, shared_genes_ensembl].to_adata()
    raw_file = ln.File(adata_raw, description=f"Raw data of file {file.uid}")
    raw_files.append(raw_file)
ln.save(raw_files)

ds_train = ln.Dataset(raw_files, name="My training dataset", version="2")
ds_train.save()
ds_train.view_flow()

## PyTorch DataLoader

If you need to train your model on a list of files, you can use {meth}`~lamindb.Dataset.mapped` with the PyTorch `DataLoader`.

It loads only batches into memory, which makes it possible to work with very large datasets.

In [None]:
from torch.utils.data import DataLoader, WeightedRandomSampler

Files in the dataset should have the same variables; we have already taken care of this by subsetting every file to the shared genes.

In [None]:
ds_mapped = ds_train.mapped(label_keys=["cell_type"])

This is compatible with the PyTorch `DataLoader` because it implements `__getitem__` over a list of `AnnData` files.

In [None]:
ds_mapped[5]
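
To see how integer indexing over several files can work in principle, here is a minimal sketch that maps a global observation index to a file and a row within that file. The class `ConcatIndex` and its `locate` method are hypothetical names chosen for illustration; this is not necessarily how lamindb implements it.

```python
from bisect import bisect_right
from itertools import accumulate


class ConcatIndex:
    """Map a global observation index to (file index, local row index)."""

    def __init__(self, n_obs_per_file):
        # Cumulative end offsets, e.g. [3, 5, 9] for sizes [3, 2, 4]
        self.offsets = list(accumulate(n_obs_per_file))

    def __len__(self):
        return self.offsets[-1] if self.offsets else 0

    def locate(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # First file whose end offset is strictly greater than idx
        file_idx = bisect_right(self.offsets, idx)
        start = self.offsets[file_idx - 1] if file_idx > 0 else 0
        return file_idx, idx - start


index = ConcatIndex([3, 2, 4])  # three hypothetical files
print(index.locate(5))
```

With such a mapping, a dataset object only needs to read a single row from a single file per index, which is what keeps memory usage low.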

The `labels` are encoded into integers.

In [None]:
ds_mapped.encoders
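
As a rough illustration of what such an encoder holds, the sketch below builds a label-to-integer mapping in plain Python. The labels are made up, and the order in which lamindb assigns codes may differ from the first-appearance order assumed here.

```python
# Hypothetical cell type labels for five observations
labels = ["T cell", "B cell", "T cell", "monocyte", "B cell"]

# Assign integer codes in order of first appearance (an assumption)
encoder = {}
for label in labels:
    encoder.setdefault(label, len(encoder))

# Replace each label by its integer code
encoded = [encoder[label] for label in labels]
print(encoder)
print(encoded)
```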

Let us use a weighted sampler:

In [None]:
# the label key used for the weights doesn't have to be among the label_keys passed to mapped()
sampler = WeightedRandomSampler(
    weights=ds_mapped.get_label_weights("cell_type"), num_samples=len(ds_mapped)
)
dl = DataLoader(ds_mapped, batch_size=128, sampler=sampler)
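
A weighted sampler is typically used to counter class imbalance. The sketch below shows one common weighting scheme, inverse class frequency, in plain Python; it is an assumption that `get_label_weights` behaves this way, and the labels are made up. `random.choices` stands in for `WeightedRandomSampler`, as both draw with replacement by default.

```python
import random
from collections import Counter

# Hypothetical, imbalanced labels: 8 T cells and 2 B cells
labels = ["T cell"] * 8 + ["B cell"] * 2

# Inverse-frequency weight per observation
counts = Counter(labels)
weights = [1.0 / counts[label] for label in labels]

# Total weight per class is equal, so classes are drawn
# equally often in expectation
weight_per_class = Counter()
for label, weight in zip(labels, weights):
    weight_per_class[label] += weight

# Draw indices with replacement according to the weights
rng = random.Random(0)
sample = rng.choices(range(len(labels)), weights=weights, k=1000)
sampled_counts = Counter(labels[i] for i in sample)
print(weight_per_class)
print(sampled_counts)
```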

We can now iterate through the data loader:

In [None]:
for batch in dl:
    pass

Close the connections in `MappedDataset`:

In [None]:
ds_mapped.close()

:::{dropdown} In practice, use a context manager

```python
with ds_train.mapped(label_keys=["cell_type"]) as ds_mapped:
    sampler = WeightedRandomSampler(
        weights=ds_mapped.get_label_weights("cell_type"), num_samples=len(ds_mapped)
    )
    dl = DataLoader(ds_mapped, batch_size=128, sampler=sampler)
    for batch in dl:
        pass
```
:::
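
The reason to prefer a context manager is that cleanup runs even if training fails midway. The following minimal sketch demonstrates this with a hypothetical `HandlePool` class that stands in for an object holding open file connections; it does not reproduce lamindb's implementation.

```python
class HandlePool:
    """Hypothetical stand-in for an object that holds open file connections."""

    def __init__(self, names):
        self.open_handles = set(names)

    def close(self):
        self.open_handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the handles, then let any exception propagate
        self.close()
        return False


pool = HandlePool(["a.h5ad", "b.h5ad"])
try:
    with pool:
        raise RuntimeError("training failed")
except RuntimeError:
    pass

# The handles were released despite the exception
print(pool.open_handles)
```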

In [None]:
# clean up test instance
!lamin delete --force test-scrna
!rm -r ./test-scrna