## Using the MultiVI latent representations, predict the cell types of the ATAC cells from the labeled RNA cells

In [None]:
# record when this notebook was run (shell command via IPython)
!date

#### import libraries

In [None]:
import scvi
import scanpy as sc
from autogluon.tabular import TabularDataset, TabularPredictor
import torch
from anndata import AnnData
from pandas import DataFrame, concat, Series
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
from seaborn import barplot

import warnings
# silence every Python warning for cleaner notebook output; note this also hides
# deprecation/runtime warnings raised by the libraries imported above
warnings.filterwarnings('ignore')

# scvi-tools global random seed, for reproducible runs
scvi.settings.seed = 42

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
# render inline figures at high (retina) resolution
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# variables and constants
project = 'aging_phase2'
DEBUG = True
MODEL_QUAL_PRESET = 'good'
# use the GPU when torch can see one, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# directories: everything lives under the phase2 working directory
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir, models_dir, figures_dir = (f'{wrk_dir}/{sub_dir}'
                                       for sub_dir in ('quants', 'models', 'figures'))
# scanpy saves its figures here (trailing slash expected by the path join)
sc.settings.figdir = f'{figures_dir}/'

# in files
in_h5ad_file = f'{quants_dir}/{project}.dev.multivi.h5ad'

# out files
out_h5ad_file = f'{quants_dir}/{project}.dev.multivi.annotated.h5ad'
trained_model_path = f'{models_dir}/{project}_dev_trained_cellpred'

if DEBUG:
    # echo the resolved settings as name=repr(value)
    for var_name, var_value in (('in_h5ad_file', in_h5ad_file),
                                ('out_h5ad_file', out_h5ad_file),
                                ('trained_model_path', trained_model_path),
                                ('device', device)):
        print(f'{var_name}={var_value!r}')

#### functions

In [None]:
def peek_anndata(adata: AnnData, message: str=None, verbose: bool=False):
    """Print a quick summary of an AnnData object.

    Args:
        adata: the object to summarise; its printed representation is shown.
        message: optional header printed first; skipped when None or empty.
        verbose: if True, also display the first rows of ``adata.obs`` and
            ``adata.var`` (relies on the notebook-provided ``display``).
    """
    # a plain truthiness test covers both None and the empty string
    if message:
        print(message)
    print(adata)
    if verbose:
        display(adata.obs.head())
        display(adata.var.head())

def peek_dataframe(df: DataFrame, message: str=None, verbose: bool=False):
    """Print the shape of a DataFrame, optionally with a header and its first rows.

    Args:
        df: the frame to summarise.
        message: optional header printed first; skipped when None or empty.
        verbose: if True, also display ``df.head()`` (relies on the
            notebook-provided ``display``).
    """
    # a plain truthiness test covers both None and the empty string
    if message:
        print(message)
    print(f'{df.shape=}')
    if verbose:
        display(df.head())

def eval_classifier_model(model: TabularPredictor, this_data: TabularDataset, 
                          target: str, verbose: bool=False) -> tuple[Series, DataFrame]:
    """Predict on a labeled dataset and score the predictions per class.

    Prints accuracy, balanced accuracy and the Matthews correlation coefficient
    taken from ``model.evaluate_predictions(..., detailed_report=True)``, then
    turns its classification report into a per-class table.

    Args:
        model: a fitted AutoGluon TabularPredictor.
        this_data: data holding the feature columns and the ``target`` column.
        target: name of the label column in ``this_data``.
        verbose: if True, display the true and predicted label counts and the
            per-class table (relies on the notebook-provided ``display``).

    Returns:
        ``(x_pred, ret_df)``: the predictions from ``model.predict`` and one row
        per class (classification-report columns plus ``cell_type``), sorted by
        descending ``f1-score`` with a fresh 0..n-1 index.
    """
    # classification-report rows that summarise all classes instead of one class
    summary_rows = ['accuracy', 'macro avg', 'weighted avg']

    x_pred = model.predict(this_data)
    eval_results = model.evaluate_predictions(this_data[target], x_pred, 
                                              detailed_report=True)
    # outer double quotes: reusing ' inside a '...' f-string is a SyntaxError
    # before Python 3.12 (PEP 701); the printed text is unchanged
    print(f"## {target=} {eval_results.get('accuracy')=}")
    print(f"## {target=} {eval_results.get('balanced_accuracy')=}")
    print(f"## {target=} Matthews Correlation Coefficient: {eval_results.get('mcc')}")
    ret_df = (DataFrame(eval_results.get('classification_report')).transpose()
              .sort_values('f1-score', ascending=False))
    ret_df['cell_type'] = ret_df.index.values
    ret_df = ret_df.loc[~ret_df.cell_type.isin(summary_rows)].reset_index(drop=True)
    if verbose:
        display(this_data[target].value_counts())
        display(x_pred.value_counts())
        display(ret_df)
    return x_pred, ret_df

## load the multiVI latent features

In [None]:
# read the MultiVI AnnData; its latent coordinates live in obsm['MultiVI_latent']
adata = sc.read_h5ad(filename=in_h5ad_file)
peek_anndata(adata, message='loaded the multiVI anndata', verbose=DEBUG)

## split the anndata into training, test, and inference datasets

Here we use the GEX and ARC cells for training/testing, and the ATAC cells for inference.

In [None]:
# 'expression' and 'paired' cells are used for train/test;
# 'accessibility' cells are the ones whose labels get inferred
adata_train_test = adata[adata.obs.modality.isin(['expression', 'paired'])]
adata_infer = adata[adata.obs.modality == 'accessibility']


