In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl

from torch import Tensor
from lightly.models import utils
from typing import Optional, List,Tuple, Dict
from sklearn.metrics import roc_auc_score, average_precision_score

In [2]:
class EvaluationModel(pl.LightningModule):
    def __init__(self,
                backbone:nn.Module,
                learning_rate: float =  1e-3,
                weight_decay: float = 0.0,
                output_dim: int = 14,
                freeze: bool = False,
                max_epochs: int = 50,
                mask_ratio: float = 0.15,
                ) -> None:
        super().__init__()

        # self.save_hyperparameters() 
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.output_dim = output_dim
        self.mask_ratio = mask_ratio
        self.backbone = backbone
        
        self.train_step_preds = []
        self.train_step_label = []
        
        self.val_step_preds = []
        self.val_step_label = []
        
        self.test_step_preds = []
        self.test_step_label = []
        
     
        if freeze:
            utils.deactivate_requires_grad(self.backbone)
        else:
            utils.activate_requires_grad(self.backbone)
               
    def forward(self,
                x: Tensor
               ) -> Tensor:
        x = self.backbone(x)
        x = torch.sigmoid(x)
        return x
    
    

    def training_step(self, 
                      batch: List[Tensor], 
                      batch_idx: int
                     ) -> float:
        
        input, label = batch
        prediction = self.forward(input)
        
        self.train_step_label.append(label)
        self.train_step_preds.append(prediction)

        loss = nn.BCELoss()(prediction, label)            
        self.log("train_loss", loss, on_epoch= True,on_step=False , logger=True, prog_bar=True)
        
        return {'loss':loss,
                'pred':prediction,
                'label':label}

    def on_train_epoch_end(self) -> None:


        
        y = torch.cat(self.train_step_label).detach().cpu()
        pred = torch.cat(self.train_step_preds).detach().cpu()

        auroc = np.round(roc_auc_score(y, pred), 4)
        auprc = np.round(average_precision_score(y, pred), 4)   
        self.log('train_auroc',auroc, on_epoch=True, on_step=False,logger=True, prog_bar=True)
        self.log('train_auprc',auprc, on_epoch=True, on_step=False,logger=True, prog_bar=True)      
        self.train_step_label.clear()
        self.train_step_preds.clear()
        
    def validation_step (self, 
                      batch: List[Tensor], 
                      batch_idx: int
                     ) -> float:
        
        input,label = batch
        prediction = self.forward(input) 
        
        self.val_step_label.append(label)
        self.val_step_preds.append(prediction)

        loss = self._bce_loss(prediction, label,mode='val')       
        self.log("val_loss", loss, on_epoch= True,on_step=False,logger=True, prog_bar=True)

        return {'loss':loss,
                'pred':prediction,
                'label':label}

    def on_validation_epoch_end(self,*arg, **kwargs) -> None:
        
        y = torch.cat(self.val_step_label).detach().cpu()
        pred = torch.cat(self.val_step_preds).detach().cpu()

#         i=0
#         for output in outputs:
#             if i==0:
#                 y = output['label'].detach().cpu()
#                 pred = output['pred'].detach().cpu()
#             else:
#                 y = torch.cat((y, output['label'].detach().cpu()))
#                 pred = torch.cat((pred, output['pred'].detach().cpu()))
#             i+=1

        auroc = np.round(roc_auc_score(y, pred), 4)
        auprc = np.round(average_precision_score(y, pred), 4)   
        self.log('val_auroc',auroc, on_epoch=True, on_step=False, logger=True, prog_bar=True)
        self.log('val_auprc',auprc, on_epoch=True, on_step=False, logger=True, prog_bar=True)    
        self.val_step_label.clear()
        self.val_step_preds.clear()
        
    def test_step(self, 
                  batch: List[Tensor], 
                  batch_idx: int
                 ) -> float:
        input, label = batch
        prediction = self.forward(input)
        
        self.test_step_label.append(label)
        self.test_step_preds.append(prediction)
        
        loss = self._bce_loss(prediction, label,mode='test')
        self.log("test_loss", loss, on_epoch= True,on_step=False , logger=True, prog_bar=True)

        
        return {'loss':loss,
                'pred':prediction,
                'label':label}

    def on_test_epoch_end(self,*arg, **kwargs) -> None:
        y = torch.cat(self.test_step_label).detach().cpu()
        pred = torch.cat(self.test_step_preds).detach().cpu()

