In [None]:
import os
import json
import cv2 as cv
import numpy as np
from tqdm import tqdm
import h5py

import pydicom
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import albumentations as A

import torch.optim as optim
from torch.optim import lr_scheduler
import time
import torch.nn as nn
import copy

from collections import OrderedDict

In [None]:
# Diagnosis categories, in label order (0-based index == numeric prefix):
# air leak, hyperinflation, atelectasis, neonatal respiratory distress syndrome,
# pneumonia, pleural effusion, normal.
class_names = ['0.공기누출', '1.과다팽창', '2.무기폐', '3.신생아호흡곤란증후군', '4.폐렴', '5.흉막삼출', '6.정상']
num_class = len(class_names)
classes = list(range(num_class))   # [0, 1, ..., num_class - 1]
print(num_class)

In [None]:
# Data-pipeline settings: augmentation flag, mini-batch size, and the square side (pixels)
# every image is resized to.
transform, batch_size, image_size = True, 16, 256

In [None]:
# Lazily-built augmentation pipeline shared by every _random_augment call.
_AUGMENT_PIPELINE = None


def _random_augment(image):
    """Apply the training-time random augmentation pipeline to one HWC image.

    The pipeline is random flips, a small shift/scale/rotate, a brightness/contrast jitter,
    a Gaussian blur and an optical distortion, each applied with its own probability.

    The ``A.Compose`` object is built once on first use and reused afterwards. Previously it
    was rebuilt for every sample, which is pure overhead in the data-loader hot path.
    The randomness comes from the transforms at call time, not at construction.

    Args:
        image: HWC numpy image (float32 in [0, 1] when called from ``InfantDataset``).

    Returns:
        The augmented image, same layout as the input.
    """
    global _AUGMENT_PIPELINE
    if _AUGMENT_PIPELINE is None:
        _AUGMENT_PIPELINE = A.Compose([
            A.HorizontalFlip(p=0.3),
            A.VerticalFlip(p=0.3),
            # Constant (black) borders so geometric transforms do not smear edge pixels inward.
            A.ShiftScaleRotate(shift_limit=(-0.05, 0.05), rotate_limit=(-10, 10), scale_limit=(0, 0.05),
                               border_mode=cv.BORDER_CONSTANT, value=0, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=(0, 0.3), contrast_limit=(0, 0.3), p=0.3),
            A.GaussianBlur(blur_limit=(7, 15), p=0.3),
            A.OpticalDistortion(distort_limit=0.1, shift_limit=0.1,
                                border_mode=cv.BORDER_CONSTANT, value=0, p=0.3),
        ])

    return _AUGMENT_PIPELINE(image=image)['image']


def _to_tensor(image, label, name):
    image = np.transpose(image, (2, 0, 1))
    image = torch.from_numpy(image)
    
    data = {'name':name, 'input': image, 'label': label}
    return data


