## Prototype embed-classify Labelator EC_LBL8R 

### Overview
This notebook prototypes a "labelator".  The purpose of a "labelator" is to easily classify _cell types_ for out-of-sample "Test" data. 

We are currently prototyping with several `anndata` _dataloaders_.  `scvi-tools`, `scarches`, and `anndata` each have an implementation of a `torch` _dataloader_.   The `scarches` flavor seems to be a good middle ground, and they DO have an SCVI implementation.    We will probably use the _native_ loader for each model type, but an `scarches` variant for our simpler models. 

To state our confirmation bias: `scarches` implements the SCVI models, which we like.

We will validate potential models and calibrate them with simple expectations using a typical "Train"/"Validate" and "Test"/"Probe" approach.  


Definitions:
- "Train": data samples on which the model being tested is trained.  The `torch lightning` framework used by `scvi-tools` semi-automatically "validates" during training to test out-of-sample prediction fidelity.
- "Test": held-out samples used to test the fidelity of the model.  
- "Probe": externally generated data, which _probes_ the fidelity of the model on general scRNAseq data.
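
The "Train"/"Test" distinction above can be sketched as a random held-out split over cells. This is a minimal illustration only: the helper name `split_train_test`, the 30% test fraction, and the per-cell (rather than per-sample) split are assumptions, not the procedure used to build the files loaded later in this notebook.

```python
import numpy as np


def split_train_test(n_cells, test_fraction=0.3, seed=0):
    """Return boolean (train_mask, test_mask) arrays over cells.

    Hypothetical helper: holds out a random fraction of cells as "Test".
    """
    rng = np.random.default_rng(seed)
    n_test = int(round(n_cells * test_fraction))
    # Draw the held-out cell indices without replacement
    test_idx = rng.choice(n_cells, size=n_test, replace=False)
    test_mask = np.zeros(n_cells, dtype=bool)
    test_mask[test_idx] = True
    return ~test_mask, test_mask


train_mask, test_mask = split_train_test(1000, test_fraction=0.3)
print(train_mask.sum(), test_mask.sum())
```

"Probe" data would not come from such a split at all; it is a separate, externally generated dataset passed through the already-trained model.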



#### embed and classify 

#### 2 step: ec_lbl8r encode + categorize
In two steps:
1) _encode_: embedding the scRNAseq counts into a latent sub-space
    - VAE (e.g. MMD-VAE, infoVAE etc)
    - PCA (_naive_ linear encoding)
    - scVI-latent (__naive__ VAE)
    - scVI (__transfer__ VAE)
    - etc.

2) _categorize_: predicting a probability for each category 
    - Linear classifier (e.g. multinomial Logistic Regression)
    - NN non-linear classifier (MLP)
    - boosted trees (XGboost)
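
As a rough illustration of the two steps, the sketch below chains the simplest option from each list: PCA as the encoder and multinomial logistic regression as the classifier. It runs on synthetic counts; the data, the sizes, and the choice of 10 components are all assumptions made for the example and are not taken from the `lbl8r` package.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(0)
n_cells, n_genes, n_types = 300, 50, 3

# Synthetic counts: each "cell type" has its own mean expression profile
labels = rng.integers(0, n_types, size=n_cells)
profiles = rng.gamma(shape=2.0, scale=2.0, size=(n_types, n_genes))
counts = rng.poisson(profiles[labels])
X = np.log1p(counts)  # simple log transform before the linear encoder

# Step 1 (encode): PCA.  Step 2 (categorize): logistic regression.
model = make_pipeline(
    PCA(n_components=10, random_state=0),
    LogisticRegression(max_iter=1000),
)
model.fit(X[:200], labels[:200])

# One probability per category for each held-out cell
proba = model.predict_proba(X[200:])
print(proba.shape)
```

Swapping either stage (for example an scVI latent for the PCA, or boosted trees for the logistic regression) keeps the same encode-then-categorize shape.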



## Caveats
There are several gotchas to anticipate:
- features.  Currently we are locked into the 3k genes we are testing with.  Handling subsets and supersets is TBC.
- batch.  In principle each "embedding" or decode part of the model should be able to measure a "batch-correction" parameter explicitly.  In scVI this is explicitly _learned_.  However, in _naive_ inference mode it should just be an inferred fudge factor.
- noise.  Whether or not to include `doublet`, `mito`, or `ribo` metrics.
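
For the features caveat, one possible way to handle a probe dataset whose genes are a subset or superset of the training genes is to reorder its columns to the training gene list, zero-filling genes that are missing and dropping extras. This is only a sketch of that idea: the helper `align_to_training_genes` is hypothetical, the gene names are placeholders, and zero-filling is an assumption rather than a decided strategy.

```python
import numpy as np


def align_to_training_genes(X, genes, train_genes):
    """Reorder the columns of X to match train_genes.

    Genes absent from `genes` are zero-filled; genes not in
    `train_genes` are dropped.  Hypothetical helper for illustration.
    """
    col_of = {g: j for j, g in enumerate(genes)}
    aligned = np.zeros((X.shape[0], len(train_genes)), dtype=X.dtype)
    for j, gene in enumerate(train_genes):
        if gene in col_of:
            aligned[:, j] = X[:, col_of[gene]]
    return aligned


train_genes = ["GAD1", "SLC17A7", "MBP"]   # placeholder training features
probe_genes = ["MBP", "GFAP", "GAD1"]      # probe has one extra, one missing
probe_X = np.array([[5, 1, 0],
                    [0, 2, 7]])

aligned = align_to_training_genes(probe_X, probe_genes, train_genes)
print(aligned)
```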




### List of models

ec lbl8r xgb variants:
- raw counts PCA loadings n=50 features
- normalized counts (scVI) PCA loadings
- scVI latent
- etc.





In [1]:
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip uninstall -y typing_extensions
    !pip install --quiet scvi-colab
    from scvi_colab import install
    install()

else:
    import os
    # os.chdir('../')

    ### import local python functions in ../lbl8r
    sys.path.append(os.path.abspath((os.path.join(os.getcwd(), '..'))))

In [2]:

import sys
import warnings
from pathlib import Path

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import scvi

# import pandas as pd
# import scarches as sca

from lbl8r.utils import make_pc_loading_adata
from lbl8r.xgb import get_xgb_data, train_xgboost, test_xgboost

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    

%load_ext autoreload
%autoreload 2


### Data Paths

In [3]:
if IN_COLAB:
    root_path = Path("/content/drive/MyDrive/")
    data_path = root_path / "SingleCellModel/data"
else:
    root_path = Path("../")
    if sys.platform == "darwin":
        data_path = root_path / "data/xylena_raw"
    else:
        data_path = root_path / "data/scdata/xylena"
        raw_data_path = root_path / "data/scdata/xylena_raw"

XYLENA_ANNDATA = "brain_atlas_anndata.h5ad"
XYLENA_METADATA = "final_metadata.csv"
XYLENA_ANNDATA2 = "brain_atlas_anndata_updated.h5ad"

XYLENA_TRAIN = XYLENA_ANNDATA.replace(".h5ad", "_train_cnt.h5ad")
XYLENA_TEST = XYLENA_ANNDATA.replace(".h5ad", "_test_cnt.h5ad")

XYLENA_TRAIN_SPARSE = XYLENA_TRAIN.replace(".h5ad", "_sparse.h5ad")
XYLENA_TEST_SPARSE = XYLENA_TEST.replace(".h5ad", "_sparse.h5ad")

In [4]:
model_path = root_path / "lbl8r_models"
# Create the model directory (and any missing parents) if needed
model_path.mkdir(parents=True, exist_ok=True)



--------------

## xgb_LBL8R on raw count PCAs 

This is a zeroth-order "baseline" for performance.  


### load data

In [5]:
OUT_PATH = data_path / "LBL8R"


filen = OUT_PATH / XYLENA_TRAIN.replace("_cnt.h5ad", "_pca_out.h5ad")
train_ad = ad.read_h5ad(filen)




In [6]:
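# NOTE: without parentheses this displays the bound method (and the underlying
# Series), not per-type totals; call .value_counts() to get the counts.
train_ad.obs['cell_type'].value_counts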

<bound method IndexOpsMixin.value_counts of cells
GGCCTAATCGATTTAG-1_1        ExN
TAGTAACGTAGTCAAT-1_1        ExN
GAAAGCCAGCAGCTCA-1_1        ExN
ACTCACCTCCTCCCTC-1_1        ExN
CTTCATCCAATCGCAC-1_1        ExN
                          ...  
GTTGTGAGTCGCAATA-1_138    Astro
GAAGTCAAGCCACAAT-1_138      ExN
CTGGACCAGGCTGTGC-1_138    Astro
TCCTCACAGGAGTAAT-1_138      ExN
GCAGCCAGTTGTGATG-1_138    Oligo
Name: cell_type, Length: 502085, dtype: category
Categories (7, object): ['Astro', 'ExN', 'InN', 'MG', 'OPC', 'Oligo', 'VC']>

### train model

In [7]:

X_train, y_train, label_encoder = get_xgb_data(train_ad)

bst = train_xgboost(X_train, y_train)



[0]	valid-mlogloss:1.05707
[10]	valid-mlogloss:0.11159
[20]	valid-mlogloss:0.07661
[30]	valid-mlogloss:0.07379
[40]	valid-mlogloss:0.07353
[47]	valid-mlogloss:0.07392


### test and save

In [8]:


filen = OUT_PATH / XYLENA_TEST.replace("_cnt.h5ad", "_pca_out.h5ad")

test_ad = ad.read_h5ad(filen)


test_ad.obs["cell_type"].value_counts()

Oligo    86666
ExN      50541
InN      25488
Astro    18646
OPC      12809
MG       11052
VC        2524
Name: cell_type, dtype: int64

In [9]:

test_xgboost(bst, test_ad, label_encoder)

# Save the model for later use
bst.save_model(model_path / 'xgb_raw_pca.model')


              precision    recall  f1-score   support

       Astro       0.96      0.89      0.93     18646
         ExN       0.95      0.98      0.96     50541
         InN       0.98      0.92      0.95     25488
          MG       0.96      0.99      0.98     11052
         OPC       1.00      0.90      0.95     12809
       Oligo       0.94      0.97      0.96     86666
          VC       0.87      0.72      0.79      2524

    accuracy                           0.95    207726
   macro avg       0.95      0.91      0.93    207726
weighted avg       0.95      0.95      0.95    207726
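
To make the summary rows of the report concrete, the sketch below recomputes the macro and weighted F1 averages from the per-class `f1-score` and `support` columns printed above. The macro average treats every cell type equally, while the weighted average scales each class by its support, so a weak rare class such as `VC` lowers the macro figure more than the weighted one. Because the inputs are the rounded values from the table, the results agree only to two decimals.

```python
# Per-class (f1-score, support) copied from the classification report above
report = {
    "Astro": (0.93, 18646),
    "ExN":   (0.96, 50541),
    "InN":   (0.95, 25488),
    "MG":    (0.98, 11052),
    "OPC":   (0.95, 12809),
    "Oligo": (0.96, 86666),
    "VC":    (0.79, 2524),
}

total_support = sum(support for _, support in report.values())

# Macro average: unweighted mean over classes
macro_f1 = sum(f1 for f1, _ in report.values()) / len(report)

# Weighted average: each class contributes in proportion to its support
weighted_f1 = sum(f1 * support for f1, support in report.values()) / total_support

print(f"total support: {total_support}")
print(f"macro F1:      {macro_f1:.2f}")
print(f"weighted F1:   {weighted_f1:.2f}")
```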





--------------

## xgb_LBL8R on scVI normalized PCAs 

To give the PCA "baseline" a fair shake, it's important to use normalized counts.  Using the `scVI` normalization is our best shot.  (Note that the current models are NOT batch correcting, since we don't have a good strategy for doing this with probe data.)

In [18]:
filen = OUT_PATH / XYLENA_TRAIN.replace("_cnt.h5ad", "_exp_nb_out.h5ad")
filen

PosixPath('../data/scdata/xylena/LBL8R/brain_atlas_anndata_train_exp_nb_out.h5ad')

In [16]:
# Load & prep Training data
filen = OUT_PATH / XYLENA_TRAIN.replace("_cnt.h5ad", "_exp_nb_pca_out.h5ad")
norm_train_ad = ad.read_h5ad(filen)

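# NOTE: pass the scVI-normalized object loaded above.  The recorded output below
# appears to come from a run that passed `train_ad` (raw-count PCAs) instead,
# which would explain the poor test scores.
X_train, y_train, label_encoder = get_xgb_data(norm_train_ad)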

# train 
bst = train_xgboost(X_train, y_train)

# test
filen = OUT_PATH / XYLENA_TEST.replace("_cnt.h5ad", "_exp_nb_pca_out.h5ad")
test_ad = ad.read_h5ad(filen)
test_xgboost(bst, test_ad, label_encoder)

# save
bst.save_model(model_path / 'xgb_scVInorm_pca.model')


[0]	valid-mlogloss:1.05750
[10]	valid-mlogloss:0.11306
[20]	valid-mlogloss:0.07784
[30]	valid-mlogloss:0.07458
[40]	valid-mlogloss:0.07432
[45]	valid-mlogloss:0.07445
              precision    recall  f1-score   support

       Astro       0.00      0.00      0.00     18646
         ExN       0.36      1.00      0.53     50541
         InN       0.92      0.01      0.01     25488
          MG       0.00      0.00      0.00     11052
         OPC       0.00      0.00      0.00     12809
       Oligo       0.47      0.17      0.25     86666
          VC       0.00      0.00      0.00      2524

    accuracy                           0.31    207726
   macro avg       0.25      0.17      0.11    207726
weighted avg       0.40      0.31      0.23    207726
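
A quick sanity check for a report like the one above is the majority-class baseline: the accuracy obtained by always predicting the most common cell type. The sketch below computes it from the test-set supports printed in the table and compares it with the reported accuracy; no other inputs are assumed.

```python
# Test-set support per cell type, copied from the report above
support = {
    "Astro": 18646,
    "ExN": 50541,
    "InN": 25488,
    "MG": 11052,
    "OPC": 12809,
    "Oligo": 86666,
    "VC": 2524,
}
reported_accuracy = 0.31  # "accuracy" row of the report above

total = sum(support.values())
majority_type = max(support, key=support.get)
baseline_accuracy = support[majority_type] / total

print(f"always predicting {majority_type}: {baseline_accuracy:.3f}")
print(f"reported accuracy:          {reported_accuracy:.2f}")
```

An accuracy below this trivial baseline, alongside a validation loss that looks healthy, hints at a mismatch between the features used for training and those used for testing rather than a merely weak model.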





In [14]:
test_ad

AnnData object with n_obs × n_vars = 207726 × 3000
    obs: 'seurat_clusters', 'cell_type', 'sample', 'doublet_score', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'percent.rb', 'batch', 'sample_other', 'S.Score', 'G2M.Score', 'Phase', 'RNA_snn_res.0.3', 'seurat_clusters_other', 'ExN1', 'InN2', 'MG3', 'Astro4', 'Oligo5', 'OPC6', 'VC7', 'type', 'UMAP_1', 'UMAP_2', 'clean', 'test', 'train', 'tmp', '_scvi_batch', '_scvi_labels'
    var: 'feat'
    uns: 'pca'
    obsm: 'X_pca', 'X_scVI', '_X_pca'
    varm: 'PCs', '_PCs'

--------------

## xgb_LBL8R on scVI latents  


In [19]:
# Load & prep Training data
filen = OUT_PATH / XYLENA_ANNDATA.replace(".h5ad", "_train_scvi_nb_out.h5ad")
norm_train_ad = ad.read_h5ad(filen)

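# NOTE: pass the scVI object loaded above.  The recorded output below appears to
# come from a run that passed `train_ad` (raw-count PCAs) instead, which would
# explain the poor test scores.
X_train, y_train, label_encoder = get_xgb_data(norm_train_ad)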

# train 
bst = train_xgboost(X_train, y_train)

# test
filen = OUT_PATH / XYLENA_ANNDATA.replace(".h5ad", "_test_scvi_nb_out.h5ad")
test_ad = ad.read_h5ad(filen)
test_xgboost(bst, test_ad, label_encoder)

# save
bst.save_model(model_path / 'xgb_scvi_nb.model')


[0]	valid-mlogloss:1.05822
[10]	valid-mlogloss:0.11433
[20]	valid-mlogloss:0.07914
[30]	valid-mlogloss:0.07584
[40]	valid-mlogloss:0.07530
[49]	valid-mlogloss:0.07542
              precision    recall  f1-score   support

       Astro       0.00      0.00      0.00     18646
         ExN       0.14      0.15      0.15     50541
         InN       0.08      0.17      0.11     25488
          MG       0.00      0.00      0.00     11052
         OPC       0.00      0.00      0.00     12809
       Oligo       0.47      0.01      0.01     86666
          VC       0.01      0.43      0.02      2524

    accuracy                           0.07    207726
   macro avg       0.10      0.11      0.04    207726
weighted avg       0.24      0.07      0.05    207726



