## Using the multiVI latent representations predicted the cell-types for the ATAC cells using the RNA cells

In [1]:
!date

Thu Aug 28 11:11:20 AM EDT 2025


#### import libraries

In [2]:
import scvi
import scanpy as sc
from autogluon.tabular import TabularDataset, TabularPredictor
import torch
from anndata import AnnData
from pandas import DataFrame, concat

import warnings
warnings.filterwarnings('ignore')

scvi.settings.seed = 42

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

Seed set to 42


#### set notebook variables

In [20]:
# variables and constants
project = 'aging_phase2'
DEBUG = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_QUAL_PRESET = 'good'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
models_dir = f'{wrk_dir}/models'
figures_dir = f'{wrk_dir}/figures'
sc.settings.figdir = f'{figures_dir}/'

# in files
in_h5ad_file = f'{quants_dir}/{project}.dev.multivi.h5ad'

# out files
out_h5ad_file = f'{quants_dir}/{project}.dev.multivi.annotated.h5ad'
trained_model_path = f'{models_dir}/{project}_dev_trained_cellpred'

if DEBUG:
    print(f'{in_h5ad_file=}')
    print(f'{out_h5ad_file=}')
    print(f'{trained_model_path=}')
    print(f'{device=}')

in_h5ad_file='/labshare/raph/datasets/adrd_neuro/brain_aging/phase2/quants/aging_phase2.dev.multivi.h5ad'
out_h5ad_file='/labshare/raph/datasets/adrd_neuro/brain_aging/phase2/quants/aging_phase2.dev.multivi.annotated.h5ad'
trained_model_path='/labshare/raph/datasets/adrd_neuro/brain_aging/phase2/models/aging_phase2_dev_trained_cellpred'
device='cuda'


#### functions

In [4]:
def peek_anndata(adata: AnnData, message: str=None, verbose: bool=False):
    if not message is None and len(message) > 0:
        print(message)
    print(adata)
    if verbose:
        display(adata.obs.head())
        display(adata.var.head())

def peek_dataframe(df: DataFrame, message: str=None, verbose: bool=False):
    if not message is None and len(message) > 0:
        print(message)
    print(f'{df.shape=}')
    if verbose:
        display(df.head())

## load the multiVI latent features

In [5]:
adata = sc.read_h5ad(in_h5ad_file)
peek_anndata(adata, 'loaded the multiVI anndata', DEBUG)

loaded the multiVI anndata
AnnData object with n_obs × n_vars = 206210 × 3754
    obs: 'sample_id', 'donor_id', 'geno_IID', 'sex', 'ancestry', 'age', 'gex_pool', 'atac_pool', 'pmi', 'ph', 'smoker', 'bmi', 'rin', 'phase1_cluster', 'phase1_celltype', 'Study', 'Study_type', 'cell_label', 'modality', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'n_genes', '_indices', '_scvi_batch', '_scvi_labels', 'leiden_MultiVI', 'umap_density_modality'
    var: 'ID', 'modality', 'chr', 'start', 'end', 'mt', 'ribo', 'hb', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly

Unnamed: 0,sample_id,donor_id,geno_IID,sex,ancestry,age,gex_pool,atac_pool,pmi,ph,...,pct_counts_ribo,total_counts_hb,log1p_total_counts_hb,pct_counts_hb,n_genes,_indices,_scvi_batch,_scvi_labels,leiden_MultiVI,umap_density_modality
AAACAGCCAACCCTCC_paired,Aging120,NHBCC-1512,4256126241_A,Female,Caucasian,30.0,non,non,9.5,6.92,...,0.009991,0.0,0.0,0.0,4684,0,2,0,0,0.842875
AAACAGCCAAGTCGCT_paired,Aging121,NHBCC-1890,4572348740_R01C02,Male,Caucasian,42.1,non,non,24.5,6.34,...,0.008122,0.0,0.0,0.0,6169,1,2,0,8,0.579315
AAACAGCCAATCCCTT_paired,Aging120,NHBCC-1512,4256126241_A,Female,Caucasian,30.0,non,non,9.5,6.92,...,0.010535,0.0,0.0,0.0,4520,2,2,0,2,0.704063
AAACAGCCAATTAGGA_paired,Aging121,NHBCC-1890,4572348740_R01C02,Male,Caucasian,42.1,non,non,24.5,6.34,...,0.0,0.0,0.0,0.0,9065,3,2,0,8,0.754612
AAACAGCCAGGTTCAC_paired,Aging119,NHBCC-1390,4463344122_R01C02,Female,African American,42.7,non,non,8.17,6.34,...,0.102522,0.0,0.0,0.0,2542,4,2,0,15,0.046652


