In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset
import torchvision as tv

import os, sys
import numpy as np
import pickle
from PIL import Image
import math
from tqdm import tqdm

from sklearn.metrics import roc_auc_score

# from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, cosine_distances

from model_mlm_cls import Text_Classification
from transformers import BertConfig

from misc.config import Config
# Project-level configuration object; later cells read cfg.dataset_root, cfg.max_length,
# cfg.CUDA and the dropout settings from it.
cfg  = Config()
# Hard-coded CUDA device index for this notebook (index 1).
# NOTE(review): presumably needs changing on machines with a single GPU — confirm.
cfg.GPU_ID = 1

# Select that device as the current CUDA device for the rest of the notebook.
torch.cuda.set_device(cfg.GPU_ID)

In [2]:
# Side length passed to the crop transforms below.
MAX_DIM = 2048

# Augmenting pipeline used for the 'train' split.
# Timings quoted from the original author's measurements (per 10 images):
# rotation ~0.1s, random crop ~0.6s, colour jitter ~0.32s.
train_transform = transforms.Compose([
    transforms.RandomRotation(degrees=15),
    transforms.RandomCrop(size=MAX_DIM, pad_if_needed=True),
    transforms.ColorJitter(
        brightness=[0.5, 1.8],
        contrast=[0.5, 1.8],
        saturation=[0.5, 1.8],
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])

# Deterministic pipeline used for the validation / test splits.
val_transform = transforms.Compose([
    transforms.CenterCrop(size=MAX_DIM),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])


class MimicDataset(Dataset):
    """Tokenised report dataset with a 14-class multi-label target.

    Each item is looked up by uid: the label vector comes from ``dataset['label'][uid]``
    and the token ids from ``dataset['image'][uid]['token_ids']``. Token sequences are
    padded with zeros, or truncated, to ``max_length + 1`` positions.

    Args:
        root: Stored on the instance as ``self.root``; not otherwise used by this class.
        dataset: Dict providing the keys ``'label'``, ``'image'``, ``'split'``,
            ``'idx2word'`` and ``'word2idx'``. NOTE: ``idx2word`` / ``word2idx`` are
            modified in place to register the ``[MASK]`` token.
        max_length: Maximum caption length; one extra position is reserved, so the
            arrays returned by ``__getitem__`` have ``max_length + 1`` entries.
        transform: Stored on the instance; not applied by ``__getitem__``.
        mode: One of ``'train'``, ``'val'`` (val1 + val2), ``'val2'`` or ``'test'``.
        log_dir: Accepted for interface compatibility; unused.

    Raises:
        ValueError: If ``mode`` is not one of the supported split names.
    """

    # Token id under which '[MASK]' is registered in the vocabulary dicts.
    MASK_TOKEN_ID = 8410

    def __init__(self, root, dataset, max_length, transform=train_transform, mode='train', log_dir='test'):
        super().__init__()

        self.root = root  # save dir
        self.transform = transform
        self.mode = mode

        self.classes = dataset['label']  # uid -> label vector (multi-label one-hot per the original comment)
        self.datadict = dataset['image']  # uid -> dict; __getitem__ reads its 'token_ids' entry

        # Resolve the uid list for the requested split. An unknown mode used to leave
        # `self.keys` undefined and only failed later inside __len__/__getitem__.
        splits = dataset['split']
        if mode == 'train':
            self.keys = splits['train']
        elif mode == 'val':
            self.keys = np.concatenate([splits['val1'], splits['val2']])
        elif mode == 'val2':
            self.keys = splits['val2']
        elif mode == 'test':
            self.keys = splits['test']
        else:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of 'train', 'val', 'val2', 'test'"
            )

        # Register the [MASK] token in both vocabulary directions (mutates the caller's dicts).
        self.idx2word = dataset['idx2word']
        self.idx2word[self.MASK_TOKEN_ID] = '[MASK]'
        self.word2idx = dataset['word2idx']
        self.word2idx['[MASK]'] = self.MASK_TOKEN_ID
        self.__sep_id__ = dataset['word2idx']['[SEP]']
        self.vocab_size = len(dataset['word2idx'])  # counted after [MASK] has been added
        self.max_length = max_length + 1

        ## classification params
        self.class_to_idx = {
            'Atelectasis': 0, 'Cardiomegaly': 1, 'Consolidation': 2, 'Edema': 3, 'Enlarged Cardiomediastinum': 4,
            'Fracture': 5, 'Lung Lesion': 6, 'Lung Opacity': 7, 'No Finding': 8, 'Pleural Effusion': 9,
            'Pleural Other': 10, 'Pneumonia': 11, 'Pneumothorax': 12, 'Support Devices': 13
        }
        # Inverse mapping, derived so the two tables cannot drift apart.
        self.idx_to_class = {index: name for name, index in self.class_to_idx.items()}
        self.num_classes = len(self.class_to_idx)

    def __len__(self):
        """Return the number of uids in the selected split."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(token_ids, mask, labels, uid, length)`` for the idx-th uid.

        ``token_ids`` and ``mask`` are numpy arrays of ``self.max_length`` entries,
        ``labels`` is a float tensor, and ``length`` is the number of True entries in ``mask``.
        """
        uid = self.keys[idx]

        classes = torch.tensor(self.classes[uid]).float()

        ## load text input ##
        max_len_array = np.zeros(self.max_length, dtype='int')
        cap_mask = np.zeros(self.max_length, dtype='int')
        caption = np.array(self.datadict[uid]['token_ids'])
        if len(caption) <= self.max_length:
            # Short caption: copy it in; the tail stays zero in both array and mask.
            cap_mask[:len(caption)] = 1
            max_len_array[:len(caption)] = caption
        else:
            # Long caption: truncate and terminate with [SEP]. Writing into the
            # preallocated buffer keeps the dtype identical to the padded branch.
            cap_mask[:] = 1
            max_len_array[:] = caption[:self.max_length]
            max_len_array[-1] = self.__sep_id__
        cap_mask = cap_mask.astype(bool)
        cap_lens = cap_mask.sum(-1)

        return max_len_array, cap_mask, classes, uid, cap_lens


