In [1]:
from typing import *

import torch 
import numpy as np 
import shutil 
import json 
import zipfile 
import io 
import pytorch_lightning as pl 
from scipy.sparse import csc_matrix 
from pathlib import Path 
from pytorch_tabnet.utils import (
    create_explain_matrix,
    ComplexEncoder,
)
import torch.nn.functional as F
from torchmetrics.functional import accuracy, precision, recall 
from pytorch_tabnet.tab_network import TabNet
import copy
import warnings

import sys, os 
sys.path.append('../src')

from torchmetrics.functional import *
import torchmetrics 

from data import *
from lightning_train import *
from model import *

from torchmetrics.functional.classification.stat_scores import _stat_scores_update, _stat_scores
from sklearn.metrics import classification_report
import model

In [3]:
module = DataModule(
    datafiles=['../data/dental/human_dental_T.h5ad'],
    labelfiles=['../data/dental/labels_human_dental.tsv'],
    class_label='cell_type',
    sep='\t',
    batch_size=16,
    num_workers=0,
#     subset=list(range(1000)),
    stratify=True,
    drop_last=False,
)

trained = TabNetLightning.load_from_checkpoint(
    '../checkpoints/checkpoint-80-desc-dental.ckpt',
    input_dim=module.num_features,
    output_dim=module.num_labels,
)

trainer = pl.Trainer()

trainer.test(trained, datamodule=module)

Initializing network
Initializing explain matrix


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Creating train/val/test DataLoaders...
Done, continuing to training.
Calculating weights


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

Median f1 score is 0.9615384615384616 for epoch=0
Test F1 scores are [0.96153846 0.         0.97967318 0.97609329 0.5794702  0.
 0.         0.96940024 0.98089172]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.9392621517181396,
 'test_precision': 0.9392621517181396,
 'test_recall': 0.9392621517181396}
--------------------------------------------------------------------------------


  precision = tp / (tp + fp)


[{'test_accuracy': 0.9392621517181396,
  'test_precision': 0.9392621517181396,
  'test_recall': 0.9392621517181396}]

In [8]:
from functools import partial 

# def confusion_matrix(dataloader):
#     confusion_matrix = torch.zeros(num_classes, num_classes)
#     with torch.no_grad():
#         for i, (inputs, classes) in enumerate(tqdm(dataloader)):
#             outputs, _ = model(inputs)
            
#             _, preds = torch.max(outputs, 1)
#             for t, p in zip(classes.view(-1), preds.view(-1)):
#                     confusion_matrix[t.long(), p.long()] += 1
                    
#     return confusion_matrix 

# def median_f1(tps, fps, fns):
#     precisions = tps / (tps+fps)
#     recalls = tps / (tps+fns)
    
#     f1s = 2*(np.dot(precisions, recalls)) / (precisions + recalls)
    
#     return np.nanmedian(f1s)

# def per_class_f1(*args, **kwargs):
#     res = torchmetrics.functional.f1_score(*args, **kwargs, average='none')
#     return res

# def per_class_precision(*args, **kwargs):
#     res = torchmetrics.functional.precision(*args, **kwargs, average='none')
    
#     return res

# def per_class_recall(*args, **kwargs):
#     res = torchmetrics.functional.precision(*args, **kwargs, average='none')
    
#     return res 

# def weighted_accuracy(*args, **kwargs):
#     res = torchmetrics.functional.accuracy(*args, **kwargs, average='weighted')
    
#     return res 

# def balanced_accuracy(*args, **kwargs):
#     res = torchmetrics.functional.accuracy(*args, **kwargs, average='macro')
    
#     return res 

