In [None]:
# Run from the project root so the relative './data/...' paths used below resolve.
# The explicit %cd magic does not depend on IPython's automagic setting.
%cd /home/ncp/workspace

In [None]:
import os
import json
import cv2 as cv
import numpy as np
from tqdm import tqdm
import h5py
import glob

import pydicom
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import albumentations as A

import torch.optim as optim
from torch.optim import lr_scheduler
import time
import torch.nn as nn
import copy

from collections import OrderedDict

In [None]:
# Display names of the seven diagnosis classes; the list position is the class index.
class_names = ['0.공기누출', '1.과다팽창', '2.무기폐', '3.신생아호흡곤란증후군', '4.폐렴', '5.흉막삼출', '6.정상']
# Integer class indices, in the same order as class_names.
classes = [0, 1, 2, 3, 4, 5, 6]
# Number of classes; the metric functions divide summed per-class scores by this.
num_class = len(class_names)
print(num_class)

In [None]:
# Not read by any code visible in this notebook (datasets are built with transform=False).
transform = True

# DataLoader batch size, and the side length (in pixels) images are resized to.
batch_size = 16
image_size = 256

In [None]:
def _to_tensor(image, label, name):
    image = np.transpose(image, (2, 0, 1))
    image = torch.from_numpy(image)
    
    data = {'name':name, 'input': image, 'label': label}
    return data


In [None]:
def load_data(path, external = False):
    """Collect the metadata file paths of every class for one data split.

    Args:
        path: Accepted for interface compatibility; the search always runs
            under the fixed /home/ncp/workspace/data directory.
        external: True selects the 'test' split, False the 'train' split.

    Returns:
        A list of paths matched by glob, grouped class by class in the order
        of the class folders listed below.
    """
    mode = 'test' if external else 'train'

    # One folder per diagnosis class; this order defines the grouping of the result.
    class_folders = (
        "01.공기누출",
        "02.과다팽창",
        "03.무기폐",
        "04.신생아호흡곤란증후군",
        "05.폐렴",
        "06.흉막삼출",
        "09.정상",
    )

    dataset = []
    for folder in class_folders:
        dataset.extend(glob.glob(f"/home/ncp/workspace/data/{mode}/{folder}/*/*/metadata/*"))

    return dataset

In [None]:
class InfantDataset(Dataset):
    """Dataset of X-ray images described by per-image JSON metadata files.

    Each item is built from one metadata JSON: the referenced DICOM is read,
    contrast-enhanced with CLAHE, min-max normalized, replicated to three
    channels, optionally resized and packed by ``_to_tensor``.

    Args:
        root_dir: Directory that paths found in the metadata are joined onto.
        transform: Stored on the instance; not used by this class.
        image_size: Side length of the square output image; ``None`` skips
            the resize step.
        external: Forwarded to ``load_data`` (True -> 'test', False -> 'train').
    """

    def __init__(self, root_dir='./data/test', transform=True, image_size=None, external=True):
        self.root_dir = root_dir
        self.image_size = image_size
        self.transform = transform
        self.external = external
        self.dataset = load_data(self.root_dir, self.external)

    def __len__(self):
        """Return the number of metadata files found."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Load, preprocess and package the sample at ``index``.

        Returns:
            dict with keys 'name', 'input' (float32 tensor, channels first)
            and 'label' (zero-based class index).
        """
        metadata_path = self.dataset[index]
        image, label = self._load_dicom(os.path.join(self.root_dir, metadata_path))
        image = self._preprocess_image(image).astype('float32')

        # Resize only when a target size was configured; cv.resize cannot
        # take a (None, None) size.
        if self.image_size is not None:
            dim = (self.image_size, self.image_size)
            image = cv.resize(image, dim, interpolation=cv.INTER_AREA)

        # The name is the metadata path minus its last five characters
        # (presumably the ".json" extension -- TODO confirm).
        return _to_tensor(image, label, metadata_path[:-5])

    @staticmethod
    def _min_max_normalize(image):
        """Rescale ``image`` linearly so its minimum maps to 0 and maximum to 1.

        A constant image has zero range; it is returned as all zeros instead
        of the NaNs a 0/0 division would produce.
        """
        shifted = image - np.min(image)
        peak = np.max(shifted)
        if peak == 0:
            return np.zeros(shifted.shape, dtype=np.float64)
        return shifted / peak

    def _preprocess_image(self, image):
        """Return a contrast-enhanced, [0, 1]-normalized, 3-channel image.

        Args:
            image: Pixel array; a 3-axis input is first converted to
                grayscale with ``cv.COLOR_BGR2GRAY``.

        Returns:
            Array of shape (height, width, 3) whose channels are identical.
        """
        if len(image.shape) == 3:
            image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)

        clahe = cv.createCLAHE(clipLimit=80)
        image = clahe.apply(image)

        image = self._min_max_normalize(image)

        # The image is single-channel at this point, so replicate it into
        # three identical channels.
        return np.stack([image, image, image], axis=-1)

    def _load_dicom(self, path):
        """Read one metadata JSON and the DICOM image it references.

        Args:
            path: Path of the metadata JSON file.

        Returns:
            (image, label): the DICOM pixel array and the zero-based class index.
        """
        with open(path, 'r', encoding='UTF8') as f:
            content = json.load(f)

        # The stored DICOM path is used without its first four characters
        # (presumably a prefix that is not valid under root_dir -- TODO confirm).
        dicom_path = os.path.join(self.root_dir, content['mask_image']['org_dicom_file'][4:])

        # Diagnosis code 9 is remapped to 7 so that, after the -1 below, it
        # becomes index 6, the last entry of class_names.
        class_id = int(content['patient']['diagnosis'])
        if class_id == 9:
            class_id = 7

        image = pydicom.dcmread(dicom_path).pixel_array

        return image, class_id - 1
        