Unnamed: 0,ID,modality,chr,start,end,mt,ribo,hb,n_cells_by_counts,mean_counts,...,pct_dropout_by_counts,total_counts,log1p_total_counts,n_cells,highly_variable,highly_variable_rank,means,variances,variances_norm,highly_variable_nbatches
SYNE2,ENSG00000054654,Gene Expression,chr14,63852982,64216322,False,False,False,19161,0.124657,...,91.637024,28561.0,10.259832,18403,True,1390.0,0.132622,0.464146,1.625825,3
ZFP36L1,ENSG00000185650,Gene Expression,chr14,68791242,68794736,False,False,False,11831,0.085459,...,94.836263,19580.0,9.882315,11595,True,1379.0,0.093507,0.339088,1.379622,3
PLEKHH1,ENSG00000054690,Gene Expression,chr14,67533289,67533290,False,False,False,44853,0.395291,...,80.423539,90568.0,11.413867,43528,True,3830.0,0.42851,1.434227,1.177728,2
LINC01500,ENSG00000258583,Gene Expression,chr14,58646630,58828473,False,False,False,13916,0.119005,...,93.926247,27266.0,10.213432,13329,True,1201.0,0.126565,0.501839,1.80924,6
AL161757.4,ENSG00000258776,Gene Expression,chr14,56893248,56893710,False,False,False,383,0.001833,...,99.832836,420.0,6.042633,366,True,,0.001954,0.002552,0.810734,0


## split the anndata into training, test, and inference datasets

here will use the GEX for training, the ARC for test, and ATAC for inference

In [6]:
data_sets = {}
for modality in adata.obs.modality.unique():
    print(modality)
    adata_set = adata[adata.obs.modality == modality]    
    data = TabularDataset(adata_set.obsm['MultiVI_latent'])
    data['cell_label'] = adata_set.obs.cell_label.values
    data.index = adata_set.obs.index.values
    data_sets[modality] = data

if DEBUG:
    for name, data in data_sets.items():
        peek_dataframe(data, f'\n## dataset for {name}', DEBUG)
        display(data.cell_label.value_counts())

paired
expression
accessibility

## dataset for paired
df.shape=(15193, 6)


Unnamed: 0,0,1,2,3,4,cell_label
AAACAGCCAACCCTCC_paired,0.001939,0.243171,-0.065698,-1.148102,-0.206896,Oligodendrocytes
AAACAGCCAAGTCGCT_paired,0.013593,0.04831,-0.45998,-0.116494,-0.504673,Unknown
AAACAGCCAATCCCTT_paired,0.002519,0.056038,-0.208126,-0.822854,-0.287302,ExN RORB
AAACAGCCAATTAGGA_paired,0.004357,-0.153279,-0.045215,-0.419864,-0.314169,Unknown
AAACAGCCAGGTTCAC_paired,0.017464,0.226176,1.247877,-0.355685,0.399117,OPCs


cell_label
Oligodendrocytes    3809
Unknown             3117
Astrocytes          2185
Microglia           1384
OPCs                1232
ExN RORB             741
ExN CUX2             675
ExN SEMA3E           519
InN SST              319
ExN THEMIS           255
InN LAMP5            219
InN VIP              204
InN PVALB            156
ExN LAMP5            130
Fibroblasts          107
Endothelial           92
InN PAX6              38
ExN BCL11B             7
Pericytes/VSMCs        4
Name: count, dtype: int64


