## scANVI analysis for healthy PBMC pilot study (Cai 2020 and Cai 2022)

**Objective**: Run scANVI analysis for label transfer for healthy PBMCs [Cai 2020 and Cai 2022]


- **Developed by**: Mairi McClean

- **Institute of Computational Biology - Computational Health Centre - Helmholtz Munich**

- v230317

- Following this tutorial: https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scarches_scvi_tools.html
> "This particular workflow is useful in the case where a model is trained on some data (called reference here) and new samples are received (called query)."



In [1]:
# Sanity check
import os 
os.write(1, b"text\n")

text


5

### Import modules 

In [2]:
import sys
import warnings

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scrublet as scr
import scvi



Global seed set to 0
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [3]:
warnings.simplefilter(action="ignore", category=FutureWarning)


sc.set_figure_params(figsize=(4, 4))
scvi.settings.seed = 94705

%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'


Global seed set to 94705


### Read in data


In [4]:
yoshida = sc.read_h5ad('/Volumes/Lacie/data_lake/Mairi_example/INBOX/sc_downloads/yoshida_2021/meyer_nikolic_covid_pbmc.cellxgene.20210813.h5ad')

yoshida

AnnData object with n_obs × n_vars = 422220 × 33751
    obs: 'patient_id', 'Ethnicity', 'BMI', 'annotation_broad', 'annotation_detailed', 'annotation_detailed_fullNames', 'Age_group', 'COVID_severity', 'COVID_status', 'Group', 'Sex', 'Smoker', 'sample_id', 'sequencing_library', 'Protein_modality_weight'
    var: 'name'
    obsm: 'X_ umap (wnn derived)', 'X_umap (after harmony ADT)', 'X_umap (after harmony RNA)', 'X_umap (before harmony ADT)', 'X_umap (before harmony RNA)'

In [5]:
yoshida.obs

Unnamed: 0,patient_id,Ethnicity,BMI,annotation_broad,annotation_detailed,annotation_detailed_fullNames,Age_group,COVID_severity,COVID_status,Group,Sex,Smoker,sample_id,sequencing_library,Protein_modality_weight
CV001_KM10202384-CV001_KM10202394_AAACCTGAGGCAGGTT-1,AN5,EUR,Unknown,Monocyte,Monocyte CD14,Classical monocyte,Adult,Healthy,Healthy,Adult,Female,Non-smoker,AN5,CV001_KM10202384-CV001_KM10202394,0.359517
CV001_KM10202384-CV001_KM10202394_AAACCTGAGTGTCCCG-1,AN5,EUR,Unknown,T CD4+,T CD4 helper,T CD4 helper,Adult,Healthy,Healthy,Adult,Female,Non-smoker,AN5,CV001_KM10202384-CV001_KM10202394,0.577522
CV001_KM10202384-CV001_KM10202394_AAACCTGCAGATGGGT-1,AN3,EUR,Unknown,T CD4+,T CD4 helper,T CD4 helper,Adult,Healthy,Healthy,Adult,Male,Non-smoker,AN3,CV001_KM10202384-CV001_KM10202394,0.369143
CV001_KM10202384-CV001_KM10202394_AAACCTGGTATAGTAG-1,AN5,EUR,Unknown,T CD8+,T CD8 naive,T CD8 naive,Adult,Healthy,Healthy,Adult,Female,Non-smoker,AN5,CV001_KM10202384-CV001_KM10202394,0.785563
CV001_KM10202384-CV001_KM10202394_AAACCTGGTGTGCGTC-1,AN5,EUR,Unknown,T CD4+,T CD4 naive,T CD4 naive,Adult,Healthy,Healthy,Adult,Female,Non-smoker,AN5,CV001_KM10202384-CV001_KM10202394,0.564174
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
S28_TTTGTCAGTTCTGTTT-1,PC9,EUR,27.17,NK,NK,NK,Adult,Severe,Post-COVID-19,Adult,Male,Non-smoker,PC9,CV001_KM9294396-CV001_KM9294404,0.429398
S28_TTTGTCATCAACCAAC-1,PC9,EUR,27.17,Monocyte,Monocyte CD14,Classical monocyte,Adult,Severe,Post-COVID-19,Adult,Male,Non-smoker,PC9,CV001_KM9294396-CV001_KM9294404,0.677910
S28_TTTGTCATCATTATCC-1,PC9,EUR,27.17,Monocyte,Monocyte CD14,Classical monocyte,Adult,Severe,Post-COVID-19,Adult,Male,Non-smoker,PC9,CV001_KM9294396-CV001_KM9294404,0.422796
S28_TTTGTCATCCTATGTT-1,PC9,EUR,27.17,DC,pDC,pDC,Adult,Severe,Post-COVID-19,Adult,Male,Non-smoker,PC9,CV001_KM9294396-CV001_KM9294404,0.471905


