![scrna5/6](https://img.shields.io/badge/scrna5/6-lightgrey)
[![Jupyter Notebook](https://img.shields.io/badge/Jupyter%20Notebook-orange)](https://github.com/laminlabs/lamin-usecases/blob/main/docs/scrna4.ipynb)
[![lamindata](https://img.shields.io/badge/laminlabs/lamindata-mediumseagreen)](https://lamin.ai/laminlabs/lamindata/record/core/Transform?uid=Qr1kIHvK506rz8)

# Iteratively train an ML model on a dataset

In the [previous tutorial](scrna3), we loaded an entire dataset into memory to perform a simple analysis.

Here, we'll iterate over the files within the dataset to train an ML model.

In [None]:
import lamindb as ln
import anndata as ad
import numpy as np

In [None]:
ln.track()

## Setup

In [None]:
dataset_v2 = ln.Dataset.filter(name="My versioned scRNA-seq dataset", version="2").one()

dataset_v2

We import [scvi-tools](https://github.com/scverse/scvi-tools).

In [None]:
import scvi

Similar to what we did in the [previous tutorial](scrna3), we could load the entire dataset into memory and train a model in 4 lines of code.

:::{dropdown} What would this look like?

```python
data_train = dataset_v2.load(join="inner")
scvi.model.SCVI.setup_anndata(data_train)
vae = scvi.model.SCVI(data_train)
vae.train(max_epochs=1)  # we use max_epochs=1 to be able to run it on CI
```

:::

Let us instead load all file records:

In [None]:
file1, file2 = dataset_v2.files.list()

We'd like some context on what the first file contains and where it's from:

In [None]:
file1.describe()
file1.view_flow()

We'll need to make a decision on the features that we want to use for training the model.

Because each file is validated, all of them are indexed by `ensembl_gene_id` in the `var` slot of AnnData, so we can intersect their gene sets directly.

In [None]:
shared_genes = file1.features["var"] & file2.features["var"]
shared_genes_ensembl = shared_genes.list("ensembl_gene_id")
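
The intersection above relies on both files using the same gene identifiers. As a minimal sketch of the same idea in plain Python, independent of lamindb, the following keeps only the IDs found in both files. The ID lists and the helper `shared_in_order` are hypothetical placeholders, not part of the tutorial's dataset or of any library API.

```python
# Placeholder Ensembl-style gene IDs standing in for the `var` index of each file
genes_file1 = ["ENSG00000141510", "ENSG00000012048", "ENSG00000139618"]
genes_file2 = ["ENSG00000139618", "ENSG00000141510", "ENSG00000171862"]


def shared_in_order(first, second):
    """Return the IDs present in both lists, in the order of `first`."""
    second_set = set(second)  # set membership checks are O(1) on average
    return [gene for gene in first if gene in second_set]


shared = shared_in_order(genes_file1, genes_file2)
print(shared)
```

Keeping the order of the first list means the same list of IDs yields the same column order whenever it is used to subset a file.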

## Train the model

Let us load the first file into memory:

In [None]:
data_train1 = file1.load().raw[:, shared_genes_ensembl].to_adata()
data_train1

Train the model on this first file:

In [None]:
scvi.model.SCVI.setup_anndata(data_train1)
vae = scvi.model.SCVI(data_train1)
vae.train(max_epochs=1)  # we use max_epochs=1 to run it on CI
vae.save("saved_models/scvi1")

Load the second file and resume training the model:

In [None]:
data_train2 = file2.load().raw[:, shared_genes_ensembl].to_adata()
vae = scvi.model.SCVI.load("saved_models/scvi1", data_train2)
vae.train(max_epochs=1)
vae.save("saved_models/scvi1", overwrite=True)
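
The load, train, save sequence above extends naturally to any number of files. The following is a toy sketch of that loop using only the standard library: a running mean stands in for the model, and a JSON file stands in for the saved model directory. All names here (`train_on_batch`, `save_checkpoint`, `load_checkpoint`) are hypothetical and do not correspond to scvi-tools or lamindb functions.

```python
import json
import tempfile
from pathlib import Path


def train_on_batch(state, batch):
    """Update a running mean with one batch (a toy stand-in for a training step)."""
    n_new = state["n"] + len(batch)
    mean_new = (state["mean"] * state["n"] + sum(batch)) / n_new
    return {"n": n_new, "mean": mean_new}


def save_checkpoint(state, path):
    path.write_text(json.dumps(state))


def load_checkpoint(path):
    return json.loads(path.read_text())


# Each inner list plays the role of one file that is loaded on its own
batches = [[1.0, 2.0, 3.0], [4.0, 5.0]]

with tempfile.TemporaryDirectory() as tmp:
    checkpoint = Path(tmp) / "model.json"
    save_checkpoint({"n": 0, "mean": 0.0}, checkpoint)  # untrained "model"
    for batch in batches:
        state = load_checkpoint(checkpoint)   # resume from the last checkpoint
        state = train_on_batch(state, batch)  # only this batch is in memory
        save_checkpoint(state, checkpoint)    # overwrite the checkpoint
    final_state = load_checkpoint(checkpoint)

print(final_state)
```

Unlike this running mean, gradient-based training generally depends on the order in which files are seen, so the toy only illustrates the checkpointing pattern, not the training dynamics.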

## Save the model

In [None]:
weights = ln.File("saved_models/scvi1/model.pt", description="My trained model")
weights.save()

## Save latent representation as a new dataset

In [None]:
latent1 = vae.get_latent_representation(data_train1)
latent2 = vae.get_latent_representation(data_train2)

adata_latent1 = ad.AnnData(X=latent1, obs=data_train1.obs)
adata_latent2 = ad.AnnData(X=latent2, obs=data_train2.obs)

Because the latent representation is low-dimensional, we can typically fit a very large number of observations into memory.

Hence, let's store it as a single concatenated `AnnData` object.

In [None]:
adata_latent = ad.concat([adata_latent1, adata_latent2])
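
To put the memory argument in rough numbers, here is a back-of-the-envelope sketch. It assumes dense 32-bit floats, 10 latent dimensions, 20,000 genes, and one million cells. All of these are illustrative assumptions rather than properties of this dataset, and real expression matrices are usually stored sparsely and take less space than the dense figure.

```python
def dense_size_mb(n_obs, n_vars, bytes_per_value=4):
    """Size in megabytes of a dense n_obs x n_vars matrix."""
    return n_obs * n_vars * bytes_per_value / 1e6


n_obs = 1_000_000  # assumed number of cells

latent_mb = dense_size_mb(n_obs, n_vars=10)  # assumed latent dimensionality
expression_mb = dense_size_mb(n_obs, n_vars=20_000)  # assumed number of genes

print(f"latent: {latent_mb:.0f} MB, dense expression: {expression_mb / 1000:.0f} GB")
```

Under these assumptions the latent matrix is 2,000 times smaller than the dense expression matrix, which is simply the ratio of genes to latent dimensions.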

In [None]:
dataset_v2_latent = ln.Dataset(
    adata_latent,
    name="Latent representation of scRNA-seq dataset v2",
    description="For the original data, see dataset T5x0SkRJNviE0jYGbJKt",
)
dataset_v2_latent.save()

Let us look at the data flow:

In [None]:
dataset_v2_latent.view_flow()

Compare this with the data flow of the model weights:

In [None]:
weights.view_flow()

Annotate the latent dataset with the labels of the original dataset:

In [None]:
dataset_v2_latent.labels.add_from(dataset_v2)

dataset_v2_latent.describe()

In [None]:
# clean up test instance
!lamin delete --force test-scrna
!rm -r ./test-scrna