## dataset for expression
df.shape=(110676, 6)


Unnamed: 0,0,1,2,3,4,cell_label
AAACCCAAGCCAGACA-1_expression,-0.005614,0.378016,0.704691,1.704608,0.158816,ExN CUX2
AAACCCAAGCTTCGTA-1_expression,0.113897,0.120229,-0.423404,1.841352,-1.241813,ExN RORB
AAACCCAAGGAACGTC-1_expression,0.025832,-0.132238,-0.351114,1.841085,0.01294,ExN CUX2
AAACGCTAGCTAAACA-1_expression,0.049334,-1.188595,0.590004,-0.329548,0.003459,InN VIP
AAACGCTGTGAAGCGT-1_expression,0.028623,1.176501,1.285304,-1.34697,0.769792,OPCs


cell_label
ExN CUX2            18418
ExN SEMA3E          14073
Oligodendrocytes    12687
ExN RORB            10376
ExN LAMP5            8881
InN VIP              6629
InN LAMP5            5807
Astrocytes           5753
OPCs                 5182
InN SST              5107
InN PVALB            4481
ExN BCL11B           4159
ExN THEMIS           2407
Unknown              1624
Microglia            1566
Endothelial          1434
InN PAX6             1223
Fibroblasts           854
Pericytes/VSMCs        15
Name: count, dtype: int64


## dataset for accessibility
df.shape=(80341, 6)


Unnamed: 0,0,1,2,3,4,cell_label
AAACGAAAGAAACGCC-29_accessibility,-0.000411,0.138493,-2.05446,-0.132706,-0.231953,Unknown
AAACGAAAGAAAGGGT-1_accessibility,0.007484,-0.189278,0.049153,-0.036451,-0.652306,Unknown
AAACGAAAGAACGTTA-6_accessibility,-0.008942,-0.092481,-0.273481,0.026667,1.639003,Unknown
AAACGAAAGAACTAAC-2_accessibility,-0.004381,0.029047,0.019562,-0.844669,-0.523771,Unknown
AAACGAAAGAAGAGTG-32_accessibility,-0.003423,-0.067709,-0.2948,-0.104835,1.755333,Unknown


cell_label
Unknown    80341
Name: count, dtype: int64

## use autoML to train a model

### initialize a AutoGluon Tabular Predictor

In [21]:
predictor = TabularPredictor(label='cell_label', path=trained_model_path, 
                             verbosity=2, log_to_file=True, eval_metric='mcc')



### train the predictor model

In [22]:
train_data = data_sets.get('expression')

In [23]:
%%time
predictor.fit(train_data, presets=MODEL_QUAL_PRESET, num_gpus=1)

Preset alias specified: 'good' maps to 'good_quality'.
Verbosity: 2 (Standard Logging)
AutoGluon Version:  1.4.0
Python Version:     3.12.2
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #132-Ubuntu SMP Thu Aug 29 13:45:52 UTC 2024
CPU Count:          64
Memory Avail:       236.68 GB / 1007.74 GB (23.5%)
Disk Space Avail:   2041.85 GB / 205168.86 GB (1.0%)
Presets specified: ['good']
Using hyperparameters preset: hyperparameters='light'
Setting dynamic_stacking from 'auto' to True. Reason: Enable dynamic_stacking when use_bag_holdout is disabled. (use_bag_holdout=False)
Stack configuration (auto_stack=True): num_stack_levels=1, num_bag_folds=8, num_bag_sets=1
Note: `save_bag_folds=False`! This will greatly reduce peak disk usage during fit (by ~8x), but runs the risk of an out-of-memory error during model refit if memory is small relative to the data size.
	You can avoid this risk by setting `save_bag_folds=True`.
DyStack is enabled (dynamic_stacking=True). Au

CPU times: user 4d 14h 51min 28s, sys: 6min 31s, total: 4d 14h 57min 59s
Wall time: 2h 53min 5s


<autogluon.tabular.predictor.predictor.TabularPredictor at 0x7ef2ecba4e00>

