In [2]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset
import torchvision as tv

import os, sys
import numpy as np
import pickle
from PIL import Image
import math
from tqdm import tqdm

from sklearn.metrics import roc_auc_score

# from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, cosine_distances

from model_old import ImageEncoder_Classification
# from transformers import BertConfig

from misc.config import Config
cfg  = Config()
cfg.GPU_ID = 2

torch.cuda.set_device(cfg.GPU_ID)

In [5]:
MAX_DIM = 2048

train_transform = tv.transforms.Compose([
    tv.transforms.RandomRotation(15), # rotation will cost 0.1s for each 10 images
    tv.transforms.RandomCrop(MAX_DIM, pad_if_needed=True), # 0.6s for each 10 images
    tv.transforms.ColorJitter(brightness=[0.5, 1.8] # colorjitter will cost 0.32s for each 10 images
                              , contrast=[0.5, 1.8]
                              , saturation=[0.5, 1.8]),
    tv.transforms.ToTensor(), 
    tv.transforms.Normalize(0.5, 0.5)
])

val_transform = tv.transforms.Compose([
    tv.transforms.CenterCrop(MAX_DIM),
    tv.transforms.ToTensor(),  
    tv.transforms.Normalize(0.5, 0.5)
])

class MimicDataset(Dataset):
    def __init__(self, root, dataset, max_length, transform=train_transform, mode='train'):
        super().__init__()
        
        self.root = root #save dir
        self.transform = transform
        self.mode = mode

        self.classes = dataset['label'] # multi-label one-hot vector
        self.datadict = dataset['image'] # uid: {text:text, filenames:[filename]}
        if self.mode == 'train':
            self.keys = dataset['split']['train'] # uid list
        elif self.mode == 'val':
            self.keys = np.concatenate([dataset['split']['val1'], dataset['split']['val2']]) # uid list
        elif self.mode == 'val2':
            self.keys = dataset['split']['val2'] # uid list
        elif self.mode == 'test':
            self.keys = dataset['split']['test'] # uid list
            
        self.class_to_idx = {
            'Atelectasis': 0, 'Cardiomegaly': 1, 'Consolidation': 2, 'Edema': 3, 'Enlarged Cardiomediastinum': 4, 
            'Fracture': 5, 'Lung Lesion': 6, 'Lung Opacity': 7, 'No Finding': 8, 'Pleural Effusion': 9, 
            'Pleural Other': 10, 'Pneumonia': 11, 'Pneumothorax': 12, 'Support Devices': 13
        }
        
        self.idx_to_class = {
            0:'Atelectasis', 1:'Cardiomegaly', 2:'Consolidation', 3:'Edema', 4:'Enlarged Cardiomediastinum', 
            5:'Fracture', 6:'Lung Lesion', 7:'Lung Opacity', 8:'No Finding', 9:'Pleural Effusion', 
            10:'Pleural Other', 11:'Pneumonia', 12:'Pneumothorax', 13:'Support Devices'
        }
        
        self.num_classes = 14
        
#         self.err_log = os.path.join(log_dir, 'err.log') # create error log
#         if not os.path.exists(self.err_log):
#             with open(self.err_log, 'w') as f:
#                 f.write('Epoch 0:\n\n')

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        uid = self.keys[idx]
        
        image_id = np.random.choice(self.datadict[uid]['filenames'])# get one file name randomly
        image_path = os.path.join(self.root, image_id.replace('dcm','jpg')) #original used 'jpg', try 'png'

        try:
            with Image.open(image_path) as img: 
                if self.transform:
                    image = self.transform(img)

        except Exception as ex:
            print(ex)
#             with open(self.err_log, 'a+') as f:
#                 f.write('%s\nERR_IMG %s\n' % (ex, image_path))
            return None ## return None, collate_fn will ignore this broken sample
        
        classes = torch.tensor(self.classes[uid]).float()
        
        return image, classes, uid
        

def build_dataset(mode='train', cfg=None):
    data_dir = cfg.dataset_root
    img_dir = os.path.join(data_dir, 'physionet.org/files/', 'mimic-cxr-jpg/2.0.0/')
    with open(os.path.join(data_dir, 'lm_reports/class_label_mit.pkl'), 'rb') as f:
        dataset = pickle.load(f)
    if mode == 'train':
        data = MimicDataset(img_dir, dataset, 
                               max_length=cfg.max_length, mode=mode, 
                               transform=train_transform)
        return data

    elif mode == 'val':
        data = MimicDataset(img_dir, dataset, 
                               max_length=cfg.max_length, mode=mode, 
                               transform=val_transform)
        return data
    
    elif mode == 'val2':
        data = MimicDataset(img_dir, dataset, 
                               max_length=cfg.max_length, mode=mode, 
                               transform=val_transform)
        return data
    
    elif mode == 'test':
        data = MimicDataset(img_dir, dataset, 
                               max_length=cfg.max_length, mode=mode, 
                               transform=val_transform)
        return data

    else:
        raise NotImplementedError(f"{mode} not supported")
        
## collate_fn for handling None type item due to image corruption ##
## This will return batch size - broken image number ##
def collate_fn_ignore_none(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)

