In [None]:
import os
import json
import cv2 as cv
import numpy as np
from tqdm import tqdm
import h5py

import pydicom
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import albumentations as A

import torch.optim as optim
from torch.optim import lr_scheduler
import time
import torch.nn as nn
import copy

from collections import OrderedDict

In [None]:
# Label names (Korean): air leak, hyperinflation, atelectasis, neonatal
# respiratory distress syndrome, pneumonia, pleural effusion, normal.
class_names = ['0.공기누출', '1.과다팽창', '2.무기폐', '3.신생아호흡곤란증후군', '4.폐렴', '5.흉막삼출', '6.정상']
num_class = len(class_names)
classes = list(range(num_class))  # integer label for each entry of class_names
print(num_class)

In [None]:
# Loader / model input configuration.
transform = True  # NOTE(review): not read in the visible code; get_data picks transform per split
image_size = 224  # side length passed to A.Resize(image_size, image_size)
batch_size = 16

In [None]:
def _random_augment(image):
    """Apply the training-time augmentation pipeline to one grayscale image.

    Pipeline: random horizontal flips, at most one of {intensity inversion,
    small shift/scale/rotate, mild optical distortion}, a resize to
    ``image_size`` x ``image_size`` (module-level setting), and a final
    normalization with mean 0.5 / std 0.5.

    Args:
        image: 2-D array. InfantDataset._preprocess_image min-max scales it
            into [0, 1] before calling this function.

    Returns:
        The augmented, resized and normalized image array.
    """
    transform = A.Compose([
        A.HorizontalFlip(p=0.3),
        # NOTE(review): a second HorizontalFlip looks like a typo (VerticalFlip?);
        # kept so the current augmentation distribution is unchanged.
        A.HorizontalFlip(p=0.05),
        A.OneOf([
            A.InvertImg(p=0.5),
            A.ShiftScaleRotate(shift_limit=(0, 0), rotate_limit=(-5, 5), scale_limit=(-0.15,0.05), border_mode=cv.BORDER_CONSTANT, value=0, p=0.3),
            A.OpticalDistortion(distort_limit=0.1, shift_limit=0, border_mode=cv.BORDER_CONSTANT, value=0, p=0.3),
        ]),
        A.Resize(image_size, image_size),
        # Pixels are already in [0, 1], so max_pixel_value must be 1.0: with the
        # default (255) Normalize would map every pixel to roughly -1.
        A.Normalize(mean=0.5, std=0.5, max_pixel_value=1.0),
    ])

    return transform(image=image)['image']


def _to_tensor(image, label, name):
    
#     image = np.transpose(image, (2, 0, 1))
    image = torch.from_numpy(image)
    image = image.reshape((1, image.shape[0], image.shape[1]))
    
    data = {'name':name, 'input': image, 'label': label}
    return data