In [None]:
def get_data(root_dir, external=False):
    """Build the dataset and its DataLoader for one data split.

    Args:
        root_dir: Directory passed to ``InfantDataset`` as ``root_dir``.
        external: True loads the external ('test') split, False the 'train'
            split. The flag is forwarded explicitly; otherwise the dataset's
            own default (external=True) would be used for both cases.

    Returns:
        (dataset, loader): the ``InfantDataset`` and a non-shuffling
        ``DataLoader`` over it using the notebook-level ``batch_size``.
    """
    dataset = InfantDataset(root_dir=root_dir, transform=False, image_size=image_size, external=external)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1)

    return dataset, loader

# Build the dataset and PyTorch DataLoader for the external test data under ./data/test.
external_test_dataset, external_test_loader = get_data(root_dir='./data/test', external = True)   

In [None]:
# Number of samples in the external test dataset.
num_data_external = len(external_test_dataset) 

# Number of batches needed to cover the dataset; np.ceil returns a float, not an int.
num_batch_external = np.ceil(num_data_external / batch_size)

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# `.to(device)` calls below do not fail on a CPU-only host.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
import warnings 
# Silence every warning for the rest of the session (this also hides
# deprecation and numerical warnings raised by the libraries used below).
warnings.filterwarnings(action='ignore')

