In [1]:
import scsims 
import pandas as pd 
import numpy as np
import os
import anndata as an
from scsims import SIMS 
import scanpy as sc
import anndata as an
from torch.utils.data import DataLoader
from scsims.testing import TestAnndatasetMatrix
import torch
from tqdm import tqdm
from scsims import clean_sample 
import numpy as np
import pandas as pd 
from scsims import SIMSClassifier
from scsims.data import AnnDatasetMatrix
import plotly.express as px 
from scsims import DataModule
import os 
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl 

from bigcsv import to_h5ad

class UploadCallback(pl.callbacks.Callback):
    """Lightning callback that saves a checkpoint to disk every ``epochs`` epochs.

    Args:
        path: Directory the checkpoint files are written into.
        desc: Free-form tag embedded in every checkpoint filename.
        upload_path: Stored on the instance; not read by this callback itself.
        epochs: Checkpoint interval, in epochs. Must be at least 1.

    Raises:
        ValueError: If ``epochs`` is smaller than 1 (the interval is used as a
            modulus, so 0 would otherwise fail with ZeroDivisionError at the
            end of the first training epoch).
    """

    def __init__(
        self,
        path: str,
        desc: str,
        upload_path='model_checkpoints',
        epochs: int=1,
    ) -> None:
        super().__init__()
        if epochs < 1:
            raise ValueError(f'epochs must be a positive integer, got {epochs!r}')
        self.path = path
        self.desc = desc
        self.upload_path = upload_path
        self.epochs = epochs

    def on_train_epoch_end(self, trainer, pl_module):
        """Save ``checkpoint-{epoch}-desc-{desc}.ckpt`` under ``self.path`` when due.

        A checkpoint is written when the trainer's current epoch is a positive
        multiple of ``self.epochs``; epoch 0 is always skipped.
        """
        epoch = trainer.current_epoch

        # Save every `self.epochs` epochs (every epoch with the default of 1).
        if epoch > 0 and epoch % self.epochs == 0:
            checkpoint = f'checkpoint-{epoch}-desc-{self.desc}.ckpt'
            trainer.save_checkpoint(os.path.join(self.path, checkpoint))
            print(f'Saving checkpoint at epoch {epoch}')

In [40]:
# Load the expression matrices: the primary dataset is used for training,
# the organoid dataset for testing.
traindata = an.read_h5ad('../data/bhaduri/primary_T.h5ad')
testdata = an.read_h5ad('../data/bhaduri/organoid_T.h5ad')

In [3]:
# Load the per-cell label tables that accompany the primary (train) and
# organoid (test) expression matrices.
trainlabels = pd.read_csv('../data/bhaduri/primary_labels_clean.csv')
testlabels = pd.read_csv('../data/bhaduri/organoid_labels_clean.csv')

In [25]:
# Display the training label table loaded from primary_labels_clean.csv.
trainlabels

Unnamed: 0,cell,Cell,Area,Individual,Age,Class,State,Type,Subtype,Cluster,categorical_Subtype
0,0,AAACCTGAGCTGCCCA_50646,motor,GW14,14,Non-neuronal,Non-dividing,Radial Glia,24,26,oRG
1,1,AAACCTGAGCTTATCG_50647,motor,GW14,14,Non-neuronal,Non-dividing,Radial Glia,24,26,oRG
2,2,AAACCTGAGTATGACA_50652,motor,GW14,14,Neuron,Postmitotic,Excitatory Neuron,8,15,Layer VI Occipital
3,3,AAACCTGAGTCGCCGT_50654,motor,GW14,14,Neuron,Postmitotic,Excitatory Neuron,13,35,Newborn
4,4,AAACCTGCACCAGCAC_50657,motor,GW14,14,Neuron,Postmitotic,Excitatory Neuron,8,15,Layer VI Occipital
...,...,...,...,...,...,...,...,...,...,...,...
168692,189404,CS22_CTTAACTCAGTAGAGC_6016,Occipital cortex,CS22,10,Non-neuronal,Postmitotic,Mural,12,18,Mural
168693,189405,CS22_GGATGTTTCGACCAGC_6380,Occipital cortex,CS22,10,Non-neuronal,Postmitotic,Mural,12,18,Mural
168694,189406,CS22_GCTGCTTAGCACCGTC_6308,Occipital cortex,CS22,10,Neuron,Postmitotic,Excitatory Neuron,1,27,Deep Layer
168695,189407,CS22_ATAGACCTCCTAGTGA_5271,Occipital cortex,CS22,10,Non-neuronal,Postmitotic,Microglia,11,39,Microglia