In [None]:
class InfantDataset(Dataset):
    """Image-classification dataset backed by one HDF5 file per sample.

    Every file in ``<root_dir>/<mode>`` must contain an ``input`` dataset
    (the image) and a ``class_id`` dataset whose first entry is the label.
    Items are the dicts produced by ``_to_tensor``: {'name', 'input', 'label'}.

    Args:
        root_dir: directory that holds one sub-directory per split.
        transform: if True, apply the random training augmentation
            (``_random_augment``); otherwise apply only the deterministic
            resize + normalization so eval images match training statistics.
        image_size: side length for the eval-time resize (None = no resize).
        mode: split name, used as the sub-directory of ``root_dir``.
    """

    def __init__(self, root_dir='/home/ncp/workspace/seung-ah/hdf5', transform=True, image_size=None, mode='train'):
        self.root_dir = os.path.join(root_dir, mode)
        self.image_size = image_size
        self.transform = transform
        self.mode = mode
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.dataset = sorted(os.listdir(self.root_dir))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        file_name = self.dataset[index]
        image, label = self._load_hdf5(os.path.join(self.root_dir, file_name))
        image = self._preprocess_image(image)
        # Sample name is the file name without its extension ("x.hdf5" -> "x").
        name = os.path.splitext(file_name)[0]
        return _to_tensor(image, label, name)

    @staticmethod
    def _scale_to_unit(image):
        """Min-max scale ``image`` into [0, 1]; a constant image maps to zeros."""
        shifted = image - np.min(image)
        peak = np.max(shifted)
        if peak == 0:
            # Avoid the 0/0 division that would fill the image with NaNs.
            return np.zeros(shifted.shape, dtype=np.float64)
        return shifted / peak

    def _deterministic_transform(self, image):
        """Eval-time counterpart of the augmentation: resize + normalize only.

        Mirrors the final Resize and Normalize(mean=0.5, std=0.5) steps of the
        training pipeline for an input already scaled to [0, 1].
        """
        if self.image_size:
            image = cv.resize(image, (self.image_size, self.image_size), interpolation=cv.INTER_LINEAR)
        return ((image - 0.5) / 0.5).astype(np.float32)

    def _preprocess_image(self, image):
        """Grayscale -> CLAHE -> [0, 1] scaling -> augmentation or eval transform."""
        if image.ndim == 3:
            image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)

        # Contrast-limited adaptive histogram equalization.
        # NOTE: OpenCV's CLAHE only accepts 8- or 16-bit single-channel images.
        clahe = cv.createCLAHE(clipLimit=80, tileGridSize=(16, 16))
        image = clahe.apply(image)

        image = self._scale_to_unit(image)

        if self.transform:
            return _random_augment(image)
        return self._deterministic_transform(image)

    def _load_hdf5(self, hdf5_path):
        """Read the ``input`` image and 0-based class id from one HDF5 file.

        Raises:
            KeyError: if the file lacks the ``input`` or ``class_id`` dataset.
        """
        with h5py.File(hdf5_path, 'r') as hf:
            missing = [key for key in ('input', 'class_id') if key not in hf]
            if missing:
                raise KeyError(f"{hdf5_path} is missing dataset(s): {missing}")
            image = np.array(hf['input'])
            # The stored id has 1 subtracted, so stored ids are presumably
            # 1-based; the result indexes class_names / CrossEntropyLoss targets.
            class_id = np.array(hf['class_id'])[0] - 1
        return image, class_id
                        
        