In [24]:
%%time
display(train_data.cell_label.value_counts())
x_pred = predictor.predict(train_data)
display(x_pred.value_counts())
eval_results = predictor.evaluate_predictions(train_data.cell_label, x_pred, 
                                              detailed_report=True)
print(f'## {eval_results.get('accuracy')=}')
print(f'## {eval_results.get('balanced_accuracy')=}')
print(f'## Matthews Correlation Coefficient: {eval_results.get('mcc')}')
display(DataFrame(eval_results.get('classification_report')).transpose()
        .sort_values('f1-score', ascending=False))

cell_label
ExN CUX2            18418
ExN SEMA3E          14073
Oligodendrocytes    12687
ExN RORB            10376
ExN LAMP5            8881
InN VIP              6629
InN LAMP5            5807
Astrocytes           5753
OPCs                 5182
InN SST              5107
InN PVALB            4481
ExN BCL11B           4159
ExN THEMIS           2407
Unknown              1624
Microglia            1566
Endothelial          1434
InN PAX6             1223
Fibroblasts           854
Pericytes/VSMCs        15
Name: count, dtype: int64

cell_label
ExN CUX2            18559
ExN SEMA3E          14155
Oligodendrocytes    12699
ExN RORB            10096
ExN LAMP5            8995
InN VIP              6730
Astrocytes           5824
InN LAMP5            5814
OPCs                 5193
InN SST              4785
InN PVALB            4778
ExN BCL11B           4195
ExN THEMIS           2439
Unknown              1512
Microglia            1491
Endothelial          1478
InN PAX6             1126
Fibroblasts           806
Pericytes/VSMCs         1
Name: count, dtype: int64

## eval_results.get('accuracy')=0.9681593118652644
## eval_results.get('balanced_accuracy')=0.9012652344668188
## Matthews Correlation Coefficient: 0.9650218200259223


Unnamed: 0,precision,recall,f1-score,support
Oligodendrocytes,0.99559,0.996532,0.996061,12687.0
ExN SEMA3E,0.992158,0.997939,0.99504,14073.0
OPCs,0.99326,0.995369,0.994313,5182.0
Astrocytes,0.982143,0.994264,0.988166,5753.0
InN LAMP5,0.984004,0.98519,0.984597,5807.0
ExN CUX2,0.976777,0.984255,0.980501,18418.0
InN VIP,0.969985,0.984764,0.977319,6629.0
ExN LAMP5,0.965759,0.978156,0.971918,8881.0
Endothelial,0.957375,0.98675,0.971841,1434.0
accuracy,0.968159,0.968159,0.968159,0.968159


CPU times: user 9min 15s, sys: 1.68 s, total: 9min 17s
Wall time: 20.8 s


### individual model scores for training

In [25]:
print(predictor.model_best)
display(predictor.leaderboard())

WeightedEnsemble_L3_FULL


Unnamed: 0,model,score_val,eval_metric,pred_time_val,fit_time,pred_time_val_marginal,fit_time_marginal,stack_level,can_infer,fit_order
0,WeightedEnsemble_L3,0.946263,mcc,89.295409,1869.133172,0.05603,27.443832,3,False,22
1,WeightedEnsemble_L2,0.94585,mcc,45.964689,458.413161,0.055006,13.748901,2,False,11
2,RandomForestEntr_BAG_L2,0.945654,mcc,62.140407,1326.942596,8.21862,12.545937,2,False,16
3,ExtraTreesEntr_BAG_L2,0.945491,mcc,61.78581,1317.994837,7.864023,3.598177,2,False,19
4,ExtraTreesGini_BAG_L2,0.945473,mcc,61.782391,1318.044592,7.860604,3.647932,2,False,18
5,RandomForestGini_BAG_L2,0.945418,mcc,62.580063,1325.41044,8.658276,11.01378,2,False,15
6,NeuralNetFastAI_BAG_L2,0.944874,mcc,55.194215,1700.12778,1.272428,385.73112,2,False,12
7,XGBoost_BAG_L2,0.944838,mcc,54.976649,1399.673866,1.054862,85.277206,2,False,20
8,CatBoost_BAG_L2,0.944734,mcc,54.070978,1881.278171,0.149191,566.881511,2,False,17
9,LightGBMXT_BAG_L1,0.944336,mcc,35.585213,68.251815,35.585213,68.251815,1,False,2