def build_dataset(mode='train', cfg=None, out_dir=None):
    """Load the pickled label/token data and wrap one split in a ``MimicDataset``.

    Args:
        mode: Split to build: ``'train'``, ``'val'``, ``'val2'`` or ``'test'``.
        cfg: Configuration object; ``cfg.dataset_root`` and ``cfg.max_length`` are read.
        out_dir: Forwarded to ``MimicDataset`` as ``log_dir``.

    Returns:
        A ``MimicDataset`` for the requested split. ``'train'`` receives
        ``train_transform``; every other split receives ``val_transform``.

    Raises:
        NotImplementedError: If ``mode`` is not a supported split name.
        ValueError: If ``cfg`` is not provided.
    """
    # Validate arguments up front so a bad call fails before the pickles are read.
    if mode not in ('train', 'val', 'val2', 'test'):
        raise NotImplementedError(f"{mode} not supported")
    if cfg is None:
        raise ValueError("build_dataset requires a cfg object providing dataset_root and max_length")

    data_dir = cfg.dataset_root
    img_dir = os.path.join(data_dir, 'physionet.org/files/', 'mimic-cxr-jpg/2.0.0/')

    # NOTE: pickle.load can execute arbitrary code; only point dataset_root at trusted files.
    with open(os.path.join(data_dir, 'lm_reports/class_label_mit.pkl'), 'rb') as f:
        dataset = pickle.load(f)
    with open(os.path.join(data_dir, 'lm_reports/mimic_dataset_mit_normalized.pkl'), 'rb') as f2:
        dataset_token = pickle.load(f2)
    dataset['word2idx'] = dataset_token['word2idx']  # copy the token dicts to the dataset
    dataset['idx2word'] = dataset_token['idx2word']

    split_transform = train_transform if mode == 'train' else val_transform
    return MimicDataset(img_dir, dataset,
                        max_length=cfg.max_length, mode=mode,
                        transform=split_transform, log_dir=out_dir)


def collate_fn_ignore_none(batch):
    """Collate a batch after discarding samples that are ``None``.

    Intended for datasets that return ``None`` for unreadable items (the original
    note mentions corrupted images), so the resulting batch may be smaller than the
    loader's batch size by the number of dropped samples.
    """
    usable_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(usable_samples)

In [3]:
# cfg.max_length = 127

# Build the split to evaluate; swap in the commented line to use the 'test' split instead.
test_set = build_dataset(mode='val2', cfg=cfg)
# test_set = build_dataset('test', cfg)
print('Testing set %d is loaded.' % len(test_set))

# Fixed-order, complete pass over the split: no shuffling and the last partial batch is kept.
loader_options = dict(
    batch_size=50,
    shuffle=False,
    drop_last=False,
    num_workers=8,
    pin_memory=False,
    collate_fn=collate_fn_ignore_none,
)
test_loader = torch.utils.data.DataLoader(test_set, **loader_options)
print('Vocab size is %d.' % test_set.vocab_size)
print('Max length is %d' % test_set.max_length)

