In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset
import torchvision as tv

import os, sys
import numpy as np
import pickle
from PIL import Image
import math
from tqdm import tqdm

from sklearn.metrics import roc_auc_score

# from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, cosine_distances

from model_mlm_cls import Text_Classification
from transformers import BertConfig

from misc.config import Config
# Notebook-wide configuration object; later cells read cfg.CUDA, cfg.max_length,
# cfg.dataset_root and the dropout settings from it.
cfg = Config()
cfg.GPU_ID = 1  # index of the GPU this notebook should run on

# Only pin the CUDA device when CUDA is requested and actually present. Every
# other cell guards GPU work with `cfg.CUDA`, so an unconditional set_device
# here was the one statement that made the notebook unusable on a CPU-only host.
if cfg.CUDA and torch.cuda.is_available():
    torch.cuda.set_device(cfg.GPU_ID)

In [2]:
# Side length (in pixels) of the square centre crop applied to every image.
MAX_DIM = 2048

# Deterministic evaluation-time preprocessing: centre crop, convert to a tensor,
# then normalise with mean 0.5 and std 0.5.
val_transform = transforms.Compose(
    [
        transforms.CenterCrop(MAX_DIM),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)


class IUDataset(Dataset):
    """Report-text dataset yielding padded token ids, a mask and multi-hot labels.

    Image loading is currently disabled (see the commented-out block in
    ``__getitem__``), so ``root`` and ``transform`` are stored but not used.

    Args:
        root: Image directory; only kept for the disabled image-loading path.
        dataset: Mapping with the keys ``'classes'`` (uid -> iterable of label
            names), ``'data_dict'`` (uid -> record holding ``'cb_token_ids'``)
            and ``'data_split'`` (``'train_uids'``, ``'val_uids'``,
            ``'test_uids'``).
        max_length: Token budget; sequences are padded or truncated to
            ``max_length + 1`` positions.
        transform: Image transform, stored for the disabled image path.
        mode: Only ``'test'`` is supported. It evaluates on the concatenation
            of the train, val and test uids.
        sep_token_id: Token id written into the last position of a truncated
            caption.
            NOTE(review): 102 is [SEP] in the 28,996-token BERT cased WordPiece
            vocabulary, which matches ``vocab_size`` below -- confirm against
            the tokenizer that produced ``cb_token_ids``.

    Raises:
        NotImplementedError: If ``mode`` is anything other than ``'test'``.
    """

    def __init__(self, root, dataset, max_length, transform=val_transform, mode='test',
                 sep_token_id=102):
        super().__init__()

        self.root = root  # save dir
        self.transform = transform
        self.mode = mode

        self.classes = dataset['classes']  # multi-label list
        self.datadict = dataset['data_dict']  # uid: {text:text, filenames:[filename]}
        if self.mode == 'test':
            self.keys = np.concatenate([dataset['data_split']['train_uids'],
                                        dataset['data_split']['val_uids'],
                                        dataset['data_split']['test_uids']])  # uid list
        else:
            # Previously `self.keys` was silently left undefined here and the
            # first call to len()/indexing failed with an AttributeError.
            raise NotImplementedError(f"{mode} not supported")

        self.vocab_size = 28996
        self.max_length = max_length + 1
        # The truncation branch of __getitem__ reads this attribute; its old
        # assignment (from dataset['word2idx']) had been commented out, so any
        # caption longer than max_length raised AttributeError.
        self.__sep_id__ = sep_token_id

        # Label positions follow a 14-class scheme:
        #   0 Atelectasis, 1 Cardiomegaly, 2 Consolidation, 3 Edema,
        #   4 Enlarged Cardiomediastinum, 5 Fracture, 6 Lung Lesion,
        #   7 Lung Opacity, 8 No Finding, 9 Pleural Effusion, 10 Pleural Other,
        #   11 Pneumonia, 12 Pneumothorax, 13 Support Devices
        # Only the eight names below are mapped; any other label name in
        # `dataset['classes']` raises KeyError in __getitem__.
        self.class_to_idx = {
            'no finding': 8,
            'edema': 3,
            'consolidation': 2,
            'pneumonia': 11,
            'pneumothorax': 12,
            'atelectasis': 0,
            'cardiomegaly': 1,
            'effusion': 9,
        }

        self.idx_to_class = {
            0: 'atelectasis',
            1: 'cardiomegaly',
            9: 'effusion',
            3: 'edema',
            2: 'consolidation',
            11: 'pneumonia',
            12: 'pneumothorax',
            8: 'no finding',
        }

        self.num_classes = 14

    def __len__(self):
        """Return the number of uids served by this dataset."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(token_ids, mask, multi_hot_labels, uid, length)`` for one uid.

        ``token_ids`` and ``mask`` are NumPy arrays of length ``self.max_length``;
        ``multi_hot_labels`` is a float tensor of length ``self.num_classes``;
        ``length`` is the number of unmasked positions.
        """
        uid = self.keys[idx]

        # Disabled image-loading path, kept for reference:
        #         image_id = np.random.choice(self.datadict[uid]['filenames'])# get one file name randomly
        #         image_path = os.path.join(self.root, image_id) #original used 'jpg', try 'png'
        #         try:
        #             with Image.open(image_path) as img:
        #                 if self.transform:
        #                     image = self.transform(img)
        #         except Exception as ex:
        #             print(ex)
        #             return None ## return None, collate_fn will ignore this broken sample

        # dtype must be given explicitly: torch.tensor([]) is float32, and
        # scatter_ rejects a non-int64 index, so a uid without labels crashed.
        classes = torch.tensor([self.class_to_idx[x] for x in self.classes[uid]],
                               dtype=torch.long)
        y_onehot = torch.zeros(self.num_classes, dtype=torch.float32)
        y_onehot.scatter_(0, classes, 1)

        ## load text input ##
        max_len_array = np.zeros(self.max_length, dtype='int')
        cap_mask = np.zeros(self.max_length, dtype='int')
        caption = np.array(self.datadict[uid]['cb_token_ids'])
        if len(caption) <= self.max_length:
            cap_mask[:len(caption)] = 1
            max_len_array[:len(caption)] = caption
        else:
            cap_mask[:] = 1
            # Copy into the preallocated buffer so both branches return the same
            # dtype, then end the truncated sequence with the separator token.
            max_len_array[:] = caption[:self.max_length]
            max_len_array[-1] = self.__sep_id__
        cap_mask = cap_mask.astype(bool)
        cap_lens = cap_mask.sum(-1)

        return max_len_array, cap_mask, y_onehot, uid, cap_lens


def build_dataset(mode='test', cfg=None):
    """Load the pickled dataset description and wrap it in an ``IUDataset``.

    Args:
        mode: Dataset mode; only ``'test'`` is implemented.
        cfg: Configuration object providing ``dataset_root`` (directory that
            contains ``cleaned_dataset_v4.pickle`` and the ``images`` folder)
            and ``max_length``.

    Returns:
        The constructed ``IUDataset``.

    Raises:
        NotImplementedError: If ``mode`` is not ``'test'``.
    """
    # Reject unsupported modes before touching the disk, so a typo in `mode`
    # is reported immediately instead of after unpickling the whole dataset.
    if mode != 'test':
        raise NotImplementedError(f"{mode} not supported")

    data_dir = cfg.dataset_root
    img_dir = os.path.join(data_dir, 'images', 'images_normalized')
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # cfg.dataset_root at files you produced or otherwise trust.
    with open(os.path.join(data_dir, 'cleaned_dataset_v4.pickle'), 'rb') as f:
        dataset = pickle.load(f)

    return IUDataset(img_dir, dataset,
                     max_length=cfg.max_length, mode=mode,
                     transform=val_transform)


def collate_fn_ignore_none(batch):
    """Collate a batch after discarding samples that are ``None``.

    A dataset may return ``None`` for a sample it could not load (e.g. a
    corrupted image). Those entries are dropped here, so the collated batch
    holds ``len(batch)`` minus the number of broken samples.
    """
    usable = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(usable)

In [3]:
# cfg.max_length = 127

# Build the evaluation dataset from the paths and lengths held in `cfg`.
test_set = build_dataset(mode='test', cfg=cfg)
print('Testing set %d is loaded.' % len(test_set))

# Order-preserving loader: no shuffling and no dropped tail batch, so every
# sample is scored exactly once.
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=50,
    shuffle=False,
    drop_last=False,
    num_workers=8,
    pin_memory=False,
    collate_fn=collate_fn_ignore_none,
)
print('Vocab size is %d.' % test_set.vocab_size)
print('Max length is %d' % test_set.max_length)

