In [None]:
import os
import json
import cv2 as cv
import numpy as np
from tqdm import tqdm
import h5py

import pydicom
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import albumentations as A

import torch.optim as optim
from torch.optim import lr_scheduler
import time
import torch.nn as nn
import copy

from collections import OrderedDict

In [None]:
# Diagnosis labels in class-index order. English glosses of the Korean names:
# 0 air leak, 1 hyperinflation, 2 atelectasis, 3 neonatal respiratory distress
# syndrome, 4 pneumonia, 5 pleural effusion, 6 normal.
class_names = ['0.공기누출', '1.과다팽창', '2.무기폐', '3.신생아호흡곤란증후군', '4.폐렴', '5.흉막삼출', '6.정상']
num_class = len(class_names)
# Integer class ids 0 .. num_class - 1, one per entry of class_names.
classes = list(range(num_class))
print(num_class)

In [None]:
# Evaluation settings shared by the cells below.
# `transform` is the augmentation switch; it is off for evaluation.
transform = False
# Mini-batch size used by the DataLoader and the image size handed to the dataset.
batch_size, image_size = 16, 256

In [None]:
def _to_tensor(image, label, name):
    """Package one sample as the dictionary consumed by the evaluation loop.

    Args:
        image: NumPy array; it is converted with ``torch.from_numpy`` and
            reshaped to ``(1, rows, cols)`` using its first two dimensions.
        label: class label, passed through unchanged.
        name: sample identifier, passed through unchanged.

    Returns:
        dict with keys ``'name'``, ``'input'`` (the reshaped tensor) and ``'label'``.
    """
    tensor = torch.from_numpy(image)
    rows, cols = tensor.shape[0], tensor.shape[1]
    # Prepend a singleton dimension in front of the two image dimensions.
    return {'name': name, 'input': tensor.reshape((1, rows, cols)), 'label': label}


In [None]:
class InfantDataset(Dataset):
    """Dataset of DICOM images described by per-sample JSON annotation files.

    An index file ``<list_dir>/<mode>.txt`` lists one annotation file per line,
    relative to ``<root_dir>/<mode>``.  Each item is returned as the dict built
    by ``_to_tensor`` (keys ``'name'``, ``'input'``, ``'label'``).

    Args:
        root_dir: base data directory; ``mode`` is appended to it.
        transform: when truthy, ``_random_augment`` is applied to every image.
        image_size: stored on the instance; no method of this class reads it.
        mode: split name, used for both the data sub-directory and the index file.
        list_dir: directory that holds the ``<mode>.txt`` index files.
    """

    def __init__(self, root_dir='/home/ncp/workspace/data/', transform=True, image_size=None, mode='train',
                 list_dir='/home/ncp/workspace/seung-ah/'):
        self.root_dir = os.path.join(root_dir, mode)
        self.image_size = image_size
        self.transform = transform
        self.mode = mode

        # Read the index inside a context manager so the handle is closed, and
        # drop blank lines, which would otherwise become unusable entries.
        index_path = os.path.join(list_dir, f"{self.mode}.txt")
        with open(index_path, 'r') as index_file:
            self.dataset = [line for line in index_file.read().splitlines() if line.strip()]

    def __len__(self):
        """Return the number of entries listed in the index file."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Load, preprocess and package the sample at ``index``."""
        entry = self.dataset[index]
        image, label = self._load_dicom(os.path.join(self.root_dir, entry))
        image = self._preprocess_image(image)

        # The sample name is the entry minus its last five characters
        # (presumably the ".json" extension -- confirm against the index files).
        return _to_tensor(image, label, entry[:-5])

    @staticmethod
    def _min_max_normalize(image):
        """Rescale ``image`` so its minimum maps to 0 and its maximum to 1.

        A constant image has no range to stretch; it is returned as all zeros
        instead of the NaNs a 0/0 division would produce.
        """
        shifted = image - np.min(image)
        peak = np.max(shifted)
        return shifted / (peak if peak != 0 else 1)

    def _preprocess_image(self, image):
        """Convert to grayscale if needed, apply CLAHE, scale to [0, 1], optionally augment."""
        if len(image.shape) == 3:
            image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)

        # Contrast-limited adaptive histogram equalisation.
        # NOTE(review): the dtype of the incoming pixel array is not checked
        # here before it is passed to CLAHE -- confirm it is one CLAHE accepts.
        clahe = cv.createCLAHE(clipLimit=20)
        image = clahe.apply(image)

        image = self._min_max_normalize(image)

        # Augmentation works on the image rescaled to the 0-255 range.
        # NOTE(review): `_random_augment` is not defined in the visible part of
        # this notebook; it must exist before `transform` is enabled.
        if self.transform:
            image = _random_augment(image*255)

        return image

    def _load_dicom(self, image_path):
        """Read a JSON annotation and the DICOM file it points to.

        Args:
            image_path: path of the JSON annotation file.

        Returns:
            ``(pixel_array, label)`` where ``label`` is the zero-based class id.
        """
        with open(image_path, 'r', encoding='UTF8') as f:
            content = json.load(f)

        # The stored DICOM path has its first four characters stripped before
        # being joined to the split directory (presumably a fixed prefix -- TODO confirm).
        dicom_path = os.path.join(self.root_dir, content['mask_image']['org_dicom_file'][4:])

        dcm = pydicom.dcmread(dicom_path)
        origin_image = dcm.pixel_array

        # Diagnosis code 9 is remapped to 7, then all codes are shifted to start at 0.
        class_id = int(content['patient']['diagnosis'])
        if class_id == 9:
            class_id = 7

        return origin_image, class_id - 1
                        
        