In [21]:
# Distinct organoid (test) labels; these are the keys used in `mapping`.
testlabels['Subtype'].unique()

array(['panRG', 'hindbrainRG', 'earlyRG', 'hindbrainAstrocyte',
       'MatureIPC', 'Newborn', 'panNeuron', 'glycolyticneurons',
       'glycolyticRG', 'lowquality', 'UpperLayer', 'Astrocyte',
       'DeepLayer', 'Unknown', 'Interneuron', 'Outlier'], dtype=object)

In [18]:
# Distinct primary (train) labels; these are the values allowed in `mapping`.
trainlabels['categorical_Subtype'].unique()

array(['oRG', 'Layer VI Occipital', 'Newborn', 'vRG', 'early', 'Mural',
       'IPC/newborn', 'OPC', 'Upper Layer Occipital', 'IPC-new',
       'Cajal Retzius', 'Upper Layer', 'Layer VI Pan-area', 'Outlier',
       'SST-MGE1', 'IPC-div1', 'PFC', 'Upper Layer PFC', 'Microglia',
       'Deep Layer', 'late', 'Parietal and Temporal', 'IPC-div2',
       'Layer IV', 'MGE2', 'Endothelial', 'oRG/Astrocyte', 'tRG'],
      dtype=object)

In [125]:
# For every organoid label (testlabels['Subtype']) list the primary labels
# (trainlabels['categorical_Subtype']) that count as a correct prediction.
#
# No specific organoid subtype maps to Microglia or Endothelial. Only the
# catch-all labels (lowquality, Unknown, Outlier) accept them, because those
# accept every primary label.
#
# The shared groups are defined once so that each key appears exactly once and
# the repeated lists cannot drift apart.
_RADIAL_GLIA = ('oRG', 'tRG', 'oRG/Astrocyte', 'vRG')
_PAN_NEURON = (
    'Layer VI Occipital', 'early',
    'Upper Layer Occipital', 'Layer VI Pan-area',
    'Upper Layer', 'Upper Layer PFC',
    'PFC', 'Deep Layer', 'Parietal and Temporal', 'Newborn', 'IPC/newborn',
)
# Every primary label: used where the organoid label could be anywhere.
_ANY_SUBTYPE = (
    'oRG', 'Layer VI Occipital', 'Newborn', 'vRG', 'early', 'Mural',
    'IPC/newborn', 'OPC', 'Upper Layer Occipital', 'IPC-new',
    'Cajal Retzius', 'Upper Layer', 'Layer VI Pan-area', 'Outlier',
    'SST-MGE1', 'IPC-div1', 'PFC', 'Upper Layer PFC', 'Microglia',
    'Deep Layer', 'late', 'Parietal and Temporal', 'IPC-div2',
    'Layer IV', 'MGE2', 'Endothelial', 'oRG/Astrocyte', 'tRG',
)

# Each value is a fresh list so that mutating one entry never affects another.
mapping = {
    'panRG': list(_RADIAL_GLIA),
    'hindbrainRG': list(_RADIAL_GLIA),
    'earlyRG': list(_RADIAL_GLIA),
    'hindbrainAstrocyte': list(_RADIAL_GLIA),
    'MatureIPC': ['IPC/newborn', 'IPC-new', 'IPC-div1', 'IPC-div2'],
    'Newborn': ['Newborn', 'IPC/newborn', 'early', 'Upper Layer', 'Upper Layer Occipital'],
    'panNeuron': list(_PAN_NEURON),
    'glycolyticneurons': list(_PAN_NEURON),  # look at these
    'glycolyticRG': list(_RADIAL_GLIA),
    'lowquality': list(_ANY_SUBTYPE),  # could be anywhere
    'Interneuron': ['SST-MGE1', 'MGE2'],
    'UpperLayer': ['Upper Layer Occipital', 'Upper Layer', 'Upper Layer PFC', 'Layer IV'],
    'Astrocyte': list(_RADIAL_GLIA),
    'DeepLayer': ['Layer VI Occipital', 'Layer VI Pan-area', 'Deep Layer', 'late', 'Parietal and Temporal'],
    'Unknown': list(_ANY_SUBTYPE),  # could be anywhere
    'Outlier': list(_ANY_SUBTYPE),
}

