# Train an autoencoder to get a low-dimensional representation

In [None]:
import lamindb as ln
import anndata as ad
import numpy as np
import scgen

In [None]:
# Start lamindb tracking for this notebook (presumably so that records saved
# below are linked to this run — confirm against the lamindb docs).
ln.track()

In [None]:
# Look up version "2" of the dataset by name. `.one()` is expected to yield a
# single record — NOTE(review): presumably raises if zero or several match.
dataset = ln.Dataset.filter(name="My versioned scRNA-seq dataset", version="2").one()

## Train scgen model on the concatenated dataset

In [None]:
# Load the dataset as a single in-memory object. join="inner" presumably keeps
# only the variables shared by all member files — TODO confirm.
data_train = dataset.load(join="inner")

In [None]:
# Display the loaded object.
data_train

In [None]:
# Count the rows of `obs` per `file_id` value.
data_train.obs.file_id.value_counts()

We use `SCGEN` here instead of `SCVI` or `SCANVI` because we only have access to normalized expression data.

In [None]:
# Register the data with SCGEN before the model is constructed on it.
scgen.SCGEN.setup_anndata(data_train)

In [None]:
# Build the model on the concatenated data.
vae = scgen.SCGEN(data_train)

In [None]:
vae.train(max_epochs=1)  # we use max_epochs=1 to be able to run it on CI

## Train on the files iteratively

For a large number of huge files it might be better to train the model iteratively.

In [None]:
file1, file2 = dataset.files.list()

In [None]:
shared_genes = file1.features["var"] & file2.features["var"]
shred_genes_ensembl = shared_genes.list("ensembl_gene_id")

In [None]:
data_train1 = file1.load()[:, shred_genes_ensembl].copy()

In [None]:
data_train1

In [None]:
scgen.SCGEN.setup_anndata(data_train1)

In [None]:
vae = scgen.SCGEN(data_train1)

In [None]:
vae.train(max_epochs=1)  # we use max_epochs=1 to be able to run it on CI

In [None]:
vae.save("saved_models/scgen")

In [None]:
data_train2 = file2.load()[:, shred_genes_ensembl].copy()

In [None]:
data_train2

In [None]:
vae = scgen.SCGEN.load("saved_models/scgen", data_train2)

In [None]:
vae.train(max_epochs=1)  # we use max_epochs=1 to be able to run it on CI

In [None]:
vae.save("saved_models/scgen", overwrite=True)

## Save the model weights

In [None]:
# Wrap the saved weights file as a lamindb file with the given storage key.
weights = ln.File("saved_models/scgen/model.pt", key="models/scgen/model.pt")

In [None]:
weights.save()

## Get and store the low-dimensional representation

In [None]:
# Embed both subsets with the current `vae` (last trained on `data_train2`).
latent1 = vae.get_latent_representation(data_train1)
latent2 = vae.get_latent_representation(data_train2)

# Stack row-wise: rows from file 1 first, then rows from file 2.
latent = np.vstack((latent1, latent2))

In [None]:
# Wrap the stacked latent matrix in a new AnnData object.
adata_latent = ad.AnnData(X=latent)

Set file id.

In [None]:
# Label every row with the id of the file it came from. The order follows the
# stacking of the latent matrix: file 1 rows first, then file 2 rows.
adata_latent.obs["file_id"] = np.repeat(
    [file1.id, file2.id],
    [len(data_train1), len(data_train2)],
)

In [None]:
# Wrap the latent AnnData as a lamindb file with the given storage key.
file_latent = ln.File(adata_latent, key="adata_latent.h5ad")

In [None]:
file_latent.save()

In [None]:
# Attach the shared gene set computed earlier to the latent file.
# NOTE(review): the latent columns are model dimensions rather than genes;
# presumably this records which genes the representation was derived from —
# confirm the intent.
file_latent.genes.set(shared_genes)

In [None]:
file_latent.describe()

## Append the low-dimensional representation to the dataset

In [None]:
# New dataset version: the files of the existing dataset plus the latent file.
dataset_v3 = ln.Dataset(
    dataset.files.list() + [file_latent],
    is_new_version_of=dataset,
)

In [None]:
dataset_v3

In [None]:
dataset_v3.save()

In [None]:
# clean up test instance
# NOTE(review): the local `saved_models/` directory written by this notebook is
# not removed here.
!lamin delete --force test-scrna
!rm -r ./test-scrna