Testing set 3666 is loaded.
Vocab size is 28996.
Max length is 160


In [4]:

# Architecture of the BERT text encoder: 3 layers, 8 heads, hidden size 512.
# The vocabulary size comes from the dataset and dropout rates from `cfg`.
bert_config = BertConfig(
    vocab_size=test_loader.dataset.vocab_size,
    hidden_size=512,
    num_hidden_layers=3,
    num_attention_heads=8,
    intermediate_size=2048,
    hidden_act='gelu',
    hidden_dropout_prob=cfg.hidden_dropout_prob,
    attention_probs_dropout_prob=cfg.attention_probs_dropout_prob,
    max_position_embeddings=512,
    layer_norm_eps=1e-12,
    initializer_range=0.02,
    type_vocab_size=2,
    pad_token_id=0,
)

##### change the checkpoint path here #####
# Earlier checkpoints, kept for reference:
# cfg.text_encoder_path = '/media/My1TBSSD1/IPMI2021/output/MIMIC_class_txt_ft_new_2021_03_03_01_29_36/Model/Txt_class_model16.pth'
# cfg.text_encoder_path = '/media/My1TBSSD1/IPMI2021/output/MIMIC_class_mlm_cls_ft_2021_03_19_14_33_50/Model/Txt_class_model4.pth'
cfg.text_encoder_path = '/media/My1TBSSD1/MICCAI2021/output/MIMIC-CXR_mimic_mlm_cls_wp_2021_03_26_10_40_49/Model/Txt_class_model14.pth'