In [None]:
def get_data(mode):
    """Build the dataset and DataLoader for one split.

    Args:
        mode: 'train' (augmented, shuffled) or 'val' / 'test'
            (deterministic preprocessing, fixed order).

    Returns:
        (dataset, loader) tuple.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode not in ('train', 'val', 'test'):
        raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")

    is_train = mode == 'train'
    dataset = InfantDataset(transform=is_train, image_size=image_size, mode=mode)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=2)
    return dataset, loader

# Build the (dataset, DataLoader) pair for every split.
train_dataset, train_loader = get_data('train')
val_dataset, val_loader = get_data('val')
test_dataset, test_loader = get_data('test')

# Smoke test: load one training sample so path / HDF5 problems surface early.
data = train_dataset[0]
    

In [None]:
# Number of samples per split.
num_data_train, num_data_val, num_data_test = len(train_dataset), len(val_dataset), len(test_dataset)

# Number of batches per split (np.ceil returns floats).
num_batch_train, num_batch_val, num_batch_test = (
    np.ceil(count / batch_size) for count in (num_data_train, num_data_val, num_data_test)
)

In [None]:
from segmentation_models_pytorch.unet.model import Unet
from torch import nn
from torch.optim.lr_scheduler import StepLR

In [None]:

# Use the GPU when one is visible; fall back to CPU instead of failing on
# the first .to('cuda') call on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
# model = torchvision.models.vgg19(pretrained=True).to('cuda')
# model = model.features
# model_name = 'vgg_cnn'

# classifier = nn.Sequential(OrderedDict([
#                                 ('conv1', nn.Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).to('cuda')),
#                                 ('relu', nn.ReLU()),
#                                 ('conv2', nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).to('cuda')),
#                                 ('relu', nn.ReLU()),
#                                 ('bn', nn.BatchNorm2d(128)),
#                                 ('maxpool', nn.MaxPool2d(kernel_size=3)),
#                                 ('drop-out1', nn.Dropout(p=0.2, inplace=True)),
#                                 ('flatten', nn.Flatten()),
#                                 ('fc1', nn.Linear(512, 512)), 
#                                 ('drop-out2', nn.Dropout(p=0.2, inplace=True)),
#                                 ('fc2', nn.Linear(512, 7)), 
#                             ]))

# # model.classifier = classifier.to('cuda')
# model[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).to('cuda')

In [None]:
# model = torchvision.models.resnet18(pretrained=True).to('cuda')

# model = torch.nn.Sequential(*(list(model.children())[:-2]))

# model_name = 'res18_cnn'

# classifier = nn.Sequential(OrderedDict([
#                                 ('conv1', nn.Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).to('cuda')),
#                                 ('relu', nn.ReLU()),
#                                 ('conv2', nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).to('cuda')),
#                                 ('relu', nn.ReLU()),
#                                 ('bn', nn.BatchNorm2d(128)),
#                                 ('maxpool', nn.MaxPool2d(kernel_size=3)),
#                                 ('drop-out1', nn.Dropout(p=0.2, inplace=True)),
#                                 ('flatten', nn.Flatten()),
#                                 ('fc1', nn.Linear(512, 512)), 
#                                 ('drop-out2', nn.Dropout(p=0.2, inplace=True)),
#                                 ('fc2', nn.Linear(512, 7)), 
#                             ]))

# model.classifier = classifier.to('cuda')
# model[0] = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)).to('cuda')

# ImageNet-pretrained ResNet-18 adapted to single-channel input and num_class outputs.
model = torchvision.models.resnet18(pretrained=True).to(device)
model_name = 'resnet18'

num_ftrs = model.fc.in_features

fc = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(num_ftrs, 100)),
    ('relu', nn.ReLU()),
    ('drop-out', nn.Dropout(p=0.2, inplace=True)),
    ('fc2', nn.Linear(100, num_class)),  # one logit per entry of class_names
]))
model.fc = fc.to(device)

# Swap the RGB stem for a 1-channel conv. Rather than discarding the pretrained
# stem for a random init, seed the new filters with the pretrained RGB filters
# summed over the input-channel axis (a common grayscale adaptation).
rgb_weight = model.conv1.weight.detach()
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)).to(device)
with torch.no_grad():
    model.conv1.weight.copy_(rgb_weight.sum(dim=1, keepdim=True))

In [None]:

# Optimisation hyper-parameters (an earlier run used lr=0.00006).
lr = 1e-5
momentum = 0.9  # NOTE(review): unused; the Adam optimizer below takes no momentum argument
num_epoch = 500

train_continue = False  # NOTE(review): not read anywhere in the visible code

criterion = nn.CrossEntropyLoss()

# No LR scheduler is attached here; the training loop divides the learning
# rate by 10 when the early stopper reports that validation loss has stalled.
optimizer = optim.Adam(model.parameters(), lr=lr)
print(f'lr: {lr}, optim:{optimizer}')

In [None]:
def load_model(path, model, mode='test', map_location=None):
    """Load a checkpoint produced by this notebook.

    In ``mode == 'test'`` whatever ``torch.load`` returns is handed back
    unchanged (the best-model file is written with ``torch.save(model, ...)``,
    so this is the pickled module). In any other mode ``path`` must be a
    ``save_model`` checkpoint ``{'model': state_dict, 'optim': state_dict}``
    whose file name contains ``epoch<N>``; its weights are loaded into
    ``model`` and ``(model, N)`` is returned.

    Args:
        path: checkpoint file path.
        model: module to load weights into (non-test mode only).
        mode: 'test' or anything else (resume mode).
        map_location: forwarded to ``torch.load`` (e.g. 'cpu').

    Raises:
        ValueError: in resume mode, if the file name has no ``epoch<N>`` part.

    SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    NOTE(review): recent torch versions default to ``weights_only=True``,
    which rejects full pickled modules such as the best-model file.
    """
    import re

    checkpoint = torch.load(path, map_location=map_location)
    print("Get saved weights successfully.")
    if mode == 'test':
        return checkpoint

    match = re.search(r'epoch(\d+)', os.path.basename(path))
    if match is None:
        raise ValueError(f"cannot infer the epoch from checkpoint name {path!r}")
    model.load_state_dict(checkpoint['model'])
    return model, int(match.group(1))

def save_model(ckpt_dir, model, optim, epoch):
    """Save model and optimizer state dicts into ``ckpt_dir``.

    Writes ``model_best.pth`` when the directory path contains 'best',
    otherwise ``model_epoch{epoch}.pth``. The directory is created if needed.

    Args:
        ckpt_dir: target directory, relative or absolute.
        model: module whose ``state_dict()`` is stored under key 'model'.
        optim: optimizer whose ``state_dict()`` is stored under key 'optim'.
        epoch: epoch number used in the non-best file name.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    checkpoint = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if 'best' in ckpt_dir:
        file_name = 'model_best.pth'
    else:
        file_name = 'model_epoch%d.pth' % epoch
    # os.path.join keeps absolute directories absolute (a "./%s/" prefix did not).
    torch.save(checkpoint, os.path.join(ckpt_dir, file_name))
    print(f'>> save {file_name}')