In [6]:
yoshida.obs['COVID_status'].value_counts()

COVID_status
Healthy          173684
COVID-19         151312
Post-COVID-19     97224
Name: count, dtype: int64

In [7]:
reference = yoshida[yoshida.obs['COVID_status'] == 'Healthy']

reference

View of AnnData object with n_obs × n_vars = 173684 × 33751
    obs: 'patient_id', 'Ethnicity', 'BMI', 'annotation_broad', 'annotation_detailed', 'annotation_detailed_fullNames', 'Age_group', 'COVID_severity', 'COVID_status', 'Group', 'Sex', 'Smoker', 'sample_id', 'sequencing_library', 'Protein_modality_weight'
    var: 'name'
    obsm: 'X_ umap (wnn derived)', 'X_umap (after harmony ADT)', 'X_umap (after harmony RNA)', 'X_umap (before harmony ADT)', 'X_umap (before harmony RNA)'

In [8]:
reference.obs['COVID_status'].value_counts()

COVID_status
Healthy    173684
Name: count, dtype: int64

In [9]:
# Note: Query data is that of only Cai, not the Cai+Yoshida data, but clustered following scvi. This is not the correct data object?

# adata_query = sc.read_h5ad('/Volumes/Lacie/data_lake/Mairi_example/processed_files/scvi/post_sccaf/CaiY_healthy_scRNA_PBMC_mm230316_scVI-clustered.raw.h5ad')

cai = sc.read_h5ad('/Volumes/LaCie/data_lake/Mairi_example/processed_files/abridged_qc/human/Cai2020_scRNA_PBMC_mm230315_qcd.h5ad')
cai

AnnData object with n_obs × n_vars = 73146 × 61533
    obs: 'study', 'individual', 'sample', 'tissue', 'donor', 'age', 'gender', 'status', 'data_type', 'centre', 'version', 'batch', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'n_counts', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', 'predicted_doublets'
    var: 'gene_id', 'mt', 'ribo', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'
    uns: 'donor_colors'
    layers: 'counts', 'sqrt_norm'

In [10]:
cai.obs