In [126]:
# Sanity check: print every candidate label in `mapping` that does not occur in
# the training labels. No output means the mapping only uses known labels.
# The set of known labels is computed once instead of once per candidate.
known_subtypes = set(trainlabels['categorical_Subtype'].unique())
for candidates in mapping.values():
    for candidate in candidates:
        if candidate not in known_subtypes:
            print(candidate)

Now let's train the model and run inference

In [26]:
# Gene names of both datasets, upper-cased so they can be compared regardless
# of case. Test gene names are split on '|' and only the part before the first
# '|' is kept.
testgenes = [x.split('|')[0].upper() for x in testdata.var['index']]
traingenes = [x.upper() for x in traindata.var['index']]

# Genes present in both datasets. Sorted because the iteration order of a set
# of strings changes between interpreter sessions, which would otherwise make
# the feature order (and therefore saved checkpoints) session-dependent.
combined_genes = sorted(set(testgenes).intersection(traingenes))

module = DataModule(
    datafiles=['../data/bhaduri/primary_T.h5ad'],
    labelfiles=['../data/bhaduri/primary_labels_clean.csv'],
    class_label='categorical_Subtype',
    index_col='cell',
    batch_size=16,
    num_workers=0,
    deterministic=True,
    normalize=True,
    # Fixed: this was the undefined name `traingenesews`, which raised NameError.
    currgenes=traingenes,
    refgenes=combined_genes,
)

module.prepare_data()
module.setup()

wandb_logger = WandbLogger(
    project="Bhaduri human organoid model",
)

lr_callback = pl.callbacks.LearningRateMonitor(logging_interval='epoch')

# With the default interval of 1 this writes a checkpoint after every epoch
# except epoch 0.
upload_callback = UploadCallback(
    path='checkpoints',
    desc='local_bhaduri_human_organoid'
)

# Stop training once 'val_loss' has failed to improve for 6 consecutive checks.
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=6,
)

trainer = pl.Trainer(
    gpus=(1 if torch.cuda.is_available() else 0),
    auto_lr_find=False,
    logger=wandb_logger,
    max_epochs=500,
    gradient_clip_val=0.5,
    callbacks=[
        lr_callback,
        upload_callback,
        early_stopping_callback,
    ],
    stochastic_weight_avg=True,
)


# Model dimensions and class weights are taken from the prepared DataModule.
model = SIMSClassifier(
    input_dim=module.num_features,
    output_dim=module.num_labels,
    weights=module.weights,
)

print(f'Input dim and output dim are {module.num_features} / {module.num_labels}')

trainer.fit(model, datamodule=module)

Labels are non-numeric, using sklearn.preprocessing.LabelEncoder to encode.
Transforming labelfile 1/1
Creating train/val/test DataLoaders...
Done, continuing to training.
Calculating weights