In [None]:
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix, f1_score
from sklearn.preprocessing import label_binarize

import ttach as tta

# Test-time augmentation: the TTA wrapper in compute_metrics combines
# predictions made at intensity scales 0.7 and 1.0.
transforms = tta.Compose([tta.Multiply(factors=[0.7, 1])])

def compute_metrics(model, test_loader, plot_roc_curve = False, mode='val'):
    
    model.eval()
    val_loss = 0
    losses = []
        
    criterion = nn.CrossEntropyLoss()
    
    score_list   = torch.Tensor([]).to(device)
    pred_list    = torch.Tensor([]).to(device).long()
    target_list  = torch.Tensor([]).to(device).long()

    tta_model = tta.ClassificationTTAWrapper(model, transforms)
    
    for iter_num, data in enumerate(test_loader):
        
        # Convert image data into single channel data
        image, target = data['input'].to(device).float(), data['label'].to(device)
        
        if mode == 'val' :
            with torch.no_grad():
                output = model(image)
        elif mode == 'test':
            with torch.no_grad():
                output = tta_model(image)
        
        # Log loss
        val_loss += criterion(output, target.long()).item()
        
        # Calculate the number of correctly classified examples
        pred = output.argmax(dim=1, keepdim=True)
        
        pred_list    = torch.cat([pred_list, pred.squeeze()])
        target_list  = torch.cat([target_list, target.squeeze()])
    
    classification_metrics = classification_report(target_list.tolist(), pred_list.tolist(),
                                                  target_names = class_names,
                                                  output_dict= True)

    # sensitivity is the recall of the positive class
    sensitivity = 0
    for name in class_names:
        sensitivity += classification_metrics[f'{name}']['recall']
        
    # specificity is the recall of the negative class 
    specificity = 0
    for name in class_names:
        specificity += classification_metrics[f'{name}']['precision']
        
    # accuracy
    accuracy = classification_metrics['accuracy']
    
    
    f1_score = 2 * (specificity * sensitivity) / (specificity + sensitivity)
    
    # confusion matrix
    conf_matrix = confusion_matrix(target_list.tolist(), pred_list.tolist())
    
    # put together values
    metrics_dict = {"Accuracy": accuracy * 100,
                    "Sensitivity": (sensitivity * 100) / num_class,
                    "Specificity": (specificity * 100) / num_class,
                    "F1 Score": (f1_score * 100) / num_class,
                    "Validation Loss": val_loss / len(test_loader),
                    "Confusion Matrix": conf_matrix,
                    "pred_list": pred_list.tolist(),
                    "target_list": target_list.tolist(),}
    
    
    return metrics_dict

In [None]:
from collections import deque

class EarlyStopping(object):
    """Track smoothed validation loss/accuracy and report exhausted patience.

    Each signal is smoothed with an exponential moving average (weight 0.2 on
    the newest value, 0.8 on the previous smoothed value). For each signal a
    counter records consecutive updates whose smoothed value did not improve
    on the best smoothed value seen so far.
    """

    def __init__(self, patience = 8):
        super(EarlyStopping, self).__init__()
        self.patience = patience
        # Smoothed values from the previous update; the first update seeds them.
        self.previous_loss = int(1e8)
        self.previous_accuracy = 0
        self.init = False
        # Consecutive non-improving updates per signal.
        self.accuracy_decrease_iters = 0
        self.loss_increase_iters = 0
        # Best smoothed values observed so far.
        self.best_running_accuracy = 0
        self.best_running_loss = int(1e7)

    def add_data(self, model, loss, accuracy):
        """Fold one epoch's validation ``loss`` and ``accuracy`` into the trackers.

        ``model`` is accepted for API compatibility and not used.
        """
        if self.init:
            smoothed_loss = 0.2 * loss + 0.8 * self.previous_loss
            smoothed_acc = 0.2 * accuracy + 0.8 * self.previous_accuracy
        else:
            smoothed_loss, smoothed_acc = loss, accuracy
            self.init = True

        # Accuracy: any value below the best counts as a non-improvement.
        if smoothed_acc < self.best_running_accuracy:
            self.accuracy_decrease_iters += 1
        else:
            self.best_running_accuracy, self.accuracy_decrease_iters = smoothed_acc, 0

        # Loss: any value above the best counts as a non-improvement.
        if smoothed_loss > self.best_running_loss:
            self.loss_increase_iters += 1
        else:
            self.best_running_loss, self.loss_increase_iters = smoothed_loss, 0

        self.previous_accuracy = smoothed_acc
        self.previous_loss = smoothed_loss

    def stop(self):
        """Return 1 if both signals ran out of patience, 2 if only accuracy,
        3 if only loss, otherwise 0."""
        accuracy_exhausted = self.accuracy_decrease_iters > self.patience
        loss_exhausted = self.loss_increase_iters > self.patience
        codes = {(True, True): 1, (True, False): 2, (False, True): 3, (False, False): 0}
        return codes[(accuracy_exhausted, loss_exhausted)]

    def reset(self):
        """Clear both non-improvement counters (best values are kept)."""
        self.accuracy_decrease_iters = self.loss_increase_iters = 0