In [29]:
test_set = build_dataset('val', cfg)
# test_set = build_dataset('test', cfg)
print('Testing set %d is loaded.' % len(test_set))
test_loader = torch.utils.data.DataLoader(
                test_set, batch_size=50, 
                collate_fn=collate_fn_ignore_none, drop_last=False,
                shuffle=False, num_workers=8, pin_memory=False)
# print('Vocab size is %d.' % test_set.vocab_size)

Testing set 4999 is loaded.


In [26]:
##### change the checkpoint path here #####
# raw checkpoint: '/media/My1TBSSD1/IPMI2021/output/MIMIC_class_raw_2020_12_09_22_17_55/Model/image_encoder25.pth'
# fz checkpoint: '/media/My1TBSSD1/IPMI2021/pretrained/image_encoder14.pth'
# ft checkpoint: '/media/My1TBSSD1/IPMI2021/output/MIMIC_class_ft_2020_12_11_09_18_56/Model/image_encoder12.pth'
cfg.text_encoder_path = '/media/My1TBSSD1/IPMI2021/pretrained/image_encoder14.pth'
# ################### encoders ################################# #      
image_encoder = ImageEncoder_Classification(num_class=14, encoder_path='', pretrained=False, cfg = cfg)

if cfg.CUDA:
    image_encoder = image_encoder.cuda()
    
if cfg.text_encoder_path != '':
    img_encoder_path = cfg.text_encoder_path.replace('text_encoder', 'image_encoder')
    print('Load image encoder checkpoint from:', img_encoder_path)
    state_dict = torch.load(img_encoder_path, map_location='cpu')
    image_encoder.load_state_dict(state_dict['model'])


Load image encoder checkpoint from: /media/My1TBSSD1/IPMI2021/pretrained/image_encoder14.pth


In [8]:
@torch.no_grad()
def evaluate(cnn_model, dataloader_val):
    cnn_model.eval()
    val_data_iter = iter(dataloader_val)
    y_preds = []
    y_trues = []
    class_auc = []
    #####################################
    for step in tqdm(range(len(val_data_iter))):  
        real_imgs, classes, uids = val_data_iter.next()
        if cfg.CUDA:
            real_imgs, classes = real_imgs.cuda(), classes.cuda()

        y_pred = cnn_model(real_imgs)
        y_pred_sigmoid = torch.sigmoid(y_pred)
        y_preds.append(y_pred_sigmoid.detach().cpu().numpy())
        y_trues.append(classes.detach().cpu().numpy())

    y_preds = np.concatenate(y_preds,axis=0)
    y_trues = np.concatenate(y_trues,axis=0)
    for i in range(y_preds.shape[-1]):
        class_auc.append(roc_auc_score(y_trues[:,i],y_preds[:,i]))

    return class_auc

In [30]:
auc = evaluate(image_encoder, test_loader)

100%|██████████| 100/100 [01:06<00:00,  1.51it/s]


In [31]:
for idx in range(len(auc)):
    print('%s: %.4f' % (test_loader.dataset.idx_to_class[idx], auc[idx]))

avg= np.mean(np.array(auc)[[0,1,2,3,4,5,6,7,9,10,11,12,13]])
print('Avg: %.4f' % avg)

# weight = np.array([679, 808, 191, 659, 132, 78, 108, 974, 539, 990, 52, 309, 94, 1061]) # weight for hold_out_set
# weight = np.array([958, 997, 233, 600, 174, 79, 168, 1134, 1636, 1243, 31, 361, 234, 1566]) # weight for val+test set
# weight = np.array([566, 575, 131, 354, 110, 33, 105, 650, 1017, 710, 28, 227, 124, 881]) # weight for val2 set
weight = np.array([566, 575, 131, 354, 110, 33, 105, 650, 710, 28, 227, 124, 881]) # weight for val2 set

weight = weight / weight.sum()
wavg = np.array(auc)[[0,1,2,3,4,5,6,7,9,10,11,12,13]] @ weight
print('wAvg: %.4f' % wavg)

Atelectasis: 0.8208
Cardiomegaly: 0.8183
Consolidation: 0.8074
Edema: 0.8895
Enlarged Cardiomediastinum: 0.7566
Fracture: 0.6757
Lung Lesion: 0.7187
Lung Opacity: 0.7539
No Finding: 0.8733
Pleural Effusion: 0.9048
Pleural Other: 0.8396
Pneumonia: 0.7498
Pneumothorax: 0.8527
Support Devices: 0.9137
Avg: 0.8078
wAvg: 0.8397


## split

In [3]:
## statistic for val sets and test sets
hold = np.array([679, 808, 191, 659, 132, 78, 108, 974, 539, 990, 52, 309, 94, 1061])
val1 = np.array([392, 422, 102, 246, 64, 46, 63, 484, 619, 533, 19, 134, 110, 685])
val2 = np.array([566, 575, 131, 354, 110, 33, 105, 650, 1017, 710, 12, 227, 124, 881])
print(list((val1+val2)))

[958, 997, 233, 600, 174, 79, 168, 1134, 1636, 1243, 31, 361, 234, 1566]


In [16]:
for i in val2:
    print(i)

566
575
131
354
110
33
105
650
1017
710
12
227
124
881
