# scArches Annotation 

This tutorial is based on: https://docs.scarches.org/en/latest/hlca_map_classify.html

### Setting up the environment:

In [4]:
import os
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=DeprecationWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import sys
import scanpy as sc
import numpy as np
import pandas as pd
import scarches as sca
import anndata as ad
from scipy import sparse
import gdown
import gzip
import shutil
import urllib.request

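In order to use sagenet models, please install pytorch geometric (see https://pytorch-geometric.readthedocs.io) and captum (see https://github.com/pytorch/captum).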


### Set directory paths

In [6]:
# Parent directories under which the reference model and the surgery
# (query-adapted) model directories are stored.
ref_model_dir_prefix = "."
surgery_model_dir_prefix = "."

# File the downloaded reference embedding is written to (and later read from).
path_reference_emb = "../data/hlca_core.h5ad"

# CHANGE THIS to point at your own query dataset (.h5ad).
path_query_data = "../../data/cellref/cellref_set1.h5ad"

# Derived model directories -- don't change these.
ref_model_dir = os.path.join(ref_model_dir_prefix, "HLCA_reference_model")
surgery_model_dir = os.path.join(surgery_model_dir_prefix, "surgery_model")

### Download reference data and reference model from Zenodo

In [None]:
# Download the HLCA reference model archive from Zenodo, unpack it under
# ref_model_dir_prefix, and delete the archive afterwards.
url = "https://zenodo.org/record/7599104/files/HLCA_reference_model.zip"
output = "HLCA_reference_model.zip"
gdown.download(url, output, quiet=False)
# Reuse `output` for unpacking and cleanup so the three steps always refer to
# the same file (the archive name was previously hard-coded a second time).
shutil.unpack_archive(output, extract_dir=ref_model_dir_prefix)
os.remove(output)

# Download the embedding of the full HLCA reference (2.3 gb) to path_reference_emb.
url = "https://zenodo.org/record/7599104/files/HLCA_full_v1.1_emb.h5ad"
output = path_reference_emb
# The target directory (e.g. "../data") may not exist yet; create it first.
output_dir = os.path.dirname(output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
gdown.download(url, output, quiet=False)

In [None]:
# Read the downloaded HLCA embedding and keep only the cells of the HLCA core.
adata_ref = sc.read_h5ad(path_reference_emb)
is_core_cell = adata_ref.obs["core_or_extension"] == "core"
adata_ref = adata_ref[is_core_cell, :].copy()

# After subsetting, obs columns that are missing for every cell only described
# HLCA extension data (healthy + diseased datasets); drop them.
all_missing = adata_ref.obs.isnull().all()
adata_ref.obs.drop(columns=all_missing[all_missing].index.tolist(), inplace=True)
adata_ref

### Load and prepare query data

In [7]:
# Read the query dataset that will be mapped onto the HLCA reference.
adata_query_unprep = sc.read_h5ad(path_query_data)

# Store the expression matrix in CSR sparse format (this converts it when it is
# dense or in another sparse format; it does not merely check it).
adata_query_unprep.X = sparse.csr_matrix(adata_query_unprep.X)

# Drop all .obsm and .varm entries so they cannot cause issues downstream.
del adata_query_unprep.obsm, adata_query_unprep.varm

**Check whether the query dataset contains raw or normalized counts (scArches needs raw counts):**

In [10]:
# Display a summary of the expression matrix (type, shape, dtype, stored elements).
adata_query_unprep.X

<115990x32284 sparse matrix of type '<class 'numpy.float64'>'
	with 330795181 stored elements in Compressed Sparse Row format>

In [8]:
# Show the first 10 observations x first 30 variables as a dense array.
# Non-integer values (e.g. 0.12) indicate normalized data rather than raw counts.
adata_query_unprep.X[:10, :30].toarray()

array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.12, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.12, 0.  , 0.  , 0.  , 0.  , 0.12, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.62, 0.45, 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.96, 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.36, 0.  ,
        0.  , 0.  , 0.  , 0.  , 

In [None]:
# scArches expects raw counts in .X. Replace .X with the counts stored in .raw,
# but only when .X does not already look like raw counts (i.e. it contains
# non-integer values). Previously .X was overwritten unconditionally.
_x_sample = adata_query_unprep.X[:1000]  # a sample of rows suffices for this check
_x_values = (
    _x_sample.data if sparse.issparse(_x_sample) else np.asarray(_x_sample).ravel()
)
if not np.all(np.mod(_x_values, 1) == 0):
    if adata_query_unprep.raw is None:
        raise ValueError(
            "adata_query_unprep.X does not look like raw counts and no .raw is "
            "available; please provide a query dataset containing raw counts."
        )
    # .raw can hold more genes than .X; select exactly the genes in .X (same
    # order) so the replacement matrix has a matching shape.
    adata_query_unprep.X = sparse.csr_matrix(
        adata_query_unprep.raw[:, adata_query_unprep.var_names].X
    )

**Check whether your reference model uses gene names or gene ids as input features.**  
The HLCA reference model that was just downloaded uses Ensembl IDs as input features, so we need to verify:  
1.) whether the query data also uses Ensembl IDs   
2.) whether they are stored in `adata_query_unprep.var.index`

View the features used by the reference model:

In [None]:
# The reference model lists its input features in var_names.csv (no header row);
# print the first few to see whether they are Ensembl IDs or gene names.
ref_var_names_path = os.path.join(ref_model_dir, "var_names.csv")
ref_model_features = pd.read_csv(ref_var_names_path, header=None)
print(ref_model_features.head(5))

In [None]:
# view query data input features: check which identifiers (Ensembl IDs or gene
# symbols) the query uses and whether they live in the index or in a column
adata_query_unprep.var

## If your query dataset does not use Ensembl IDs, you have to map its features to Ensembl IDs
**If the input features already match your reference, you can move on to the next section!**

This example maps gene names to Ensembl IDs.

In [None]:
# Local destination (next to the reference model) for the table that maps
# Ensembl gene IDs to gene symbols.
path_gene_mapping_df = os.path.join(
    ref_model_dir, "HLCA_reference_model_gene_order_ids_and_symbols.csv"
)

# Fetch the mapping table from the same Zenodo record as the reference model.
url = "https://zenodo.org/record/7599104/files/HLCA_reference_model_gene_order_ids_and_symbols.csv"
gdown.download(url, path_gene_mapping_df, quiet=False)

In [None]:
# Load the mapping table: its index holds Ensembl gene IDs and its
# `gene_symbol` column the corresponding gene names (as used below).
gene_id_to_gene_name_df = pd.read_csv(filepath_or_buffer=path_gene_mapping_df, index_col=0)
gene_id_to_gene_name_df

In [None]:
# Optional: if the gene names currently reside in .var.index, copy them into a
# column first by uncommenting the next line:
#adata_query_unprep.var["gene_names"] = adata_query_unprep.var.index

# Name of the .var column in the query data that contains gene names.
gene_name_column_name = "features" ## CHANGE to match your data

# Count how many of the reference model's input genes occur in the query.
# Counting from the model side keeps n_overlap within [0, n_genes_model] even
# when the query contains duplicated gene names; counting query rows would
# count each duplicate separately and could report more than 100%.
query_gene_names = adata_query_unprep.var[gene_name_column_name].unique()
n_overlap = int(gene_id_to_gene_name_df.gene_symbol.isin(query_gene_names).sum())
n_genes_model = gene_id_to_gene_name_df.shape[0]
print(
    f"Number of model input genes detected: {n_overlap} out of {n_genes_model} ({round(n_overlap/n_genes_model*100)}%)"
)

In [None]:
# Keep only the query genes whose symbol is used by the reference model.
in_model = adata_query_unprep.var[gene_name_column_name].isin(
    gene_id_to_gene_name_df.gene_symbol
)
adata_query_unprep = adata_query_unprep[:, in_model.to_numpy()].copy()

# Translate gene symbols to Ensembl IDs and store them in .var.index.
symbol_to_gene_id = dict(
    zip(gene_id_to_gene_name_df.gene_symbol, gene_id_to_gene_name_df.index)
)
adata_query_unprep.var.index = adata_query_unprep.var[gene_name_column_name].map(
    symbol_to_gene_id
)

# remove index name to prevent bugs later on
adata_query_unprep.var.index.name = None

# Duplicated gene symbols in the query map to the same Ensembl ID, leaving
# non-unique variable names (AnnData warns about these and selecting genes by
# name becomes ambiguous). Keep the first occurrence of each ID.
is_duplicate = adata_query_unprep.var.index.duplicated(keep="first")
if is_duplicate.any():
    print(
        f"Dropping {int(is_duplicate.sum())} genes whose Ensembl ID is duplicated "
        "(keeping the first occurrence)."
    )
    adata_query_unprep = adata_query_unprep[:, ~is_duplicate].copy()

adata_query_unprep.var["gene_ids"] = adata_query_unprep.var.index

# check that the mapping was successful -- the index should be ensembl ID now
adata_query_unprep.var.head(3)

### Once your query data has the correct input features, continue here to prepare it for scArches:

In [None]:
# Align the query with the reference model's genes: genes the model expects but
# the query lacks are padded with zeros. inplace=False returns a new AnnData.
adata_query = sca.models.SCANVI.prepare_query_anndata(
    adata=adata_query_unprep,
    reference_model=ref_model_dir,
    inplace=False,
)

# All query cells form a single batch and start out unlabeled.
for obs_column, fill_value in {"dataset": "lung_batch", "scanvi_label": "unlabeled"}.items():
    adata_query.obs[obs_column] = fill_value

adata_query

## Perform surgery

In [None]:
# Build the surgery model from the reference model in ref_model_dir with the
# prepared query data attached; dropout is frozen.
surgery_model = sca.models.SCANVI.load_query_data(
    adata_query, ref_model_dir, freeze_dropout=True
)
print(surgery_model)

# Show the arguments the model was set up with.
setup_args = surgery_model.registry_["setup_args"]
print(setup_args)

In [None]:
# Training arguments for surgery.
surgery_epochs = 500
# The monitor/patience/min_delta settings only take effect when
# `early_stopping=True` is also passed to train(); without it the trainer
# ignores them and always runs the full surgery_epochs.
early_stopping_kwargs_surgery = {"early_stopping": True,
                                 "early_stopping_monitor": "elbo_train",
                                 "early_stopping_patience": 10,
                                 "early_stopping_min_delta": 0.001,
                                 "plan_kwargs": {"weight_decay": 0.0}}

# Perform surgery by training on a single CPU device.
surgery_model.train(accelerator='cpu', devices=1, max_epochs=surgery_epochs, **early_stopping_kwargs_surgery)

In [None]:
# Persist the trained surgery model (overwriting any earlier copy) and reload it
# together with the query data. If the model was trained in an earlier session,
# only the load step is needed.
surgery_model.save(surgery_model_dir, overwrite=True)
surgery_model = sca.models.SCANVI.load(surgery_model_dir, adata_query)

## Latent Embedding

In [None]:
# Latent embedding of the query cells, carrying a copy of the query's obs.
adata_query_latent = sc.AnnData(surgery_model.get_latent_representation(adata_query))
adata_query_latent.obs = adata_query.obs.copy()

# Joint reference + query embedding, tagged by origin.
adata_query_latent.obs["ref_or_query"] = "query"
adata_ref.obs["ref_or_query"] = "ref"
combined_emb = sc.concat(
    (adata_ref, adata_query_latent), index_unique=None, join="outer"
)

# Prepare the combined embedding for writing to .h5ad: obs columns that are
# neither categorical nor float are converted to strings to prevent write
# errors. This must act on combined_emb, the object that is written to disk;
# converting adata_query_latent after the concat has no effect on it.
for cat in combined_emb.obs.columns:
    column = combined_emb.obs[cat]
    if isinstance(column.values, pd.Categorical) or pd.api.types.is_float_dtype(column):
        continue
    print(f"Setting obs column {cat} (not categorical neither float) to strings to prevent writing error.")
    combined_emb.obs[cat] = column.astype(str)

In [None]:
# Write the joint embedding to disk and read it back; when it was saved in an
# earlier session, only the read is needed.
combined_emb_path = "combined_embedding.h5ad"
combined_emb.write_h5ad(combined_emb_path)
combined_emb = sc.read_h5ad(combined_emb_path)

## Label Transfer

In [None]:
# Download the HLCA cell-type table and store it next to the reference model.
path_celltypes = os.path.join(ref_model_dir, "HLCA_celltypes_ordered.csv")
# Alternative source: "https://github.com/LungCellAtlas/mapping_data_to_the_HLCA/raw/main/supporting_files/HLCA_celltypes_ordered.csv"
url = "https://github.com/LungCellAtlas/HLCA_reproducibility/raw/main/supporting_files/celltype_structure_and_colors/manual_anns_and_leveled_anns_ordered.csv"
gdown.download(url, path_celltypes, quiet=False)

# Rename Level_1 ... Level_5 to labtransf_ann_level_1 ... labtransf_ann_level_5
# so they can be picked up by prefix during label transfer.
level_renames = {f"Level_{lev}": f"labtransf_ann_level_{lev}" for lev in range(1, 6)}
cts_ordered = pd.read_csv(path_celltypes, index_col=0).rename(columns=level_renames)

cts_ordered.head(5)

In [None]:
# Add annotations for all five levels to the reference by joining the cell-type
# table on each cell's ann_finest_level. Columns left over from a previous run
# of this cell are dropped first; otherwise DataFrame.join raises on the
# overlapping column names. The join key itself is never dropped.
overlapping_columns = adata_ref.obs.columns.intersection(cts_ordered.columns).drop(
    "ann_finest_level", errors="ignore"
)
adata_ref.obs = adata_ref.obs.drop(columns=overlapping_columns).join(
    cts_ordered, on="ann_finest_level"
)

In [None]:
# Fit a weighted kNN model (50 neighbours) on the reference embedding in adata_ref.X.
knn_transformer = sca.utils.knn.weighted_knn_trainer(
    train_adata=adata_ref, train_adata_emb="X", n_neighbors=50
)

# Transfer every reference obs column whose name starts with
# "labtransf_ann_level_" to the query cells (embedding in adata_query_latent.X),
# obtaining the transferred labels and their uncertainties.
labels, uncert = sca.utils.knn.weighted_knn_transfer(
    query_adata=adata_query_latent,
    query_adata_emb="X",
    label_keys="labtransf_ann_level_",
    knn_model=knn_transformer,
    ref_adata_obs=adata_ref.obs,
)

In [None]:
# CHANGE to the query obs column that holds the ground-truth labels.
true_label_name = 'ann_finest_level'

# Combine the transferred labels with their uncertainties (uncertainty columns
# get the "_uncert" suffix), then attach the query's true labels by index.
labels_with_uncert = pd.merge(
    left=labels,
    right=uncert,
    left_index=True,
    right_index=True,
    suffixes=("", "_uncert"),
)
df = pd.merge(
    left=labels_with_uncert,
    right=adata_query_latent.obs[true_label_name],
    left_index=True,
    right_index=True,
)
df.head()