In [44]:
from __future__ import print_function


import os
import sys
sys.path.append('..')

from misc.config import Config
from ds_mimic_cls_img_neg import build_dataset
from model_cls_img_neg import ImageEncoder_Classification

import json
from tqdm import tqdm
import time
import random
import pandas as pd
# import pprint
import datetime
import dateutil.tz
import argparse
import numpy as np
# import pandas as pd
import torch
import torchvision.transforms as transforms
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_fscore_support, classification_report

In [2]:
# Project configuration object. The cells below read val_batch_size, CUDA,
# init_image_encoder_path and pretrained from it, and build_models() writes
# text_encoder_path onto it.
cfg  = Config()

In [3]:
def collate_fn_ignore_none(batch):
    """Collate a batch after discarding every entry that is ``None``.

    Args:
        batch: Sequence of samples handed over by the DataLoader; entries
            that are ``None`` are dropped before collation.

    Returns:
        Whatever ``default_collate`` produces for the remaining samples.
    """
    kept_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept_samples)


# Build the 'test' split and wrap it in an unshuffled loader that keeps the
# final partial batch and drops samples for which the dataset returned None.
test_data_set = build_dataset('test', cfg, out_dir=None)
print('test set %d is loaded.' % len(test_data_set))

test_loader = torch.utils.data.DataLoader(
    test_data_set,
    batch_size=cfg.val_batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn_ignore_none,
)
    

test set 3041 is loaded.


In [4]:
def build_models(trained_model_path):
    """Create the image classifier and optionally restore a checkpoint.

    Args:
        trained_model_path: Path of the checkpoint to restore. An empty
            string skips the restore step.

    Returns:
        The ``ImageEncoder_Classification`` instance, moved to the GPU when
        ``cfg.CUDA`` is set.

    Side effects:
        Stores ``trained_model_path`` on ``cfg.text_encoder_path``.
    """
    encoder = ImageEncoder_Classification(
        num_class=test_data_set.num_classes,
        encoder_path=cfg.init_image_encoder_path,
        pretrained=cfg.pretrained,
        cfg=cfg,
    )

    # The checkpoint path is recorded on the shared config before it is used.
    cfg.text_encoder_path = trained_model_path
    if cfg.CUDA:
        encoder = encoder.cuda()

    if cfg.text_encoder_path == '':
        return encoder

    checkpoint_path = cfg.text_encoder_path
    print('Load image encoder checkpoint from:', checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # A checkpoint either wraps the weights under 'model' or is the weights.
    weights = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
    encoder.load_state_dict(weights)
    return encoder


In [84]:
# Checkpoint of the run trained from scratch; uncomment this call (and comment
# out the one below) to evaluate that run instead.
# image_encoder = build_models(trained_model_path='../../output/MIMIC_neg_img_cls_2021_06_20_23_08_47/Model/image_encoder10.pth')

# Checkpoint of the run trained starting from pretrained weights.
image_encoder = build_models(trained_model_path='../../output/MIMIC_neg_img_cls_pt_2021_06_22_02_46_35/Model/image_encoder10.pth')

Load image encoder from: /media/My1TBSSD1/MICCAI2021/output/MIMIC_pretrain_2021_05_04_23_03_10/Model/image_encoder28.pth
Load image encoder checkpoint from: ../../output/MIMIC_neg_img_cls_pt_2021_06_22_02_46_35/Model/image_encoder10.pth


In [85]:
@torch.no_grad()
def evaluate(cnn_model):
    """Run the classifier over ``test_loader`` and compute per-class ROC AUC.

    Args:
        cnn_model: Model whose forward pass returns a 5-tuple; only the first
            element (raw class scores, passed through a sigmoid here) is used.

    Returns:
        tuple: ``(class_auc, y_trues, y_preds)`` where ``class_auc`` is a list
        with one ROC AUC per output column (0 for columns whose ground truth
        holds fewer than two distinct values, for which AUC is undefined),
        ``y_trues`` is the concatenated label array and ``y_preds`` the
        concatenated sigmoid scores.
    """
    cnn_model.eval()

    y_preds = []
    y_trues = []
    # Iterate the loader directly: the iterator's Python-2 style ``.next()``
    # alias is not available on current PyTorch DataLoader iterators.
    for real_imgs, classes, _uids in tqdm(test_loader, leave=False):
        if cfg.CUDA:
            real_imgs, classes = real_imgs.cuda(), classes.cuda()

        y_pred, _, _, _, _ = cnn_model(real_imgs)
        y_pred_sigmoid = torch.sigmoid(y_pred)

        y_preds.append(y_pred_sigmoid.detach().cpu().numpy())
        y_trues.append(classes.detach().cpu().numpy())

    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)

    print(y_preds.shape,y_trues.shape)

    class_auc = []
    for i in range(y_preds.shape[-1]):
        if len(np.unique(y_trues[:, i])) < 2:
            # Only one label value present (e.g. "No Finding" in this split):
            # roc_auc_score would raise, so a placeholder of 0 is recorded.
            class_auc.append(0)
        else:
            class_auc.append(roc_auc_score(y_trues[:, i], y_preds[:, i]))

    return class_auc, y_trues, y_preds