Unnamed: 0_level_0,study,individual,sample,tissue,donor,age,gender,status,data_type,centre,...,pct_counts_mt,total_counts_ribo,pct_counts_ribo,percent_mt2,n_counts,percent_chrY,XIST-counts,S_score,G2M_score,predicted_doublets
barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGAGAACAATC-acTB3,CaiY_2021,SAMN14048025,PBMC_TB_3,PBMC,SAMN14048025,34,male,active_TB,scRNAseq,Shenzhen University,...,6.017039,1676.0,29.747959,0.060170,5634.0,0.053248,0.0,-0.352188,-0.193287,0.0
AAACCTGAGAAGGTGA-acTB3,CaiY_2021,SAMN14048025,PBMC_TB_3,PBMC,SAMN14048025,34,male,active_TB,scRNAseq,Shenzhen University,...,1.668552,1901.0,53.761311,0.016686,3536.0,0.113122,0.0,-0.064944,-0.071169,0.0
AAACCTGAGATCTGCT-acTB3,CaiY_2021,SAMN14048025,PBMC_TB_3,PBMC,SAMN14048025,34,male,active_TB,scRNAseq,Shenzhen University,...,9.967497,367.0,19.880823,0.099675,1846.0,0.054171,0.0,-0.231399,-0.080643,0.0
AAACCTGAGCACAGGT-acTB3,CaiY_2021,SAMN14048025,PBMC_TB_3,PBMC,SAMN14048025,34,male,active_TB,scRNAseq,Shenzhen University,...,4.557977,772.0,19.333836,0.045580,3993.0,0.075131,0.0,-0.227884,-0.258770,0.0
AAACCTGAGCGTGAAC-acTB3,CaiY_2021,SAMN14048025,PBMC_TB_3,PBMC,SAMN14048025,34,male,active_TB,scRNAseq,Shenzhen University,...,2.457002,1003.0,49.287468,0.024570,2035.0,0.147420,0.0,0.020959,0.025030,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCAGTCGCGGTT-H1,CaiY_2021,SAMN14048019,PBMC_HC_1,PBMC,SAMN14048019,26,male,Healthy,scRNAseq,Shenzhen University,...,3.250689,1904.0,52.451790,0.032507,3630.0,0.110193,0.0,-0.165353,-0.250572,0.0
TTTGTCAGTCGCTTCT-H1,CaiY_2021,SAMN14048019,PBMC_HC_1,PBMC,SAMN14048019,26,male,Healthy,scRNAseq,Shenzhen University,...,4.287570,868.0,37.592033,0.042876,2309.0,0.043309,0.0,-0.177318,-0.173327,0.0
TTTGTCAGTTCGAATC-H1,CaiY_2021,SAMN14048019,PBMC_HC_1,PBMC,SAMN14048019,26,male,Healthy,scRNAseq,Shenzhen University,...,6.650446,712.0,28.872669,0.066504,2466.0,0.040551,0.0,-0.235875,-0.013744,0.0
TTTGTCAGTTTGGGCC-H1,CaiY_2021,SAMN14048019,PBMC_HC_1,PBMC,SAMN14048019,26,male,Healthy,scRNAseq,Shenzhen University,...,3.483043,1555.0,47.509930,0.034830,3273.0,0.091659,0.0,-0.050351,0.022269,0.0


In [11]:
cai.obs['status'].value_counts()

status
active_TB    33104
Healthy      22049
latent_TB    17993
Name: count, dtype: int64

In [12]:
query = cai[cai.obs['status'] == 'Healthy']
query

View of AnnData object with n_obs × n_vars = 22049 × 61533
    obs: 'study', 'individual', 'sample', 'tissue', 'donor', 'age', 'gender', 'status', 'data_type', 'centre', 'version', 'batch', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'n_counts', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', 'predicted_doublets'
    var: 'gene_id', 'mt', 'ribo', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'
    uns: 'donor_colors'
    layers: 'counts', 'sqrt_norm'

In [13]:
query.obs['status'].value_counts()

status
Healthy    22049
Name: count, dtype: int64

In [14]:
def X_is_raw(adata):
    return np.array_equal(adata.X.sum(axis=0).astype(int), adata.X.sum(axis=0))

In [15]:
X_is_raw(query)

False

In [16]:
X_is_raw(query)

False

In [17]:
reference.obs['seed_labels'] = reference.obs['annotation_broad'].copy()


  reference.obs['seed_labels'] = reference.obs['annotation_broad'].copy()


In [18]:
query.obs['seed_labels'] = 'Unknown'

  query.obs['seed_labels'] = 'Unknown'


In [19]:
reference.var

Unnamed: 0,name
MIR1302-2HG,MIR1302-2HG
FAM138A,FAM138A
OR4F5,OR4F5
AL627309.1,AL627309.1
AL627309.3,AL627309.3
...,...
AB-TRAC,AB-TRAC
AB-TRG,AB-TRG
AB-TIGIT,AB-TIGIT
AB-CRLF2,AB-CRLF2