# ################### encoders ################################# #
# Despite its name, `image_encoder` holds the Text_Classification model; the
# name (and the log message below) are kept because later cells refer to them.
image_encoder = Text_Classification(
    num_class=14,
    pretrained=False,
    cfg=cfg,
    bert_config=bert_config,
)
if cfg.CUDA:
    image_encoder = image_encoder.cuda()

# Restore the fine-tuned weights; the checkpoint stores them under 'model'.
if cfg.text_encoder_path != '':
    print('Load image encoder checkpoint from:', cfg.text_encoder_path)
    state_dict = torch.load(cfg.text_encoder_path, map_location='cpu')
    image_encoder.load_state_dict(state_dict['model'])


Bert encoder with MaskedLMhead.
Initiate text encoder from MLM pretrained parameters from: /media/My1TBSSD1/MICCAI2021/output/MIMIC-CXR_mlm_wordpiece_2021_03_25_19_20_18/Model/text_encoder.pth
Load image encoder checkpoint from: /media/My1TBSSD1/MICCAI2021/output/MIMIC-CXR_mimic_mlm_cls_wp_2021_03_26_10_40_49/Model/Txt_class_model14.pth


In [5]:
@torch.no_grad()
def evaluate(cnn_model, dataloader_val, class_indices=(0, 1, 2, 3, 8, 9, 11, 12)):
    """Score a text classifier on a loader and return one ROC-AUC per class.

    Args:
        cnn_model: Model called as ``cnn_model(captions, cap_masks)`` and
            returning per-class logits; it is switched to eval mode.
        dataloader_val: Iterable of batches shaped like the ``IUDataset``
            output: ``(captions, cap_masks, classes, uids, cap_lens)``.
        class_indices: Label columns to score. The default is the set of
            columns this notebook has always evaluated.

    Returns:
        list: ROC-AUC values, one per entry of ``class_indices``, in order.

    Notes:
        Reads the module-level ``cfg.CUDA`` flag to decide whether batches are
        moved to the GPU. ``roc_auc_score`` raises ``ValueError`` for a column
        whose ground truth contains a single class only.
    """
    cnn_model.eval()
    y_preds = []
    y_trues = []

    # Iterate the loader directly. The previous `iter(loader).next()` relied on
    # a Python 2 style alias that current PyTorch iterators no longer provide.
    for captions, cap_masks, classes, uids, cap_lens in tqdm(dataloader_val):
        if cfg.CUDA:
            captions, cap_masks, classes = captions.cuda(), cap_masks.cuda(), classes.cuda()

        y_pred = cnn_model(captions, cap_masks)
        y_pred_sigmoid = torch.sigmoid(y_pred)  # logits -> independent per-class probabilities
        y_preds.append(y_pred_sigmoid.detach().cpu().numpy())
        y_trues.append(classes.detach().cpu().numpy())

    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)

    return [roc_auc_score(y_trues[:, i], y_preds[:, i]) for i in class_indices]