In [None]:
class InfantDataset(Dataset):
    """HDF5-backed dataset of infant chest images with integer class labels.

    Every file in ``<root_dir>/<mode>`` must contain an ``input`` image dataset and a
    ``class_id`` dataset. Each item is the dict built by ``_to_tensor``:
    ``{'name': file stem, 'input': CHW float32 tensor, 'label': 0-based class id}``.

    Args:
        root_dir: directory holding one sub-directory per split.
        transform: apply ``_random_augment`` (use for training only).
        image_size: side length to resize to; ``None`` keeps the original size.
        mode: split sub-directory name ('train', 'val', 'test').
    """

    def __init__(self, root_dir='/home/ncp/workspace/seung-ah/hdf5', transform=True, image_size=None, mode='train'):
        self.root_dir = os.path.join(root_dir, mode)
        self.image_size = image_size
        self.transform = transform
        self.mode = mode
        # os.listdir order is filesystem-dependent; sort so item order (and therefore the
        # order of val/test predictions) is reproducible across machines and runs.
        self.dataset = sorted(os.listdir(self.root_dir))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        file_name = self.dataset[index]
        image, label = self._load_hdf5(os.path.join(self.root_dir, file_name))
        image = self._preprocess_image(image).astype('float32')

        # augmentation (enabled for the training split)
        if self.transform:
            image = _random_augment(image)

        # resize to image_size x image_size; with no size requested keep the native resolution
        if self.image_size is not None:
            dim = (self.image_size, self.image_size)
            image = cv.resize(image, dim, interpolation=cv.INTER_AREA)

        # sample name = file name without its extension (e.g. '.hdf5')
        name = os.path.splitext(file_name)[0]
        return _to_tensor(image, label, name)

    def _preprocess_image(self, image):
        """Grayscale -> CLAHE -> min-max scale to [0, 1] -> replicate into 3 channels (HWC)."""
        if len(image.shape) == 3:
            image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)

        # OpenCV's CLAHE only accepts 8/16-bit single-channel input, so the stored image is
        # presumably uint8/uint16 -- TODO confirm against the HDF5 writer.
        clahe = cv.createCLAHE(clipLimit=80)
        image = clahe.apply(image)

        shifted = image - np.min(image)
        value_range = np.max(shifted)
        # A constant image has zero range; emit zeros rather than NaNs from 0/0.
        image = shifted / value_range if value_range > 0 else np.zeros(shifted.shape)

        if len(image.shape) != 3:
            # Pretrained ImageNet backbones expect 3 input channels.
            image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
        return image

    def _load_hdf5(self, hdf5_path):
        """Read the ``input`` image and the 0-based ``class_id`` from one HDF5 file.

        Raises:
            KeyError: if either dataset is missing from the file.
        """
        with h5py.File(hdf5_path, 'r') as hf:
            missing = [key for key in ('input', 'class_id') if key not in hf]
            if missing:
                raise KeyError(f"{hdf5_path} is missing dataset(s): {missing}")
            image = np.array(hf['input'])
            # The -1 shifts stored ids to 0-based class indices (assumes the files store
            # 1-based ids, as the original code did).
            class_id = np.array(hf['class_id'])[0] - 1
        return image, class_id
                        
        