Testing set 3000 is loaded.
Vocab size is 8411.
Max length is 160


In [29]:

# BERT configuration for the text encoder: vocabulary size taken from the dataset,
# 3 hidden layers of width 512 with 8 attention heads; dropout rates come from cfg.
# NOTE(review): presumably these must match the settings the checkpoint below was
# trained with — confirm.
bert_config = BertConfig(vocab_size=test_loader.dataset.vocab_size, hidden_size=512, num_hidden_layers=3,
                    num_attention_heads=8, intermediate_size=2048, hidden_act='gelu',
                    hidden_dropout_prob=cfg.hidden_dropout_prob, attention_probs_dropout_prob=cfg.attention_probs_dropout_prob,
                    max_position_embeddings=512, layer_norm_eps=1e-12,
                    initializer_range=0.02, type_vocab_size=2, pad_token_id=0)

##### change the checkpoint path here #####
cfg.text_encoder_path = '../../output/MIMIC_class_mlm_cls_ft_2021_03_19_14_33_50/Model/Txt_class_model4.pth'
# ################### encoders ################################# #
# NOTE(review): despite its name, `image_encoder` holds a Text_Classification model;
# the name is kept because the evaluation cell below refers to it.
image_encoder = Text_Classification(num_class=14, pretrained=False, cfg=cfg, bert_config=bert_config)

# Move the model to the current CUDA device when the configuration enables CUDA.
if cfg.CUDA:
    image_encoder = image_encoder.cuda()
    
# The path is assigned a non-empty value a few lines above, so this branch always runs in this cell.
if cfg.text_encoder_path != '':
#     img_encoder_path = cfg.text_encoder_path.replace('text_encoder', 'image_encoder')
    print('Load image encoder checkpoint from:', cfg.text_encoder_path)
    # The checkpoint is deserialised on the CPU; the weights are read from its 'model' entry.
    state_dict = torch.load(cfg.text_encoder_path, map_location='cpu')
    image_encoder.load_state_dict(state_dict['model'])


Bert encoder with MaskedLMhead.
Initiate text encoder from MLM pretrained parameters from: /media/My1TBSSD1/IPMI2021/output/MIMIC_mlm_2021_02_24_19_58_21/Model/text_encoder.pth
Load image encoder checkpoint from: /media/My1TBSSD1/IPMI2021/output/MIMIC_class_mlm_cls_ft_2021_03_19_14_33_50/Model/Txt_class_model4.pth


In [30]:
@torch.no_grad()
def evaluate(cnn_model, dataloader_val):
    """Compute the per-class ROC AUC of ``cnn_model`` over ``dataloader_val``.

    Args:
        cnn_model: Model called as ``cnn_model(captions, cap_masks)``; its output is
            passed through a sigmoid to obtain per-class scores.
        dataloader_val: Iterable of ``(captions, cap_masks, classes, uids, cap_lens)`` batches.

    Returns:
        A list with one ROC AUC value per column of the model output.

    Note:
        Reads the notebook-level ``cfg.CUDA`` flag to decide whether to move batches to the GPU.
    """
    cnn_model.eval()
    y_preds = []
    y_trues = []
    #####################################
    # Iterate the loader directly: the iterator's `.next()` alias used previously is
    # not available on current PyTorch DataLoader iterators.
    for captions, cap_masks, classes, uids, cap_lens in tqdm(dataloader_val):
        if cfg.CUDA:
            captions, cap_masks, classes = captions.cuda(), cap_masks.cuda(), classes.cuda()

        y_pred = cnn_model(captions, cap_masks)
        y_pred_sigmoid = torch.sigmoid(y_pred)
        y_preds.append(y_pred_sigmoid.detach().cpu().numpy())
        y_trues.append(classes.detach().cpu().numpy())

    # Stack the per-batch arrays along the sample axis before scoring each class column.
    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)
    class_auc = [
        roc_auc_score(y_trues[:, i], y_preds[:, i])
        for i in range(y_preds.shape[-1])
    ]

    return class_auc

In [31]:
# Per-class ROC AUC list for the loaded checkpoint on the split held by `test_loader`.
auc = evaluate(image_encoder, test_loader)

100%|██████████| 60/60 [00:01<00:00, 31.84it/s]


In [32]:
# Report the AUC of every class by name.
for idx, class_auc in enumerate(auc):
    print('%s: %.4f' % (test_loader.dataset.idx_to_class[idx], class_auc))

# Both averages skip index 8 ('No Finding').
kept_idx = [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13]
kept_auc = np.array(auc)[kept_idx]

