In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl

from torch import Tensor
from lightly.models import utils
from typing import Optional, List,Tuple, Dict
from sklearn.metrics import roc_auc_score, average_precision_score
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping

In [2]:
def evaluate_new(df):
    yt =np.array([np.array(x) for x in df['y_truth'].values])
    yp =np.array([np.array(x) for x in df['y_pred'].values])
    auroc = roc_auc_score(yt, yp)
    auprc = average_precision_score(yt, yp)
    return auprc, auroc

def bootstraping_eval(df, num_iter):
    """This function samples from the testing dataset to generate a list of performance metrics using bootstraping method"""
    auroc_list = []
    auprc_list = []
    for _ in range(num_iter):
        sample = df.sample(frac=1, replace=True)
        auprc, auroc = evaluate_new(sample)
        auroc_list.append(auroc)
        auprc_list.append(auprc)
    return auprc_list, auroc_list

def computing_confidence_intervals(list_,true_value):

    """This function calcualts the 95% Confidence Intervals"""
    delta = (true_value - list_)
    list(np.sort(delta))
    delta_lower = np.percentile(delta, 97.5)
    delta_upper = np.percentile(delta, 2.5)

    upper = true_value - delta_upper
    lower = true_value - delta_lower
    return (upper,lower)

def get_model_performance(df):
    test_auprc, test_auroc = evaluate_new(df)
    auprc_list, auroc_list = bootstraping_eval(df, num_iter=1000)
    upper_auprc, lower_auprc = computing_confidence_intervals(auprc_list, test_auprc)
    upper_auroc, lower_auroc = computing_confidence_intervals(auroc_list, test_auroc)
    print("--------------")
    text_a=str(f"AUROC {round(test_auroc, 3)} ({round(lower_auroc, 3)}, {round(upper_auroc, 3)}) CI 95%")
    text_b=str(f"AUPRC {round(test_auprc, 3)} ({round(lower_auprc, 3)}, {round(upper_auprc, 3)}) CI 95% ")
    print(text_a)
    print(text_b)

    return (test_auprc, upper_auprc, lower_auprc), (test_auroc, upper_auroc, lower_auroc), (text_a,text_b)

