## Using the MultiVI latent representations, predict the cell types for the ATAC cells using the RNA cells

In [None]:
!date

#### import libraries

In [None]:
import scvi
import scanpy as sc
from autogluon.tabular import TabularDataset, TabularPredictor
import torch
from anndata import AnnData
from pandas import DataFrame, concat

import warnings
warnings.filterwarnings('ignore')

scvi.settings.seed = 42

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# variables and constants
# `project` is the prefix used to build the input/output file names below
project = 'aging_phase2'
# DEBUG gates the diagnostic prints and is passed as the `verbose` flag
# to the peek_* helpers
DEBUG = True
# compute device: 'cuda' when torch reports an available GPU, otherwise 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# value handed to the `presets` argument of predictor.fit
MODEL_QUAL_PRESET = 'good'

# directories (all derived from the project working directory)
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
models_dir = f'{wrk_dir}/models'
figures_dir = f'{wrk_dir}/figures'
# point scanpy's figure output directory at the project figures directory
sc.settings.figdir = f'{figures_dir}/'

# in files
# h5ad read at the start of the analysis
in_h5ad_file = f'{quants_dir}/{project}.dev.multivi.h5ad'

# out files
out_h5ad_file = f'{quants_dir}/{project}.dev.multivi.annotated.h5ad'
# directory handed to TabularPredictor as its `path`
trained_model_path = f'{models_dir}/{project}_dev_trained_cellpred'

# echo the resolved settings so a run's configuration is visible in its output
if DEBUG:
    print(f'{in_h5ad_file=}')
    print(f'{out_h5ad_file=}')
    print(f'{trained_model_path=}')
    print(f'{device=}')

#### functions

In [None]:
def peek_anndata(adata: AnnData, message: str=None, verbose: bool=False):
    """Print a quick summary of an AnnData object.

    Args:
        adata: the AnnData object to summarize.
        message: optional header line printed before the summary; nothing
            is printed for it when it is None or empty.
        verbose: when True, also display the first rows of ``adata.obs``
            and ``adata.var``.
    """
    # truthiness covers both the None and the empty-string case
    if message:
        print(message)
    print(adata)
    if verbose:
        # `display` is the notebook (IPython) rich-output function
        display(adata.obs.head())
        display(adata.var.head())

def peek_dataframe(df: DataFrame, message: str=None, verbose: bool=False):
    """Print the shape of a DataFrame and optionally show its first rows.

    Args:
        df: the table to summarize.
        message: optional header line printed before the shape; nothing
            is printed for it when it is None or empty.
        verbose: when True, also display ``df.head()``.
    """
    # truthiness covers both the None and the empty-string case
    if message:
        print(message)
    print(f'{df.shape=}')
    if verbose:
        # `display` is the notebook (IPython) rich-output function
        display(df.head())

## load the multiVI latent features

In [None]:
# read the input h5ad; later cells pull the 'MultiVI_latent' matrix from
# its .obsm and the modality / cell_label columns from its .obs
adata = sc.read_h5ad(in_h5ad_file)
peek_anndata(adata, 'loaded the multiVI anndata', DEBUG)

## split the anndata into training, test, and inference datasets

Here we will use the GEX cells for training, the ARC cells for testing, and the ATAC cells for inference.

In [None]:
def _modality_table(adata_set):
    """Return the MultiVI latent matrix of `adata_set` as a TabularDataset
    with a `cell_label` column, indexed by the obs index."""
    table = TabularDataset(adata_set.obsm['MultiVI_latent'])
    table['cell_label'] = adata_set.obs.cell_label.values
    table.index = adata_set.obs.index.values
    return table


# one table per modality present in the data, keyed by the modality value
data_sets = {}
for modality in adata.obs.modality.unique():
    print(modality)
    data_sets[modality] = _modality_table(adata[adata.obs.modality == modality])

# show the size and label distribution of every split
if DEBUG:
    for name, data in data_sets.items():
        peek_dataframe(data, f'\n## dataset for {name}', DEBUG)
        display(data.cell_label.value_counts())