def aggregate_metrics(num_classes) -> Dict[str, Callable]:
    metrics = {
        # Accuracies
        'total_accuracy': torchmetrics.functional.accuracy,
        'balanced_accuracy': partial(balanced_accuracy, num_classes=num_classes),
        'weighted_accuracy': weighted_accuracy,
        
        # Precision, recall and f1s
        'precision': torchmetrics.functional.precision,
        'recall': torchmetrics.functional.recall,
        'f1': torchmetrics.functional.f1_score,
        
        # Per class 
        'per_class_f1': per_class_f1,
        'per_class_precision': per_class_precision,
        'per_class_recall': per_class_recall,
    }
    
    return metrics 

In [7]:
class TabNetLightning(pl.LightningModule):
    def __init__(
        self,
        input_dim,
        output_dim,
        n_d=8,
        n_a=8,
        n_steps=3,
        gamma=1.3,
        cat_idxs=[],
        cat_dims=[],
        cat_emb_dim=1,
        n_independent=2,
        n_shared=2,
        epsilon=1e-15,
        virtual_batch_size=128,
        momentum=0.02,
        mask_type="sparsemax",
        lambda_sparse = 1e-3,
        optim_params: Dict[str, float]={
            'optimizer': torch.optim.Adam,
            'lr': 0.001,
            'weight_decay': 0.01,
        },
        metrics: Dict[str, Callable]=None,
        scheduler_params: Dict[str, float]=None,
        weights=None,
        loss=None, # will default to cross_entropy
        pretrained=None,
        no_explain=False,
    ) -> None:
        super().__init__()

        # Stuff needed for training
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lambda_sparse = lambda_sparse

        self.optim_params = optim_params
        self.scheduler_params = scheduler_params
        
        self.weights = weights 
        self.loss = loss 

        if pretrained is not None:
            self._from_pretrained(**pretrained.get_params())
        # self.device = ('cuda:0' if torch.cuda.is_available() else 'cpu!')
        
        if metrics is None:
            self.metrics = aggregate_metrics()
        else:
            self.metrics = metrics 
            
        print(f'Initializing network')
        self.network = TabNet(
            input_dim=input_dim, 
            output_dim=output_dim, 
            n_d=n_d,
            n_a=n_a,
            n_steps=n_steps,
            gamma=gamma,
            cat_idxs=cat_idxs,
            cat_dims=cat_dims,
            cat_emb_dim=cat_emb_dim,
            n_independent=n_independent,
            n_shared=n_shared,
            epsilon=epsilon,
            virtual_batch_size=virtual_batch_size,
            momentum=momentum,
            mask_type=mask_type,
        )
        
        if not no_explain:
            print(f'Initializing explain matrix')
            self.reducing_matrix = create_explain_matrix(
                self.network.input_dim,
                self.network.cat_emb_dim,
                self.network.cat_idxs,
                self.network.post_embed_dim,
            )

    def forward(self, x):
        return self.network(x)

    def _compute_loss(self, y, y_hat):
        # If user doesn't specify, just set to cross_entropy
        if self.loss is None:
            self.loss = F.cross_entropy 

        return self.loss(y, y_hat, weight=self.weights)

    def _step(self, batch, tag):
        x, y = batch
        y_hat, M_loss = self.network(x)

        loss = self._compute_loss(y_hat, y)
        # Add the overall sparsity loss
        loss = loss - self.lambda_sparse * M_loss
        self._compute_metrics(y_hat, y, tag)
        
        tp, fp, _, fn = _stat_scores_update(
            preds=y_hat,
            target=y,
            num_classes=self.output_dim,
            reduce="macro",
        )

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
        }

    # Calculations on step
    def training_step(self, batch, batch_idx):
        return self._step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._step(batch, 'test')
    
    def _epoch_end(self, step_outputs):
        tps, fps, fns = [], [], []
        
        for i in range(len(step_outputs)):
            res = step_outputs[i]
            tp, fp, fn = res['tp'], res['fp'], res['fn']
                
            tps.append(tp.numpy())
            fps.append(fp.numpy())
            fns.append(fn.numpy())
            
        tp = np.sum(np.array(tps), axis=0)
        fp = np.sum(np.array(fps), axis=0)
        fn = np.sum(np.array(fns), axis=0)
        
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1s = 2*(precision * recall) / (precision + recall)
        f1s = np.nan_to_num(f1s)
        print(f"Median f1 score is {np.nanmedian(f1s)} for epoch={self.current_epoch}")

    # Calculation on epoch end, for "median F1 score"
    def training_epoch_end(self, step_outputs):
        self._epoch_end(step_outputs)
        
    def validation_epoch_end(self, step_outputs):
        self._epoch_end(step_outputs) 
    
    def test_epoch_end(self, step_outputs):
        self._epoch_end(step_outputs) 
    
    def configure_optimizers(self):
        if 'optimizer' in self.optim_params:
            optimizer = self.optim_params.pop('optimizer')
            optimizer = optimizer(self.parameters(), **self.optim_params)
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=0.2, weight_decay=1e-5)

        if self.scheduler_params is not None:
            scheduler = self.scheduler_params.pop('scheduler')
            scheduler = scheduler(optimizer, **self.scheduler_params)

        if self.scheduler_params is None:
            return optimizer
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss',
        }
    
    def _compute_metrics(self, 
        y_hat: torch.Tensor, 
        y: torch.Tensor, 
        tag: str,
        on_epoch=True, 
        on_step=False,
    ):
        metrics = {}
        for name, metric in self.metrics.items():
            val = metric(y_hat, y)
            metrics[name] = val
            self.log(
                f"{tag}_{name}", 
                val, 
                on_epoch=on_epoch, 
                on_step=on_step,
                logger=True,
            )