def _latent_to_tabular(subset):
    """Wrap a subset's MultiVI latent matrix as a TabularDataset with a
    'cell_label' column, indexed by the subset's obs names."""
    frame = TabularDataset(subset.obsm['MultiVI_latent'])
    frame['cell_label'] = subset.obs.cell_label.values
    frame.index = subset.obs.index.values
    return frame


train_test_data = _latent_to_tabular(adata_train_test)
inference_data = _latent_to_tabular(adata_infer)

### split the train and test dataset

In [None]:
# hold out 30% of the cells, stratified on cell_label so each class keeps
# its proportion in both splits
features = train_test_data.drop(columns=['cell_label'])
labels = train_test_data['cell_label']
X_train, X_test, y_train, y_test = train_test_split(features, labels, stratify=labels,
                                                    test_size=0.3, random_state=42)

# pull the full rows (features + label) back out by index, then make sure
# there are not any unlabeled cells in the train and test
train_data = train_test_data.loc[X_train.index]
train_data = train_data[train_data['cell_label'] != 'Unknown']
test_data = train_test_data.loc[X_test.index]
test_data = test_data[test_data['cell_label'] != 'Unknown']

peek_dataframe(train_data, message='training dataframe', verbose=DEBUG)
peek_dataframe(test_data, message='testing dataframe', verbose=DEBUG)

In [None]:
if DEBUG:
    # class balance of the train, test and inference tables, in that order
    for split_df in (train_data, test_data, inference_data):
        display(split_df.cell_label.value_counts())

#### how many of the test-split cells are ARC samples?

In [None]:
# obs table of the 'paired' (ARC) cells only
arc_temp = adata[adata.obs.modality == 'paired'].obs
display(arc_temp.modality.value_counts())
print(f'{arc_temp.shape=}')
print(f'{test_data.shape=}')
# test-split cells whose obs name is also a paired cell
arc_in_test = set(test_data.index).intersection(arc_temp.index)
print(len(arc_in_test))

## use autoML to train a model

### initialize an AutoGluon TabularPredictor

In [None]:
# predictor for the 'cell_label' column, scored with the Matthews correlation
# coefficient; models and the log file are written under trained_model_path
predictor = TabularPredictor(
    label='cell_label',
    eval_metric='mcc',
    path=trained_model_path,
    verbosity=2,
    log_to_file=True,
)

### train the predictor model

In [None]:
%%time
predictor.fit(train_data, presets=MODEL_QUAL_PRESET, num_gpus=1)

### train data eval

In [None]:
target_var = 'cell_label'
# in-sample scores: these cells were used to fit the predictor
train_predictions, train_scores = eval_classifier_model(model=predictor, this_data=train_data,
                                                        target=target_var)

### test data eval

In [None]:
# held-out scores on the test split
test_predictions, test_scores = eval_classifier_model(model=predictor, this_data=test_data,
                                                      target=target_var)

In [None]:
if DEBUG:
    # per-class score tables: train first, then test
    for split_scores in (train_scores, test_scores):
        display(split_scores)

### visualize the eval data for train and test

### per-cell-type F1 scores for the training and testing data

In [None]:
# tag each per-class score table with its split, then stack them for plotting
train_scores['dataset'] = 'train'
test_scores['dataset'] = 'test'
scores_df = concat([test_scores, train_scores])

# rc_context restores the global rcParams (including the style) on exit
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    ax = barplot(data=scores_df, x='cell_type', y='f1-score', hue='dataset',
                 palette='colorblind')
    ax.legend(title='dataset', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=90)
    ax.set_xlabel('Cell Types')
    ax.set_ylabel('F1 Score')

In [None]:
print(f'{predictor.model_best=}')
if DEBUG:
    # AutoGluon's per-model leaderboard from the fit
    leaderboard_df = predictor.leaderboard()
    display(leaderboard_df)

In [None]:
if DEBUG:
    # true label counts for train and test, then predicted counts on test
    for counts in (train_data.cell_label.value_counts(),
                   test_data.cell_label.value_counts(),
                   test_predictions.value_counts()):
        display(counts)

## infer the cell labels for the unknown cells

In [None]:
%%time
atac_labels = predictor.predict(inference_data)

In [None]:
if DEBUG:
    # true label counts for train and test, then predicted counts for inference
    for counts in (train_data.cell_label.value_counts(),
                   test_data.cell_label.value_counts(),
                   atac_labels.value_counts()):
        display(counts)

### visualize the number of cells by cell-types for each of the datasets

In [None]:
# cells per cell type for each split, one column per split
combined_counts = concat([train_data.cell_label.value_counts(), 
                           test_data.cell_label.value_counts(), 
                           atac_labels.value_counts()], axis='columns')
combined_counts.columns = ['Train', 'Test', 'Inference']
# a cell type absent from one split comes back as NaN from the column-wise
# concat; count it as 0 cells instead
combined_counts = combined_counts.fillna(0)
# name the index explicitly before reset_index: value_counts() only names its
# index after the source Series on pandas >= 2.0, and concat drops the name if
# the three series disagree; without it the melt on 'cell_label' fails
combined_counts = combined_counts.rename_axis('cell_label').reset_index()
combined_counts = combined_counts.melt(id_vars=['cell_label'], 
                                       value_vars=['Train', 'Test', 'Inference'], 
                                       var_name='Dataset', value_name='Cell_counts')
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    barplot(data=combined_counts, x='cell_label', y='Cell_counts', hue='Dataset', palette='colorblind')
    plt.legend(title='Dataset', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=90)
    plt.xlabel('Cell Types')
    plt.ylabel('Cell Counts')

### which features are most important for the test dataset?

In [None]:
%%time
# NOTE(review): the heading above asks about the test dataset, but this call
# passes train_data; AutoGluon's feature_importance is meant to be computed on
# held-out labeled data, so test_data is likely intended here — confirm.
# feat_imp = predictor.feature_importance(train_data)