In [None]:
def get_data(mode):
    """Build the ``InfantDataset`` and ``DataLoader`` for one split.

    Args:
        mode: 'train', 'val' or 'test'. Only 'train' is augmented and shuffled, so val/test
            batches are deterministic and in dataset order.

    Returns:
        ``(dataset, loader)``.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode not in ('train', 'val', 'test'):
        raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")

    is_train = mode == 'train'
    dataset = InfantDataset(transform=is_train, image_size=image_size, mode=mode)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=1)
    return dataset, loader


# pytorch Dataloader
train_dataset, train_loader = get_data('train')
val_dataset, val_loader = get_data('val')
test_dataset, test_loader = get_data('test')
    

In [None]:
# Samples per split, and mini-batches per split (ceil: the last batch may be partial).
num_data_train, num_data_val, num_data_test = (
    len(split) for split in (train_dataset, val_dataset, test_dataset)
)
num_batch_train, num_batch_val, num_batch_test = (
    np.ceil(count / batch_size) for count in (num_data_train, num_data_val, num_data_test)
)

In [None]:
# Use the GPU when present; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
# ImageNet-pretrained backbones, each moved to the compute device (built in this order).
alex_model, res18_model, res50_model = (
    build(pretrained=True).to(device)
    for build in (torchvision.models.alexnet,
                  torchvision.models.resnet18,
                  torchvision.models.resnet50)
)

In [None]:
def _make_classifier_head(in_features, n_classes):
    """Two-layer MLP head: Linear(in_features, 100) -> ReLU -> Dropout(0.2) -> Linear(100, n_classes).

    Layer names (fc1/relu/drop-out/fc2) match the original heads so existing pickled
    checkpoints keep the same module structure.
    """
    return nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(in_features, 100)),
        ('relu', nn.ReLU()),
        # Not in-place: ReLU keeps its output for the backward pass, and an in-place dropout
        # on that output can make autograd fail with an in-place modification error.
        ('drop-out', nn.Dropout(p=0.2)),
        ('fc2', nn.Linear(100, n_classes)),  # one logit per class in class_names
    ]))


# AlexNet keeps its classifier stack and swaps only the final Linear (index 6);
# the ResNets replace their single `fc` layer.
alex_model.classifier[6] = _make_classifier_head(alex_model.classifier[6].in_features, num_class).to(device)
res18_model.fc = _make_classifier_head(res18_model.fc.in_features, num_class).to(device)
res50_model.fc = _make_classifier_head(res50_model.fc.in_features, num_class).to(device)

In [None]:
# Only a cell's last expression is displayed, so report all three placements together.
{
    'alex': next(alex_model.parameters()).is_cuda,
    'res18': next(res18_model.parameters()).is_cuda,
    'res50': next(res50_model.parameters()).is_cuda,
}

In [None]:

# Optimisation hyper-parameters shared by the three backbones.
lr = 1e-5
momentum = 0.9      # only meaningful for an SGD optimizer; the Adam optimizers below ignore it
num_epoch = 200

train_continue = False

criterion = nn.CrossEntropyLoss()

# One Adam optimizer per backbone, over all of that backbone's parameters.
alex_optimizer, res18_optimizer, res50_optimizer = (
    optim.Adam(net.parameters(), lr=lr)
    for net in (alex_model, res18_model, res50_model)
)

In [None]:
def load_model(path, mode='test', **load_kwargs):
    """Load an object saved with ``torch.save`` (a pickled model or a checkpoint dict).

    Args:
        path: checkpoint file path.
        mode: 'test' returns only the loaded object. Any other value also parses the
            epoch number from a ``...epoch<N>.pth`` file name (the pattern ``save_model``
            writes).
        **load_kwargs: forwarded to ``torch.load`` (e.g. ``map_location='cpu'`` to load GPU
            checkpoints on a CPU-only machine).

    Returns:
        The loaded object, or ``(loaded_object, epoch)`` when ``mode != 'test'``.

    NOTE(review): ``torch.load`` unpickles the file and can execute arbitrary code, so only
    load trusted checkpoints. Whole-model pickles (``torch.save(model, ...)``) also need
    ``weights_only=False`` on PyTorch versions where ``weights_only`` defaults to True.
    """
    loaded = torch.load(path, **load_kwargs)
    print("Get saved weights successfully.")
    if mode == 'test':
        return loaded
    file_name = os.path.basename(path)
    epoch = int(file_name.split('epoch')[1].split('.pth')[0])
    return loaded, epoch

def save_model(ckpt_dir, model, optim, epoch):
    """Save ``{'model': state_dict, 'optim': state_dict}`` under ``ckpt_dir``.

    The file is ``model_best.pth`` when 'best' occurs anywhere in ``ckpt_dir`` (overwritten
    on each call), otherwise ``model_epoch<epoch>.pth``. ``ckpt_dir`` may be relative or
    absolute and is created if needed.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    if 'best' in ckpt_dir:
        file_name = 'model_best.pth'
    else:
        file_name = 'model_epoch%d.pth' % epoch
    # os.path.join instead of "./%s/..." so absolute ckpt_dir values are not made relative.
    torch.save({'model': model.state_dict(), 'optim': optim.state_dict()},
               os.path.join(ckpt_dir, file_name))
    print(f'>> save {file_name}')

In [None]:
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix, f1_score
from sklearn.preprocessing import label_binarize