In [None]:
def get_data(mode):
    """Build the un-augmented dataset and its sequential loader for one split.

    Args:
        mode: split name forwarded to ``InfantDataset``.

    Returns:
        ``(dataset, loader)``; the loader uses the module-level ``batch_size``,
        keeps the sample order (``shuffle=False``) and one worker process.
    """
    split_dataset = InfantDataset(mode=mode, image_size=image_size, transform=False)
    split_loader = DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
    )
    return split_dataset, split_loader


# Dataset and DataLoader for the test split.
test_dataset, test_loader = get_data('test')

In [None]:
# Number of test samples, and the number of mini-batches needed to cover them
# (kept as the floating-point value returned by np.ceil).
num_data_test = len(test_dataset)
num_batch_test = np.ceil(np.true_divide(num_data_test, batch_size))

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
def load_model(path):
    """Load a checkpoint saved with ``torch.save`` and return the stored object.

    Args:
        path: location of the checkpoint file.

    Returns:
        Whatever object the checkpoint contains, exactly as ``torch.load`` yields it.
    """
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects checkpoints holding a whole pickled model -- confirm against the
    # installed version and the checkpoint format.
    # Without CUDA, remap GPU-saved tensors to the CPU instead of failing;
    # with CUDA, keep the default placement (map_location=None).
    map_location = None if torch.cuda.is_available() else 'cpu'
    dict_model = torch.load(path, map_location=map_location)
    print("Get saved weights successfully.")

    return dict_model

In [None]:
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix, f1_score
from sklearn.preprocessing import label_binarize


# `tta` must be imported before its first use; without this import the cell
# raises NameError when the notebook is run top to bottom on a fresh kernel.
import ttach as tta

# Test-time augmentation pipeline: each input is evaluated at the listed
# intensity multipliers (0.7 and 1).
transforms = tta.Compose([
    tta.Multiply(factors=[0.7, 1]),
])

def _per_class_rates(conf_matrix):
    """Derive per-class recall, precision and specificity from a confusion matrix.

    Rows of ``conf_matrix`` are true classes and columns are predicted classes
    (the layout returned by ``sklearn.metrics.confusion_matrix``).  A rate whose
    denominator is zero is reported as 0 rather than NaN.

    Returns:
        ``(recall, precision, specificity)`` as 1-D float arrays, one value per class.
    """
    counts = np.asarray(conf_matrix, dtype=np.float64)
    true_pos = np.diag(counts)
    false_neg = counts.sum(axis=1) - true_pos
    false_pos = counts.sum(axis=0) - true_pos
    true_neg = counts.sum() - true_pos - false_neg - false_pos

    def ratio(numerator, denominator):
        result = np.zeros_like(numerator)
        np.divide(numerator, denominator, out=result, where=denominator > 0)
        return result

    recall = ratio(true_pos, true_pos + false_neg)
    precision = ratio(true_pos, true_pos + false_pos)
    specificity = ratio(true_neg, true_neg + false_pos)
    return recall, precision, specificity