[34m[1mwandb[0m: Currently logged in as: [33mjlehrer1[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_deprecation(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Initializing network
Initializing explain matrix
Input dim and output dim are 16507 / 28
Creating train/val/test DataLoaders...
Done, continuing to training.
Calculating weights



  | Name    | Type   | Params
-----------------------------------
0 | network | TabNet | 1.1 M 
-----------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.381     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  precision = tp / (tp + fp)
  recall = tp / (tp + fn)
  f1s = 2*(precision * recall) / (precision + recall)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Saving checkpoint at epoch 1


Validation: 0it [00:00, ?it/s]

Saving checkpoint at epoch 2


Validation: 0it [00:00, ?it/s]

Saving checkpoint at epoch 3


Validation: 0it [00:00, ?it/s]

Saving checkpoint at epoch 4


Validation: 0it [00:00, ?it/s]

Saving checkpoint at epoch 5


Validation: 0it [00:00, ?it/s]

Saving checkpoint at epoch 6


Validation: 0it [00:00, ?it/s]

Saving checkpoint at epoch 7


Validation: 0it [00:00, ?it/s]

Saving checkpoint at epoch 8


In [113]:
from scsims.data import CollateLoader 
from scsims.testing import TestAnndatasetMatrix


# Wrap the raw organoid expression matrix in a dataset without labels.
testdataset = TestAnndatasetMatrix(
    testdata.X,
)

# Batch the test matrix for inference.
# NOTE(review): presumably the loader uses `currgenes` (the test matrix's own
# gene order) and `refgenes` (the shared gene set used for training) to align
# the columns to the model's input, and `normalize=True` to mirror the training
# DataModule. Confirm against the scsims documentation.
testloader = CollateLoader(
    dataset=testdataset, 
    batch_size=64, 
    num_workers=0, 
    refgenes=combined_genes, 
    currgenes=testgenes,
    normalize=True,
)

In [116]:
from functools import partial

# Progress bar pinned to one line. Bound under its own name so the imported
# `tqdm` is not shadowed when this cell is run (or re-run).
progress = partial(tqdm, position=0, leave=True)

preds = []

# Inference only: switch the model to eval mode and disable gradient tracking.
model.eval()
with torch.no_grad():
    for X in progress(testloader):
        # The model returns a pair; only the first element is used here.
        res, _ = model(X)

        # Indices of the 3 largest entries along dim 1, largest first.
        # NOTE(review): assumes `res` is (batch, n_classes) - confirm.
        _, top_preds = res.topk(3, dim=1)
        # .cpu() is a no-op on CPU tensors and makes .numpy() safe on GPU runs.
        preds.extend(top_preds.cpu().numpy())

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3674/3674 [08:08<00:00,  7.52it/s]


In [117]:
# Stack the per-cell top-3 index arrays into a DataFrame: one row per test
# cell, column 0 holding the highest-scoring class index.
preds = pd.DataFrame(preds)

In [41]:
# Keep only the cells listed in the label table.
# NOTE(review): 'cell' appears to hold integer row positions into the original
# matrix (see the table displayed above) - confirm. This rebinds `traindata`,
# so re-running the cell subsets the already-subset object again.
traindata = traindata[trainlabels['cell'].values, :]

In [43]:
# Turn the subset into an independent AnnData object before .obs is modified.
traindata = traindata.copy()

In [50]:
# Index the test labels by column 'V1' and join them onto testdata.obs
# (DataFrame.join aligns on the index).
# NOTE(review): presumably 'V1' holds the cell identifiers used as
# testdata.obs.index - confirm.
# Not idempotent: after the first run 'V1' is no longer a column, so
# set_index('V1') fails when this cell is re-run.
testlabels = testlabels.set_index('V1')
testdata.obs = testdata.obs.join(testlabels)

In [58]:
# Index the training labels by column 'Cell' and join them onto traindata.obs
# (DataFrame.join aligns on the index).
# Not idempotent: after the first run 'Cell' is no longer a column, so
# set_index('Cell') fails when this cell is re-run.
trainlabels = trainlabels.set_index('Cell')

traindata.obs = traindata.obs.join(trainlabels)

In [118]:
from sklearn.preprocessing import LabelEncoder

# NOTE(review): `le` is fitted here but not used by the lines below; the
# decoding uses module.label_encoder instead.
le = LabelEncoder().fit(traindata.obs['categorical_Subtype'])

# Column 0 of `preds` is the top-ranked class index for each test cell; decode
# it back to a label string with the DataModule's encoder.
testdata.obs['predicted'] = preds.loc[:, 0].values
testdata.obs['predicted_str'] = module.label_encoder.inverse_transform(testdata.obs['predicted'])

In [119]:
# Number of test cells whose predicted label is one of the primary labels
# accepted by `mapping` for the cell's true organoid subtype. This is a count,
# not yet a percentage.
acc = sum(
    1
    for truth, pred in zip(testdata.obs["Subtype"], testdata.obs["predicted_str"])
    if pred in mapping[truth]
)

acc

76160

In [120]:
# Number of test cells per true organoid subtype (the per-label denominators).
testdata.obs['Subtype'].value_counts()

panRG                 69272
panNeuron             46665
lowquality            25224
glycolyticRG          19798
glycolyticneurons     16344
Newborn               14757
UpperLayer            12406
hindbrainRG            7670
earlyRG                5800
DeepLayer              4585
hindbrainAstrocyte     3565
MatureIPC              2210
Outlier                2160
Astrocyte              1873
Interneuron            1662
Unknown                1130
Name: Subtype, dtype: int64

In [109]:
# Inspect the DataModule's label encoder (the one used to decode predictions).
module.label_encoder

LabelEncoder()

In [98]:
# Move the current obs index into a regular column and replace it with a fresh
# integer index. Not idempotent: each re-run adds another index column.
testdata.obs = testdata.obs.reset_index(drop=False)

In [127]:
# Per-label accuracy: for each true organoid subtype, the percentage of its
# cells whose predicted label is accepted by `mapping`.
#
# Only `.obs` is needed, so the AnnData object is not sliced per label, and the
# hit counter has its own name so the global `acc` (the overall match count
# used for the overall percentage) is not overwritten.
test_obs = testdata.obs
for label in test_obs['Subtype'].unique():
    label_obs = test_obs[test_obs['Subtype'] == label]

    n_hits = sum(
        1
        for truth, pred in zip(label_obs["Subtype"], label_obs["predicted_str"])
        if pred in mapping[truth]
    )

    print(f'Accuracy for {label} is {n_hits / len(label_obs) * 100}')

Accuracy for panRG is 36.77098972167687
Accuracy for hindbrainRG is 34.485006518904825
Accuracy for earlyRG is 31.06896551724138
Accuracy for hindbrainAstrocyte is 29.957924263674613
Accuracy for MatureIPC is 3.619909502262444
Accuracy for Newborn is 6.173341465067426
Accuracy for panNeuron is 59.71070395371263
Accuracy for glycolyticneurons is 32.74596182085169
Accuracy for glycolyticRG is 17.875542984139813
Accuracy for lowquality is 100.0
Accuracy for UpperLayer is 13.872319845236175
Accuracy for Astrocyte is 29.578216764548852
Accuracy for DeepLayer is 36.205016357688116
Accuracy for Unknown is 100.0
Accuracy for Interneuron is 16.365824308062578
Accuracy for Outlier is 100.0


Exception in thread ChkStopThr:
Traceback (most recent call last):
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/threading.py", line 973, in _bootstrap_inner
    self.run()
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 170, in check_status
    status_response = self._interface.communicate_stop_status()
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/site-packages/wandb/sdk/interface/interface.py", line 127, in communicate_stop_status
    resp = self._communicate_stop_status(status)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/site-packages/wandb/sdk/interface/interface_sock.py", line 69, in _communicate_stop_status
    data = super()._communicate_stop_status(status)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/site-packages/wandb/sdk/interface/interface_sh

In [101]:
# Inspect the predictions for the cells whose true subtype is 'panRG'.
testdata.obs[testdata.obs['Subtype'] == 'panRG']

Unnamed: 0,index,Cluster,Sample,Line,Protocol,Age,iPSCorhESC,Class,State,Type,Subtype,predicted,predicted_str
0,H1SWeek3_AAACCTGAGACAAAGG,29,H1SWeek3,H1,Less Directed,3,hESC,Nonneuronal,Dividing,RadialGlia,panRG,23,late
1,H1SWeek3_AAACCTGAGCACACAG,5,H1SWeek3,H1,Less Directed,3,hESC,Nonneuronal,Dividing,RadialGlia,panRG,11,Microglia
2,H1SWeek3_AAACCTGAGGATGGAA,35,H1SWeek3,H1,Less Directed,3,hESC,Nonneuronal,Dividing,RadialGlia,panRG,21,Upper Layer PFC
4,H1SWeek3_AAACCTGCAGCGTAAG,5,H1SWeek3,H1,Less Directed,3,hESC,Nonneuronal,Dividing,RadialGlia,panRG,16,PFC
5,H1SWeek3_AAACCTGCATTACGAC,5,H1SWeek3,H1,Less Directed,3,hESC,Nonneuronal,Dividing,RadialGlia,panRG,11,Microglia
...,...,...,...,...,...,...,...,...,...,...,...,...,...
235101,WTC10SWeek10_TTTGCGCGTAGCCTCG,14,YH10SWeek10,YH10,Less Directed,10,iPSC,Nonneuronal,Nondividing,RadialGlia,panRG,16,PFC
235106,WTC10SWeek10_TTTGGTTAGCGTCAAG,17,YH10SWeek10,YH10,Less Directed,10,iPSC,Nonneuronal,Dividing,RadialGlia,panRG,16,PFC
235112,WTC10SWeek10_TTTGGTTTCACGCGGT,4,YH10SWeek10,YH10,Less Directed,10,iPSC,Nonneuronal,Dividing,RadialGlia,panRG,11,Microglia
235113,WTC10SWeek10_TTTGGTTTCTCTTATG,29,YH10SWeek10,YH10,Less Directed,10,iPSC,Nonneuronal,Dividing,RadialGlia,panRG,23,late


In [70]:
# Overall accuracy in percent.
# NOTE(review): `acc` must hold the overall match count at this point; re-run
# the overall-count cell first if `acc` has been reassigned since.
acc / len(preds) * 100

21.788781095691153