In [3]:
class EvaluationModel(pl.LightningModule):
    def __init__(self,
                backbone:nn.Module,
                learning_rate: float =  1e-3,
                weight_decay: float = 0.0,
                output_dim: int = 14,
                freeze: bool = False,
                max_epochs: int = 50,
                mask_ratio: float = 0.15,
                ) -> None:
        super().__init__()

        # self.save_hyperparameters() 
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.output_dim = output_dim
        self.mask_ratio = mask_ratio
        self.backbone = backbone
        
        self.train_step_preds = []
        self.train_step_label = []
        
        self.val_step_preds = []
        self.val_step_label = []
        
        self.test_step_preds = []
        self.test_step_label = []
        
     
        if freeze:
            utils.deactivate_requires_grad(self.backbone)
        else:
            utils.activate_requires_grad(self.backbone)
               
    def forward(self,
                x: Tensor
               ) -> Tensor:
        x = self.backbone(x)
        x = torch.sigmoid(x)
        return x
    
    

    def training_step(self, 
                      batch: List[Tensor], 
                      batch_idx: int
                     ) -> float:
        
        input, label = batch
        prediction = self.forward(input)
        
        self.train_step_label.append(label)
        self.train_step_preds.append(prediction)

        loss = nn.BCELoss()(prediction, label)            
        self.log("train_loss", loss, on_epoch= True,on_step=False , logger=True, prog_bar=True)
        
        return {'loss':loss,
                'pred':prediction,
                'label':label}

    def on_train_epoch_end(self) -> None:


        
        y = torch.cat(self.train_step_label).detach().cpu()
        pred = torch.cat(self.train_step_preds).detach().cpu()

        auroc = np.round(roc_auc_score(y, pred), 4)
        auprc = np.round(average_precision_score(y, pred), 4)   
        self.log('train_auroc',auroc, on_epoch=True, on_step=False,logger=True, prog_bar=True)
        self.log('train_auprc',auprc, on_epoch=True, on_step=False,logger=True, prog_bar=True)      
        self.train_step_label.clear()
        self.train_step_preds.clear()
        
    def validation_step (self, 
                      batch: List[Tensor], 
                      batch_idx: int
                     ) -> float:
        
        input,label = batch
        prediction = self.forward(input) 
        
        self.val_step_label.append(label)
        self.val_step_preds.append(prediction)

        loss = self._bce_loss(prediction, label,mode='val')       
        self.log("val_loss", loss, on_epoch= True,on_step=False,logger=True, prog_bar=True)

        return {'loss':loss,
                'pred':prediction,
                'label':label}

    def on_validation_epoch_end(self,*arg, **kwargs) -> None:
        
        y = torch.cat(self.val_step_label).detach().cpu()
        pred = torch.cat(self.val_step_preds).detach().cpu()

        auroc = np.round(roc_auc_score(y, pred), 4)
        auprc = np.round(average_precision_score(y, pred), 4)   
        self.log('val_auroc',auroc, on_epoch=True, on_step=False, logger=True, prog_bar=True)
        self.log('val_auprc',auprc, on_epoch=True, on_step=False, logger=True, prog_bar=True)    
        self.val_step_label.clear()
        self.val_step_preds.clear()
        
    def test_step(self, 
                  batch: List[Tensor], 
                  batch_idx: int
                 ) -> float:
        input, label = batch
        prediction = self.forward(input)
        
        self.test_step_label.append(label)
        self.test_step_preds.append(prediction)
        
        loss = self._bce_loss(prediction, label,mode='test')
        self.log("test_loss", loss, on_epoch= True,on_step=False , logger=True, prog_bar=True)

        
        return {'loss':loss,
                'pred':prediction,
                'label':label}

    def on_test_epoch_end(self,*arg, **kwargs) -> None:
        y = torch.cat(self.test_step_label).detach().cpu()
        pred = torch.cat(self.test_step_preds).detach().cpu()


        auroc = np.round(roc_auc_score(y, pred), 4)
        auprc = np.round(average_precision_score(y, pred), 4)   
        self.log('test_auroc',auroc, on_epoch=True, on_step=False, logger=True)
        self.log('test_auprc',auprc, on_epoch=True, on_step=False, logger=True) 
        
        df = pd.DataFrame()
        df['y_truth'] = y.tolist()
        df['y_pred'] = pred.tolist()
        get_model_performance(df)            
          
        self.test_step_label.clear()
        self.test_step_preds.clear()

    def configure_optimizers(self):

        optimizer = optim.Adam(params=self.parameters(), 
                                   lr=self.learning_rate, 
                                   weight_decay=self.weight_decay
                                   )

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                         eta_min=0,
                                                         T_max=self.max_epochs
                                                         )
        
        return {'optimizer': optimizer,
                'lr_scheduler': scheduler
               }
    

    
    def _bce_loss(self, preds, y,mode='train'):
        loss = nn.BCELoss()(preds, y)
        if torch.is_tensor(y):
            y = y.detach().cpu().numpy()
        return loss

In [4]:
import os
import glob
import torch
import numpy as np
import pandas as pd 
import pytorch_lightning as pl
import torchvision.transforms as T

from PIL import Image
from typing import Optional
from torch.utils.data import Dataset, DataLoader

In [5]:
def preprocess(data_dir:str,
               paths: list,
               split:str,
               ) -> Tuple[List]:
    CLASSES  = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
                'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other',
                'Pneumonia', 'Pneumothorax', 'Support Devices']
        
    filenames_to_path = {path.split('/')[-1].split('.')[0]: path for path in paths}
    metadata = pd.read_csv(os.path.join(data_dir,'mimic-cxr-2.0.0-metadata.csv'))
    labels = pd.read_csv(os.path.join(data_dir,'mimic-cxr-2.0.0-chexpert.csv'))
    labels[CLASSES] = labels[CLASSES].fillna(0)
    labels = labels.replace(-1.0, 0.0)
    splits = pd.read_csv(os.path.join(data_dir,'mimic-cxr-ehr-split.csv'))
    metadata_with_labels = metadata.merge(labels[CLASSES+['study_id'] ], how='inner', on='study_id')
    filesnames_to_labels = dict(zip(metadata_with_labels['dicom_id'].values, metadata_with_labels[CLASSES].values))
    filenames_loaded = splits.loc[splits.split==split]['dicom_id'].values
    
    filenames_loaded = [filename  for filename in filenames_loaded if filename in filesnames_to_labels]

    return filenames_to_path, filenames_loaded, filesnames_to_labels

In [6]:
IMAGENET_STAT = {"mean":torch.tensor([0.4884, 0.4550, 0.4171]),
                 "std":torch.tensor([0.2596, 0.2530, 0.2556])}