def _classification_summary(targets, preds):
    """Macro-averaged classification metrics for integer labels in ``range(num_class)``.

    Returns a dict of fractions in [0, 1]:
      - 'accuracy': overall accuracy.
      - 'recall' (sensitivity): macro per-class recall.
      - 'precision': macro per-class precision.
      - 'specificity': macro one-vs-rest TN / (TN + FP).
      - 'f1': harmonic mean of macro precision and macro recall.
    It also holds 'confusion', a num_class x num_class matrix (rows = true class).
    """
    labels = list(range(num_class))
    targets = np.asarray(targets)
    preds = np.asarray(preds)

    # `labels=` keeps the report valid when a class is absent from this split; without it
    # classification_report rejects target_names whose length no longer matches.
    report = classification_report(targets, preds, labels=labels,
                                   target_names=class_names, output_dict=True)
    conf = confusion_matrix(targets, preds, labels=labels)

    recall = np.mean([report[name]['recall'] for name in class_names])
    precision = np.mean([report[name]['precision'] for name in class_names])

    # One-vs-rest counts read off the confusion matrix.
    tp = np.diag(conf)
    fp = conf.sum(axis=0) - tp
    tn = conf.sum() - conf.sum(axis=1) - fp
    negatives = tn + fp
    specificity = np.mean(np.divide(tn, negatives, out=np.zeros(len(labels)), where=negatives > 0))

    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    accuracy = float(np.mean(targets == preds)) if targets.size else 0.0

    return {'accuracy': accuracy, 'recall': recall, 'precision': precision,
            'specificity': specificity, 'f1': f1, 'confusion': conf}