## check the test datasets predictions

In [26]:
%%time
test_data = data_sets.get('paired').copy()
test_data = test_data.loc[test_data.cell_label != 'Unknown']
y_pred = predictor.predict(test_data)
peek_dataframe(y_pred, 'model predictions for the ARC data', DEBUG)  # Predictions

model predictions for the ARC data
df.shape=(12076,)


AAACAGCCAACCCTCC_paired    Oligodendrocytes
AAACAGCCAATCCCTT_paired    Oligodendrocytes
AAACAGCCAGGTTCAC_paired             InN SST
AAACAGCCATGCAACC_paired    Oligodendrocytes
AAACATGCAAACTGTT_paired            ExN RORB
Name: cell_label, dtype: object

CPU times: user 1min 24s, sys: 961 ms, total: 1min 25s
Wall time: 3.9 s


In [27]:
display(test_data.cell_label.value_counts())
display(y_pred.value_counts())

cell_label
Oligodendrocytes    3809
Astrocytes          2185
Microglia           1384
OPCs                1232
ExN RORB             741
ExN CUX2             675
ExN SEMA3E           519
InN SST              319
ExN THEMIS           255
InN LAMP5            219
InN VIP              204
InN PVALB            156
ExN LAMP5            130
Fibroblasts          107
Endothelial           92
InN PAX6              38
ExN BCL11B             7
Pericytes/VSMCs        4
Unknown                0
Name: count, dtype: int64

cell_label
Oligodendrocytes    3678
Astrocytes          1966
Microglia           1324
OPCs                1200
ExN RORB             846
ExN CUX2             720
ExN SEMA3E           501
InN SST              339
ExN LAMP5            283
ExN BCL11B           278
InN LAMP5            244
ExN THEMIS           198
InN VIP              186
InN PVALB            114
Endothelial           71
Fibroblasts           55
Unknown               54
InN PAX6              19
Name: count, dtype: int64

In [28]:
eval_results = predictor.evaluate_predictions(test_data.cell_label, y_pred, 
                                              detailed_report=True)
print(f'## {eval_results.get('accuracy')=}')
print(f'## {eval_results.get('balanced_accuracy')=}')
print(f'## Matthews Correlation Coefficient: {eval_results.get('mcc')}')
display(DataFrame(eval_results.get('classification_report')).transpose()
        .sort_values('f1-score', ascending=False))

## eval_results.get('accuracy')=0.8928453130175554
## eval_results.get('balanced_accuracy')=0.7071577232427553
## Matthews Correlation Coefficient: 0.87317074962425


Unnamed: 0,precision,recall,f1-score,support
OPCs,0.994167,0.968344,0.981086,1232.0
Oligodendrocytes,0.992931,0.958782,0.975558,3809.0
Microglia,0.979607,0.937139,0.957903,1384.0
ExN SEMA3E,0.974052,0.94027,0.956863,519.0
Astrocytes,0.996948,0.897025,0.944351,2185.0
weighted avg,0.926634,0.892845,0.906502,12076.0
accuracy,0.892845,0.892845,0.892845,0.892845
InN LAMP5,0.844262,0.940639,0.889849,219.0
ExN CUX2,0.848611,0.905185,0.875986,675.0
InN VIP,0.892473,0.813725,0.851282,204.0


In [29]:
predictor.leaderboard(test_data)