train_transforms = T.Compose([T.Resize(256),
                              T.RandomHorizontalFlip(),
                              T.RandomAffine(degrees=45, scale=(.85, 1.15), shear=0, translate=(0.15, 0.15)),
                              T.CenterCrop(224),
                              T.ToTensor(),
                              T.Normalize(mean=IMAGENET_STAT["mean"],
                                          std=IMAGENET_STAT["std"])                                                                   
                            ])


val_test_transforms = T.Compose([T.Resize(256),
                                 T.CenterCrop(224),
                                 T.ToTensor(),
                                 T.Normalize(mean=IMAGENET_STAT["mean"],
                                             std=IMAGENET_STAT["std"])                                                                
                            ])

In [7]:
class MIMICCXR(Dataset):
    def __init__(self, 
                 paths: str,
                 data_dir: str, 
                 transform: Optional[T.Compose] = None, 
                 split: str = 'validate',
                 percentage:float = 1.0
                 ) -> None:
        self.data_dir = data_dir
        self.transform = transform
        self.filenames_to_path, \
        self.filenames_loaded, \
        self.filesnames_to_labels = preprocess(data_dir=self.data_dir,
                                               paths=paths,
                                               split=split
                                              )
        limit = (round(len(self.filenames_loaded) * percentage))
        self.filenames_loaded = self.filenames_loaded[0:limit]
 
        
    def __getitem__(self, index):
        if isinstance(index, str):
            img = Image.open(self.filenames_to_path[index]).convert('RGB')
            labels = torch.tensor(self.filesnames_to_labels[index]).float()

            if self.transform is not None:
                img = self.transform(img)
            return img, labels
        
        filename = self.filenames_loaded[index]
        
        img = Image.open(self.filenames_to_path[filename]).convert('RGB')

        labels = torch.tensor(self.filesnames_to_labels[filename]).float()
        
            
        if self.transform is not None:
            img = self.transform(img)
        return img, labels

    def __len__(self):
        return len(self.filenames_loaded)

In [8]:
data_dir = '/scratch/fs999/shamoutlab/data/physionet.org/files/mimic-cxr-jpg/2.0.0'

In [9]:
paths = glob.glob(os.path.join(data_dir,'resized','**','*.jpg'), recursive=True)
train_dataset = MIMICCXR(paths=paths,
                         data_dir= data_dir, 
                         split='train', 
                         transform=train_transforms,
                         )
val_dataset = MIMICCXR(paths=paths,
                        data_dir=data_dir, 
                        split='validate', 
                        transform=val_test_transforms,
                        )
test_dataset = MIMICCXR(paths=paths,
                        data_dir=data_dir, 
                        split='test', 
                        transform=val_test_transforms,
                        )

In [10]:
train_dataloader = DataLoader(dataset=train_dataset,
                             batch_size=64,
                             shuffle=True,
                             num_workers=16,
                             pin_memory=True,
                             drop_last=True
                             )

val_dataloader = DataLoader(dataset=val_dataset,
                             batch_size=64,
                             shuffle=False,
                             num_workers=16,
                             pin_memory=True,
                             drop_last=True
                             )

test_dataloader = DataLoader(dataset=test_dataset,
                             batch_size=64,
                             shuffle=False,
                             num_workers=16,
                             pin_memory=True,
                             drop_last=True
                             )

In [12]:
from torchvision.models.vision_transformer import VisionTransformer

In [13]:
backbone = VisionTransformer(image_size=224,
                             patch_size=16,
                             num_layers=12,
                             num_heads=6,
                             hidden_dim=768,
                             mlp_dim=768*4)

In [14]:
# checkpoint_dir = '/scratch/sas10092/ChexMSN/models/lightning_logs/version_1/checkpoints/epoch=99-step=589300.ckpt'
# all_weights = torch.load(checkpoint_dir,map_location='cpu')['state_dict']
# /scratch/sas10092/ChexMSN/models/lightning_logs/4181612-test-run-rand-cls/epoch=0-step=11757.ckpt

In [28]:
def parse_weights(weights: Dict[str,Tensor]) -> Dict[str,Tensor]:
    
    for k in list(weights.keys()):

        if k.startswith('backbone.'):
            
            if k.startswith('backbone.') and not k.startswith('backbone.heads'):
                
                weights[k[len("backbone."):]] = weights[k]
                
        del weights[k]
#     del weights['class_token'] 
#     del weights['encoder.pos_embedding']    
    return weights