def compute_metrics(model, test_loader, plot_roc_curve = False, mode='val'):
    """Evaluate ``model`` on ``test_loader`` and return a metrics dict.

    Args:
        model: classifier returning logits of shape [N, num_class].
        test_loader: yields dicts with 'input' and 'label' tensors.
        plot_roc_curve, mode: accepted for call-site compatibility; not used.

    Returns:
        Dict with 'Accuracy', 'Sensitivity', 'Specificity' and 'F1 Score' (all percentages),
        'Validation Loss' (mean per-batch cross-entropy), 'Confusion Matrix', and the flat
        'pred_list' / 'target_list'. 'F1 Score' is the harmonic mean of macro precision and
        macro recall.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    total_loss = 0.0
    pred_chunks, target_chunks = [], []
    with torch.no_grad():
        for data in test_loader:
            image = data['input'].to(device)
            target = data['label'].to(device).long()
            output = model(image)
            total_loss += criterion(output, target).item()
            # view(-1) rather than squeeze(): a final batch with a single sample would
            # otherwise become a 0-d tensor, which torch.cat refuses to concatenate.
            pred_chunks.append(output.argmax(dim=1).view(-1))
            target_chunks.append(target.view(-1))

    pred_list = torch.cat(pred_chunks).tolist() if pred_chunks else []
    target_list = torch.cat(target_chunks).tolist() if target_chunks else []

    summary = _classification_summary(target_list, pred_list)

    metrics_dict = {"Accuracy": summary['accuracy'] * 100,
                    "Sensitivity": summary['recall'] * 100,
                    "Specificity": summary['specificity'] * 100,
                    "F1 Score": summary['f1'] * 100,
                    "Validation Loss": total_loss / max(len(test_loader), 1),
                    "Confusion Matrix": summary['confusion'],
                    "pred_list": pred_list,
                    "target_list": target_list,}

    return metrics_dict

In [None]:
import warnings 
warnings.filterwarnings(action='ignore')

# Train/Valid

In [None]:
from collections import deque

class EarlyStopping(object):
    """Track smoothed validation loss/score and report when improvement stalls.

    Each ``add_data`` call blends the new values into running averages
    (``0.2 * new + 0.8 * previous``; the first call seeds them directly). It counts how many
    calls in a row the running score stayed strictly below its best value and the running
    loss stayed strictly above its best value. ``stop`` turns those counters into a status
    code once they exceed ``patience``.
    """

    def __init__(self, patience = 8):
        super(EarlyStopping, self).__init__()
        self.patience = patience
        self.init = False                    # becomes True once add_data seeds the averages
        # most recent smoothed values
        self.previous_loss = int(1e8)
        self.previous_accuracy = 0
        # best smoothed values observed so far
        self.best_running_loss = int(1e7)
        self.best_running_accuracy = 0
        # calls since the last improvement
        self.loss_increase_iters = 0
        self.accuracy_decrease_iters = 0

    def add_data(self, model, loss, accuracy):
        """Fold one evaluation's ``loss`` and ``accuracy`` (any higher-is-better score) in.

        ``model`` is accepted for call-site compatibility and is not used.
        """
        if self.init:
            smoothed_loss = 0.2 * loss + 0.8 * self.previous_loss
            smoothed_score = 0.2 * accuracy + 0.8 * self.previous_accuracy
        else:
            smoothed_loss, smoothed_score = loss, accuracy
            self.init = True

        # Score: a strict drop below the best is a stall; a tie or a gain becomes the new best.
        score_worse = smoothed_score < self.best_running_accuracy
        if score_worse:
            self.accuracy_decrease_iters += 1
        else:
            self.best_running_accuracy, self.accuracy_decrease_iters = smoothed_score, 0

        # Loss: a strict rise above the best is a stall; a tie or a drop becomes the new best.
        loss_worse = smoothed_loss > self.best_running_loss
        if loss_worse:
            self.loss_increase_iters += 1
        else:
            self.best_running_loss, self.loss_increase_iters = smoothed_loss, 0

        self.previous_accuracy, self.previous_loss = smoothed_score, smoothed_loss

    def stop(self):
        """Status code: 1 = both counters exhausted, 2 = score only, 3 = loss only, 0 = keep going."""
        score_exhausted = self.accuracy_decrease_iters > self.patience
        loss_exhausted = self.loss_increase_iters > self.patience
        if score_exhausted:
            return 1 if loss_exhausted else 2
        return 3 if loss_exhausted else 0

    def reset(self):
        """Zero both stall counters; the best and running values are kept."""
        self.accuracy_decrease_iters = self.loss_increase_iters = 0

In [None]:
def _resolve_logger(log_fn):
    """Pick the metrics logger: explicit ``log_fn``, else an active ``wandb`` run, else a no-op."""
    if log_fn is not None:
        return log_fn
    wandb_module = globals().get('wandb')
    # wandb.log needs an initialised run; without one, fall back to not logging at all.
    if wandb_module is not None and getattr(wandb_module, 'run', None) is not None:
        return wandb_module.log
    return lambda metrics: None


def _train_one_epoch(model, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        ``(mean per-batch loss, number of correctly classified training samples)``.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    for data in tqdm(train_loader):
        image = data['input'].to(device)            # [N, 3, image_size, image_size]
        target = data['label'].to(device).long()    # class indices, presumably [N]

        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    # criterion returns a per-batch mean, so average over batches, not samples.
    return loss_sum / max(len(train_loader), 1), correct


def training(model, mode, log_fn=None):
    """Train ``model`` with the optimizer registered for ``mode``, with validation and early stopping.

    After every epoch the model is evaluated on ``val_loader``. Whenever the validation F1
    improves, the whole model is pickled to ``./checkpoint/best/best_model-{mode}.pt``.
    Training stops when both the smoothed validation loss and F1 have stalled. When only the
    loss stalls, each param group's learning rate is divided by 10.

    Args:
        model: network to train (already on ``device``).
        mode: 'alex', 'res18' or 'res50'; selects the matching global optimizer.
        log_fn: optional callable taking a metrics dict. Defaults to ``wandb.log`` when a wandb
            run is active, otherwise metrics are only printed.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    optimizers = {'alex': alex_optimizer, 'res18': res18_optimizer, 'res50': res50_optimizer}
    if mode not in optimizers:
        raise ValueError(f"mode must be one of {sorted(optimizers)}, got {mode!r}")
    optimizer = optimizers[mode]
    log = _resolve_logger(log_fn)

    early_stopper = EarlyStopping(patience = 8)
    print(f'pulmonary-classification-{mode}')

    best_dir = os.path.join('.', 'checkpoint', 'best')
    os.makedirs(best_dir, exist_ok=True)   # torch.save does not create missing directories
    best_path = os.path.join(best_dir, f'best_model-{mode}.pt')

    best_val_score = 0
    num_train = max(len(train_loader.dataset), 1)

    for epoch in range(1, num_epoch + 1):
        train_loss, train_correct = _train_one_epoch(model, optimizer)

        # Compute the validation metrics
        metrics_dict = compute_metrics(model, val_loader)

        # Keep the model with the best validation F1 score
        if metrics_dict['F1 Score'] > best_val_score:
            torch.save(model, best_path)
            best_val_score = metrics_dict['F1 Score']
            print('Save best model...')

        epoch_acc = 100.0 * train_correct / num_train
        print('Training Performance Epoch {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            epoch, train_loss, train_correct, len(train_loader.dataset), epoch_acc))
        log({"train-loss": train_loss, "train-acc": epoch_acc, })

        print(f"Validation Performance Epoch: {epoch}, Loss: {metrics_dict['Validation Loss']}, Accuracy: {metrics_dict['Accuracy']}, F1 Score: {metrics_dict['F1 Score']}")
        log({"val-loss": metrics_dict["Validation Loss"],
             "val-acc": metrics_dict["Accuracy"],
             "val-sensitivity": metrics_dict["Sensitivity"],
             "val-specificity": metrics_dict["Specificity"],
             "val_f1-score": metrics_dict["F1 Score"],
             })

        early_stopper.add_data(model, metrics_dict['Validation Loss'], metrics_dict['F1 Score'])
        stop_code = early_stopper.stop()

        # Both validation loss and F1 have stalled: stop training.
        if stop_code == 1:
            break

        # Only the loss has stalled: decay every param group's own lr tenfold, once.
        if stop_code == 3:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
            print('Updating the learning rate to {}'.format(optimizer.param_groups[0]['lr']))
            early_stopper.reset()

In [None]:
# Train the three backbones in sequence; each run keeps its own best checkpoint.
for _display_name, _net, _tag in (('AlexNet', alex_model, 'alex'),
                                  ('ResNet18', res18_model, 'res18'),
                                  ('ResNet50', res50_model, 'res50')):
    print(f'################################# Training {_display_name} #################################')
    training(_net, _tag)

# Test

In [None]:
def majority_voting_by_3(alex_prediction, res18_prediction,res50_prediction):
    """Combine three models' class predictions by majority vote.

    With three voters a class wins if at least two models agree. If all three disagree,
    the ResNet50 prediction is used as the tie-break. Both cases reduce to: take AlexNet's
    label when AlexNet and ResNet18 agree, otherwise ResNet50's label. (When ResNet50 sides
    with either of the others, its label is already the majority.)

    Args:
        alex_prediction, res18_prediction, res50_prediction: equal-length 1-D integer
            tensors (or array-likes) of predicted class indices, on the same device.

    Returns:
        numpy array of ensemble predictions. The inputs are not modified.
    """
    alex = torch.as_tensor(alex_prediction)
    res18 = torch.as_tensor(res18_prediction)
    res50 = torch.as_tensor(res50_prediction)
    # Vectorised vote: avoids a per-element Python loop (and per-element GPU syncs).
    ensemble = torch.where(alex == res18, alex, res50)
    return ensemble.cpu().numpy()


In [None]:
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix
from sklearn.preprocessing import label_binarize

def compute_metrics_test(alex_model, res18_model, res50_model, test_loader):
    """Evaluate the AlexNet/ResNet18/ResNet50 majority-vote ensemble on ``test_loader``.

    Returns:
        Dict with 'Accuracy', 'Sensitivity' (macro recall), 'Specificity' (macro one-vs-rest
        TN / (TN + FP)) and 'F1 Score' (harmonic mean of macro precision and recall), all as
        percentages. It also holds 'Confusion Matrix', 'Validation Loss' (per-batch
        cross-entropy averaged over batches and the three models), and the flat
        'pred_list' / 'target_list'.
    """
    models = (alex_model, res18_model, res50_model)
    for net in models:
        net.eval()

    criterion = nn.CrossEntropyLoss()
    loss_sums = [0.0] * len(models)
    pred_chunks = [[] for _ in models]
    target_chunks = []

    with torch.no_grad():
        for data in test_loader:
            image = data['input'].to(device)
            target = data['label'].to(device).long()
            for i, net in enumerate(models):
                output = net(image)
                loss_sums[i] += criterion(output, target).item()
                # view(-1) instead of squeeze(): a single-sample batch must stay 1-D for torch.cat.
                pred_chunks[i].append(output.argmax(dim=1).view(-1))
            target_chunks.append(target.view(-1))

    alex_pred_list, res18_pred_list, res50_pred_list = (torch.cat(chunks) for chunks in pred_chunks)
    target_list = torch.cat(target_chunks).tolist()
    pred_list = np.asarray(majority_voting_by_3(alex_pred_list, res18_pred_list, res50_pred_list)).tolist()

    labels = list(range(num_class))
    # `labels=` keeps the report valid when a class never appears in the test split.
    report = classification_report(target_list, pred_list, labels=labels,
                                   target_names=class_names, output_dict=True)
    conf_matrix = confusion_matrix(target_list, pred_list, labels=labels)

    # sensitivity: macro-averaged per-class recall; precision is needed for F1
    sensitivity = np.mean([report[name]['recall'] for name in class_names])
    precision = np.mean([report[name]['precision'] for name in class_names])

    # specificity: macro-averaged one-vs-rest TN / (TN + FP) from the confusion matrix
    tp = np.diag(conf_matrix)
    fp = conf_matrix.sum(axis=0) - tp
    tn = conf_matrix.sum() - conf_matrix.sum(axis=1) - fp
    negatives = tn + fp
    specificity = np.mean(np.divide(tn, negatives, out=np.zeros(len(labels)), where=negatives > 0))

    f1 = 2 * precision * sensitivity / (precision + sensitivity) if precision + sensitivity > 0 else 0.0
    accuracy = float(np.mean(np.asarray(target_list) == np.asarray(pred_list))) if target_list else 0.0

    val_loss = np.mean(loss_sums) / max(len(test_loader), 1)

    metrics_dict = {"Accuracy": accuracy * 100,
                    "Sensitivity": sensitivity * 100,
                    "Specificity": specificity * 100,
                    "F1 Score": f1 * 100,
                    "Confusion Matrix": conf_matrix,
                    "Validation Loss": val_loss,
                    "pred_list": pred_list,
                    "target_list": target_list,}

    return metrics_dict

In [None]:
# Best checkpoints are written by training() as ./checkpoint/best/best_model-{mode}.pt,
# with mode in {'alex', 'res18', 'res50'}.
alex_model = load_model('./checkpoint/best/best_model-alex.pt')
res18_model = load_model('./checkpoint/best/best_model-res18.pt')
res50_model = load_model('./checkpoint/best/best_model-res50.pt')

metrics_dict = compute_metrics_test(alex_model, res18_model, res50_model, test_loader)
print('------------------- Test Performance --------------------------------------')
print("Accuracy \t {:.3f}".format(metrics_dict['Accuracy']))
print("Sensitivity \t {:.3f}".format(metrics_dict['Sensitivity']))
print("Specificity \t {:.3f}".format(metrics_dict['Specificity']))
print("F1 Score \t {:.3f}".format(metrics_dict['F1 Score']))
print("---------------------------------------------------------------------------")