In [1]:
from __future__ import print_function


import os
import sys
sys.path.append('..')

from misc.config import Config
from ds_mimic_cls_txt_neg import build_dataset
from model_cls_phr_neg import Text_Classification

import json
from tqdm import tqdm
import time
import random
import pandas as pd
# import pprint
import datetime
import dateutil.tz
import argparse
import numpy as np
from transformers import BertConfig
import torch
import torchvision.transforms as transforms
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_fscore_support, classification_report

In [2]:
# Shared configuration object; later cells read dataset, model and device
# settings from it (e.g. cfg.val_batch_size, cfg.pretrained, cfg.CUDA).
cfg  = Config()

In [3]:
def collate_fn_ignore_none(batch):
    """Collate a batch after discarding samples that came back as ``None``.

    The remaining samples are handed to PyTorch's ``default_collate`` unchanged,
    so the output structure matches what a plain DataLoader would produce.
    """
    kept_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept_samples)


# Evaluation split with manually labelled ground truth.
test_data_set = build_dataset('manual_gt_with_zeros', cfg, out_dir=None)
print('test set %d is loaded.' % len(test_data_set))

# Keep dataset order (shuffle=False) and the trailing partial batch
# (drop_last=False) so no sample is skipped because of batching; samples the
# dataset returns as None are dropped by collate_fn_ignore_none.
test_loader = torch.utils.data.DataLoader(
    test_data_set,
    batch_size=cfg.val_batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn_ignore_none,
)
    

test set 493 is loaded.


In [4]:
# Text-encoder architecture: a 3-layer, 8-head BERT with 512-wide hidden states.
# NOTE(review): presumably must match the checkpoint loaded in build_models,
# since load_state_dict is strict by default — confirm before changing sizes.
bert_config = BertConfig(
    vocab_size=test_loader.dataset.vocab_size,  # vocabulary taken from the dataset
    hidden_size=512,
    num_hidden_layers=3,
    num_attention_heads=8,
    intermediate_size=2048,
    hidden_act='gelu',
    hidden_dropout_prob=cfg.hidden_dropout_prob,
    attention_probs_dropout_prob=cfg.attention_probs_dropout_prob,
    max_position_embeddings=512,
    layer_norm_eps=1e-12,
    initializer_range=0.02,
    type_vocab_size=2,
    pad_token_id=0,
)


def build_models(trained_model_path):
    """Construct the text classifier and optionally load trained weights.

    Side effect: stores ``trained_model_path`` on ``cfg.text_encoder_path``.
    When that path is the empty string, the freshly constructed model is
    returned without loading a checkpoint. Otherwise the checkpoint is read onto
    the CPU and loaded, unwrapping a ``{'model': state_dict}`` container if
    present.

    Args:
        trained_model_path: path to a saved checkpoint, or ``''`` to skip loading.

    Returns:
        The ``Text_Classification`` model, moved to the GPU when ``cfg.CUDA``.
    """
    text_encoder = Text_Classification(
        num_class=test_data_set.num_classes,
        txt_encoder_path=cfg.init_text_encoder_path,
        pretrained=cfg.pretrained,
        cfg=cfg,
        bert_config=bert_config,
    )

    # Recorded on the shared config; other code may read it back.
    cfg.text_encoder_path = trained_model_path
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()

    if cfg.text_encoder_path == '':
        return text_encoder

    print('Load text encoder checkpoint from:', cfg.text_encoder_path)
    checkpoint = torch.load(cfg.text_encoder_path, map_location='cpu')
    # Checkpoints are either a bare state_dict or wrapped under a 'model' key.
    weights = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
    text_encoder.load_state_dict(weights)
    return text_encoder


In [47]:
# Trained checkpoints for the two regimes compared in this notebook.
TXT_CLS_CHECKPOINTS = {
    # train from scratch
    'fs': '../../output/MIMIC_neg_txt_cls_fs_2021_06_25_05_39_29/Model/Txt_class_model10.pth',
    # train from pretrained weights
    'ft': '../../output/MIMIC_neg_txt_cls_ft_2021_06_25_05_34_59/Model/Txt_class_model10.pth',
}
MODEL_TAG = 'ft'  # switch to 'fs' to evaluate the from-scratch model

text_encoder = build_models(trained_model_path=TXT_CLS_CHECKPOINTS[MODEL_TAG])

Load text encoder checkpoint from: ../../output/MIMIC_neg_txt_cls_ft_2021_06_25_05_34_59/Model/Txt_class_model10.pth