In [29]:
# weight = parse_weights(all_weights)

In [15]:
# backbone.load_state_dict(weight,strict=False)

In [16]:
checkpoint_callback = ModelCheckpoint(monitor='val_auroc', 
                                      mode='max',
                                      every_n_epochs=1,
                                      save_top_k=1,
                                     )
early_stop = EarlyStopping(monitor='val_auroc', 
                           min_delta=0.00001,
                           mode='max', 
                           patience=4
                          )

In [17]:
model = EvaluationModel(backbone=backbone,
                        learning_rate=0.0005,freeze=False)
# tiny running now

In [18]:
model.backbone.heads.head = nn.Linear(in_features=model.backbone.heads.head.in_features,
                                      out_features=model.output_dim)
model.backbone.heads.head

Linear(in_features=768, out_features=14, bias=True)

In [19]:
trainer = pl.Trainer(max_epochs=50,num_sanity_val_steps=0)

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sas10092/.conda/envs/chexmsn-env/lib/python3.9 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]

In [20]:
trainer.fit(model=model, train_dataloaders=train_dataloader,val_dataloaders=val_dataloader)
trainer.test(model=model,dataloaders=test_dataloader)

You are using a CUDA device ('A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type              | Params
-----------------------------------------------
0 | backbone | VisionTransformer | 85.8 M
-----------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.238   Total estimated model params size (MB)


Epoch 0: 100%|██████████| 5081/5081 [37:14<00:00,  2.27it/s, v_num=2]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/238 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/238 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 1/238 [00:00<00:34,  6.95it/s][A
Validation DataLoader 0:   1%|          | 2/238 [00:00<00:33,  6.97it/s][A
Validation DataLoader 0:   1%|▏         | 3/238 [00:00<00:33,  6.99it/s][A
Validation DataLoader 0:   2%|▏         | 4/238 [00:00<00:33,  7.03it/s][A
Validation DataLoader 0:   2%|▏         | 5/238 [00:00<00:33,  7.06it/s][A
Validation DataLoader 0:   3%|▎         | 6/238 [00:00<00:32,  7.08it/s][A
Validation DataLoader 0:   3%|▎         | 7/238 [00:00<00:32,  7.09it/s][A
Validation DataLoader 0:   3%|▎         | 8/238 [00:01<00:32,  7.11it/s][A
Validation DataLoader 0:   4%|▍         | 9/238 [00:01<00:32,  7.12it/s][A
Validation DataLoader 0:   4%|▍         | 10/238 [00:01<00:31,  7.13it/s]

Validation DataLoader 0:  44%|████▎     | 104/238 [00:14<00:18,  7.12it/s][A
Validation DataLoader 0:  44%|████▍     | 105/238 [00:14<00:18,  7.12it/s][A
Validation DataLoader 0:  45%|████▍     | 106/238 [00:14<00:18,  7.11it/s][A
Validation DataLoader 0:  45%|████▍     | 107/238 [00:15<00:18,  7.11it/s][A
Validation DataLoader 0:  45%|████▌     | 108/238 [00:15<00:18,  7.11it/s][A
Validation DataLoader 0:  46%|████▌     | 109/238 [00:15<00:18,  7.12it/s][A
Validation DataLoader 0:  46%|████▌     | 110/238 [00:15<00:17,  7.12it/s][A
Validation DataLoader 0:  47%|████▋     | 111/238 [00:15<00:17,  7.12it/s][A
Validation DataLoader 0:  47%|████▋     | 112/238 [00:15<00:17,  7.12it/s][A
Validation DataLoader 0:  47%|████▋     | 113/238 [00:15<00:17,  7.12it/s][A
Validation DataLoader 0:  48%|████▊     | 114/238 [00:16<00:17,  7.11it/s][A
Validation DataLoader 0:  48%|████▊     | 115/238 [00:16<00:17,  7.11it/s][A
Validation DataLoader 0:  49%|████▊     | 116/238 [00:16<00:17, 

Validation DataLoader 0:  88%|████████▊ | 209/238 [00:29<00:04,  7.11it/s][A
Validation DataLoader 0:  88%|████████▊ | 210/238 [00:29<00:03,  7.12it/s][A
Validation DataLoader 0:  89%|████████▊ | 211/238 [00:29<00:03,  7.12it/s][A
Validation DataLoader 0:  89%|████████▉ | 212/238 [00:29<00:03,  7.12it/s][A
Validation DataLoader 0:  89%|████████▉ | 213/238 [00:29<00:03,  7.12it/s][A
Validation DataLoader 0:  90%|████████▉ | 214/238 [00:30<00:03,  7.12it/s][A
Validation DataLoader 0:  90%|█████████ | 215/238 [00:30<00:03,  7.12it/s][A
Validation DataLoader 0:  91%|█████████ | 216/238 [00:30<00:03,  7.12it/s][A
Validation DataLoader 0:  91%|█████████ | 217/238 [00:30<00:02,  7.12it/s][A
Validation DataLoader 0:  92%|█████████▏| 218/238 [00:30<00:02,  7.12it/s][A
Validation DataLoader 0:  92%|█████████▏| 219/238 [00:30<00:02,  7.12it/s][A
Validation DataLoader 0:  92%|█████████▏| 220/238 [00:30<00:02,  7.12it/s][A
Validation DataLoader 0:  93%|█████████▎| 221/238 [00:31<00:02, 

Validation DataLoader 0:  30%|███       | 72/238 [00:10<00:23,  7.02it/s][A
Validation DataLoader 0:  31%|███       | 73/238 [00:10<00:23,  7.02it/s][A
Validation DataLoader 0:  31%|███       | 74/238 [00:10<00:23,  7.01it/s][A
Validation DataLoader 0:  32%|███▏      | 75/238 [00:10<00:23,  7.01it/s][A
Validation DataLoader 0:  32%|███▏      | 76/238 [00:10<00:23,  7.00it/s][A
Validation DataLoader 0:  32%|███▏      | 77/238 [00:10<00:22,  7.01it/s][A
Validation DataLoader 0:  33%|███▎      | 78/238 [00:11<00:22,  7.01it/s][A
Validation DataLoader 0:  33%|███▎      | 79/238 [00:11<00:22,  7.01it/s][A
Validation DataLoader 0:  34%|███▎      | 80/238 [00:11<00:22,  7.01it/s][A
Validation DataLoader 0:  34%|███▍      | 81/238 [00:11<00:22,  7.01it/s][A
Validation DataLoader 0:  34%|███▍      | 82/238 [00:11<00:22,  7.01it/s][A
Validation DataLoader 0:  35%|███▍      | 83/238 [00:11<00:22,  7.00it/s][A
Validation DataLoader 0:  35%|███▌      | 84/238 [00:12<00:22,  7.00it/s][A

Validation DataLoader 0:  74%|███████▍  | 177/238 [00:24<00:08,  7.09it/s][A
Validation DataLoader 0:  75%|███████▍  | 178/238 [00:25<00:08,  7.09it/s][A
Validation DataLoader 0:  75%|███████▌  | 179/238 [00:25<00:08,  7.09it/s][A
Validation DataLoader 0:  76%|███████▌  | 180/238 [00:25<00:08,  7.09it/s][A
Validation DataLoader 0:  76%|███████▌  | 181/238 [00:25<00:08,  7.09it/s][A
Validation DataLoader 0:  76%|███████▋  | 182/238 [00:25<00:07,  7.09it/s][A
Validation DataLoader 0:  77%|███████▋  | 183/238 [00:25<00:07,  7.09it/s][A
Validation DataLoader 0:  77%|███████▋  | 184/238 [00:25<00:07,  7.09it/s][A
Validation DataLoader 0:  78%|███████▊  | 185/238 [00:26<00:07,  7.09it/s][A
Validation DataLoader 0:  78%|███████▊  | 186/238 [00:26<00:07,  7.10it/s][A
Validation DataLoader 0:  79%|███████▊  | 187/238 [00:26<00:07,  7.10it/s][A
Validation DataLoader 0:  79%|███████▉  | 188/238 [00:26<00:07,  7.10it/s][A
Validation DataLoader 0:  79%|███████▉  | 189/238 [00:26<00:06, 

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sas10092/.conda/envs/chexmsn-env/lib/python3.9 ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 572/572 [01:20<00:00,  7.14it/s]--------------
AUROC 0.557 (0.553, 0.561) CI 95%
AUPRC 0.17 (0.168, 0.172) CI 95% 
Testing DataLoader 0: 100%|██████████| 572/572 [10:42<00:00,  0.89it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auprc                 0.1699
       test_auroc                 0.5567
        test_loss           0.3589421808719635
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.3589421808719635, 'test_auroc': 0.5567, 'test_auprc': 0.1699}]

# Evaluate