In [1]:
from torchvision import models
from torch import nn
from torch.nn import init

class DensNetWithHead(nn.Module):
    def __init__(self,  hidden_layer_sizes, dropout_rate, num_classes):
        super(DensNetWithHead, self).__init__()

        # Pretrained DenseNet backbone
        self.backbone = models.densenet121(pretrained=True)
        num_features = self.backbone.classifier.in_features

        # Remove the last classification layer of the backbone
        self.backbone.classifier = nn.Identity()

        # Custom head with hidden layers
        layers = []
        input_size = num_features

        for size in hidden_layer_sizes:
            linear_layer = nn.Linear(input_size, size)
            init.kaiming_uniform_(linear_layer.weight, nonlinearity='relu')
            layers.append(linear_layer)
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(size))
            layers.append(nn.Dropout(dropout_rate))
            input_size = size

        # Output layer
        layers.append(nn.Linear(input_size, num_classes))

        # Assemble the custom head
        self.custom_head = nn.Sequential(*layers)

    def forward(self, x):
        # Forward pass through the backbone
        features = self.backbone(x)
  

        # Forward pass through the custom head
        output = self.custom_head(features)

        return output

  from .autonotebook import tqdm as notebook_tqdm


In [39]:
from data.image.dataset_class import Chexpert
import torchvision.transforms as transforms
import torch

# create the dataset

dataset = Chexpert(dataframe_path = '/fs01/home/hhamidi/projects/stable-diffusion/data/csv_files/val_from_train.csv',
                    path_image = '/datasets/chexpert/',
                    MAPPING = {-1: 1, 0: 0, 1: 1, float('NaN'): 0},
                    transform = transforms.Compose([
                        transforms.Resize([224,224]),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                    ]))

# if dataset.dataset_size > 2000:
#     dataset.dataset_size = 2000


In [40]:
class CLS():
    def __init__(self, model, device):
        self.model = model
        self.model.eval()
        self.device = device
        self.LABEL_NAMES =['Atelectasis',
                            'Cardiomegaly',
                            'Consolidation',
                            'Edema',
                            'Enlarged Cardiomediastinum',
                            'Fracture',
                            'Lung Lesion',
                            'Lung Opacity',
                            'No Finding',
                            'Pleural Effusion',
                            'Pleural Other',
                            'Pneumonia',
                            'Pneumothorax',
                            'Support Devices']
    def predict(self, image):
        image = image.to(self.device)
        with torch.no_grad():
            if len(image.shape) == 3:
                output = self.model(image[None, ...])

                
        output = torch.sigmoid(output)
        result = dict(zip(self.LABEL_NAMES, output[0].tolist()))
        return result

In [41]:
state_dict = torch.load('/fs01/home/hhamidi/fairness_on_embeddings/results/image_chexpert_real_real/checkpoints/best_val_loss-epoch=09-val_loss=0.2735.ckpt')['state_dict']

new_state_dict = {}
for k, v in state_dict.items():
    name = k.replace('model.', '')
    new_state_dict[name] = v


model = DensNetWithHead(hidden_layer_sizes=[768, 128], dropout_rate=0.0, num_classes=14)
model.load_state_dict(new_state_dict)
model = model.to('cuda')
cls = CLS(model, 'cuda')

In [42]:
from tqdm import tqdm
# run the model on the dataset
probabilities = []
labels = []

for i in tqdm(range(dataset.dataset_size)):
    item =  dataset[i]
    image, label = item['data'], item['labels']
    labels.append(label.numpy())
    output = cls.predict(image)
    probabilities.append(list(output.values()))

100%|██████████| 44683/44683 [25:13<00:00, 29.51it/s]


In [38]:
# calculate the roc auc
from metrics.metrics import find_best_threshold, calculate_roc_auc
import numpy as np
thersholds = find_best_threshold( np.array(probabilities),np.array(labels))
# create a dictionary of the thresholds with CLS.LABEL_NAMES
thersholds = dict(zip(cls.LABEL_NAMES, list(thersholds.values())))
print('thersholds: ', thersholds)
_,_,roc_auc = calculate_roc_auc(np.array(probabilities),np.array(labels))
roc_auc = dict(zip(cls.LABEL_NAMES, list(roc_auc.values())))
print('roc_auc: ', roc_auc)