#         i=0
#         for output in outputs:
#             if i==0:
#                 y = output['label'].detach().cpu()
#                 pred = output['pred'].detach().cpu()
#             else:
#                 y = torch.cat((y, output['label'].detach().cpu()))
#                 pred = torch.cat((pred, output['pred'].detach().cpu()))
#             i+=1

        auroc = np.round(roc_auc_score(y, pred), 4)
        auprc = np.round(average_precision_score(y, pred), 4)   
        self.log('test_auroc',auroc, on_epoch=True, on_step=False, logger=True)
        self.log('test_auprc',auprc, on_epoch=True, on_step=False, logger=True)   
        
        self.test_step_label.clear()
        self.test_step_preds.clear()

    def configure_optimizers(self):

        optimizer = optim.Adam(params=self.parameters(), 
                                   lr=self.learning_rate, 
                                   weight_decay=self.weight_decay
                                   )

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                         eta_min=0,
                                                         T_max=self.max_epochs
                                                         )
        
        return {'optimizer': optimizer,
                'lr_scheduler': scheduler
               }
    

    
    def _bce_loss(self, preds, y,mode='train'):
        loss = nn.BCELoss()(preds, y)
        if torch.is_tensor(y):
            y = y.detach().cpu().numpy()
        return loss

In [3]:
import os
import glob
import torch
import numpy as np
import pandas as pd 
import pytorch_lightning as pl
import torchvision.transforms as T

from PIL import Image
from typing import Optional
from torch.utils.data import Dataset, DataLoader

In [4]:
def preprocess(data_dir:str,
               paths: list,
               split:str,
               ) -> Tuple[List]:
    CLASSES  = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
                'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other',
                'Pneumonia', 'Pneumothorax', 'Support Devices']
        
    filenames_to_path = {path.split('/')[-1].split('.')[0]: path for path in paths}
    metadata = pd.read_csv(os.path.join(data_dir,'mimic-cxr-2.0.0-metadata.csv'))
    labels = pd.read_csv(os.path.join(data_dir,'mimic-cxr-2.0.0-chexpert.csv'))
    labels[CLASSES] = labels[CLASSES].fillna(0)
    labels = labels.replace(-1.0, 0.0)
    splits = pd.read_csv(os.path.join(data_dir,'mimic-cxr-ehr-split.csv'))
    metadata_with_labels = metadata.merge(labels[CLASSES+['study_id'] ], how='inner', on='study_id')
    filesnames_to_labels = dict(zip(metadata_with_labels['dicom_id'].values, metadata_with_labels[CLASSES].values))
    filenames_loaded = splits.loc[splits.split==split]['dicom_id'].values
    
    filenames_loaded = [filename  for filename in filenames_loaded if filename in filesnames_to_labels]

    return filenames_to_path, filenames_loaded, filesnames_to_labels

In [5]:
MIMIC_NORMALIZE ={"mean":torch.tensor([0.4723, 0.4723, 0.4723]), 
                  "std":torch.tensor([0.3023, 0.3023, 0.3023])}

train_transforms = T.Compose([T.Resize(256),
                              T.RandomHorizontalFlip(),
                              T.RandomAffine(degrees=45, scale=(.85, 1.15), shear=0, translate=(0.15, 0.15)),
                              T.CenterCrop(224),
                              T.ToTensor(),
                              T.Normalize(mean=MIMIC_NORMALIZE["mean"],
                                          std=MIMIC_NORMALIZE["std"])                                                                   
                            ])


val_test_transforms = T.Compose([T.Resize(256),
                                 T.CenterCrop(224),
                                 T.ToTensor(),
                                 T.Normalize(mean=MIMIC_NORMALIZE["mean"],
                                             std=MIMIC_NORMALIZE["std"])                                                                
                            ])