In [None]:
def load_model(path, mode='test'):
    """Load whatever object was saved at ``path`` with ``torch.load``.

    Args:
        path: Location of the saved checkpoint file.
        mode: Accepted for interface compatibility; not used.

    Returns:
        The object stored in the file, exactly as ``torch.load`` returns it.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(path)
    print("Get saved weights successfully.")

    return checkpoint

In [None]:
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix, f1_score
from sklearn.preprocessing import label_binarize

import ttach as tta

# Test-time augmentation pipeline handed to tta.ClassificationTTAWrapper in
# compute_metrics: a single Multiply transform with factors 0.7 and 1.
transforms = tta.Compose([           
    tta.Multiply(factors=[0.7, 1]),

])

def compute_metrics(model, test_loader, plot_roc_curve = False, mode='val'):
    """Evaluate ``model`` on ``test_loader`` and summarize multi-class metrics.

    Args:
        model: Classifier whose output has the class scores along dim 1.
        test_loader: Iterable of batches; each batch is a dict with the keys
            'input' and 'label'.
        plot_roc_curve: Not used; kept for interface compatibility.
        mode: 'val' runs the model directly, 'test' runs it through the
            test-time-augmentation wrapper built from ``transforms``.

    Returns:
        dict with "Accuracy", "Sensitivity" (macro recall), "Specificity"
        (macro one-vs-rest TN / (TN + FP)), "Precision" (macro precision) and
        "F1 Score" (harmonic mean of macro precision and macro recall), all
        in percent, plus "Validation Loss" (mean loss per batch),
        "Confusion Matrix" (num_class x num_class, rows are true classes),
        "pred_list" and "target_list".

    Raises:
        ValueError: If ``mode`` is not 'val' or 'test', or the loader yields
            no batches.
    """
    if mode not in ('val', 'test'):
        raise ValueError(f"mode must be 'val' or 'test', got {mode!r}")

    model.eval()
    criterion = nn.CrossEntropyLoss()

    # Only build the TTA wrapper when it is actually going to be used.
    predictor = model if mode == 'val' else tta.ClassificationTTAWrapper(model, transforms)

    val_loss = 0.0
    pred_batches = []
    target_batches = []

    for data in test_loader:
        image, target = data['input'].to(device), data['label'].to(device)

        # Gradients are never needed for evaluation, in either mode.
        with torch.no_grad():
            output = predictor(image)

        # Log loss
        val_loss += criterion(output, target.long()).item()

        # reshape(-1) keeps a 1-D tensor even when the batch holds a single
        # sample; squeeze() would yield a 0-d tensor that torch.cat rejects.
        pred_batches.append(output.argmax(dim=1).reshape(-1).cpu())
        target_batches.append(target.long().reshape(-1).cpu())

    if not pred_batches:
        raise ValueError("test_loader yielded no batches")

    pred_list = torch.cat(pred_batches).tolist()
    target_list = torch.cat(target_batches).tolist()

    # Fixing the label set keeps the matrix num_class x num_class even when a
    # class is absent from both the targets and the predictions.
    conf_matrix = confusion_matrix(target_list, pred_list, labels=list(range(num_class)))

    def _ratio(numerator, denominator):
        # Element-wise division that yields 0 where the denominator is 0.
        return np.divide(numerator, denominator,
                         out=np.zeros_like(numerator), where=denominator != 0)

    counts = conf_matrix.astype(np.float64)
    true_pos = np.diag(counts)
    per_true = counts.sum(axis=1)   # samples whose true class is i
    per_pred = counts.sum(axis=0)   # samples predicted as class i
    total = counts.sum()

    # Macro averages over all classes.
    sensitivity = _ratio(true_pos, per_true).mean()
    precision = _ratio(true_pos, per_pred).mean()
    # One-vs-rest specificity: TN / (TN + FP), with TN + FP = total - per_true.
    true_neg = total - per_true - per_pred + true_pos
    specificity = _ratio(true_neg, total - per_true).mean()

    accuracy = true_pos.sum() / total if total else 0.0

    pr_sum = precision + sensitivity
    f1 = 2 * precision * sensitivity / pr_sum if pr_sum else 0.0

    # put together values
    metrics_dict = {"Accuracy": accuracy * 100,
                    "Sensitivity": sensitivity * 100,
                    "Specificity": specificity * 100,
                    "Precision": precision * 100,
                    "F1 Score": f1 * 100,
                    "Validation Loss": val_loss / len(pred_batches),
                    "Confusion Matrix": conf_matrix,
                    "pred_list": pred_list,
                    "target_list": target_list,}

    return metrics_dict

In [None]:
def majority_voting_by_3(alex_prediction, res18_prediction,res50_prediction):
    """Combine three models' class predictions by majority vote.

    For each sample the class chosen by at least two models wins; when all
    three disagree, the ResNet-50 prediction is used as the tie-breaker.

    Args:
        alex_prediction: Per-sample class indices from the first model
            (tensor, array or sequence).
        res18_prediction: Per-sample class indices from the second model.
        res50_prediction: Per-sample class indices from the third model;
            also the tie-breaker. It is not modified.

    Returns:
        NumPy array with the ensemble's class index for every sample.

    Raises:
        ValueError: If the three inputs do not have the same shape.
    """
    def _as_label_array(predictions):
        # Tensors (possibly on the GPU) must be moved to host memory before
        # NumPy can read them.
        if torch.is_tensor(predictions):
            predictions = predictions.detach().cpu().numpy()
        return np.asarray(predictions)

    alex = _as_label_array(alex_prediction)
    res18 = _as_label_array(res18_prediction)
    res50 = _as_label_array(res50_prediction)

    if not (alex.shape == res18.shape == res50.shape):
        raise ValueError(
            "prediction inputs must have the same shape, got "
            f"{alex.shape}, {res18.shape} and {res50.shape}"
        )

    # With three voters: if the first two agree, that class has the majority.
    # Otherwise the third vote either sides with one of them (and wins) or
    # all three differ (and the third is the tie-breaker), so it decides.
    return np.where(alex == res18, alex, res50)


In [None]:
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix
from sklearn.preprocessing import label_binarize

def compute_metrics_test(alex_model, res18_model, res50_model, test_loader):
    """Evaluate the three-model majority-vote ensemble on ``test_loader``.

    Args:
        alex_model: First classifier; output has class scores along dim 1.
        res18_model: Second classifier.
        res50_model: Third classifier; its prediction breaks three-way ties
            (see ``majority_voting_by_3``).
        test_loader: Iterable of batches; each batch is a dict with the keys
            'input' and 'label'.

    Returns:
        dict with "Accuracy", "Sensitivity" (macro recall), "Specificity"
        (macro one-vs-rest TN / (TN + FP)), "Precision" (macro precision) and
        "F1 Score" (harmonic mean of macro precision and macro recall), all
        in percent, plus "Confusion Matrix" (num_class x num_class, rows are
        true classes), "Validation Loss" (loss averaged over the three models
        and over batches), "pred_list" and "target_list".

    Raises:
        ValueError: If the loader yields no batches.
    """
    models = (alex_model, res18_model, res50_model)
    for member in models:
        member.eval()

    criterion = nn.CrossEntropyLoss()

    val_loss = [0.0, 0.0, 0.0]
    pred_batches = ([], [], [])   # one list of per-batch predictions per model
    target_batches = []

    for data in tqdm(test_loader):
        image, target = data['input'].to(device), data['label'].to(device)

        with torch.no_grad():
            outputs = [member(image) for member in models]

        for slot, output in enumerate(outputs):
            # Log loss
            val_loss[slot] += criterion(output, target.long()).item()
            # reshape(-1) keeps a 1-D tensor even when the batch holds a
            # single sample; squeeze() would yield a 0-d tensor that
            # torch.cat rejects.
            pred_batches[slot].append(output.argmax(dim=1).reshape(-1).cpu())

        target_batches.append(target.long().reshape(-1).cpu())

    if not target_batches:
        raise ValueError("test_loader yielded no batches")

    alex_pred_list, res18_pred_list, res50_pred_list = (
        torch.cat(batches) for batches in pred_batches
    )
    target_list = torch.cat(target_batches).tolist()

    pred_list = majority_voting_by_3(alex_pred_list, res18_pred_list, res50_pred_list)
    pred_list = pred_list.tolist()

    # Fixing the label set keeps the matrix num_class x num_class even when a
    # class is absent from both the targets and the predictions.
    conf_matrix = confusion_matrix(target_list, pred_list, labels=list(range(num_class)))

    def _ratio(numerator, denominator):
        # Element-wise division that yields 0 where the denominator is 0.
        return np.divide(numerator, denominator,
                         out=np.zeros_like(numerator), where=denominator != 0)

    counts = conf_matrix.astype(np.float64)
    true_pos = np.diag(counts)
    per_true = counts.sum(axis=1)   # samples whose true class is i
    per_pred = counts.sum(axis=0)   # samples predicted as class i
    total = counts.sum()

    # Macro averages over all classes.
    sensitivity = _ratio(true_pos, per_true).mean()
    precision = _ratio(true_pos, per_pred).mean()
    # One-vs-rest specificity: TN / (TN + FP), with TN + FP = total - per_true.
    true_neg = total - per_true - per_pred + true_pos
    specificity = _ratio(true_neg, total - per_true).mean()

    accuracy = true_pos.sum() / total if total else 0.0

    # F1 combines precision and recall; using recall for both terms would
    # make it collapse to the recall itself.
    pr_sum = precision + sensitivity
    f1 = 2 * precision * sensitivity / pr_sum if pr_sum else 0.0

    # put together values (accuracy is reported in percent, like the others)
    metrics_dict = {"Accuracy": accuracy * 100,
                    "Sensitivity": sensitivity * 100,
                    "Specificity": specificity * 100,
                    "Precision": precision * 100,
                    "F1 Score": f1 * 100,
                    "Confusion Matrix": conf_matrix,
                    "Validation Loss": np.mean(val_loss) / len(target_batches),
                    "pred_list": pred_list,
                    "target_list": target_list,}

    return metrics_dict

In [None]:
# Load the three saved models that form the ensemble.
alex_model = load_model('/home/ncp/workspace/seung-ah/checkpoint/best/best_model-alex.pt')
res18_model = load_model('/home/ncp/workspace/seung-ah/checkpoint/best/best_model-res18.pt')
res50_model = load_model('/home/ncp/workspace/seung-ah/checkpoint/best/best_model-resnet50(v1).pt')

criterion = nn.CrossEntropyLoss()

# Evaluate the majority-vote ensemble on the external test loader and report
# the headline metrics, one per line.
metrics_dict = compute_metrics_test(alex_model, res18_model, res50_model, external_test_loader)
print('------------------- Test Performance --------------------------------------')
for metric_name in ("Accuracy", "Sensitivity", "Specificity", "F1 Score"):
    print("{} \t {:.3f}".format(metric_name, metrics_dict[metric_name]))
print("---------------------------------------------------------------------------")