## Prototype embed-classify Labelator EC_LBL8R 

### overview.
This notebook protypes a "labelator".  The purpose of a "labelator" is to easily classify _cell types_ for out-of-sample "Test" data. 

Currently we are prototyping with several `anndata` _dataloaders_.  `scvi-tools`, `scarches`, and `anndata` each have an implimenation of a `torch` _dataloader_.   The `scarches` flavor seems to be a good middle ground and then DO have an SCVI implimentation.    Probably will use the _native_ loader for each type, but an `scarches` variant for our simpler models. 

To state our confirmation bias, it impliments the SCVI models which we like.

We will validate potential models and calibrate them with simple expectations using a typical "Train"/"Validate" and "Test"/"Probe" approach.  


Definitions:
- "Train": data samples on which the model being tested is trained.  The `torch lightning` framework used by `scvi-tools` semi-automatically will "validate" to test out-of-sample prediction fidelity during training.
- "Test": held-out samples to test the fidelity of the model.  
- "Probe": data generated externally,which is _probing_ the fidelity of the model to general scRNAseq data.



#### embed and classify 

#### 2 step: ec_lbl8r encode + categorize
In two steps:
1) _encode_: embedding the scRNAseq counts into a latent sub-space
    - VAE (e.g. MMD-VAE, infoVAE etc)
    - PCA (_naive_ linear encoding)
    - scVI-latent (__naive__ VAE)
    - etc.
        - scVI (__transfer__ VAE)

2) _categorize_: predicting creating a probability of a each category 
    - Linear classifier (e.g. multinomial Logistic Regression)
    - NN non-linear classifier (MLP)
    - boosted trees (XGboost)



## Caveats
There are several gotchas to anticipate:
- features.  Currently we are locked into the 3k genes we are testing with.  Handling subsets and supersets is TBC.
- batch.  In principle each "embedding" or decode part of the model should be able to measure a "batch-correction" parameter explicitly.  in scVI this is explicitly _learned_.  However in _naive_ inference mode it should just be an inferred fudge factor.
- noise.  including or not including `doublet`, `mito`, or `ribo` metrics




### List of models

ec lbl8r xgb variants:
- raw counts PCA loadings n=50 features
- normalized counts (scVI) PCA loadings
- scVI latent
- etc.





In [1]:
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip uninstall -y typing_extensions
    !pip install --quiet scvi-colab
    from scvi_colab import install
    install()

In [2]:

import numpy as np
import scanpy as sc

#### 
import sys
import warnings

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
# import pandas as pd
import scvi
from pathlib import Path
# import scarches as sca

from lbl8r.utils import make_pc_loading_adata
from lbl8r.xgb import get_xgb_data, train_xgboost, test_xgboost

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    

%load_ext autoreload
%autoreload 2


### Data Paths

In [3]:
if IN_COLAB:
    root_path = Path("/content/drive/MyDrive/")
    data_path = root_path / "SingleCellModel/data"
else:
    root_path = Path("../")
    if sys.platform == "darwin":
        data_path = root_path / "data/xylena_raw"
    else:
        data_path = root_path / "data/scdata/xylena_raw"
        
XYLENA_ANNDATA = "brain_atlas_anndata.h5ad"
XYLENA_METADATA = "final_metadata.csv"
XYLENA_ANNDATA2 = "brain_atlas_anndata_updated.h5ad"

XYLENA_TRAIN = XYLENA_ANNDATA.replace(".h5ad", "_train.h5ad")
XYLENA_TEST = XYLENA_ANNDATA.replace(".h5ad", "_test.h5ad")




In [4]:
model_path = root_path / "lbl8r_models"
if not model_path.exists():
  model_path.mkdir()



--------------

## xgb_LBL8R on raw count PCAs 

This is a zeroth order "baseline" for performance.  


### load data

In [6]:
filen = data_path / XYLENA_ANNDATA.replace(".h5ad", "_train_pca.h5ad")
train_ad = ad.read_h5ad(filen)


### train model

In [7]:

X_train, y_train, label_encoder = get_xgb_data(train_ad)

bst = train_xgboost(X_train, y_train)



[0]	valid-mlogloss:1.06754
[10]	valid-mlogloss:0.11488
[20]	valid-mlogloss:0.07826
[30]	valid-mlogloss:0.07471
[40]	valid-mlogloss:0.07442
[49]	valid-mlogloss:0.07469


### test and save

In [13]:

filen = data_path / XYLENA_ANNDATA.replace(".h5ad", "_test_pca.h5ad")
test_ad = ad.read_h5ad(filen)