# Test Module with Small Dental Dataset

In [5]:
module = DataModule(
    datafiles=['../data/dental/human_dental_T.h5ad'],
    labelfiles=['../data/dental/labels_human_dental.tsv'],
    class_label='cell_type',
    sep='\t',
    batch_size=4,
    num_workers=0,
#     subset=list(range(1000)),
    stratify=True,
    drop_last=False,
)

module.setup()
wandb_logger = WandbLogger(
    project=f"custom metric tests",
    name='Dental Model, First 500 samples'
)

model = TabNetLightning(
    input_dim=module.num_features,
    output_dim=module.num_labels,
    no_explain=True,
)

trainer = pl.Trainer(
    max_epochs=200,
#   logger=wandb_logger,
)

# trainer.fit(model, datamodule=module)

Creating train/val/test DataLoaders...
Done, continuing to training.
Calculating weights


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Initializing network


In [6]:
trained = TabNetLightning.load_from_checkpoint(
    '../checkpoints/checkpoint-80-desc-dental.ckpt',
    input_dim=module.num_features,
    output_dim=module.num_labels,
)

trainer.test(trained, datamodule=module)

Initializing network
Initializing explain matrix


  rank_zero_deprecation(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

ValueError: When you set `average` as macro, you have to provide the number of classes.

In [None]:
labels = []
for X, y in module.trainloader:
    labels.extend(list(y.numpy()))

In [None]:
set(labels)

In [None]:
labels = []
for X, y in module.trainloader:
    print(y.numpy())
    labels.extend(list(y.numpy()))
    
print(set(labels))

In [None]:
labels = []
for X, y in module.valloader:
    labels.extend(list(y.numpy()))
    
print(set(labels))

In [None]:
labels = []
for X, y in module.testloader:
    labels.extend(list(y.numpy()))
    
print(set(labels))

In [None]:
labels = pd.read_csv('../data/dental/labels_human_dental.tsv', sep='\t')

In [None]:
labels['cell_type'].value_counts()

In [None]:
module = DataModule(
    datafiles=['../data/dental/human_dental_T.h5ad'],
    labelfiles=['../data/dental/labels_human_dental.tsv'],
    class_label='cell_type',
    sep='\t',
    batch_size=64,
    num_workers=0,
    subset=list(range(1000)),
)

module.num_labels