early_stopper = EarlyStopping(patience=10)

In [None]:
import wandb  # used throughout this cell but not imported anywhere else in the visible notebook

# Passing config to init records it on the run (assigning a plain dict to
# wandb.config would only rebind the module attribute).
wandb.init(project=f'pulmonary-classification-{model_name}', reinit=True,
           config={'learning_rate': lr, 'epochs': num_epoch, 'batch_size': batch_size})
wandb.run.name = 'v1'
wandb.watch(model)

best_model = model
best_val_score = 0
best_ckpt_dir = './checkpoint/best'
os.makedirs(best_ckpt_dir, exist_ok=True)  # torch.save does not create directories

st_epoch = 0
for epoch in range(st_epoch + 1, num_epoch + 1):

    model.train()
    train_loss = 0
    train_correct = 0

    for iter_num, data in enumerate(tqdm(train_loader), 1):
        image = data['input'].to(device).float()  # (N, 1, H, W), see _to_tensor
        target = data['label'].to(device).long()  # (N,) class indices

        output = model(image)
        loss = criterion(output, target)
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count correctly classified examples in this batch.
        pred = output.argmax(dim=1)
        train_correct += pred.eq(target.view_as(pred)).sum().item()

    metrics_dict = compute_metrics(model, val_loader)

    # Keep the model with the best validation F1 score.
    if metrics_dict['F1 Score'] > best_val_score:
        torch.save(model, f"{best_ckpt_dir}/best_model-{model_name}.pt")
        best_val_score = metrics_dict['F1 Score']
        print('save best model')

    # train_loss sums per-batch mean losses, so average over batches (dividing
    # by the number of samples understated the loss by about the batch size).
    num_train = len(train_loader.dataset)
    epoch_loss = train_loss / len(train_loader)
    epoch_acc = 100.0 * train_correct / num_train
    print('Training Performance Epoch {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        epoch, epoch_loss, train_correct, num_train, epoch_acc))
    print(f"Validation Performance Epoch: {epoch}, Loss: {metrics_dict['Validation Loss']}, Accuracy: {metrics_dict['Accuracy']}, F1 Score: {metrics_dict['F1 Score']}")

    # A single log call per epoch keeps train and val metrics on the same wandb step.
    wandb.log({"train-loss": epoch_loss,
               "train-acc": epoch_acc,
               "val-loss": metrics_dict["Validation Loss"],
               "val-acc": metrics_dict["Accuracy"],
               "val-sensitivity": metrics_dict["Sensitivity"],
               "val-specificity": metrics_dict["Specificity"],
               "val_f1-score": metrics_dict["F1 Score"],
              })

    if epoch % 10 == 0:
        save_model(ckpt_dir='./checkpoint/log', model=model, optim=optimizer, epoch=epoch)

    # The early stopper tracks validation loss and validation F1.
    early_stopper.add_data(model, metrics_dict['Validation Loss'], metrics_dict['F1 Score'])
    stop_code = early_stopper.stop()

    # Neither loss nor F1 is improving: stop training.
    if stop_code == 1:
        break

    # Only the loss has stalled: decay the learning rate once for every param group.
    if stop_code == 3:
        lr *= 0.1
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print('Updating the learning rate to {}'.format(lr))
        early_stopper.reset()
            

In [None]:
# Reload the best checkpoint and report its metrics on the test split.
model = load_model(f'./checkpoint/best/best_model-{model_name}.pt', model)

metrics_dict = compute_metrics(model, test_loader)
print('------------------- Test Performance --------------------------------------')
for metric_name in ('Accuracy', 'Sensitivity', 'Specificity', 'F1 Score'):
    print(f"{metric_name} \t {metrics_dict[metric_name]:.3f}")
print("---------------------------------------------------------------------------")