In [48]:
@torch.no_grad()
def evaluate(cnn_model, loader=None):
    """Run the model over every batch of ``loader`` and compute per-class ROC-AUC.

    Args:
        cnn_model: text classifier called as ``cnn_model(captions, cap_masks)``;
            its outputs are passed through a sigmoid to get probabilities.
        loader: iterable of ``(captions, cap_masks, classes, uids, cap_lens)``
            batches. Defaults to the notebook-level ``test_loader``.

    Returns:
        ``(class_auc, y_trues, y_preds)``: a list with one ROC-AUC per output
        column (0 where the AUC is undefined), and the labels and sigmoid
        probabilities of all batches, concatenated along axis 0.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if loader is None:
        loader = test_loader
    cnn_model.eval()

    y_preds = []
    y_trues = []
    # Iterate the loader directly: current PyTorch DataLoader iterators only
    # implement __next__, not the Python-2 style .next() method.
    for captions, cap_masks, classes, _uids, _cap_lens in tqdm(loader, leave=False):
        if cfg.CUDA:
            captions, cap_masks, classes = captions.cuda(), cap_masks.cuda(), classes.cuda()

        y_pred_sigmoid = torch.sigmoid(cnn_model(captions, cap_masks))
        y_preds.append(y_pred_sigmoid.detach().cpu().numpy())
        y_trues.append(classes.detach().cpu().numpy())

    if not y_preds:
        raise ValueError('evaluate(): loader yielded no batches')

    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)
    print(y_preds.shape, y_trues.shape)

    class_auc = []
    for i in range(y_preds.shape[-1]):
        # ROC-AUC is undefined when a column holds a single label value
        # (e.g. 'No Finding' here). 0 is stored as a placeholder, so averages
        # that include it are pulled down.
        if len(np.unique(y_trues[:, i])) < 2:
            class_auc.append(0)
        else:
            class_auc.append(roc_auc_score(y_trues[:, i], y_preds[:, i]))

    return class_auc, y_trues, y_preds

In [49]:
# Score the loaded checkpoint on the test set: per-class AUCs plus raw arrays.
auc, y_trues, y_preds = evaluate(cnn_model=text_encoder)

                                             

(493, 14) (493, 14)


# Special test: manually labelled ground truth (`manual_gt_with_zeros`)

In [50]:
# Per-class precision/recall/F1 with predictions thresholded by np.round
# (probabilities above 0.5 -> 1), plus the per-class ROC-AUC from evaluate().
class_names = list(test_data_set.class_to_idx.keys())
df_res = pd.DataFrame(classification_report(y_trues,
                                            np.round(y_preds),
                                            output_dict=True,
                                            target_names=class_names)).T
class_auc = np.asarray(auc, dtype=float)
n_classes = len(class_auc)
# classification_report appends summary rows ('micro avg', 'macro avg',
# 'weighted avg', 'samples avg') after the per-class rows. Their AUC defaults
# to 0; the averages derivable from per-class AUCs are written by row label so
# each value lands in the row it describes.
df_res['auc'] = np.concatenate([class_auc, np.zeros(len(df_res) - n_classes)])
class_support = df_res['support'].values[:n_classes].astype(float)
total_support = class_support.sum()
# Weight by the actual number of positives in this test set, not a fixed total.
df_res.loc['weighted avg', 'auc'] = float(class_support @ class_auc / total_support) if total_support else 0.0
# Unweighted mean over all classes; classes whose AUC is undefined (stored as 0)
# are included, which lowers this average.
df_res.loc['macro avg', 'auc'] = float(class_auc.mean()) if n_classes else 0.0
df_res

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,precision,recall,f1-score,support,auc
Atelectasis,0.0,0.0,0.0,1.0,1.0
Cardiomegaly,0.909091,0.701754,0.792079,57.0,0.95614
Consolidation,0.703704,0.95,0.808511,20.0,0.995877
Edema,0.851852,0.896104,0.873418,77.0,0.974963
Enlarged Cardiomediastinum,0.555556,0.47619,0.512821,21.0,0.874496
Fracture,0.6,0.428571,0.5,7.0,0.71458
Lung Lesion,0.333333,0.5,0.4,2.0,0.990835
Lung Opacity,0.4,0.5,0.444444,16.0,0.836347
No Finding,0.0,0.0,0.0,0.0,0.0
Pleural Effusion,0.913793,0.929825,0.921739,57.0,0.977547


In [37]:
# Name the file after the checkpoint that was actually evaluated (its run
# directory contains '_cls_fs_' for from-scratch or '_cls_ft_' for
# from-pretrained training), instead of a hard-coded 'fs' that mislabels ft results.
if '_cls_ft_' in cfg.text_encoder_path:
    model_tag = 'ft'
elif '_cls_fs_' in cfg.text_encoder_path:
    model_tag = 'fs'
else:
    raise ValueError('Cannot tell fs/ft from checkpoint path: %r' % cfg.text_encoder_path)
df_res.to_csv('../../../MIMIC-CXR/lm_reports/neg_results_%s_manual_labels.csv' % model_tag)

# Without pretrained weights

In [12]:
# NOTE: recomputes the table from the current `auc`/`y_trues`/`y_preds`, i.e.
# whichever checkpoint was evaluated last; the saved output below comes from an
# earlier kernel state (see its execution count).
class_names = list(test_data_set.class_to_idx.keys())
df_res = pd.DataFrame(classification_report(y_trues,
                                            np.round(y_preds),
                                            output_dict=True,
                                            target_names=class_names)).T
class_auc = np.asarray(auc, dtype=float)
n_classes = len(class_auc)
# classification_report appends summary rows ('micro avg', 'macro avg',
# 'weighted avg', 'samples avg') after the per-class rows. Their AUC defaults
# to 0; the averages derivable from per-class AUCs are written by row label so
# each value lands in the row it describes.
df_res['auc'] = np.concatenate([class_auc, np.zeros(len(df_res) - n_classes)])
class_support = df_res['support'].values[:n_classes].astype(float)
total_support = class_support.sum()
# Weight by the actual number of positives in this test set, not a fixed total.
df_res.loc['weighted avg', 'auc'] = float(class_support @ class_auc / total_support) if total_support else 0.0
# Unweighted mean over all classes; classes whose AUC is undefined (stored as 0)
# are included, which lowers this average.
df_res.loc['macro avg', 'auc'] = float(class_auc.mean()) if n_classes else 0.0
df_res

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,precision,recall,f1-score,support,auc
Atelectasis,0.0,0.0,0.0,1.0,0.739837
Cardiomegaly,0.860465,0.649123,0.74,57.0,0.917834
Consolidation,0.6,0.75,0.666667,20.0,0.984567
Edema,0.774194,0.935065,0.847059,77.0,0.975056
Enlarged Cardiomediastinum,0.5,0.333333,0.4,21.0,0.782284
Fracture,0.5,0.428571,0.461538,7.0,0.513228
Lung Lesion,0.2,0.5,0.285714,2.0,0.99389
Lung Opacity,0.5,0.5625,0.529412,16.0,0.816431
No Finding,0.0,0.0,0.0,0.0,0.0
Pleural Effusion,0.881356,0.912281,0.896552,57.0,0.981209


# With pretrained weights

In [31]:
# NOTE: recomputes the table from the current `auc`/`y_trues`/`y_preds`, i.e.
# whichever checkpoint was evaluated last; the saved output below comes from an
# earlier kernel state (see its execution count).
class_names = list(test_data_set.class_to_idx.keys())
df_res = pd.DataFrame(classification_report(y_trues,
                                            np.round(y_preds),
                                            output_dict=True,
                                            target_names=class_names)).T
class_auc = np.asarray(auc, dtype=float)
n_classes = len(class_auc)
# classification_report appends summary rows ('micro avg', 'macro avg',
# 'weighted avg', 'samples avg') after the per-class rows. Their AUC defaults
# to 0; the averages derivable from per-class AUCs are written by row label so
# each value lands in the row it describes.
df_res['auc'] = np.concatenate([class_auc, np.zeros(len(df_res) - n_classes)])
class_support = df_res['support'].values[:n_classes].astype(float)
total_support = class_support.sum()
# Weight by the actual number of positives in this test set, not a fixed total.
df_res.loc['weighted avg', 'auc'] = float(class_support @ class_auc / total_support) if total_support else 0.0
# Unweighted mean over all classes; classes whose AUC is undefined (stored as 0)
# are included, which lowers this average.
df_res.loc['macro avg', 'auc'] = float(class_auc.mean()) if n_classes else 0.0
df_res

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,precision,recall,f1-score,support,auc
Atelectasis,0.5,0.222222,0.307692,9.0,0.955034
Cardiomegaly,0.942029,0.948905,0.945455,137.0,0.995022
Consolidation,0.885965,1.0,0.939535,101.0,0.998809
Edema,0.915493,0.996169,0.954128,261.0,0.995499
Enlarged Cardiomediastinum,0.912281,0.962963,0.936937,54.0,0.989887
Fracture,0.857143,1.0,0.923077,6.0,1.0
Lung Lesion,0.875,0.7,0.777778,10.0,0.988252
Lung Opacity,0.710526,0.84375,0.771429,64.0,0.974196
No Finding,0.0,0.0,0.0,0.0,0.0
Pleural Effusion,0.932624,0.97048,0.951175,271.0,0.995938