Unnamed: 0,model,score_test,score_val,eval_metric,pred_time_test,pred_time_val,fit_time,pred_time_test_marginal,pred_time_val_marginal,fit_time_marginal,stack_level,can_infer,fit_order
0,NeuralNetFastAI_BAG_L2_FULL,0.878545,,mcc,2.580466,,6667.739098,0.27097,,70.063047,2,True,34
1,WeightedEnsemble_L2_FULL,0.874461,,mcc,1.699555,,6561.427152,0.008053,,13.748901,2,True,33
2,NeuralNetFastAI_BAG_L1_FULL,0.873526,,mcc,0.233998,,90.086152,0.233998,,90.086152,1,True,23
3,WeightedEnsemble_L3_FULL,0.873171,,mcc,3.823174,,6744.234845,0.017806,,27.443832,3,True,44
4,LightGBMXT_BAG_L1_FULL,0.872531,,mcc,0.603286,,6399.138931,0.603286,,6399.138931,1,True,24
5,ExtraTreesGini_BAG_L1_FULL,0.87075,,mcc,0.256614,3.77755,1.955435,0.256614,3.77755,1.955435,1,True,29
6,ExtraTreesGini_BAG_L1,0.87075,0.936507,mcc,0.301518,3.77755,1.955435,0.301518,3.77755,1.955435,1,True,7
7,XGBoost_BAG_L2_FULL,0.870435,,mcc,2.439696,,6601.632049,0.1302,,3.955999,2,True,42
8,RandomForestEntr_BAG_L2_FULL,0.869822,,mcc,2.617105,,6610.221987,0.30761,8.21862,12.545937,2,True,38
9,ExtraTreesEntr_BAG_L1_FULL,0.869792,,mcc,0.297758,3.834191,1.844519,0.297758,3.834191,1.844519,1,True,30


## infer the cell labels for the unknown cells

In [30]:
atac_labels = predictor.predict(data_sets.get('accessibility'))

In [31]:
display(atac_labels.value_counts())

cell_label
ExN RORB            19152
Oligodendrocytes    16963
Astrocytes          12510
Microglia            8615
OPCs                 6250
ExN CUX2             5317
ExN LAMP5            4353
ExN BCL11B           2749
InN SST              1798
ExN THEMIS            689
InN LAMP5             688
ExN SEMA3E            670
Unknown               376
InN PVALB             129
InN VIP                48
InN PAX6               31
Endothelial             3
Name: count, dtype: int64

In [32]:
predictor.leaderboard(data_sets.get('accessibility'))

Unnamed: 0,model,score_test,score_val,eval_metric,pred_time_test,pred_time_val,fit_time,pred_time_test_marginal,pred_time_val_marginal,fit_time_marginal,stack_level,can_infer,fit_order
0,LightGBM_BAG_L1_FULL,0.0,,mcc,0.083682,,9.93224,0.083682,,9.93224,1,True,25
1,CatBoost_BAG_L1_FULL,0.0,,mcc,0.127132,,34.930202,0.127132,,34.930202,1,True,28
2,LightGBMLarge_BAG_L1_FULL,0.0,,mcc,0.131118,,45.287262,0.131118,,45.287262,1,True,32
3,XGBoost_BAG_L1_FULL,0.0,,mcc,0.421594,,6.44994,0.421594,,6.44994,1,True,31
4,RandomForestGini_BAG_L1_FULL,0.0,,mcc,0.764538,3.463543,3.179922,0.764538,3.463543,3.179922,1,True,26
5,ExtraTreesGini_BAG_L1_FULL,0.0,,mcc,0.768505,3.77755,1.955435,0.768505,3.77755,1.955435,1,True,29
6,RandomForestGini_BAG_L1,0.0,0.942754,mcc,0.798359,3.463543,3.179922,0.798359,3.463543,3.179922,1,True,4
7,ExtraTreesGini_BAG_L1,0.0,0.936507,mcc,0.798825,3.77755,1.955435,0.798825,3.77755,1.955435,1,True,7
8,ExtraTreesEntr_BAG_L1_FULL,0.0,,mcc,0.810242,3.834191,1.844519,0.810242,3.834191,1.844519,1,True,30
9,RandomForestEntr_BAG_L1_FULL,0.0,,mcc,0.830709,3.8032,4.871447,0.830709,3.8032,4.871447,1,True,27


### for the test dataset what are the important features

In [33]:
%%time
# feat_imp = predictor.feature_importance(train_data)

CPU times: user 10 μs, sys: 0 ns, total: 10 μs
Wall time: 21 μs