In [6]:
class MIMICCXR(Dataset):
    def __init__(self, 
                 paths: str,
                 data_dir: str, 
                 transform: Optional[T.Compose] = None, 
                 split: str = 'validate',
                 percentage:float = 1.0
                 ) -> None:
        self.data_dir = data_dir
        self.transform = transform
        self.filenames_to_path, \
        self.filenames_loaded, \
        self.filesnames_to_labels = preprocess(data_dir=self.data_dir,
                                               paths=paths,
                                               split=split
                                              )
        limit = (round(len(self.filenames_loaded) * percentage))
        self.filenames_loaded = self.filenames_loaded[0:limit]
 
        
    def __getitem__(self, index):
        if isinstance(index, str):
            img = Image.open(self.filenames_to_path[index]).convert('RGB')
            labels = torch.tensor(self.filesnames_to_labels[index]).float()

            if self.transform is not None:
                img = self.transform(img)
            return img, labels
        
        filename = self.filenames_loaded[index]
        
        img = Image.open(self.filenames_to_path[filename]).convert('RGB')

        labels = torch.tensor(self.filesnames_to_labels[filename]).float()
        
            
        if self.transform is not None:
            img = self.transform(img)
        return img, labels

    def __len__(self):
        return len(self.filenames_loaded)

In [7]:
data_dir = '/scratch/fs999/shamoutlab/data/physionet.org/files/mimic-cxr-jpg/2.0.0'

In [8]:
paths = glob.glob(os.path.join(data_dir,'resized','**','*.jpg'), recursive=True)
train_dataset = MIMICCXR(paths=paths,
                         data_dir= data_dir, 
                         split='train', 
                         transform=train_transforms,
                         )
val_dataset = MIMICCXR(paths=paths,
                        data_dir=data_dir, 
                        split='validate', 
                        transform=val_test_transforms,
                        )
test_dataset = MIMICCXR(paths=paths,
                        data_dir=data_dir, 
                        split='test', 
                        transform=val_test_transforms,
                        )

In [9]:
train_dataloader = DataLoader(dataset=train_dataset,
                             batch_size=64,
                             shuffle=True,
                             num_workers=16,
                             pin_memory=True,
                             drop_last=True
                             )

val_dataloader = DataLoader(dataset=val_dataset,
                             batch_size=64,
                             shuffle=False,
                             num_workers=16,
                             pin_memory=True,
                             drop_last=True
                             )

test_dataloader = DataLoader(dataset=test_dataset,
                             batch_size=64,
                             shuffle=False,
                             num_workers=16,
                             pin_memory=True,
                             drop_last=True
                             )

In [10]:
from torchvision.models.vision_transformer import VisionTransformer

In [11]:
backbone = VisionTransformer(image_size=224,
                             patch_size=16,
                             num_layers=12,
                             num_heads=6,
                             hidden_dim=768,
                             mlp_dim=768*4)

In [18]:
checkpoint_dir = '/scratch/sas10092/ChexMSN/notebooks/lightning_logs/version_13/checkpoints/epoch=2-step=17637.ckpt'
all_weights = torch.load(checkpoint_dir,map_location='cpu')['state_dict']
# /scratch/sas10092/ChexMSN/models/lightning_logs/4181612-test-run-rand-cls/epoch=0-step=11757.ckpt

In [19]:
def parse_weights(weights: Dict[str,Tensor]) -> Dict[str,Tensor]:
    
    for k in list(weights.keys()):

        if k.startswith('backbone.'):
            
            if k.startswith('backbone.') and not k.startswith('backbone.heads'):
                
                weights[k[len("backbone."):]] = weights[k]
                
        del weights[k]
#     del weights['class_token'] 
#     del weights['encoder.pos_embedding']    
    return weights

In [20]:
all_weights

