## Prototype end to end Labelator E2E_LBL8R 

### overview.
This notebook protypes a "labelator".  The purpose of a "labelator" is to easily classify _cell types_ for out-of-sample "Test" data. 

Currently we are prototyping with several `anndata` _dataloaders_.  `scvi-tools`, `scarches`, and `anndata` each have an implimenation of a `torch` _dataloader_.   The `scarches` flavor seems to be a good middle ground and then DO have an SCVI implimentation.    Probably will use the _native_ loader for each type, but an `scarches` variant for our simpler models. 

To state our confirmation bias, it impliments the SCVI models which we like.

We will validate potential models and calibrate them with simple expectations using a typical "Train"/"Validate" and "Test"/"Probe" approach.  


Definitions:
- "Train": data samples on which the model being tested is trained.  The `torch lightning` framework used by `scvi-tools` semi-automatically will "validate" to test out-of-sample prediction fidelity during training.
- "Test": held-out samples to test the fidelity of the model.  
- "Probe": data generated externally,which is _probing_ the fidelity of the model to general scRNAseq data.



#### end-to-end
We can also try some _end-to-end_ approaches where a single model takes us from raw counts to category probabilities.
- __naive__
    - boosted trees (e.g. xgboost)
    - cVAE
    - trVAE
- __transfer__
    - scANVI




## Caveats
There are several gotchas to anticipate:
- features.  Currently we are locked into the 3k genes we are testing with.  Handling subsets and supersets is TBC.
- batch.  In principle each "embedding" or decode part of the model should be able to measure a "batch-correction" parameter explicitly.  in scVI this is explicitly _learned_.  However in _naive_ inference mode it should just be an inferred fudge factor.
- noise.  including or not including `doublet`, `mito`, or `ribo` metrics




### List of models

e2e xgb variants:
- raw counts: n=3000 features
- normalized counts (scVI)


In [1]:
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip uninstall -y typing_extensions
    !pip install --quiet scvi-colab
    from scvi_colab import install
    install()

In [2]:

import numpy as np
import scanpy as sc

#### 
import sys
import warnings

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
# import pandas as pd
import scvi
from pathlib import Path
# import scarches as sca

from lbl8r.utils import make_pc_loading_adata
from lbl8r.xgb import get_xgb_data, train_xgboost, test_xgboost

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    

%load_ext autoreload
%autoreload 2


### Load Train, Validate Data 

In [3]:
if IN_COLAB:
    root_path = Path("/content/drive/MyDrive/")
    data_path = root_path / "SingleCellModel/data"
else:
    root_path = Path("../")
    if sys.platform == "darwin":
        data_path = root_path / "data/xylena_raw"
    else:
        data_path = root_path / "data/scdata/xylena_raw"
        
XYLENA_ANNDATA = "brain_atlas_anndata.h5ad"
XYLENA_METADATA = "final_metadata.csv"
XYLENA_ANNDATA2 = "brain_atlas_anndata_updated.h5ad"

XYLENA_TRAIN = XYLENA_ANNDATA.replace(".h5ad", "_train.h5ad")
XYLENA_TEST = XYLENA_ANNDATA.replace(".h5ad", "_test.h5ad")




## model path

In [4]:
model_path = root_path / "e2e_models"
if not model_path.exists():
  model_path.mkdir()




## Raw Counts

### load data

In [5]:
outfilen = data_path / XYLENA_TRAIN
train_ad = ad.read_h5ad(outfilen)

### train model

In [6]:


X_train, y_train, label_encoder = get_xgb_data(train_ad)

bst = train_xgboost(X_train, y_train)



[0]	valid-mlogloss:1.06628
[10]	valid-mlogloss:0.10792
[20]	valid-mlogloss:0.07040
[30]	valid-mlogloss:0.06699
[40]	valid-mlogloss:0.06674
[45]	valid-mlogloss:0.06691


### test and save

In [7]:
outfilen = data_path / XYLENA_TEST
test_ad = ad.read_h5ad(outfilen)

test_xgboost(bst, test_ad, label_encoder)

# Save the model for later use
bst.save_model(model_path / 'xgb_raw_cnt.model')



              precision    recall  f1-score   support

       Astro       1.00      0.98      0.99     42519
         ExN       0.99      0.99      0.99    110484
         InN       0.99      1.00      0.99     53325
          MG       1.00      1.00      1.00     26529
         OPC       1.00      0.98      0.99     28882
       Oligo       0.99      1.00      0.99    235180
     Unknown       1.00      0.92      0.96        12
          VC       0.99      0.97      0.98      5154

    accuracy                           0.99    502085
   macro avg       0.99      0.98      0.99    502085
weighted avg       0.99      0.99      0.99    502085





--------------

## scVI normalized counts


### load data

In [8]:
outfile = data_path / XYLENA_ANNDATA.replace(".h5ad", "_train_scvi_normalized.h5ad")
train_ad = ad.read_h5ad(outfile)

train_ad

AnnData object with n_obs × n_vars = 502085 × 3000
    obs: 'seurat_clusters', 'cell_type', 'sample', 'doublet_score', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'percent.rb', 'batch', 'S.Score', 'G2M.Score', 'Phase', 'RNA_snn_res.0.3', 'ExN1', 'InN2', 'MG3', 'Astro4', 'Oligo5', 'OPC6', 'VC7', 'type', 'UMAP_1', 'UMAP_2', 'clean', 'test', 'train', 'tmp', '_scvi_batch', '_scvi_labels'

### train model

In [26]:
def preprocess_norm_cnts(X):
    # no idea what works best... i suppose something 
    # X = 1e-2 * X
    return np.log1p(X)

X_train, y_train, label_encoder = get_xgb_data(train_ad)



X_train = preprocess_norm_cnts(X_train)

bst = train_xgboost(X_train, y_train)



[0]	valid-mlogloss:1.05149
[10]	valid-mlogloss:0.09910
[20]	valid-mlogloss:0.06446
[30]	valid-mlogloss:0.06246
[40]	valid-mlogloss:0.06262
[42]	valid-mlogloss:0.06278


### test and save

In [10]:
outfile = data_path / XYLENA_ANNDATA.replace(".h5ad", "_test_scvi_normalized.h5ad")
test_ad = ad.read_h5ad(outfilen)

test_xgboost(bst, test_ad, label_encoder)

# Save the model for later use
bst.save_model(model_path / 'xgb_scVInorm_cnt.model')



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

       Astro       0.78      0.03      0.06     42519
         ExN       0.22      0.82      0.35    110484
         InN       0.37      0.16      0.22     53325
          MG       0.77      0.01      0.02     26529
         OPC       0.11      0.18      0.14     28882
       Oligo       0.69      0.06      0.11    235180
     Unknown       0.00      0.00      0.00        12
          VC       0.29      0.05      0.09      5154

    accuracy                           0.24    502085
   macro avg       0.40      0.16      0.12    502085
weighted avg       0.53      0.24      0.16    502085



  _warn_prf(average, modifier, msg_start, len(result))


### TODO:  probe with external data

In [21]:
test_ad.X.mean(axis=1),train_ad.X.mean(axis=1)

(array([2.901     , 3.0906668 , 2.3573334 , ..., 0.15533334, 0.051     ,
        0.12466667], dtype=float32),
 array([3.3333337, 3.333333 , 3.333334 , ..., 3.3333333, 3.3333328,
        3.3333328], dtype=float32))

------------------
TODO:  evaluation for entropy of predictions


TODO:  strategy for "Unknown" low-quality predictions