array([[ 3.68088684e+02,  1.33296595e+01,  8.02294540e+00,
         5.23752823e+01, -2.03400230e+01,  4.36526947e+01,
        -4.19003716e+01, -3.73202019e+01,  1.57291861e+01,
         3.05967464e+01,  3.66521606e+01,  6.03300667e+01,
        -3.64095688e+01,  2.11280966e+00, -2.48350811e+01],
       [ 4.32298737e+02,  1.81242142e+01, -7.17947922e+01,
        -1.00064980e+02,  5.68049431e+00, -2.24803200e+01,
        -8.52943611e+00, -2.49820480e+01, -4.20321846e+00,
        -2.50607896e+00, -1.34945745e+01, -1.39766378e+01,
        -3.33317261e+01, -1.29724140e+01,  1.88233509e+01],
       [ 2.97618011e+02, -1.38771582e+00, -6.77000732e+01,
        -2.80466099e+01, -5.72715473e+00,  2.45978336e+01,
        -1.16432734e+01, -7.45540333e+00,  9.18723679e+00,
        -2.31691837e+01,  5.42982721e+00,  2.03360653e+01,
        -3.04312782e+01,  1.84899998e+01,  1.32612133e+00],
       [ 3.92290924e+02,  1.75723610e+01, -8.01843872e+01,
        -6.13158760e+01, -2.19617596e+01, -1.17024280

In [9]:

test_xgboost(bst, test_ad, label_encoder)

# Save the model for later use
bst.save_model(model_path / 'xgb_raw_pca.model')


              precision    recall  f1-score   support

       Astro       1.00      0.97      0.99     42519
         ExN       0.99      0.99      0.99    110484
         InN       0.99      1.00      0.99     53325
          MG       1.00      0.99      1.00     26529
         OPC       0.99      0.98      0.99     28882
       Oligo       0.99      1.00      0.99    235180
     Unknown       1.00      0.75      0.86        12
          VC       0.99      0.95      0.97      5154

    accuracy                           0.99    502085
   macro avg       0.99      0.95      0.97    502085
weighted avg       0.99      0.99      0.99    502085





--------------

## xgb_LBL8R on scVI normalized PCAs 

To give the pca "baseline" a fair shake its important to use normalized counts.  Using the `scVI` normalization is our best shot... (Although the current models are NOT batch correcting since we don't have a good strategy to do this with probe data)

In [10]:
# Load & prep Training data
filen = data_path / XYLENA_ANNDATA.replace(".h5ad", "_train_scvi_normalized_pca.h5ad")
norm_train_ad = ad.read_h5ad(filen)

X_train, y_train, label_encoder = get_xgb_data(train_ad)

# train 
bst = train_xgboost(X_train, y_train)

# test
filen = data_path / XYLENA_ANNDATA.replace(".h5ad", "_test_scvi_normalized_pca.h5ad")
test_ad = ad.read_h5ad(filen)
test_xgboost(bst, test_ad, label_encoder)

# save
bst.save_model(model_path / 'xgb_scVInorm_pca.model')


[0]	valid-mlogloss:1.06719
[10]	valid-mlogloss:0.11092
[20]	valid-mlogloss:0.07386
[30]	valid-mlogloss:0.07029
[40]	valid-mlogloss:0.06978
[49]	valid-mlogloss:0.06979


  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

       Astro       0.01      0.01      0.01     42519
         ExN       0.14      0.17      0.15    110484
         InN       0.01      0.03      0.02     53325
          MG       0.00      0.00      0.00     26529
         OPC       0.00      0.00      0.00     28882
       Oligo       0.53      0.24      0.33    235180
     Unknown       0.00      0.00      0.00        12
          VC       0.00      0.01      0.00      5154

    accuracy                           0.16    502085
   macro avg       0.09      0.06      0.06    502085
weighted avg       0.28      0.16      0.19    502085



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


--------------

## xgb_LBL8R on scVI latents  


In [11]:
# Load & prep Training data
filen = data_path / XYLENA_ANNDATA.replace(".h5ad", "_train_scVI_lat.h5ad")
norm_train_ad = ad.read_h5ad(filen)

X_train, y_train, label_encoder = get_xgb_data(train_ad)

# train 
bst = train_xgboost(X_train, y_train)

# test
filen = data_path / XYLENA_ANNDATA.replace(".h5ad", "_test_scVI_lat.h5ad")
test_ad = ad.read_h5ad(filen)
test_xgboost(bst, test_ad, label_encoder)

# save
bst.save_model(model_path / 'xgb_scVInorm_pca.model')


[0]	valid-mlogloss:1.06881
[10]	valid-mlogloss:0.11701
[20]	valid-mlogloss:0.08017
[30]	valid-mlogloss:0.07680
[40]	valid-mlogloss:0.07651
[47]	valid-mlogloss:0.07674


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

       Astro       0.00      0.00      0.00     42519
         ExN       0.22      1.00      0.36    110484
         InN       0.00      0.00      0.00     53325
          MG       0.00      0.00      0.00     26529
         OPC       0.00      0.00      0.00     28882
       Oligo       0.00      0.00      0.00    235180
     Unknown       0.00      0.00      0.00        12
          VC       1.00      0.01      0.01      5154

    accuracy                           0.22    502085
   macro avg       0.15      0.13      0.05    502085
weighted avg       0.06      0.22      0.08    502085



  _warn_prf(average, modifier, msg_start, len(result))


In [12]:
# Load & prep Training data
filen = data_path / XYLENA_ANNDATA.replace(".h5ad", "_train_scVImu_lat.h5ad")
norm_train_ad = ad.read_h5ad(filen)

X_train, y_train, label_encoder = get_xgb_data(train_ad)

# train 
bst = train_xgboost(X_train, y_train)

# test
filen = data_path / XYLENA_ANNDATA.replace(".h5ad", "_test_scVImu_lat.h5ad")
test_ad = ad.read_h5ad(filen)
test_xgboost(bst, test_ad, label_encoder)

# save
bst.save_model(model_path / 'xgb_scVImu_lat.model')


[0]	valid-mlogloss:1.06775
[10]	valid-mlogloss:0.11552
[20]	valid-mlogloss:0.07863
[30]	valid-mlogloss:0.07508
[40]	valid-mlogloss:0.07479
[49]	valid-mlogloss:0.07508


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

       Astro       0.00      0.00      0.00     42519
         ExN       0.22      1.00      0.36    110484
         InN       0.00      0.00      0.00     53325
          MG       0.00      0.00      0.00     26529
         OPC       0.00      0.00      0.00     28882
       Oligo       0.00      0.00      0.00    235180
     Unknown       0.00      0.00      0.00        12
          VC       0.00      0.00      0.00      5154

    accuracy                           0.22    502085
   macro avg       0.03      0.12      0.05    502085
weighted avg       0.05      0.22      0.08    502085



  _warn_prf(average, modifier, msg_start, len(result))