OrderedDict([('model.backbone.class_token',
              tensor([[[ 4.3963e-01,  1.0040e+00, -1.2678e-01, -5.1735e-01, -1.0187e-01,
                        -1.2114e-01,  1.2429e+00,  1.0870e+00, -7.4632e-01, -1.8406e-01,
                         1.0492e+00,  5.7812e-02,  6.0123e-01,  1.9954e-01,  3.5261e-01,
                         4.6070e-01,  1.7805e+00, -3.4832e-01, -1.4319e+00, -3.9237e-01,
                         5.9781e-01,  4.9268e-01,  5.6679e-01,  3.6010e-01,  1.7284e+00,
                        -3.1344e-01,  9.9826e-02, -1.6348e+00, -8.6738e-01,  5.7307e-01,
                         1.5731e-01, -1.2313e+00,  1.3724e-01, -1.9400e+00,  9.4036e-01,
                         9.7819e-01,  1.3731e-01, -8.9377e-01, -1.5845e-01,  9.4274e-01,
                         1.2089e+00,  4.6240e-02,  8.7899e-01, -1.4404e+00, -1.4930e+00,
                         4.2723e-01, -2.4042e-01,  6.1047e-01, -1.8352e+00, -5.1868e-01,
                        -1.8879e-01,  9.1617e-02,  6.7383e-01, -1.

In [17]:
all_weights

OrderedDict([('model.backbone.class_token',
              tensor([[[ 4.4363e-01,  1.0120e+00, -1.2830e-01, -5.2085e-01, -1.0102e-01,
                        -1.2314e-01,  1.2513e+00,  1.0928e+00, -7.5201e-01, -1.8729e-01,
                         1.0546e+00,  5.8526e-02,  6.0819e-01,  1.9976e-01,  3.5750e-01,
                         4.6206e-01,  1.7907e+00, -3.5272e-01, -1.4404e+00, -3.9627e-01,
                         6.0377e-01,  4.9618e-01,  5.6986e-01,  3.6200e-01,  1.7374e+00,
                        -3.1511e-01,  9.9121e-02, -1.6423e+00, -8.7386e-01,  5.7538e-01,
                         1.5922e-01, -1.2378e+00,  1.3804e-01, -1.9526e+00,  9.4625e-01,
                         9.8843e-01,  1.3954e-01, -8.9808e-01, -1.5822e-01,  9.4535e-01,
                         1.2156e+00,  4.4827e-02,  8.8538e-01, -1.4515e+00, -1.5025e+00,
                         4.2854e-01, -2.4300e-01,  6.1426e-01, -1.8440e+00, -5.2228e-01,
                        -1.9001e-01,  9.1082e-02,  6.7879e-01, -1.

In [15]:
weight = parse_weights(all_weights)

In [16]:
backbone.load_state_dict(weight,strict=False)

_IncompatibleKeys(missing_keys=['heads.head.weight', 'heads.head.bias'], unexpected_keys=[])

In [17]:
model = EvaluationModel(backbone=backbone,
                        learning_rate=0.001,freeze=True)
# tiny running now

In [18]:
model.backbone.heads.head = nn.Linear(in_features=model.backbone.heads.head.in_features,
                                      out_features=model.output_dim)
model.backbone.heads.head

Linear(in_features=768, out_features=14, bias=True)

In [19]:
trainer = pl.Trainer(max_epochs=50,num_sanity_val_steps=0)

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sas10092/.conda/envs/chexmsn-env/lib/python3.9 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]

In [20]:
trainer.fit(model=model, train_dataloaders=train_dataloader,val_dataloaders=val_dataloader)
trainer.test(model=model,dataloaders=test_dataloader)

You are using a CUDA device ('A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type              | Params
-----------------------------------------------
0 | backbone | VisionTransformer | 85.8 M
-----------------------------------------------
10.8 K    Trainable params
85.8 M    Non-trainable params
85.8 M    Total params
343.238   Total estimated model params size (MB)


Epoch 0: 100%|██████████| 5081/5081 [15:10<00:00,  5.58it/s, v_num=1]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/238 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/238 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 1/238 [00:00<00:37,  6.30it/s][A
Validation DataLoader 0:   1%|          | 2/238 [00:00<00:36,  6.51it/s][A
Validation DataLoader 0:   1%|▏         | 3/238 [00:00<00:35,  6.62it/s][A
Validation DataLoader 0:   2%|▏         | 4/238 [00:00<00:35,  6.67it/s][A
Validation DataLoader 0:   2%|▏         | 5/238 [00:00<00:34,  6.70it/s][A
Validation DataLoader 0:   3%|▎         | 6/238 [00:00<00:34,  6.71it/s][A
Validation DataLoader 0:   3%|▎         | 7/238 [00:01<00:34,  6.72it/s][A
Validation DataLoader 0:   3%|▎         | 8/238 [00:01<00:34,  6.72it/s][A
Validation DataLoader 0:   4%|▍         | 9/238 [00:01<00:34,  6.72it/s][A
Validation DataLoader 0:   4%|▍         | 10/238 [00:01<00:33,  6.74it/s]

Validation DataLoader 0:  88%|████████▊ | 209/238 [00:30<00:04,  6.78it/s][A
Validation DataLoader 0:  88%|████████▊ | 210/238 [00:30<00:04,  6.78it/s][A
Validation DataLoader 0:  89%|████████▊ | 211/238 [00:31<00:03,  6.78it/s][A
Validation DataLoader 0:  89%|████████▉ | 212/238 [00:31<00:03,  6.79it/s][A
Validation DataLoader 0:  89%|████████▉ | 213/238 [00:31<00:03,  6.79it/s][A
Validation DataLoader 0:  90%|████████▉ | 214/238 [00:31<00:03,  6.78it/s][A
Validation DataLoader 0:  90%|█████████ | 215/238 [00:31<00:03,  6.78it/s][A
Validation DataLoader 0:  91%|█████████ | 216/238 [00:31<00:03,  6.78it/s][A
Validation DataLoader 0:  91%|█████████ | 217/238 [00:31<00:03,  6.79it/s][A
Validation DataLoader 0:  92%|█████████▏| 218/238 [00:32<00:02,  6.79it/s][A
Validation DataLoader 0:  92%|█████████▏| 219/238 [00:32<00:02,  6.79it/s][A
Validation DataLoader 0:  92%|█████████▏| 220/238 [00:32<00:02,  6.79it/s][A
Validation DataLoader 0:  93%|█████████▎| 221/238 [00:32<00:02, 

Validation DataLoader 0:  74%|███████▍  | 177/238 [00:26<00:08,  6.79it/s][A
Validation DataLoader 0:  75%|███████▍  | 178/238 [00:26<00:08,  6.79it/s][A
Validation DataLoader 0:  75%|███████▌  | 179/238 [00:26<00:08,  6.79it/s][A
Validation DataLoader 0:  76%|███████▌  | 180/238 [00:26<00:08,  6.79it/s][A
Validation DataLoader 0:  76%|███████▌  | 181/238 [00:26<00:08,  6.79it/s][A
Validation DataLoader 0:  76%|███████▋  | 182/238 [00:26<00:08,  6.79it/s][A
Validation DataLoader 0:  77%|███████▋  | 183/238 [00:26<00:08,  6.79it/s][A
Validation DataLoader 0:  77%|███████▋  | 184/238 [00:27<00:07,  6.79it/s][A
Validation DataLoader 0:  78%|███████▊  | 185/238 [00:27<00:07,  6.79it/s][A
Validation DataLoader 0:  78%|███████▊  | 186/238 [00:27<00:07,  6.79it/s][A
Validation DataLoader 0:  79%|███████▊  | 187/238 [00:27<00:07,  6.79it/s][A
Validation DataLoader 0:  79%|███████▉  | 188/238 [00:27<00:07,  6.79it/s][A
Validation DataLoader 0:  79%|███████▉  | 189/238 [00:27<00:07, 

Validation DataLoader 0:  61%|██████    | 145/238 [00:21<00:13,  6.79it/s][A
Validation DataLoader 0:  61%|██████▏   | 146/238 [00:21<00:13,  6.79it/s][A
Validation DataLoader 0:  62%|██████▏   | 147/238 [00:21<00:13,  6.79it/s][A
Validation DataLoader 0:  62%|██████▏   | 148/238 [00:21<00:13,  6.79it/s][A
Validation DataLoader 0:  63%|██████▎   | 149/238 [00:21<00:13,  6.79it/s][A
Validation DataLoader 0:  63%|██████▎   | 150/238 [00:22<00:12,  6.79it/s][A
Validation DataLoader 0:  63%|██████▎   | 151/238 [00:22<00:12,  6.80it/s][A
Validation DataLoader 0:  64%|██████▍   | 152/238 [00:22<00:12,  6.80it/s][A
Validation DataLoader 0:  64%|██████▍   | 153/238 [00:22<00:12,  6.79it/s][A
Validation DataLoader 0:  65%|██████▍   | 154/238 [00:22<00:12,  6.79it/s][A
Validation DataLoader 0:  65%|██████▌   | 155/238 [00:22<00:12,  6.79it/s][A
Validation DataLoader 0:  66%|██████▌   | 156/238 [00:22<00:12,  6.79it/s][A
Validation DataLoader 0:  66%|██████▌   | 157/238 [00:23<00:11, 

Validation DataLoader 0:  47%|████▋     | 113/238 [00:16<00:18,  6.81it/s][A
Validation DataLoader 0:  48%|████▊     | 114/238 [00:16<00:18,  6.81it/s][A
Validation DataLoader 0:  48%|████▊     | 115/238 [00:16<00:18,  6.81it/s][A
Validation DataLoader 0:  49%|████▊     | 116/238 [00:17<00:17,  6.81it/s][A
Validation DataLoader 0:  49%|████▉     | 117/238 [00:17<00:17,  6.81it/s][A
Validation DataLoader 0:  50%|████▉     | 118/238 [00:17<00:17,  6.81it/s][A
Validation DataLoader 0:  50%|█████     | 119/238 [00:17<00:17,  6.81it/s][A
Validation DataLoader 0:  50%|█████     | 120/238 [00:17<00:17,  6.81it/s][A
Validation DataLoader 0:  51%|█████     | 121/238 [00:17<00:17,  6.81it/s][A
Validation DataLoader 0:  51%|█████▏    | 122/238 [00:17<00:17,  6.81it/s][A
Validation DataLoader 0:  52%|█████▏    | 123/238 [00:18<00:16,  6.81it/s][A
Validation DataLoader 0:  52%|█████▏    | 124/238 [00:18<00:16,  6.81it/s][A
Validation DataLoader 0:  53%|█████▎    | 125/238 [00:18<00:16, 

Validation DataLoader 0:  34%|███▍      | 81/238 [00:11<00:23,  6.81it/s][A
Validation DataLoader 0:  34%|███▍      | 82/238 [00:12<00:22,  6.81it/s][A
Validation DataLoader 0:  35%|███▍      | 83/238 [00:12<00:22,  6.81it/s][A
Validation DataLoader 0:  35%|███▌      | 84/238 [00:12<00:22,  6.81it/s][A
Validation DataLoader 0:  36%|███▌      | 85/238 [00:12<00:22,  6.81it/s][A
Validation DataLoader 0:  36%|███▌      | 86/238 [00:12<00:22,  6.81it/s][A
Validation DataLoader 0:  37%|███▋      | 87/238 [00:12<00:22,  6.81it/s][A
Validation DataLoader 0:  37%|███▋      | 88/238 [00:12<00:22,  6.81it/s][A
Validation DataLoader 0:  37%|███▋      | 89/238 [00:13<00:21,  6.81it/s][A
Validation DataLoader 0:  38%|███▊      | 90/238 [00:13<00:21,  6.81it/s][A
Validation DataLoader 0:  38%|███▊      | 91/238 [00:13<00:21,  6.81it/s][A
Validation DataLoader 0:  39%|███▊      | 92/238 [00:13<00:21,  6.81it/s][A
Validation DataLoader 0:  39%|███▉      | 93/238 [00:13<00:21,  6.81it/s][A

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sas10092/.conda/envs/chexmsn-env/lib/python3.9 ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 572/572 [01:24<00:00,  6.76it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auprc                 0.2628
       test_auroc                 0.6979
        test_loss           0.3202608823776245
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.3202608823776245, 'test_auroc': 0.6979, 'test_auprc': 0.2628}]

# Evaluate

In [21]:
ev = pd.read_csv('/scratch/sas10092/ChexMSN/evaluate/lightning_logs/version_0/metrics.csv')