thersholds:  {'Atelectasis': 0.07908083498477936, 'Cardiomegaly': 0.2774554491043091, 'Consolidation': 0.05144466832280159, 'Edema': 0.25417467951774597, 'Enlarged Cardiomediastinum': 0.0847725048661232, 'Fracture': 0.2180563360452652, 'Lung Lesion': 0.12561748921871185, 'Lung Opacity': 0.32183897495269775, 'No Finding': 0.1821669042110443, 'Pleural Effusion': 0.3480605185031891, 'Pleural Other': 0.039286863058805466, 'Pneumonia': 0.030648918822407722, 'Pneumothorax': 0.15165171027183533, 'Support Devices': 0.42294201254844666}
roc_auc:  {'Atelectasis': 0.6493029846273183, 'Cardiomegaly': 0.8603195599553128, 'Consolidation': 0.6320110751685948, 'Edema': 0.8146588978630822, 'Enlarged Cardiomediastinum': 0.6662043129304929, 'Fracture': 0.7999110742590022, 'Lung Lesion': 0.7832175786053984, 'Lung Opacity': 0.7078548857145512, 'No Finding': 0.8542649358299658, 'Pleural Effusion': 0.8611295308907205, 'Pleural Other': 0.774145772245826, 'Pneumonia': 0.7406809385422766, 'Pneumothorax': 0.8271

In [54]:
import numpy as np
import torch 
from torch.utils.data import Dataset
import torchvision.transforms as tfs
import cv2
from PIL import Image
import pandas as pd

class CheXpert(Dataset):
    '''
    Reference: 
        @inproceedings{yuan2021robust,
            title={Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification},
            author={Yuan, Zhuoning and Yan, Yan and Sonka, Milan and Yang, Tianbao},
            booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
            year={2021}
            }
    '''
    def __init__(self, 
                 csv_path, 
                 image_root_path='',
                 image_size=320,
                 class_index=0, 
                 use_frontal=True,
                 use_upsampling=True,
                 flip_label=False,
                 shuffle=False,
                 seed=123,
                 verbose=True,
                 upsampling_cols=['Cardiomegaly', 'Consolidation'],
                 train_cols=['Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis',  'Pleural Effusion'],
                 mode='train'):
        
    
        # load data from csv
        self.df = pd.read_csv(csv_path)
        self.df['Path'] = self.df['Path'].str.replace('CheXpert-v1.0-small/', '')
        self.df['Path'] = self.df['Path'].str.replace('CheXpert-v1.0/', '')
        if use_frontal:
            self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal']  
            
        # upsample selected cols
        if use_upsampling:
            assert isinstance(upsampling_cols, list), 'Input should be list!'
            sampled_df_list = []
            for col in upsampling_cols:
                print ('Upsampling %s...'%col)
                sampled_df_list.append(self.df[self.df[col] == 1])
            self.df = pd.concat([self.df] + sampled_df_list, axis=0)


        # impute missing values 
        for col in train_cols:
            if col in ['Edema', 'Atelectasis']:
                self.df[col].replace(-1, 1, inplace=True)  
                self.df[col].fillna(0, inplace=True) 
            elif col in ['Cardiomegaly','Consolidation',  'Pleural Effusion']:
                self.df[col].replace(-1, 0, inplace=True) 
                self.df[col].fillna(0, inplace=True)
            else:
                self.df[col].fillna(0, inplace=True)
        
        self._num_images = len(self.df)
        
        # 0 --> -1
        if flip_label and class_index != -1: # In multi-class mode we disable this option!
            self.df.replace(0, -1, inplace=True)   
            
        # shuffle data
        if shuffle:
            data_index = list(range(self._num_images))
            np.random.seed(seed)
            np.random.shuffle(data_index)
            self.df = self.df.iloc[data_index]
        
        
        assert class_index in [-1, 0, 1, 2, 3, 4], 'Out of selection!'
        assert image_root_path != '', 'You need to pass the correct location for the dataset!'

        if class_index == -1: # 5 classes
            print ('Multi-label mode: True, Number of classes: [%d]'%len(train_cols))
            self.select_cols = train_cols
            self.value_counts_dict = {}
            for class_key, select_col in enumerate(train_cols):
                class_value_counts_dict = self.df[select_col].value_counts().to_dict()
                self.value_counts_dict[class_key] = class_value_counts_dict
        else:       # 1 class
            self.select_cols = [train_cols[class_index]]  # this var determines the number of classes
            self.value_counts_dict = self.df[self.select_cols[0]].value_counts().to_dict()
        
        self.mode = mode
        self.class_index = class_index
        self.image_size = image_size
        
        self._images_list =  [image_root_path+path for path in self.df['Path'].tolist()]
        if class_index != -1:
            self._labels_list = self.df[train_cols].values[:, class_index].tolist()
        else:
            self._labels_list = self.df[train_cols].values.tolist()
    
        if verbose:
            if class_index != -1:
                print ('-'*30)
                if flip_label:
                    self.imratio = self.value_counts_dict[1]/(self.value_counts_dict[-1]+self.value_counts_dict[1])
                    print('Found %s images in total, %s positive images, %s negative images'%(self._num_images, self.value_counts_dict[1], self.value_counts_dict[-1] ))
                    print ('%s(C%s): imbalance ratio is %.4f'%(self.select_cols[0], class_index, self.imratio ))
                else:
                    self.imratio = self.value_counts_dict[1]/(self.value_counts_dict[0]+self.value_counts_dict[1])
                    print('Found %s images in total, %s positive images, %s negative images'%(self._num_images, self.value_counts_dict[1], self.value_counts_dict[0] ))
                    print ('%s(C%s): imbalance ratio is %.4f'%(self.select_cols[0], class_index, self.imratio ))
                print ('-'*30)
            else:
                print ('-'*30)
                imratio_list = []
                for class_key, select_col in enumerate(train_cols):
                    imratio = self.value_counts_dict[class_key][1]/(self.value_counts_dict[class_key][0]+self.value_counts_dict[class_key][1])
                    imratio_list.append(imratio)
                    print('Found %s images in total, %s positive images, %s negative images'%(self._num_images, self.value_counts_dict[class_key][1], self.value_counts_dict[class_key][0] ))
                    print ('%s(C%s): imbalance ratio is %.4f'%(select_col, class_key, imratio ))
                    print ()
                self.imratio = np.mean(imratio_list)
                self.imratio_list = imratio_list
                print ('-'*30)
            
    @property        
    def class_counts(self):
        return self.value_counts_dict
    
    @property
    def imbalance_ratio(self):
        return self.imratio

    @property
    def num_classes(self):
        return len(self.select_cols)
       
    @property  
    def data_size(self):
        return self._num_images 
    
    def image_augmentation(self, image):
        img_aug = tfs.Compose([tfs.RandomAffine(degrees=(-15, 15), translate=(0.05, 0.05), scale=(0.95, 1.05), fill=128)]) # pytorch 3.7: fillcolor --> fill
        image = img_aug(image)
        return image
    
    def __len__(self):
        return self._num_images
    
    def __getitem__(self, idx):

        image = cv2.imread(self._images_list[idx], 0)
        image = Image.fromarray(image)
        if self.mode == 'train':
            image = self.image_augmentation(image)
        image = np.array(image)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        
        # resize and normalize; e.g., ToTensor()
        image = cv2.resize(image, dsize=(self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)  
        image = image/255.0
        __mean__ = np.array([[[0.485, 0.456, 0.406]]])
        __std__ =  np.array([[[0.229, 0.224, 0.225]  ]]) 
        image = (image-__mean__)/__std__
        image = image.transpose((2, 0, 1)).astype(np.float32)
        if self.class_index != -1: # multi-class mode
            label = np.array(self._labels_list[idx]).reshape(-1).astype(np.float32)
        else:
            label = np.array(self._labels_list[idx]).reshape(-1).astype(np.float32)
        return image, label


if __name__ == '__main__':
    root = '/datasets/chexpert/CheXpert-v1.0-small/'
    traindSet = CheXpert(csv_path=root+'train.csv', image_root_path=root, use_upsampling=True, use_frontal=True,
     train_cols=['Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis',  'Pleural Effusion','No Finding'],
      image_size=224, mode='train', class_index=-1)
    testSet =  CheXpert(csv_path=root+'valid.csv',  image_root_path=root, use_upsampling=False, use_frontal=True, image_size=224, mode='valid', class_index=-1)
    trainloader =  torch.utils.data.DataLoader(traindSet, batch_size=32, num_workers=2, drop_last=True, shuffle=True)
    testloader =  torch.utils.data.DataLoader(testSet, batch_size=32, num_workers=2, drop_last=False, shuffle=False)

 

Upsampling Cardiomegaly...
Upsampling Consolidation...
Multi-label mode: True, Number of classes: [5]
------------------------------
Found 227395 images in total, 48021 positive images, 179374 negative images
Cardiomegaly(C0): imbalance ratio is 0.2112

Found 227395 images in total, 77866 positive images, 149529 negative images
Edema(C1): imbalance ratio is 0.3424

Found 227395 images in total, 27217 positive images, 200178 negative images
Consolidation(C2): imbalance ratio is 0.1197

Found 227395 images in total, 70593 positive images, 156802 negative images
Atelectasis(C3): imbalance ratio is 0.3104

Found 227395 images in total, 94036 positive images, 133359 negative images
Pleural Effusion(C4): imbalance ratio is 0.4135

------------------------------
Multi-label mode: True, Number of classes: [5]
------------------------------
Found 202 images in total, 66 positive images, 136 negative images
Cardiomegaly(C0): imbalance ratio is 0.3267

Found 202 images in total, 42 positive image

In [50]:
traindSet[4]

(array([[[0.07406456, 0.07406456, 0.07406456, ..., 0.07406456,
          0.07406456, 0.07406456],
         [0.07406456, 0.07406456, 0.07406456, ..., 0.07406456,
          0.07406456, 0.07406456],
         [0.07406456, 0.07406456, 0.07406456, ..., 0.07406456,
          0.07406456, 0.07406456],
         ...,
         [0.07406456, 0.07406456, 0.07406456, ..., 0.07406456,
          0.07406456, 0.07406456],
         [0.07406456, 0.07406456, 0.07406456, ..., 0.07406456,
          0.07406456, 0.07406456],
         [0.07406456, 0.07406456, 0.07406456, ..., 0.07406456,
          0.07406456, 0.07406456]],
 
        [[0.20518208, 0.20518208, 0.20518208, ..., 0.20518208,
          0.20518208, 0.20518208],
         [0.20518208, 0.20518208, 0.20518208, ..., 0.20518208,
          0.20518208, 0.20518208],
         [0.20518208, 0.20518208, 0.20518208, ..., 0.20518208,
          0.20518208, 0.20518208],
         ...,
         [0.20518208, 0.20518208, 0.20518208, ..., 0.20518208,
          0.20518208, 0.