## use autoML to train a model

### initialize an AutoGluon Tabular Predictor

In [None]:
# predictor for the `cell_label` column; its output directory is
# `trained_model_path` and its evaluation metric is 'mcc'
predictor = TabularPredictor(
    label='cell_label',
    path=trained_model_path,
    eval_metric='mcc',
    verbosity=2,
    log_to_file=True,
)

### train the predictor model

In [None]:
# training table is the 'expression' modality split; dict.get returns None
# (rather than raising) if that modality was not present in the data
train_data = data_sets.get('expression')

In [None]:
%%time
# Fit on the training table with the configured quality preset. Only request
# a GPU when CUDA was detected (`device` from the variables cell), so the
# same notebook also runs on CPU-only hosts.
predictor.fit(train_data, presets=MODEL_QUAL_PRESET,
              num_gpus=1 if device == 'cuda' else 0)

In [None]:
%%time
# Predict back onto the training table and compare with its labels; these
# scores are computed on the same data the model was fit on.
display(train_data.cell_label.value_counts())
x_pred = predictor.predict(train_data)
display(x_pred.value_counts())
eval_results = predictor.evaluate_predictions(train_data.cell_label, x_pred,
                                              detailed_report=True)
# Outer double quotes: reusing the enclosing quote character inside an
# f-string replacement field only parses on Python 3.12+ (PEP 701). The
# printed text is unchanged.
print(f"## {eval_results.get('accuracy')=}")
print(f"## {eval_results.get('balanced_accuracy')=}")
print(f"## Matthews Correlation Coefficient: {eval_results.get('mcc')}")
# per-class report, best f1-score first
display(DataFrame(eval_results.get('classification_report')).transpose()
        .sort_values('f1-score', ascending=False))

### individual model scores for training

In [None]:
# report the predictor's best model, then the leaderboard of trained models
print(predictor.model_best)
display(predictor.leaderboard())

## check the test dataset's predictions

In [None]:
%%time
# test table: a copy of the 'paired' modality split restricted to cells
# whose label is not 'Unknown'
test_data = (
    data_sets.get('paired')
    .copy()
    .loc[lambda table: table.cell_label != 'Unknown']
)
y_pred = predictor.predict(test_data)
peek_dataframe(y_pred, 'model predictions for the ARC data', DEBUG)  # Predictions

In [None]:
# label distribution of the test table, followed by that of the predictions
display(test_data.cell_label.value_counts())
display(y_pred.value_counts())

In [None]:
# score the test-set predictions against the test-set labels
eval_results = predictor.evaluate_predictions(test_data.cell_label, y_pred,
                                              detailed_report=True)
# Outer double quotes: reusing the enclosing quote character inside an
# f-string replacement field only parses on Python 3.12+ (PEP 701). The
# printed text is unchanged.
print(f"## {eval_results.get('accuracy')=}")
print(f"## {eval_results.get('balanced_accuracy')=}")
print(f"## Matthews Correlation Coefficient: {eval_results.get('mcc')}")
# per-class report, best f1-score first
display(DataFrame(eval_results.get('classification_report')).transpose()
        .sort_values('f1-score', ascending=False))

In [None]:
# leaderboard computed with the labelled test table passed in as data
predictor.leaderboard(test_data)

## infer the cell labels for the unknown cells

In [None]:
# predict labels for the 'accessibility' modality split
atac_labels = predictor.predict(data_sets.get('accessibility'))

In [None]:
# distribution of the predicted labels
display(atac_labels.value_counts())

In [None]:
# NOTE(review): this leaderboard is scored against whatever `cell_label`
# values the accessibility table carries - confirm those labels are
# meaningful before reading anything into the scores
print(predictor.model_best)
predictor.leaderboard(data_sets.get('accessibility'))

### for the test dataset what are the important features

In [None]:
%%time
# feat_imp = predictor.feature_importance(train_data)