In [86]:
# Score the currently loaded checkpoint on the test loader. `auc` is the list of
# per-class ROC AUC values; `y_trues` / `y_preds` are the stacked labels and
# sigmoid scores consumed by the report cells below.
auc, y_trues, y_preds = evaluate(image_encoder)

                                               

(1074, 14) (1074, 14)


# Without pretrained weights

In [83]:
# NOTE(review): this cell reports whatever `auc`, `y_trues` and `y_preds`
# currently hold. The table recorded under it was presumably produced while
# the from-scratch checkpoint was loaded; confirm before re-running.

# Per-class precision / recall / F1 from predictions thresholded with
# np.round on the sigmoid scores. Rows are the classes followed by sklearn's
# averaged summary rows.
df_res = pd.DataFrame(classification_report(y_trues
                                   , np.round(y_preds)
                                   , output_dict=True
                                   , target_names = list(test_data_set.class_to_idx.keys()))).T

# Per-class AUC goes on the class rows; summary rows start at 0.
per_class_auc = np.asarray(auc, dtype=float)
n_classes = len(per_class_auc)
per_class_support = df_res['support'].values[:n_classes].astype(float)
df_res['auc'] = np.concatenate([per_class_auc, np.zeros(len(df_res) - n_classes)])

# Support-weighted and unweighted mean AUC, derived from the data instead of
# hard-coded totals and written to the summary rows that carry their names.
df_res.loc['weighted avg', 'auc'] = (per_class_support * per_class_auc).sum() / per_class_support.sum()
df_res.loc['macro avg', 'auc'] = per_class_auc.mean()
df_res

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,precision,recall,f1-score,support,auc
Atelectasis,0.0,0.0,0.0,9.0,0.570683
Cardiomegaly,0.256484,0.649635,0.367769,137.0,0.77937
Consolidation,0.234375,0.29703,0.262009,101.0,0.72435
Edema,0.311404,0.816092,0.450794,261.0,0.662552
Enlarged Cardiomediastinum,0.0,0.0,0.0,54.0,0.723911
Fracture,0.0,0.0,0.0,6.0,0.596598
Lung Lesion,0.0,0.0,0.0,10.0,0.735338
Lung Opacity,0.0,0.0,0.0,64.0,0.467481
No Finding,0.0,0.0,0.0,0.0,0.0
Pleural Effusion,0.366612,0.826568,0.507937,271.0,0.723151


# With pretrained weights

In [87]:
# Per-class precision / recall / F1 from predictions thresholded with
# np.round on the sigmoid scores. Rows are the classes followed by sklearn's
# averaged summary rows.
df_res = pd.DataFrame(classification_report(y_trues
                                   , np.round(y_preds)
                                   , output_dict=True
                                   , target_names = list(test_data_set.class_to_idx.keys()))).T

# Per-class AUC goes on the class rows; summary rows start at 0.
per_class_auc = np.asarray(auc, dtype=float)
n_classes = len(per_class_auc)
per_class_support = df_res['support'].values[:n_classes].astype(float)
df_res['auc'] = np.concatenate([per_class_auc, np.zeros(len(df_res) - n_classes)])

# Support-weighted and unweighted mean AUC, derived from the data instead of
# hard-coded totals and written to the summary rows that carry their names.
df_res.loc['weighted avg', 'auc'] = (per_class_support * per_class_auc).sum() / per_class_support.sum()
df_res.loc['macro avg', 'auc'] = per_class_auc.mean()
df_res

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,precision,recall,f1-score,support,auc
Atelectasis,0.0,0.0,0.0,9.0,0.594992
Cardiomegaly,0.296636,0.708029,0.418103,137.0,0.803029
Consolidation,0.241379,0.485149,0.322368,101.0,0.796119
Edema,0.364839,0.739464,0.488608,261.0,0.725095
Enlarged Cardiomediastinum,0.5,0.018519,0.035714,54.0,0.742792
Fracture,0.0,0.0,0.0,6.0,0.812734
Lung Lesion,0.0,0.0,0.0,10.0,0.756579
Lung Opacity,0.0,0.0,0.0,64.0,0.626686
No Finding,0.0,0.0,0.0,0.0,0.0
Pleural Effusion,0.430799,0.815498,0.563776,271.0,0.794888