In [20]:
reference.var.rename(columns = {'name': 'gene_name'}, inplace = True)

In [22]:
# Why is this necessary?

reference.obs_names_make_unique()
query.obs_names_make_unique()
reference.var_names_make_unique()
query.var_names_make_unique()

In [23]:
# Check if genes in reference and query are the same
genes_reference = list(reference.var.index)
genes_query = list(query.var.index)
genes_in_reference = [gene for gene in genes_query if gene in genes_reference]

In [24]:
# Check if genes in reference and query are the same
cells_reference = list(reference.obs.index)
cells_query = list(query.obs.index)
cells_in_reference = [cell for cell in cells_query if cell in cells_reference]

- Concatenate both datasets


In [25]:
adata = reference.concatenate(query, batch_categories=['reference', 'query'], batch_key='dataset')

  warn(


In [26]:
adata.X

<195733x22792 sparse matrix of type '<class 'numpy.float32'>'
	with 242201046 stored elements in Compressed Sparse Row format>

In [27]:
adata.obs

Unnamed: 0,patient_id,Ethnicity,BMI,annotation_broad,annotation_detailed,annotation_detailed_fullNames,Age_group,COVID_severity,COVID_status,Group,...,total_counts_ribo,pct_counts_ribo,percent_mt2,n_counts,percent_chrY,XIST-counts,S_score,G2M_score,predicted_doublets,dataset
CV001_KM10202384-CV001_KM10202394_AAACCTGAGGCAGGTT-1-reference,AN5,EUR,Unknown,Monocyte,Monocyte CD14,Classical monocyte,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
CV001_KM10202384-CV001_KM10202394_AAACCTGAGTGTCCCG-1-reference,AN5,EUR,Unknown,T CD4+,T CD4 helper,T CD4 helper,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
CV001_KM10202384-CV001_KM10202394_AAACCTGCAGATGGGT-1-reference,AN3,EUR,Unknown,T CD4+,T CD4 helper,T CD4 helper,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
CV001_KM10202384-CV001_KM10202394_AAACCTGGTATAGTAG-1-reference,AN5,EUR,Unknown,T CD8+,T CD8 naive,T CD8 naive,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
CV001_KM10202384-CV001_KM10202394_AAACCTGGTGTGCGTC-1-reference,AN5,EUR,Unknown,T CD4+,T CD4 naive,T CD4 naive,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCAGTCGCGGTT-H1-query,,,,,,,,,,,...,1904.0,52.451790,0.032507,3630.0,0.110193,0.0,-0.165353,-0.250572,0.0,query
TTTGTCAGTCGCTTCT-H1-query,,,,,,,,,,,...,868.0,37.592033,0.042876,2309.0,0.043309,0.0,-0.177318,-0.173327,0.0,query
TTTGTCAGTTCGAATC-H1-query,,,,,,,,,,,...,712.0,28.872669,0.066504,2466.0,0.040551,0.0,-0.235875,-0.013744,0.0,query
TTTGTCAGTTTGGGCC-H1-query,,,,,,,,,,,...,1555.0,47.509930,0.034830,3273.0,0.091659,0.0,-0.050351,0.022269,0.0,query


In [29]:
adata

AnnData object with n_obs × n_vars = 195733 × 22792
    obs: 'patient_id', 'Ethnicity', 'BMI', 'annotation_broad', 'annotation_detailed', 'annotation_detailed_fullNames', 'Age_group', 'COVID_severity', 'COVID_status', 'Group', 'Sex', 'Smoker', 'sample_id', 'sequencing_library', 'Protein_modality_weight', 'seed_labels', 'study', 'individual', 'sample', 'tissue', 'donor', 'age', 'gender', 'status', 'data_type', 'centre', 'version', 'batch', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'n_counts', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', 'predicted_doublets', 'dataset'
    var: 'gene_id-query', 'mt-query', 'ribo-query', 'n_cells_by_counts-query', 'mean_counts-query', 'pct_dropout_by_counts-query', 'total_counts-query', 'gene_name-reference'

In [33]:
# Replace NaN values in patient ID column

adata.obs['patient_id'] = adata.obs['patient_id'].cat.add_categories("unknown").fillna("unknown")


- Make reference and query objects from original adata object

In [35]:
query = np.array([s in ["query"] for s in adata.obs.dataset])

adata_ref = adata[~query].copy()
adata_query = adata[query].copy()

In [36]:
adata_ref.obs

Unnamed: 0,patient_id,Ethnicity,BMI,annotation_broad,annotation_detailed,annotation_detailed_fullNames,Age_group,COVID_severity,COVID_status,Group,...,total_counts_ribo,pct_counts_ribo,percent_mt2,n_counts,percent_chrY,XIST-counts,S_score,G2M_score,predicted_doublets,dataset
CV001_KM10202384-CV001_KM10202394_AAACCTGAGGCAGGTT-1-reference,AN5,EUR,Unknown,Monocyte,Monocyte CD14,Classical monocyte,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
CV001_KM10202384-CV001_KM10202394_AAACCTGAGTGTCCCG-1-reference,AN5,EUR,Unknown,T CD4+,T CD4 helper,T CD4 helper,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
CV001_KM10202384-CV001_KM10202394_AAACCTGCAGATGGGT-1-reference,AN3,EUR,Unknown,T CD4+,T CD4 helper,T CD4 helper,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
CV001_KM10202384-CV001_KM10202394_AAACCTGGTATAGTAG-1-reference,AN5,EUR,Unknown,T CD8+,T CD8 naive,T CD8 naive,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
CV001_KM10202384-CV001_KM10202394_AAACCTGGTGTGCGTC-1-reference,AN5,EUR,Unknown,T CD4+,T CD4 naive,T CD4 naive,Adult,Healthy,Healthy,Adult,...,,,,,,,,,,reference
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
S22_TTTGTCAGTCGCATCG-1-reference,NP32,EUR,24.7,T CD8+,T CD8 naive,T CD8 naive,Adolescent,Healthy,Healthy,Paediatric,...,,,,,,,,,,reference
S22_TTTGTCAGTGTAAGTA-1-reference,NP32,EUR,24.7,NK,NK IFN stim,NK IFN stim,Adolescent,Healthy,Healthy,Paediatric,...,,,,,,,,,,reference
S22_TTTGTCATCATGTCCC-1-reference,NP31,AFR,18.2,Monocyte,Monocyte CD14,Classical monocyte,Child,Healthy,Healthy,Paediatric,...,,,,,,,,,,reference
S22_TTTGTCATCGAGGTAG-1-reference,NP31,AFR,18.2,T CD8+,T CD8 naive,T CD8 naive,Child,Healthy,Healthy,Paediatric,...,,,,,,,,,,reference


### HVG selection

In [38]:
adata_raw = adata_ref.copy()

In [43]:
# We run highly variable gene selection on the reference data and use these same genes for the query data.

sc.pp.highly_variable_genes(adata_ref, n_top_genes=2000, batch_key="patient_id", subset=True)

adata_query = adata_query[:, adata_ref.var_names].copy()

  result = op(self._deduped_data())


ValueError: cannot specify integer `bins` when input data contains infinity

### Train reference

> From tutorial: SCANVI tends to perform better in situations where it has been initialized using a pre-trained SCVI model. 

- scVI model

In [None]:
scvi.model.SCVI.setup_anndata(adata_ref, batch_key="dataset")

In [None]:
# Added for sarches params following tutorial; check if we require it?
arches_params = dict(
    use_layer_norm="both",
    use_batch_norm="none",
    encode_covariates=True,
    dropout_rate=0.2,
    n_layers=2,
)

vae_ref = scvi.model.SCVI(adata_ref, **arches_params)
vae_ref.train()

In [None]:
adata_ref.obsm["X_scVI"] = vae_ref.get_latent_representation()
sc.pp.neighbors(adata_ref, use_rep="X_scVI")
sc.tl.leiden(adata_ref)
sc.tl.umap(adata_ref)

In [None]:
sc.pl.umap(
    adata_ref,
    color=["dataset", "annotation_broad"],
    frameon=False,
    ncols=1,
)

### Update with query

- scANVI model

In [None]:
adata_ref.obs["labels_scanvi"] = adata_ref.obs["annotation_broad"].values

In [None]:
# unlabeled category does not exist in adata.obs[labels_key]
# so all cells are treated as labeled
vae_ref_scanvi = scvi.model.SCANVI.from_scvi_model(
    vae_ref,
    unlabeled_category="unknown",
    labels_key="labels_scanvi",
)

In [None]:
vae_ref_scanvi.train(max_epochs=20, n_samples_per_label=100)

In [None]:
adata_ref.obsm["X_scANVI"] = vae_ref_scanvi.get_latent_representation()
sc.pp.neighbors(adata_ref, use_rep="X_scANVI")
sc.tl.leiden(adata_ref)
sc.tl.umap(adata_ref)


In [None]:
sc.pl.umap(
    adata_ref,
    color=["leiden", "annotation_broad"],
    frameon=False,
    ncols=1,
)

### Update with query

In [None]:
dir_path_scan = "/Volumes/LaCie/data_lake/Mairi_example/processed_files/scanvi/models/"
vae_ref_scanvi.save(dir_path_scan, overwrite=True)


In [None]:
# again a no-op in this tutorial, but good practice to use
scvi.model.SCANVI.prepare_query_anndata(adata_query, dir_path_scan)

In [None]:
vae_ref_scanvi.registry_

In [None]:
vae_query = scvi.model.SCANVI.load_query_data(
    adata_query,
    dir_path_scan,
)

In [None]:
vae_query.train(
    max_epochs=100,
    plan_kwargs=dict(weight_decay=0.0),
    check_val_every_n_epoch=10,
)

In [None]:
adata_query.obsm["X_scANVI"] = vae_query.get_latent_representation()
adata_query.obs["predictions"] = vae_query.predict()

In [None]:
df = adata_query.obs.groupby(["annotation_broad", "predictions"]).size().unstack(fill_value=0)
norm_df = df / df.sum(axis=0)

plt.figure(figsize=(8, 8))
_ = plt.pcolor(norm_df)
_ = plt.xticks(np.arange(0.5, len(df.columns), 1), df.columns, rotation=90)
_ = plt.yticks(np.arange(0.5, len(df.index), 1), df.index)
plt.xlabel("Predicted")
plt.ylabel("Observed")

### Analyze reference and query

In [None]:
adata_full = adata_query.concatenate(adata_ref)

In [None]:
adata_full.obs.batch.cat.rename_categories(["Query", "Reference"])

In [None]:
full_predictions = vae_query.predict(adata_full)
print(f"Acc: {np.mean(full_predictions == adata_full.obs.annotation_broad)}")

adata_full.obs["predictions"] = full_predictions

In [None]:
sc.pp.neighbors(adata_full, use_rep="X_scANVI")
sc.tl.leiden(adata_full)
sc.tl.umap(adata_full)

In [None]:
sc.pl.umap(
    adata_full,
    color=["dataset", "annotation_broad"],
    frameon=False,
    ncols=1,
)

In [None]:
ax = sc.pl.umap(
    adata_full,
    frameon=False,
    show=False,
)
sc.pl.umap(
    adata_full[: adata_query.n_obs],
    color=["predictions"],
    frameon=False,
    title="Query predictions",
    ax=ax,
    alpha=0.7,
)

ax = sc.pl.umap(
    adata_full,
    frameon=False,
    show=False,
)
sc.pl.umap(
    adata_full[: adata_query.n_obs],
    color=["annotation_broad"],
    frameon=False,
    title="Query observed cell types",
    ax=ax,
    alpha=0.7,
)