In [21]:
import pandas as pd 
from sklearn.preprocessing import LabelEncoder
import sys
import anndata as an
import scanpy as sp
import h5py

sys.path.append('../src')
sys.path.append('../tests')

from models.lib.lightning_train import DataModule, generate_trainer
from models.lib.data import *
from models.lib.neural import *
from models.lib.testing import *

from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import *

import pandas as pd
from scipy.sparse import csr_matrix
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from bigcsv.bigcsv import experimental_to_h5ad
from functools import partial
import torchmetrics.functional as f
from torchmetrics import Metric
import torchmetrics 

In [36]:
class Median_F1(Metric):
    def __init__(self, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.f1 = torchmetrics.F1Score(*args, **kwargs, average='none') # calculate f1 for each class 
        
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape
        
        self.scores = self.f1(preds, target)
        
    def compute(self):
        return np.median(self.scores)
    
def median_f1(preds, target, num_classes):
    f1s = torchmetrics.functional.f1_score(preds, target, num_classes=num_classes, average='none')
    print(f1s)
    
    print(np.median(f1s))
    return np.median(f1s)

In [37]:
module = DataModule(
    datafiles=['../data/retina/retina_T.h5ad'],
    labelfiles=['../data/retina/retina_labels_numeric.csv'],
    class_label='class_label',
    index_col='cell',
    batch_size=16,
    num_workers=0,
    shuffle=True,
    drop_last=True,
    normalize=True,
    deterministic=True,
)

module.setup()

model = TabNetLightning.load_from_checkpoint(
    '../checkpoints/checkpoint-80-desc-retina.ckpt',
    input_dim=37475,
    output_dim=13,
    n_d=32,
    n_a=32,
    n_steps=10,
    metrics={
        'f1_median': median_f1
    }
)

Creating train/val/test DataLoaders...


Variable names are not unique. To make them unique, call `.var_names_make_unique`.


Done, continuing to training.
Calculating weights
Initializing network
Initializing explain matrix


In [38]:
module.num_labels

13

In [39]:
trainer = pl.Trainer()

trainer.test(model, datamodule=module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_deprecation(
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

tensor([1., 1., nan, 1., nan, nan, nan, nan, nan, 1., 1., nan, nan])
nan
tensor([1.0000, 1.0000,    nan, 1.0000,    nan, 1.0000,    nan, 0.6667,    nan,
        0.9231, 0.6667,    nan, 1.0000])
nan
tensor([nan, nan, 0., 1., nan, nan, 0., nan, 0., 1., 1., 0., 1.])
nan
tensor([1., nan, nan, nan, nan, 1., nan, 1., nan, 1., 1., nan, nan])
nan
tensor([0.6667,    nan,    nan,    nan,    nan, 1.0000,    nan,    nan,    nan,
        1.0000, 1.0000,    nan, 0.0000])
nan
tensor([nan, 1., nan, 1., 1., 1., nan, 1., nan, 1., 1., nan, 1.])
nan
tensor([   nan,    nan,    nan,    nan, 1.0000,    nan,    nan, 0.6667,    nan,
        0.9091, 0.9091, 0.0000, 1.0000])
nan
tensor([1.0000, 1.0000,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
        0.9231, 1.0000, 0.0000,    nan])
nan
tensor([nan, nan, nan, 1., nan, nan, nan, 1., nan, 1., 1., nan, 1.])
nan
tensor([1.0000,    nan,    nan, 1.0000, 1.0000,    nan, 1.0000,    nan,    nan,
        0.9091, 1.0000, 0.0000,    nan])
nan
tensor([1.0000, 

tensor([1.0000,    nan,    nan,    nan,    nan,    nan,    nan, 1.0000,    nan,
        0.9091, 0.9412,    nan,    nan])
nan
tensor([   nan,    nan,    nan, 1.0000,    nan,    nan, 1.0000, 0.6667,    nan,
        0.7500, 0.9333,    nan, 1.0000])
nan
tensor([1.0000,    nan,    nan, 1.0000,    nan,    nan, 1.0000, 1.0000,    nan,
        0.9091, 0.9231,    nan,    nan])
nan
tensor([1.0000,    nan, 0.0000, 1.0000,    nan, 1.0000,    nan, 1.0000,    nan,
        1.0000, 0.9091,    nan, 0.0000])
nan
tensor([1., nan, nan, nan, 1., nan, 1., nan, nan, 1., 1., nan, nan])
nan
tensor([1., nan, nan, nan, nan, nan, 1., nan, nan, 1., 1., nan, nan])
nan
tensor([nan, nan, nan, 1., nan, 0., 1., nan, nan, 1., 1., nan, 0.])
nan
tensor([1.0000, 1.0000, 1.0000,    nan,    nan,    nan,    nan, 0.0000,    nan,
        0.8889, 1.0000,    nan, 1.0000])
nan
tensor([   nan,    nan,    nan,    nan,    nan,    nan,    nan, 1.0000,    nan,
        0.9231, 0.9412,    nan,    nan])
nan
tensor([nan, nan, nan, nan, nan

tensor([nan, nan, nan, 1., nan, nan, nan, nan, nan, 1., 1., nan, nan])
nan
tensor([0.0000,    nan, 0.0000, 1.0000,    nan,    nan,    nan, 1.0000,    nan,
        0.8750, 0.8889,    nan, 0.0000])
nan
tensor([0.6667,    nan,    nan,    nan, 1.0000,    nan, 1.0000, 1.0000,    nan,
        1.0000, 1.0000,    nan, 0.0000])
nan
tensor([   nan,    nan,    nan,    nan,    nan, 0.6667, 1.0000, 1.0000,    nan,
        0.7273, 0.8000, 0.0000,    nan])
nan
tensor([   nan, 0.0000,    nan,    nan,    nan, 1.0000,    nan,    nan, 0.0000,
        0.9000, 0.7500,    nan,    nan])
nan
tensor([1.0000,    nan,    nan,    nan, 1.0000,    nan,    nan,    nan,    nan,
        0.9091, 0.9333,    nan, 1.0000])
nan
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, 1., 1., nan, 1.])
nan
tensor([1., nan, nan, 1., 1., nan, 1., nan, nan, 1., 1., nan, nan])
nan
tensor([0.6667,    nan,    nan, 1.0000,    nan,    nan, 1.0000,    nan,    nan,
        1.0000, 0.6667,    nan,    nan])
nan
tensor([1.0000,    nan,    n

[{'test_loss': 0.2873498201370239,
  'test_loss_epoch': 0.2873498201370239,
  'test_f1_median': nan}]

Let's do some simple preprocessing for training, and then test our library

In [2]:
# le = LabelEncoder()

# labels = pd.read_csv('../data/retina/retina_labels.tsv', sep='\t')
# labels = labels[labels['CellType'].isin(['retinal rod cell type B', 'retinal rod cell type A'])]
# labels = labels[labels['CellType'] != 'unannotated']
# labels['class_label'] = le.fit_transform(labels['CellType'])

# labels.index.name = 'cell'
# # labels = labels.iloc[0:5000, :]
# labels.to_csv('../data/retina/retina_labels_numeric.csv')

# # label_df = pd.read_csv('../data/retina/retina_labels_numeric.csv', index_col='cell')
# # label_df

# # labels.to_csv('../data/retina/retina_labels_numeric.csv')

# labels = pd.read_csv('../data/retina/retina_labels_numeric.csv')
# labels.loc[:, 'class_label']

In [3]:
# labels = pd.read_csv('../data/retina/raw_labels.tsv', sep='\t')
# corrected = pd.read_csv('../data/retina/retina_labels_numeric.csv')

In [4]:
# from sklearn.model_selection import train_test_split
# trainsplit, valsplit = train_test_split(current_labels, stratify=current_labels, random_state=42)
# trainsplit

In [5]:
corrected = pd.read_csv('../data/retina/retina_labels_numeric.csv', index_col='cell')
print(corrected.shape)
# current_labels = corrected['CellType']

# trainsplit, valsplit = train_test_split(current_labels, stratify=current_labels, random_state=42)
# trainsplit

(16446, 8)


In [12]:
module = DataModule(
    datafiles=['../data/retina/retina_T.h5ad'],
    labelfiles=['../data/retina/retina_labels_numeric.csv'],
    class_label='class_label',
    index_col='cell',
    batch_size=16,
    num_workers=32,
    shuffle=True,
    drop_last=True,
    normalize=True,
    deterministic=True,
)

module.setup()

Creating train/val/test DataLoaders...


Variable names are not unique. To make them unique, call `.var_names_make_unique`.


Done, continuing to training.
Calculating weights


In [13]:
module.valloader.dataset.labels

array([10,  9,  7, ...,  9,  9,  9])

In [7]:
import pytorch_lightning as pl

lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath='checkpoints', 
    filename='{epoch}-{weighted_val_accuracy}'
)
# pruner = pl.callbacks.ModelPruning()
progressbar = pl.callbacks.RichProgressBar()

In [8]:
# tabnetmodel = TabNetLightning(
#     input_dim=module.num_features,
#     output_dim=module.num_labels,
#     optim_params={
#         'optimizer': torch.optim.Adam,
#         'lr': 0.02,
#         'weight_decay': 0,
#     },
#     scheduler_params={
#         'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau,
#         'factor': 0.001,
#     },
#     metrics={
#         'accuracy': accuracy,
#         'precision': precision,
#         'recall': recall,
#     },
# )

# wandb_logger = WandbLogger(
#     project=f"Retina Model",
#     name='local-retina-model'
# )

# early_stop = EarlyStopping(
#     monitor="weighted_val_accuracy", 
#     min_delta=0.00, 
#     patience=3,
#     verbose=False, 
#     mode="max"
# )

# trainer = pl.Trainer(
#     logger=wandb_logger,
#     callbacks=[early_stop, lr_monitor, checkpoint],
#     max_epochs=100,
# )

In [None]:
# trainer.fit(tabnetmodel, datamodule=module)

  rank_zero_deprecation(
[34m[1mwandb[0m: Currently logged in as: [33mjlehrer1[0m (use `wandb login --relogin` to force relogin)


Output()

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


In [2]:
to_test = pd.read_csv('../data/retina/fovea_for_test_labels.tsv', sep='\t')

In [3]:
train_labels = pd.read_csv('../data/retina/retina_labels_numeric.csv')

In [4]:
to_test

Unnamed: 0,cellId,final_cluster_labels,libraryID,celltype,region,disease_ontology_term_id,disease,sex,tissue_ontology_term_id,tissue,cell_type_ontology_term_id,cell_type
0,AAACCCAGTCGGCACT-1,9,fovea_donor_1,Unknown,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,,unknown
1,AAACGAACAGGGACTA-1,15,fovea_donor_1,Glial-F1,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,CL:0000125,glial cell
2,AAACGAAGTTCTCCAC-1,3,fovea_donor_1,Cones-Fov,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,,foveal cone photoreceptor
3,AAACGAATCATAGCAC-1,5,fovea_donor_1,Bipolar-1,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,CL:0000103,bipolar neuron
4,AAACGCTAGGAGGCAG-1,8A,fovea_donor_1,Horizontal,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,CL:0000745,retina horizontal cell
...,...,...,...,...,...,...,...,...,...,...,...,...
8212,TTTGACTCATGACTGT-6,5,peripheral_donor_3,Bipolar-1,peripheral,PATO:0000461,normal,female,UBERON:0013682,peripheral region of retina,CL:0000103,bipolar neuron
8213,TTTGACTGTCCGGACT-6,6,peripheral_donor_3,Bipolar-2,peripheral,PATO:0000461,normal,female,UBERON:0013682,peripheral region of retina,CL:0000103,bipolar neuron
8214,TTTGACTTCTGAGTCA-6,2,peripheral_donor_3,Rods-2,peripheral,PATO:0000461,normal,female,UBERON:0013682,peripheral region of retina,CL:0000604,retinal rod cell
8215,TTTGGAGCAGTTTGGT-6,6,peripheral_donor_3,Bipolar-2,peripheral,PATO:0000461,normal,female,UBERON:0013682,peripheral region of retina,CL:0000103,bipolar neuron


In [5]:
train_labels

Unnamed: 0,cell,index,Expressed Genes,UMI Count,Percent Mitochond.,BroadCellType,CellType,Donor,class_label
0,0,0024369980fd003553cbc9dfe29f7f95,2351,6060.0,4.125413,retinal rod cell,retinal rod cell type A,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,9
1,2,0037f1b36684cb59b84d3585ca55ff69,950,1507.0,0.597213,retinal rod cell,retinal rod cell type B,427c0a62-9baf-42ab-a3a3-f48d10544280,10
2,3,00390952646f52d11a9ab9bba7d6ac51,961,1962.0,7.543323,retinal rod cell,retinal rod cell type A,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,9
3,4,005b3351658380695a5dc46c384d72d7,858,1384.0,0.144509,retinal rod cell,retinal rod cell type B,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,10
4,6,00906832f470fc434a52ac7d678a95bc,532,1054.0,6.451613,retinal rod cell,retinal rod cell type A,427c0a62-9baf-42ab-a3a3-f48d10544280,9
...,...,...,...,...,...,...,...,...,...
16441,19684,ff61c0c282f41e4a37885b05342441da,1494,3320.0,2.289157,retinal rod cell,retinal rod cell type B,427c0a62-9baf-42ab-a3a3-f48d10544280,10
16442,19689,ffa4633bef82949d2c6ac17b3ddf46e9,1813,4129.0,3.996125,retinal rod cell,retinal rod cell type A,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,9
16443,19690,ffa5758b0600f47722fdc755444dfe0c,2692,8244.0,16.460457,retinal cone cell,retinal cone cell,b8049daa-7458-47bf-8ec2-3f5c56d2cb34,7
16444,19691,ffd3fd6119de767f3c3b8c47b2c28bf0,516,896.0,3.794643,retinal rod cell,retinal rod cell type A,427c0a62-9baf-42ab-a3a3-f48d10544280,9


In [7]:
to_test['cell_type'].unique(), train_labels['CellType'].unique()

(array(['unknown', 'glial cell', 'foveal cone photoreceptor',
        'bipolar neuron', 'retina horizontal cell', 'endothelial cell',
        'pericyte cell', 'retinal rod cell', 'retinal ganglion cell',
        'amacrine cell', 'microglial cell',
        'peripheral cone photoreceptor'], dtype=object),
 array(['retinal rod cell type A', 'retinal rod cell type B',
        'retinal bipolar neuron type B', 'retinal rod cell type C',
        'unspecified', 'retinal bipolar neuron type C', 'Muller cell',
        'retinal cone cell', 'retinal bipolar neuron type A',
        'amacrine cell', 'retinal bipolar neuron type D',
        'retinal ganglion cell', 'microglial cell'], dtype=object))

In [8]:
# map to_test --> training labels for testing prediction on other datasets

# mapping = {
#     'glial cell': 'microglial',
#     'foveal cone photoreceptor': 
#     'bipolar neuron': 
# }

In [4]:
model = TabNetLightning(
    input_dim=module.num_features,
    output_dim=module.num_labels,
    optim_params={
        'optimizer': torch.optim.Adam,
        'lr': 0.02,
        'weight_decay': 0,
    },
    scheduler_params={
        'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau,
        'factor': 0.001,
    },
    metrics={
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
    },
)

Initializing network
Initializing explain matrix


In [17]:
model = TabNetLightning.load_from_checkpoint(
    '../checkpoints/checkpoint-80-desc-retina.ckpt',
    input_dim=37475,
    output_dim=13,
    n_d=32,
    n_a=32,
    n_steps=10,
)

Initializing network
Initializing explain matrix


In [18]:
model

TabNetLightning(
  (network): TabNet(
    (embedder): EmbeddingGenerator()
    (tabnet): TabNetNoEmbeddings(
      (initial_bn): BatchNorm1d(37475, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (encoder): TabNetEncoder(
        (initial_bn): BatchNorm1d(37475, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (initial_splitter): FeatTransformer(
          (shared): GLU_Block(
            (shared_layers): ModuleList(
              (0): Linear(in_features=37475, out_features=128, bias=False)
              (1): Linear(in_features=64, out_features=128, bias=False)
            )
            (glu_layers): ModuleList(
              (0): GLU_Layer(
                (fc): Linear(in_features=37475, out_features=128, bias=False)
                (bn): GBN(
                  (bn): BatchNorm1d(128, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)
                )
              )
              (1): GLU_Layer(
                (fc): Linear

In [None]:
datafile = '../data/retina/fovea_for_test_T.h5ad'
data = an.read_h5ad('../data/retina/fovea_for_test_T.h5ad')

dataset = AnnDataMatrix(
    data=data.X,
    labels=
)

In [23]:
!ls ../data/retina

fovea_for_test_T.h5ad     retina_T.h5ad
fovea_for_test_labels.tsv retina_labels_numeric.csv


In [24]:
labels = pd.read_csv('../data/retina/fovea_for_test_labels.tsv', sep='\t')

In [25]:
labels

Unnamed: 0,cellId,final_cluster_labels,libraryID,celltype,region,disease_ontology_term_id,disease,sex,tissue_ontology_term_id,tissue,cell_type_ontology_term_id,cell_type
0,AAACCCAGTCGGCACT-1,9,fovea_donor_1,Unknown,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,,unknown
1,AAACGAACAGGGACTA-1,15,fovea_donor_1,Glial-F1,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,CL:0000125,glial cell
2,AAACGAAGTTCTCCAC-1,3,fovea_donor_1,Cones-Fov,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,,foveal cone photoreceptor
3,AAACGAATCATAGCAC-1,5,fovea_donor_1,Bipolar-1,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,CL:0000103,bipolar neuron
4,AAACGCTAGGAGGCAG-1,8A,fovea_donor_1,Horizontal,fovea,MONDO:0005041,glaucoma (disease),male,UBERON:0001786,fovea centralis,CL:0000745,retina horizontal cell
...,...,...,...,...,...,...,...,...,...,...,...,...
8212,TTTGACTCATGACTGT-6,5,peripheral_donor_3,Bipolar-1,peripheral,PATO:0000461,normal,female,UBERON:0013682,peripheral region of retina,CL:0000103,bipolar neuron
8213,TTTGACTGTCCGGACT-6,6,peripheral_donor_3,Bipolar-2,peripheral,PATO:0000461,normal,female,UBERON:0013682,peripheral region of retina,CL:0000103,bipolar neuron
8214,TTTGACTTCTGAGTCA-6,2,peripheral_donor_3,Rods-2,peripheral,PATO:0000461,normal,female,UBERON:0013682,peripheral region of retina,CL:0000604,retinal rod cell
8215,TTTGGAGCAGTTTGGT-6,6,peripheral_donor_3,Bipolar-2,peripheral,PATO:0000461,normal,female,UBERON:0013682,peripheral region of retina,CL:0000103,bipolar neuron
