# Working with AI models

In this tutorial, we will guide you through the basic steps of how to:

1) Download external datasets
2) Prepare, process and train data
3) Predict your own dataset (query)
4) Extend the model with additional data

In [None]:
!pip install --quiet scvi-colab
from scvi_colab import install

install()

In [None]:
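import tempfile
import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import torch

import seaborn as sns
import matplotlib.pyplot as plt
from scvi.model.utils import mde  # MDE embedding helper (requires pymde)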

In [None]:
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)

In [None]:
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
torch.set_float32_matmul_precision("high")
save_dir = tempfile.TemporaryDirectory()

%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

## 1. Download public datasets

To download datasets, we suggest using the automated pipelines from [nf-core](https://nf-co.re).

To download publicly available datasets, you can execute:

```bash
nextflow run nf-core/fetchngs -profile test,docker --input ./ids.csv --outdir ./results
```
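
The `ids.csv` file passed to `--input` can be generated programmatically. The sketch below assumes the pipeline expects one accession ID per line (please verify this against the nf-core/fetchngs documentation); the accession IDs and the `write_ids_file` helper are placeholders invented for illustration.

```python
from pathlib import Path

# Hypothetical accession IDs, used only for illustration;
# replace them with the accessions of the datasets you want to fetch.
accession_ids = ["SRR0000001", "SRR0000002", "SRR0000001"]


def write_ids_file(ids, path="ids.csv"):
    """Write one accession ID per line, skipping blanks and duplicates."""
    unique_ids = list(dict.fromkeys(i.strip() for i in ids if i.strip()))
    Path(path).write_text("\n".join(unique_ids) + "\n")
    return unique_ids


written = write_ids_file(accession_ids)
print(f"Wrote {len(written)} IDs to ids.csv")
```

Removing duplicates up front avoids requesting the same run twice.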

Next, the reads have to be aligned to a reference genome. For our models, we used the following genomes:

- **Mouse (GRCm38 v102)**
    - `Mus_musculus.GRCm38.dna_sm.primary_assembly.fa`
    - `Mus_musculus.GRCm38.102.gtf`
- **Human (GRCh38 v110)**
    - `Homo_sapiens.GRCh38.dna_sm.primary_assembly.fa`
    - `Homo_sapiens.GRCh38.110.gtf`

### For SMART-seq experiments run:

```bash
nextflow run brickmanlab/scrnaseq -r feature/smartseq -c smartseq.human.config --input ./results/samplesheet/samplesheet.csv
```

with `smartseq.human.config`:

```groovy
process {
    withName: STAR_ALIGN {
        ext.args = "--readFilesCommand zcat --soloUMIdedup NoDedup --soloStrand Unstranded"
    }
}
```

### For 10X Chromium datasets run:

```bash
nextflow run brickmanlab/scrnaseq -c 10X.human.config --input ./results/samplesheet/samplesheet.csv
```

with `10X.human.config`:

```groovy
process {
    aligner = "star"
    protocol = "10XV3" // or "10XV2"
}
```

If everything ran correctly, you should find your raw count matrix in `results/star/mtx_conversions/combined_matrix.h5ad`.

Repeat the same process for every dataset you want to include. If in doubt, please see our notebooks [01_fetchngs_mouse.ipynb](../notebooks/01_fetchngs_mouse.ipynb) and [01_fetchngs_human.ipynb](../notebooks/01_fetchngs_human.ipynb) for inspiration.

## 2. Prepare, process and train data

In [None]:
import anndata


# concatenate all datasets into one (replace [...] with your list of AnnData objects)
adata = anndata.concat([...])

# make sure cells and gene names are unique
adata.obs_names_make_unique()
adata.var_names_make_unique()

In [None]:
adata.var["mt"] = adata.var_names.str.startswith("mt-")
# or for human
# adata.var["mt"] = adata.var_names.str.startswith("MT-")

sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True)
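
To make the QC columns easier to interpret, here is a minimal NumPy sketch of how `total_counts`, `n_genes_by_counts` and `pct_counts_mt` relate to the count matrix. The toy matrix and gene names are made up, and the sketch only mirrors the arithmetic; it is not scanpy's implementation.

```python
import numpy as np

# Toy count matrix (3 cells x 4 genes); the values are made up for illustration.
gene_names = np.array(["mt-Nd1", "Actb", "mt-Co1", "Gapdh"])
counts = np.array([
    [10, 50, 0, 40],
    [5, 20, 5, 70],
    [0, 30, 0, 30],
])

is_mt = np.char.startswith(gene_names, "mt-")   # same rule as adata.var["mt"]
total_counts = counts.sum(axis=1)               # counts per cell
n_genes_by_counts = (counts > 0).sum(axis=1)    # genes detected per cell
pct_counts_mt = 100 * counts[:, is_mt].sum(axis=1) / total_counts

print(total_counts, n_genes_by_counts, pct_counts_mt)
```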

In [None]:
sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'], jitter=0.4, multi_panel=True)

In [None]:
adata.layers["counts"] = adata.X.copy()

sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
adata.raw = adata

adata
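
As a rough illustration of what the two preprocessing calls do, the sketch below reproduces the arithmetic in NumPy. It assumes the default `target_sum` of `sc.pp.normalize_total`, i.e. scaling every cell to the median of the per-cell totals; the toy matrix is made up and edge cases handled by scanpy are ignored.

```python
import numpy as np

# Toy count matrix (3 cells x 3 genes), made up for illustration.
counts = np.array([
    [1.0, 3.0, 0.0],
    [4.0, 4.0, 8.0],
    [2.0, 0.0, 2.0],
])

totals = counts.sum(axis=1)                      # per-cell library size
target = np.median(totals)                       # assumed default target sum
normalized = counts / totals[:, None] * target   # every cell now sums to `target`
logged = np.log1p(normalized)                    # natural log of (1 + x)

print(normalized.sum(axis=1))
```

The untouched integer counts stay available in `adata.layers["counts"]`, which is the layer used later for model training.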

In [None]:
sc.tl.pca(adata, svd_solver='arpack')
sc.pl.pca(adata, color=['batch'], frameon=False, wspace=0.4, ncols=2)

In [None]:
adata.write("adata_raw.h5ad")

### 2.1. Train model

We use scvi-tools to build our models, which are based on variational autoencoders. For detailed information, refer to [scvi-tools](http://scvi-tools.org).

In [None]:
adata.uns['log1p']["base"] = None
sc.pp.highly_variable_genes(
    adata,
    flavor="cell_ranger",
    n_top_genes=3_000,
    batch_key="batch",
    subset=True,
)
adata.shape

Finding good training parameters can be tricky. There are two options: either test a few parameter combinations manually and check whether training improves, or follow the scvi-tools [autotune](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/tuning/autotune_new_model.html) tutorial.

In our experience, the biggest difference is seen immediately when adjusting:

- `n_layers`: 2 .. 5
- `gene_likelihood`: `nb` or `zinb`
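
For manual testing, the combinations of the two parameters above can be enumerated as a small grid. This is only a sketch: the variable names are illustrative, and each configuration would still have to be trained and compared by hand.

```python
from itertools import product

n_layers_options = [2, 3, 4, 5]
gene_likelihood_options = ["nb", "zinb"]

# One dictionary per parameter combination
configs = [
    {"n_layers": n, "gene_likelihood": g}
    for n, g in product(n_layers_options, gene_likelihood_options)
]

for cfg in configs:
    # Each dict could be unpacked into the model constructor,
    # e.g. scvi.model.SCVI(adata, **cfg)
    print(cfg)
```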

In [None]:
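scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="batch")  # register counts and batch first
vae = scvi.model.SCVI(adata, n_layers=2, gene_likelihood='nb')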

We recommend enabling the `early_stopping` flag:

In [None]:
vae.train(early_stopping=True)

Next, make sure to inspect the training and validation ELBO curves. Both should converge to a stable value. If they do not, try training for longer or adjust the parameters.
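
Convergence can also be checked numerically. The helper below is a hypothetical sketch (`has_plateaued`, `window` and `rel_tol` are our own names, not part of scvi-tools): it compares the mean of the last epochs with the mean of the epochs before them. The synthetic curves stand in for the values stored in `vae.history`.

```python
import numpy as np


def has_plateaued(values, window=10, rel_tol=0.01):
    """Return True if the mean of the last `window` values differs from the
    mean of the preceding `window` values by at most `rel_tol` (relative)."""
    values = np.asarray(values, dtype=float)
    if len(values) < 2 * window:
        return False  # not enough epochs to judge
    recent = values[-window:].mean()
    previous = values[-2 * window:-window].mean()
    return bool(abs(recent - previous) <= rel_tol * abs(previous))


# Synthetic loss curves, made up for illustration
epochs = np.arange(100)
converged = 1000 + 500 * np.exp(-epochs / 10)   # flattens out
still_improving = 2000 - 10 * epochs            # keeps decreasing

print(has_plateaued(converged), has_plateaued(still_improving))
```

A check like this complements, but does not replace, looking at the plot itself.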

In [None]:
pd.concat([vae.history['elbo_train'], vae.history['elbo_validation']], axis=1).plot.line(marker='o')

In [None]:
adata.obsm["X_scVI"] = vae.get_latent_representation(adata)
adata.obsm["X_mde_scVI"] = mde(adata.obsm["X_scVI"])

adata.layers['scVI_normalized'] = vae.get_normalized_expression(return_numpy=True)

# Save the model
vae.save("../results/scvi", overwrite=True, save_anndata=True)

### 2.2. Classifier

To train a classifier, your cells have to be labeled with cell types. These labels have to be determined in advance. **Remember**, the classifier is only as good as the data it is trained on.

If training is not performing well, or if you have a small number of cells, try setting `linear_classifier=True` in the code below.

In [None]:
lvae = scvi.model.SCANVI.from_scvi_model(vae, adata=adata, labels_key="annotations", unlabeled_category="Unknown")
# lvae = scvi.model.SCANVI.from_scvi_model(vae, adata=adata, labels_key="annotations", unlabeled_category="Unknown", linear_classifier=True)
lvae

In [None]:
max_epochs_scanvi = int(np.min([10, np.max([2, round(200 / 3.0)])]))
print(max_epochs_scanvi)

If you have a small number of cells, increasing the number of samples drawn per label (`n_samples_per_label`) can improve the overall performance.

In [None]:
lvae.train(max_epochs=15)
# lvae.train(max_epochs=15, n_samples_per_label=15)

In [None]:
# Inspect the training

fig, ax = plt.subplots(3, 3, figsize=[20, 14])
for idx, key in enumerate(lvae.history.keys()):
    lvae.history[key].plot(title=key, ax=ax[idx // 3 , idx % 3])

In [None]:
adata.obsm["X_scANVI"] = lvae.get_latent_representation(adata)
adata.obsm["X_mde_scANVI"] = mde(adata.obsm["X_scANVI"])

adata.layers['scANVI_normalized'] = lvae.get_normalized_expression(return_numpy=True)

# Save the classifier model
lvae.save("../results/scanvi", overwrite=True, save_anndata=True)

## 3. Predict your dataset

Either use your own trained classifier for prediction:

In [None]:
lvae = scvi.model.SCANVI.load("../results/scanvi/")

or use our pre-trained one from the publication:

In [None]:
from scvi.hub import HubModel


hub_model = HubModel.pull_from_huggingface_hub(
    repo_name="brickmanlab/human-scanvi",
    cache_dir="./human_scanvi_model_downloaded/",
    revision="main",
)
lvae = hub_model.model  # the HubModel wraps the trained SCANVI model

In [None]:
query = anndata.read_h5ad('./query.h5ad')

In [None]:
scvi.model.SCANVI.prepare_query_anndata(query, lvae)

lvae_q = scvi.model.SCANVI.load_query_data(query, lvae)
lvae_q.train(max_epochs=100, plan_kwargs=dict(weight_decay=0.0), check_val_every_n_epoch=10)

In [None]:
query.obs['prediction'] = lvae_q.predict()
query.obs['entropy'] = 1 - lvae_q.predict(soft=True).max(axis=1)
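
The `entropy` column above is computed as one minus the highest class probability, so values near 0 mean a confident call and higher values mean an ambiguous one. The NumPy sketch below reproduces that arithmetic on made-up probabilities; the cell-type labels and the 0.5 cutoff are assumptions for illustration, not recommendations from the publication.

```python
import numpy as np

# Made-up class probabilities (rows: cells, columns: cell types), standing in
# for the output of lvae_q.predict(soft=True).
labels = np.array(["Epiblast", "Trophectoderm", "Primitive endoderm"])
probs = np.array([
    [0.90, 0.05, 0.05],   # confident call
    [0.40, 0.35, 0.25],   # ambiguous call
])

prediction = labels[probs.argmax(axis=1)]   # most probable label per cell
uncertainty = 1 - probs.max(axis=1)         # same formula as the `entropy` column

# Hypothetical cutoff: flag cells whose score exceeds 0.5 for manual review
needs_review = uncertainty > 0.5

print(prediction, uncertainty, needs_review)
```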

In [None]:
sc.pl.pca(query, color=['batch', 'prediction', 'entropy'])

In [None]:
query.write('../results/query.predicted.h5ad')

## 4. Extend your model with new data

In [None]:
adata_newds = sc.read_h5ad('./adata_newds.h5ad')

In [None]:
extended_adata = anndata.concat([lvae.adata, adata_newds])
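
By default, `anndata.concat` performs an inner join, so only genes present in every dataset are kept. Before concatenating, it can be worth checking how many reference genes the new dataset actually covers. The sketch below uses plain Python lists with made-up gene names; in practice the lists would come from the `var_names` of the two objects.

```python
# Hypothetical gene lists, made up for illustration
reference_genes = ["GATA6", "NANOG", "SOX2", "CDX2"]
new_genes = ["NANOG", "CDX2", "KRT18", "GATA6"]

new_set = set(new_genes)
shared = [g for g in reference_genes if g in new_set]        # keeps reference order
missing = [g for g in reference_genes if g not in new_set]   # lost after an inner join

print(f"{len(shared)}/{len(reference_genes)} reference genes are shared; missing: {missing}")
```

If many genes are missing, the extended model is trained on a much smaller feature set than the original one.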

In [None]:
scvi.model.SCVI.setup_anndata(extended_adata, layer="counts", batch_key="batch",)
extended_vae = scvi.model.SCVI(extended_adata, n_layers=2, gene_likelihood='nb')

In [None]:
extended_vae.train()

In [None]:
extended_adata.obsm["X_scVI"] = extended_vae.get_latent_representation(extended_adata)
extended_adata.layers['scVI_normalized'] = extended_vae.get_normalized_expression(return_numpy=True)

### 4.1. Analysis

In [None]:
USE_REP = 'X_scVI'

sc.pp.neighbors(extended_adata, use_rep=USE_REP)
sc.tl.leiden(extended_adata)
sc.tl.pca(extended_adata)
sc.tl.draw_graph(extended_adata, n_jobs=8, random_state=3)

sc.pp.neighbors(extended_adata, use_rep=USE_REP)
sc.tl.diffmap(extended_adata)

sc.tl.paga(extended_adata, groups='prediction')
sc.pl.paga(extended_adata, color=['prediction'], frameon=False, fontoutline=True)
sc.tl.draw_graph(extended_adata, init_pos='paga', n_jobs=10)

In [None]:
# Save the dataset
extended_adata.write('../results/extended_adata.h5ad')