def compute_metrics(model, test_loader, plot_roc_curve = False, mode='val'):
    """Evaluate ``model`` on ``test_loader`` and return macro-averaged metrics.

    Args:
        model: classifier called as ``model(image)``; its output is arg-maxed
            over dimension 1 to obtain the predicted class.
        test_loader: iterable of batches that are dicts with at least the keys
            ``'input'`` and ``'label'``.
        plot_roc_curve: accepted for interface compatibility; not used.
        mode: accepted for interface compatibility; not used by this version.

    Returns:
        dict with
          * ``"Accuracy"``, ``"Sensitivity"`` (macro recall), ``"Specificity"``
            (macro TN / (TN + FP)), ``"Precision"`` (macro precision) and
            ``"F1 Score"`` (harmonic mean of macro precision and macro recall),
            all in percent;
          * ``"Validation Loss"`` -- mean cross-entropy over batches;
          * ``"Confusion Matrix"`` -- ``num_class x num_class`` array;
          * ``"pred_list"`` / ``"target_list"`` -- flat Python lists.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    total_loss = 0.0
    pred_batches, target_batches = [], []

    with torch.no_grad():
        for data in test_loader:
            image = data['input'].to(device)
            target = data['label'].to(device).long()

            output = model(image)
            # One loss evaluation per batch is enough for the running total.
            total_loss += criterion(output, target).item()

            # reshape(-1) keeps a 1-D tensor even for a batch of size one,
            # where squeeze() would yield a 0-d tensor that cannot be concatenated.
            pred_batches.append(output.argmax(dim=1).reshape(-1).cpu())
            target_batches.append(target.reshape(-1).cpu())

    if not pred_batches:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    pred_list = torch.cat(pred_batches).tolist()
    target_list = torch.cat(target_batches).tolist()

    # Passing the full label set keeps the matrix num_class x num_class even
    # when some class never occurs in this split.
    conf_matrix = confusion_matrix(target_list, pred_list, labels=list(range(num_class)))
    recall, precision, specificity = _per_class_rates(conf_matrix)

    accuracy = float(np.trace(conf_matrix)) / float(conf_matrix.sum())
    macro_recall = float(recall.mean())
    macro_precision = float(precision.mean())
    macro_specificity = float(specificity.mean())

    # Harmonic mean of macro precision and macro recall; 0 when both are 0.
    pr_sum = macro_precision + macro_recall
    f1 = 2 * macro_precision * macro_recall / pr_sum if pr_sum > 0 else 0.0

    metrics_dict = {"Accuracy": accuracy * 100,
                    "Sensitivity": macro_recall * 100,
                    "Specificity": macro_specificity * 100,
                    "Precision": macro_precision * 100,
                    "F1 Score": f1 * 100,
                    "Validation Loss": total_loss / len(pred_batches),
                    "Confusion Matrix": conf_matrix,
                    "pred_list": pred_list,
                    "target_list": target_list,}

    return metrics_dict

In [None]:
import warnings

# Suppress every warning category for the remainder of the session.
warnings.simplefilter('ignore')

In [None]:
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix, f1_score
from sklearn.preprocessing import label_binarize

import ttach as tta

# Test-time augmentation pipeline: a single Multiply step that evaluates each
# input at the intensity factors 0.7 and 1.
_tta_steps = [tta.Multiply(factors=[0.7, 1])]
transforms = tta.Compose(_tta_steps)

def _per_class_rates(conf_matrix):
    """Derive per-class recall, precision and specificity from a confusion matrix.

    Rows of ``conf_matrix`` are true classes and columns are predicted classes
    (the layout returned by ``sklearn.metrics.confusion_matrix``).  A rate whose
    denominator is zero is reported as 0 rather than NaN.

    Returns:
        ``(recall, precision, specificity)`` as 1-D float arrays, one value per class.
    """
    counts = np.asarray(conf_matrix, dtype=np.float64)
    true_pos = np.diag(counts)
    false_neg = counts.sum(axis=1) - true_pos
    false_pos = counts.sum(axis=0) - true_pos
    true_neg = counts.sum() - true_pos - false_neg - false_pos

    def ratio(numerator, denominator):
        result = np.zeros_like(numerator)
        np.divide(numerator, denominator, out=result, where=denominator > 0)
        return result

    recall = ratio(true_pos, true_pos + false_neg)
    precision = ratio(true_pos, true_pos + false_pos)
    specificity = ratio(true_neg, true_neg + false_pos)
    return recall, precision, specificity


def compute_metrics(model, test_loader, plot_roc_curve = False, mode='val'):
    """Evaluate ``model`` on ``test_loader`` and return macro-averaged metrics.

    Args:
        model: classifier called as ``model(image)``; its output is arg-maxed
            over dimension 1 to obtain the predicted class.
        test_loader: iterable of batches that are dicts with at least the keys
            ``'input'`` and ``'label'``.
        plot_roc_curve: accepted for interface compatibility; not used.
        mode: ``'val'`` runs the model directly; ``'test'`` wraps it in
            ``tta.ClassificationTTAWrapper`` with the module-level ``transforms``.

    Returns:
        dict with
          * ``"Accuracy"``, ``"Sensitivity"`` (macro recall), ``"Specificity"``
            (macro TN / (TN + FP)), ``"Precision"`` (macro precision) and
            ``"F1 Score"`` (harmonic mean of macro precision and macro recall),
            all in percent;
          * ``"Validation Loss"`` -- mean cross-entropy over batches;
          * ``"Confusion Matrix"`` -- ``num_class x num_class`` array;
          * ``"pred_list"`` / ``"target_list"`` -- flat Python lists.

    Raises:
        ValueError: if ``mode`` is neither ``'val'`` nor ``'test'``, or if
            ``test_loader`` yields no batches.
    """
    if mode not in ('val', 'test'):
        # Previously an unknown mode left `output` unbound inside the loop.
        raise ValueError(f"mode must be 'val' or 'test', got {mode!r}")

    model.eval()
    criterion = nn.CrossEntropyLoss()

    # Only the test mode uses test-time augmentation.
    predictor = tta.ClassificationTTAWrapper(model, transforms) if mode == 'test' else model

    total_loss = 0.0
    pred_batches, target_batches = [], []

    # Inference only: no_grad covers both the plain and the TTA forward pass,
    # so no autograd graph is kept alive in either mode.
    with torch.no_grad():
        for data in test_loader:
            image = data['input'].to(device)
            target = data['label'].to(device).long()

            output = predictor(image)
            # One loss evaluation per batch is enough for the running total.
            total_loss += criterion(output, target).item()

            # reshape(-1) keeps a 1-D tensor even for a batch of size one,
            # where squeeze() would yield a 0-d tensor that cannot be concatenated.
            pred_batches.append(output.argmax(dim=1).reshape(-1).cpu())
            target_batches.append(target.reshape(-1).cpu())

    if not pred_batches:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    pred_list = torch.cat(pred_batches).tolist()
    target_list = torch.cat(target_batches).tolist()

    # Passing the full label set keeps the matrix num_class x num_class even
    # when some class never occurs in this split.
    conf_matrix = confusion_matrix(target_list, pred_list, labels=list(range(num_class)))
    recall, precision, specificity = _per_class_rates(conf_matrix)

    accuracy = float(np.trace(conf_matrix)) / float(conf_matrix.sum())
    macro_recall = float(recall.mean())
    macro_precision = float(precision.mean())
    macro_specificity = float(specificity.mean())

    # Harmonic mean of macro precision and macro recall; 0 when both are 0.
    pr_sum = macro_precision + macro_recall
    f1 = 2 * macro_precision * macro_recall / pr_sum if pr_sum > 0 else 0.0

    metrics_dict = {"Accuracy": accuracy * 100,
                    "Sensitivity": macro_recall * 100,
                    "Specificity": macro_specificity * 100,
                    "Precision": macro_precision * 100,
                    "F1 Score": f1 * 100,
                    "Validation Loss": total_loss / len(pred_batches),
                    "Confusion Matrix": conf_matrix,
                    "pred_list": pred_list,
                    "target_list": target_list,}

    return metrics_dict

In [None]:
# Load the best checkpoint of the chosen architecture.
model_name = 'resnet50'
model = load_model(f'./checkpoint/best/best_model-{model_name}.pt')

# Optimiser settings and a loss object; the evaluation call below does not
# take them as arguments.
lr, momentum = 1e-5, 0.9
criterion = nn.CrossEntropyLoss()

# Evaluate on the test split in 'test' mode.
metrics_dict = compute_metrics(model, test_loader, mode='test')

# One (template, key) pair per reported line, printed between two rulers.
report_rows = (
    ("Accuracy \t {:.3f}", 'Accuracy'),
    ("Sensitivity \t {:.3f}", 'Sensitivity'),
    ("Specificity \t {:.3f}", 'Specificity'),
    ("F1 Score \t {:.3f}", 'F1 Score'),
)
print('------------------- Test Performance --------------------------------------')
for template, key in report_rows:
    print(template.format(metrics_dict[key]))
print("---------------------------------------------------------------------------")