In [6]:
# Score the loaded classifier on the whole loader; `auc` holds one ROC-AUC per
# evaluated class, in the order used by `evaluate`.
auc = evaluate(image_encoder, test_loader)

100%|██████████| 74/74 [00:02<00:00, 30.37it/s]


In [7]:
# Label ids of the eight evaluated classes, in the same order as `auc`.
openi_cls = [0, 1, 2, 3, 8, 9, 11, 12]

for idx, score in enumerate(auc):
    label = test_loader.dataset.idx_to_class[openi_cls[idx]]
    print('%s: %.4f' % (label, score))

# Positions in `auc` used for the averages; position 4 ('no finding') is left out.
disease_pos = [0, 1, 2, 3, 5, 6, 7]
disease_auc = np.array(auc)[disease_pos]

avg = np.mean(disease_auc)
print('Avg: %.4f' % avg)

# Per-class counts for the whole open-i dataset, aligned with `disease_pos`.
# With 'no finding' included: [295, 319, 28, 41, 2988, 141, 36, 25]
class_counts = np.array([295, 319, 28, 41, 141, 36, 25])
weight = class_counts / class_counts.sum()
wavg = disease_auc @ weight
print('wAvg: %.4f' % wavg)

atelectasis: 0.9274
cardiomegaly: 0.9041
consolidation: 0.9152
edema: 0.9610
no finding: 0.8583
effusion: 0.9498
pneumonia: 0.8650
pneumothorax: 0.9819
Avg: 0.9292
wAvg: 0.9228


In [17]:
# Label ids of the eight evaluated classes, in the same order as `auc`.
openi_cls = [0, 1, 2, 3, 8, 9, 11, 12]

for idx, score in enumerate(auc):
    label = test_loader.dataset.idx_to_class[openi_cls[idx]]
    print('%s: %.4f' % (label, score))

# Positions in `auc` used for the averages; position 4 ('no finding') is left out.
disease_pos = [0, 1, 2, 3, 5, 6, 7]
disease_auc = np.array(auc)[disease_pos]

avg = np.mean(disease_auc)
print('Avg: %.4f' % avg)

# Per-class counts for the whole open-i dataset, aligned with `disease_pos`.
# With 'no finding' included: [295, 319, 28, 41, 2988, 141, 36, 25]
class_counts = np.array([295, 319, 28, 41, 141, 36, 25])
weight = class_counts / class_counts.sum()
wavg = disease_auc @ weight
print('wAvg: %.4f' % wavg)

atelectasis: 0.9416
cardiomegaly: 0.9030
consolidation: 0.9456
edema: 0.9869
no finding: 0.8667
effusion: 0.9716
pneumonia: 0.8921
pneumothorax: 0.9842
Avg: 0.9464
wAvg: 0.9339


## Split statistics

In [3]:
## Statistics for the hold-out, val1 and val2 splits.
# Each array has 14 entries -- presumably one sample count per label of the
# 14-class scheme; TODO confirm the column order against the label mapping.
hold = np.array([679, 808, 191, 659, 132, 78, 108, 974, 539, 990, 52, 309, 94, 1061])
val1 = np.array([392, 422, 102, 246, 64, 46, 63, 484, 619, 533, 19, 134, 110, 685])
val2 = np.array([566, 575, 131, 354, 110, 33, 105, 650, 1017, 710, 12, 227, 124, 881])

# Element-wise total of the two validation splits. `.tolist()` converts to
# built-in ints so the output is a plain list of numbers on every NumPy
# version (NumPy >= 2 prints `np.int64(...)` for scalars inside a list).
print((val1 + val2).tolist())

[958, 997, 233, 600, 174, 79, 168, 1134, 1636, 1243, 31, 361, 234, 1566]


In [16]:
# Print the val2 entries one per line.
for count in val2:
    print(count)

566
575
131
354
110
33
105
650
1017
710
12
227
124
881