avg = kept_auc.mean()
print('Avg: %.4f' % avg)

# weight = np.array([679, 808, 191, 659, 132, 78, 108, 974, 539, 990, 52, 309, 94, 1061]) # weight for hold_out_set
# weight = np.array([958, 997, 233, 600, 174, 79, 168, 1134, 1636, 1243, 31, 361, 234, 1566]) # weight for val+test set
# weight = np.array([566, 575, 131, 354, 110, 33, 105, 650, 1017, 710, 28, 227, 124, 881]) # weight for val2 set
# NOTE(review): the 'Pleural Other' entry is 28 here, but the `val2` array in the split
# statistics cell of this notebook lists 12 for that class — confirm which is correct.
class_counts = np.array([566, 575, 131, 354, 110, 33, 105, 650, 710, 28, 227, 124, 881])  # weight for val2 set

# Normalise the counts so the weights sum to one, then take the weighted mean.
weight = class_counts / class_counts.sum()
wavg = kept_auc @ weight
print('wAvg: %.4f' % wavg)

Atelectasis: 0.9919
Cardiomegaly: 0.9901
Consolidation: 0.9979
Edema: 0.9944
Enlarged Cardiomediastinum: 0.9882
Fracture: 0.9892
Lung Lesion: 0.9951
Lung Opacity: 0.9860
No Finding: 0.9866
Pleural Effusion: 0.9943
Pleural Other: 0.9972
Pneumonia: 0.9834
Pneumothorax: 0.9909
Support Devices: 0.9845
Avg: 0.9910
wAvg: 0.9896


In [23]:
# NOTE(review): this cell repeats the AUC summary cell earlier in the notebook; its saved
# output differs, so it was presumably run against a different checkpoint — confirm or remove.

# Report the AUC of every class by name.
for idx, class_auc in enumerate(auc):
    print('%s: %.4f' % (test_loader.dataset.idx_to_class[idx], class_auc))

# Both averages skip index 8 ('No Finding').
kept_idx = [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13]
kept_auc = np.array(auc)[kept_idx]

avg = kept_auc.mean()
print('Avg: %.4f' % avg)

# weight = np.array([679, 808, 191, 659, 132, 78, 108, 974, 539, 990, 52, 309, 94, 1061]) # weight for hold_out_set
# weight = np.array([958, 997, 233, 600, 174, 79, 168, 1134, 1636, 1243, 31, 361, 234, 1566]) # weight for val+test set
# weight = np.array([566, 575, 131, 354, 110, 33, 105, 650, 1017, 710, 28, 227, 124, 881]) # weight for val2 set
class_counts = np.array([566, 575, 131, 354, 110, 33, 105, 650, 710, 28, 227, 124, 881])  # weight for val2 set

# Normalise the counts so the weights sum to one, then take the weighted mean.
weight = class_counts / class_counts.sum()
wavg = kept_auc @ weight
print('wAvg: %.4f' % wavg)

Atelectasis: 0.9343
Cardiomegaly: 0.9346
Consolidation: 0.9363
Edema: 0.9568
Enlarged Cardiomediastinum: 0.9324
Fracture: 0.9727
Lung Lesion: 0.9556
Lung Opacity: 0.9425
No Finding: 0.9703
Pleural Effusion: 0.9668
Pleural Other: 0.9353
Pneumonia: 0.9148
Pneumothorax: 0.9694
Support Devices: 0.9695
Avg: 0.9478
wAvg: 0.9501


## Split statistics

In [3]:
## Per-class statistics for the hold-out, val1 and val2 splits (14 entries each).
# NOTE(review): entries are presumably ordered by class index, as in the AUC weights — confirm.
hold = np.array([679, 808, 191, 659, 132, 78, 108, 974, 539, 990, 52, 309, 94, 1061])
val1 = np.array([392, 422, 102, 246, 64, 46, 63, 484, 619, 533, 19, 134, 110, 685])
val2 = np.array([566, 575, 131, 354, 110, 33, 105, 650, 1017, 710, 12, 227, 124, 881])

# .tolist() converts to plain Python ints, so the printed list reads [958, 997, ...] on
# every NumPy version; list() would print np.int64(...) wrappers on NumPy >= 2.
print((val1 + val2).tolist())

[958, 997, 233, 600, 174, 79, 168, 1134, 1636, 1243, 31, 361, 234, 1566]


In [16]:
# Show the val2 statistics one value per line.
print('\n'.join(str(count) for count in val2))

566
575
131
354
110
33
105
650
1017
710
12
227
124
881
