In [None]:
# Mount Google Drive at /content/drive. This cell imports google.colab, so it
# only runs inside a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from IPython.display import clear_output
import nibabel as nib
import glob
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from fastprogress import master_bar, progress_bar
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clear_output()  # clear whatever this cell has printed so far
# Number of distinct label values in a segmentation mask (labels 0..3).
# The dice_* / jac_* metrics in this notebook score labels 1, 2 and 3 separately.
NUM_CLASS = 4

def center_crop(img, crop_size = (160, 160)):
    """Cut a fixed-size in-plane window out of an image/volume, zero-padding if needed.

    The window is deliberately not exactly centred: its origin is the centred
    origin shifted 20 voxels towards index 0 on the first axis and 10 voxels on
    the second axis, clamped at 0.

    Args:
        img: array whose first two axes are the in-plane axes. Any further
            axes (e.g. a slice axis) are kept whole; 2-D input is accepted too.
        crop_size: ``(size_axis0, size_axis1)`` of the window.

    Returns:
        A float64 ``np.ndarray`` of shape ``(*crop_size, *img.shape[2:])``.
        Wherever ``img`` is smaller than the window, the remainder stays zero.
    """
    w_in, h_in = img.shape[:2]
    w_out, h_out = crop_size
    sub_w = max((w_in - w_out) // 2 - 20, 0)
    sub_h = max((h_in - h_out) // 2 - 10, 0)

    # Slicing past the end of an axis silently truncates, so `window` may be
    # smaller than the requested crop; it is pasted into the top-left corner.
    window = img[sub_w:sub_w + w_out, sub_h:sub_h + h_out]
    img_crop = np.zeros((*crop_size, *img.shape[2:]))
    img_crop[:window.shape[0], :window.shape[1]] = window
    return img_crop

def crop2img(crop_image, initial_image):
    """Paste a crop back into a zero canvas shaped like the original image.

    The paste position uses the same offsets as ``center_crop`` (centred origin
    shifted by 20 on the first axis and 10 on the second, clamped at 0).

    Args:
        crop_image: array whose first two axes are the cropped in-plane axes.
        initial_image: array giving the shape and dtype of the output.

    Returns:
        ``np.zeros_like(initial_image)`` with the crop written at the crop
        position. If the crop is larger than the image along an in-plane axis
        (``center_crop`` zero-pads such images), the excess rows/columns of the
        crop are dropped instead of raising a shape-mismatch error.
    """
    output = np.zeros_like(initial_image)
    w_in, h_in = initial_image.shape[:2]
    w_out, h_out = crop_image.shape[:2]
    sub_w = max((w_in - w_out) // 2 - 20, 0)
    sub_h = max((h_in - h_out) // 2 - 10, 0)

    # `target` is a view into `output`; its extent is truncated by the canvas
    # size, so clip the crop to it before writing.
    target = output[sub_w:sub_w + w_out, sub_h:sub_h + h_out]
    target[...] = crop_image[:target.shape[0], :target.shape[1]]
    return output

def min_max_preprocess(image, low_perc=1, high_perc=99):
    """Main pre-processing function used for the challenge (seems to work the best).

    Removes outlier voxels first, then min-max scales: intensities are clipped
    to the ``[low_perc, high_perc]`` percentiles of the strictly positive
    voxels and then mapped linearly onto ``[0, 1]``.

    Args:
        image: anything ``np.array`` can convert (array, PIL image, ...).
        low_perc: lower percentile used as the new minimum.
        high_perc: upper percentile used as the new maximum.

    Returns:
        ``np.float32`` array with the shape of ``image``. An input with no
        positive voxels, or whose two percentiles coincide, yields all zeros
        rather than NaNs.
    """
    image = np.array(image)
    foreground = image[image > 0]
    if foreground.size == 0:
        # Blank slice: there is nothing to take a percentile of.
        return np.zeros(image.shape, dtype=np.float32)

    low, high = np.percentile(foreground, [low_perc, high_perc])
    if high <= low:
        # Constant foreground: the scaling below would be 0 / 0.
        return np.zeros(image.shape, dtype=np.float32)

    image = np.clip(image, low, high)
    image = (image - low) / (high - low)
    return image.astype(np.float32)

In [None]:
# %cd /content/drive/MyDrive/dataACDCA
# all_files = sorted(glob.glob("./training/*"))
# np.random.seed(42)
# np.random.shuffle(all_files)
# train_count = 70
# train_files = all_files[:train_count]
# test_files = all_files[train_count : train_count + 10]+all_files[-10:]
# valid_files = all_files[train_count + 10: train_count + 20]

#def get_data(): #list_file
#     x_list = []
#     y_list = []
#     for files in tqdm(list_file):
#         list_image = [x for x in glob.glob(files+"/*") if x.find('frame') != -1 and x.find('gt') == -1]
#         for image_name in list_image:
#             num = image_name.find("nii")
#             mask_name = image_name[:num-1] +"_gt.nii.gz"
#             image = nib.load(image_name).get_fdata().astype(np.uint16)
#             label = nib.load(mask_name).get_fdata().astype(np.uint8)
#             image = center_crop(image)
#             label = center_crop(label)
#             for z in range(image.shape[-1]):
#                 sub_image = image[...,z]
#                 sub_label = label[...,z]
#                 y_list.append(sub_label)
#                 x_list.append(sub_image)
#    images = np.load('/content/drive/MyDrive/MRI_ACDC/Dataset2/Sunny_Brook_MRI/x_test_crop_128_endo_sun09.npy') #np.asarray(x_list)
#    masks = np.load('/content/drive/MyDrive/MRI_ACDC/Dataset2/Sunny_Brook_MRI/y_test_crop_128_endo_sun09.npy') #np.asarray(y_list)
#    return images, masks
#x_train, y_train =  get_data()
# x_val, y_val =  get_data(valid_files)
# x_test, y_test =  get_data(test_files)

#np.savez_compressed('/content/drive/MyDrive/MRI_ACDC/Dataset2/Sunny_Brook_MRI/test_crop_128_endo_sun09', image=x_train, mask=y_train)
# np.savez_compressed('ACDC_val160', image=x_val, mask=y_val)
# np.savez_compressed('ACDC_test160', image=x_test, mask=y_test)

In [None]:
# data = np.load("/content/drive/MyDrive/dataACDCA/ACDC_train160.npz")
# x_train, y_train = data["image"], data["mask"]
# %cd /content/drive/MyDrive/dataACDCA
# x_list, y_list = [], []
# for i in tqdm(range(x_train.shape[0])):
#     image, mask = Image.fromarray(x_train[i]),  Image.fromarray(y_train[i])
#     x_list.append(image)
#     y_list.append(mask)

#     sub_image = image.transpose(Image.FLIP_TOP_BOTTOM)
#     sub_label = mask.transpose(Image.FLIP_TOP_BOTTOM)
#     y_list.append(sub_label)
#     x_list.append(sub_image)

#     sub_image = image.transpose(Image.FLIP_LEFT_RIGHT)
#     sub_label = mask.transpose(Image.FLIP_LEFT_RIGHT)
#     y_list.append(sub_label)
#     x_list.append(sub_image)

#     degree = np.random.uniform(-30,30)
#     sub_image = image.rotate(degree, Image.NEAREST)
#     sub_label = mask.rotate(degree, Image.NEAREST)
#     y_list.append(sub_label)
#     x_list.append(sub_image)
# images = np.stack(x_list)
# masks = np.stack(y_list)
# np.savez_compressed("ACDC_train_aug160", image=images, mask=masks)

In [None]:
class ACDCLoader(Dataset):
    """Dataset that pairs 2-D image slices with their label masks.

    Each item is converted to a PIL image, optionally augmented (random
    rotation and flips applied identically to image and mask), intensity
    normalised with ``min_max_preprocess`` and returned as tensors with a
    leading axis of size 1.

    Args:
        images: indexable collection of 2-D arrays accepted by ``Image.fromarray``.
        masks: indexable collection of label arrays aligned with ``images``.
        transform: whether to augment; only honoured when ``typeData == "train"``.
        typeData: name of the split; any value other than ``"train"`` disables
            augmentation.
    """

    def __init__(self, images, masks,
                 transform=True, typeData = "train"):
        self.images = images
        self.masks = masks
        self.typeData = typeData
        # Augmentation is switched off for every split except "train".
        self.transform = transform if typeData == "train" else False

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.images)

    def _flip_pair(self, image, mask, method, p):
        """Apply the PIL transpose ``method`` to both inputs with probability ``p``."""
        if not (torch.rand(1) < p):
            return image, mask
        return image.transpose(method), mask.transpose(method)

    def rotate(self, image, mask, degrees=(-30,30), p=0.3):
        """With probability ``p`` rotate both by one angle drawn uniformly from ``degrees``."""
        if not (torch.rand(1) < p):
            return image, mask
        angle = np.random.uniform(*degrees)
        # Nearest-neighbour resampling for both image and mask.
        rotated_image = image.rotate(angle, Image.NEAREST)
        rotated_mask = mask.rotate(angle, Image.NEAREST)
        return rotated_image, rotated_mask

    def horizontal_flip(self, image, mask, p=0.5):
        """With probability ``p`` mirror image and mask left-to-right."""
        return self._flip_pair(image, mask, Image.FLIP_LEFT_RIGHT, p)

    def vertical_flip(self, image, mask, p=0.5):
        """With probability ``p`` mirror image and mask top-to-bottom."""
        return self._flip_pair(image, mask, Image.FLIP_TOP_BOTTOM, p)

    def augment(self, image, mask):
        """Run rotation, horizontal flip and vertical flip, in that order."""
        for step in (self.rotate, self.horizontal_flip, self.vertical_flip):
            image, mask = step(image, mask)
        return image, mask

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for sample ``idx``.

        The image is float32 after ``min_max_preprocess``; the mask is int64.
        Both get a new leading axis.
        """
        image = Image.fromarray(self.images[idx])
        mask = Image.fromarray(self.masks[idx])

        # ----------------------- data augmentation ------------------------
        if self.transform:
            image, mask = self.augment(image, mask)

        image_tensor = torch.from_numpy(min_max_preprocess(image)[np.newaxis])
        mask_tensor = torch.from_numpy(np.asarray(mask, np.int64)[np.newaxis])
        return image_tensor, mask_tensor

## Loss functions

In [None]:
class SemiActiveLoss(nn.Module):
    """Active-contour loss with level-set region and contour-length regularisers.

    ``forward`` returns ``active + alpha * (levelset + lamda * length)``.

    Work tensors (smoothing kernel, one-hot target) are created from the
    tensors passed in, so they always share the inputs' device, and the number
    of classes is taken from ``y_pred.size(1)``.

    Args:
        device: stored as ``self.device`` for backward compatibility; it no
            longer has to match the device of the inputs.
        alpha: weight of the combined level-set and length terms.
        beta: weight of the global region term relative to the local one.
        lamda: weight of the length term relative to the level-set term.
    """

    def __init__(self, device, alpha =1e-9, beta = 1e-1, lamda = 1e-3):
        super().__init__()
        self.device = device
        self.alpha = alpha
        self.beta = beta
        self.lamda = lamda

    def LevelsetLoss(self, image, y_pred, kernel_size=5, smooth=1e-5):
        """Region term: prediction-weighted squared deviation of the image from
        a local (box-filtered) and a global (per-sample, per-class) centroid.

        Args:
            image: ``(N, C_img, H, W)`` intensities; each channel is handled in turn.
            y_pred: ``(N, C, H, W)`` class scores (presumably softmax
                probabilities -- confirm against the model).
            kernel_size: side of the square averaging kernel (odd sizes keep H and W).
            smooth: small constant that keeps the divisions away from zero.

        Returns:
            Scalar tensor: local term plus ``beta`` times the global term,
            summed over every element and image channel.
        """
        # Built from y_pred so that device and dtype always match the conv input.
        kernel = torch.ones(1, y_pred.size(1), kernel_size, kernel_size,
                            device=y_pred.device, dtype=y_pred.dtype) / kernel_size**2
        padding = kernel_size // 2
        lossRegion = 0.0
        for ich in range(image.size(1)):
            target_ = image[:, ich:ich+1]
            # The kernel has a single output channel, so both convolutions sum over classes.
            pcentroid_local = F.conv2d(target_ * y_pred + smooth, kernel, padding=padding) \
                                / F.conv2d(y_pred + smooth, kernel, padding=padding)
            plevel_local = target_ - pcentroid_local
            loss_local = plevel_local * plevel_local * y_pred

            pcentroid_global = torch.sum(target_ * y_pred, dim=(2, 3), keepdim=True) \
                                / torch.sum(y_pred + smooth, dim=(2, 3), keepdim=True)
            plevel_global = target_ - pcentroid_global
            loss_global = plevel_global * plevel_global * y_pred

            lossRegion += torch.sum(loss_local) + self.beta * torch.sum(loss_global)
        return lossRegion

    def GradientLoss(self, y_pred, penalty = "l1"):
        """Length term: summed absolute (``"l1"``) or squared (``"l2"``)
        differences between neighbouring entries along the last two axes."""
        d_last = torch.abs(y_pred[..., 1:] - y_pred[..., :-1])
        d_rows = torch.abs(y_pred[:, :, 1:] - y_pred[:, :, :-1])
        if penalty == "l2":
            d_last = d_last * d_last
            d_rows = d_rows * d_rows
        return torch.sum(d_last) + torch.sum(d_rows)

    def ActiveContourLoss(self, y_true, y_pred, smooth=1e-5):
        """Binary cross-entropy over the non-zero classes, normalised per sample
        by the soft union of target and prediction, then averaged over the batch.

        Args:
            y_true: ``(N, 1, H, W)`` int64 label map (used as a scatter index).
            y_pred: ``(N, C, H, W)`` class scores; channel 0 is ignored.
            smooth: small constant inside the logs and the denominator.
        """
        dim = (1, 2, 3)
        # One-hot target with the same shape, device and dtype as y_pred; drop channel 0.
        yTrueOnehot = torch.zeros_like(y_pred).scatter_(1, y_true, 1)[:, 1:]
        y_pred = y_pred[:, 1:]

        active = - torch.log(1 - y_pred + smooth) * (1 - yTrueOnehot) \
                 - torch.log(y_pred + smooth) * yTrueOnehot
        union = torch.sum(yTrueOnehot + y_pred - yTrueOnehot * y_pred + smooth, dim=dim)
        loss = torch.sum(active, dim=dim) / union
        return torch.mean(loss)

    def forward(self, image, y_true, y_pred):
        """Return ``active + alpha * (levelset + lamda * length)`` as a scalar tensor."""
        active = self.ActiveContourLoss(y_true, y_pred)
        levelset = self.LevelsetLoss(image, y_pred)
        length = self.GradientLoss(y_pred)
        return active + self.alpha * (levelset + self.lamda * length)

class CrossEntropy(nn.Module):
    """Pixel-averaged categorical cross-entropy on already-normalised scores.

    The one-hot target is built from ``y_pred`` itself, so it always has the
    same number of classes, device and dtype as the prediction.

    Args:
        device: stored as ``self.device`` for backward compatibility; it no
            longer has to match the device of the inputs.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, y_true, y_pred):
        """Compute the loss.

        Args:
            y_true: ``(N, 1, H, W)`` int64 label map (used as a scatter index).
            y_pred: ``(N, C, H, W)`` class scores; their log is taken directly,
                so they are presumably softmax probabilities -- confirm.

        Returns:
            Scalar tensor: summed ``-log(y_pred + 1e-10)`` at the true class,
            divided by ``N * H * W``.
        """
        yTrueOnehot = torch.zeros_like(y_pred).scatter_(1, y_true, 1)

        loss = torch.sum(-yTrueOnehot * torch.log(y_pred + 1e-10))
        return loss / (y_true.size(0) * y_true.size(2) * y_true.size(3))
class DiceLoss(nn.Module):
    """Soft Dice loss over the non-zero classes.

    The one-hot target is built from ``y_pred`` itself, so it always has the
    same number of classes, device and dtype as the prediction.

    Args:
        device: stored as ``self.device`` for backward compatibility; it no
            longer has to match the device of the inputs.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, y_true, y_pred):
        """Compute the loss.

        Args:
            y_true: ``(N, 1, H, W)`` int64 label map (used as a scatter index).
            y_pred: ``(N, C, H, W)`` class scores; channel 0 is ignored.

        Returns:
            Scalar tensor: one minus the batch mean of the per-sample Dice
            coefficient pooled over channels ``1..C-1``.
        """
        # Channel 0 is dropped from both target and prediction.
        yTrueOnehot = torch.zeros_like(y_pred).scatter_(1, y_true, 1)[:, 1:]
        y_pred = y_pred[:, 1:]

        intersection = torch.sum(yTrueOnehot * y_pred, dim=[1, 2, 3])
        cardinality = torch.sum(yTrueOnehot + y_pred, dim=[1, 2, 3])
        loss = 1.0 - torch.mean((2. * intersection + 1e-5) / (cardinality + 1e-5))
        return loss

## Metrics

In [None]:
# Label values scored by the metrics below. The rv / myo / lv suffixes presumably
# stand for right ventricle, myocardium and left ventricle -- confirm with the dataset.
_RV_LABEL, _MYO_LABEL, _LV_LABEL = 1, 2, 3


def _class_overlap(y_true, y_pred, label):
    """Per-sample overlap statistics for a single label.

    Args:
        y_true: ``(N, 1, H, W)`` integer label map.
        y_pred: ``(N, C, H, W)`` class scores; the arg-max over axis 1 is the
            predicted label.
        label: label value to evaluate.

    Returns:
        ``(intersection, cardinality)``, two ``(N,)`` integer tensors:
        the number of pixels where both masks are set, and the sum of the two
        mask sizes (``|true| + |pred|``).
    """
    predicted = torch.argmax(y_pred, dim=1, keepdim=True)
    pred_mask = torch.where(predicted == label, 1, 0)
    true_mask = torch.where(y_true == label, 1, 0)
    intersection = torch.sum(true_mask * pred_mask, dim=[1, 2, 3])
    cardinality = torch.sum(true_mask + pred_mask, dim=[1, 2, 3])
    return intersection, cardinality


def _dice(y_true, y_pred, label, smooth):
    """Batch mean of ``(2*|A∩B| + smooth) / (|A| + |B| + smooth)`` for ``label``."""
    intersection, cardinality = _class_overlap(y_true, y_pred, label)
    return torch.mean((2. * intersection + smooth) / (cardinality + smooth), dim=0)


def _jaccard(y_true, y_pred, label, smooth):
    """Batch mean of ``(|A∩B| + smooth) / (|A∪B| + smooth)`` for ``label``."""
    intersection, cardinality = _class_overlap(y_true, y_pred, label)
    return torch.mean((1. * intersection + smooth) / (cardinality - intersection + smooth), dim=0)


def dice_rv(y_true, y_pred, smooth = 1e-4):
    """Dice coefficient for label 1, averaged over the batch (1.0 when the label is absent from both)."""
    return _dice(y_true, y_pred, _RV_LABEL, smooth)


def dice_myo(y_true, y_pred, smooth = 1e-4):
    """Dice coefficient for label 2, averaged over the batch (1.0 when the label is absent from both)."""
    return _dice(y_true, y_pred, _MYO_LABEL, smooth)


def dice_lv(y_true, y_pred, smooth = 1e-4):
    """Dice coefficient for label 3, averaged over the batch (1.0 when the label is absent from both)."""
    return _dice(y_true, y_pred, _LV_LABEL, smooth)


def jac_rv(y_true, y_pred, smooth = 1e-4):
    """Jaccard index (IoU) for label 1, averaged over the batch."""
    return _jaccard(y_true, y_pred, _RV_LABEL, smooth)


def jac_myo(y_true, y_pred, smooth = 1e-4):
    """Jaccard index (IoU) for label 2, averaged over the batch."""
    return _jaccard(y_true, y_pred, _MYO_LABEL, smooth)


def jac_lv(y_true, y_pred, smooth = 1e-4):
    """Jaccard index (IoU) for label 3, averaged over the batch."""
    return _jaccard(y_true, y_pred, _LV_LABEL, smooth)

## Optimizer

In [None]:
import math
import torch
from torch.optim.optimizer import Optimizer


class Nadam(Optimizer):
    """Nadam optimiser: Adam combined with Nesterov momentum.

    Proposed in `Incorporating Nesterov Momentum into Adam`__. Besides the
    usual Adam moment estimates, every parameter carries a running product of
    a slowly warming momentum coefficient (``m_schedule``).

    Arguments:
        params (iterable): parameters to optimize, or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): decay rates of the running
            averages of the gradient and of its square
        eps (float, optional): added to the denominator for numerical
            stability (default: 1e-8)
        weight_decay (float, optional): L2 penalty added to the gradient
            (default: 0)
        schedule_decay (float, optional): momentum schedule decay
            (default: 4e-3)

    __ http://cs229.stanford.edu/proj2015/054_report.pdf
    __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
        Originally taken from: https://github.com/pytorch/pytorch/pull/1408
        NOTE: Has potential issues but does work well on some problems.
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, schedule_decay=4e-3):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        super().__init__(params, dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            schedule_decay=schedule_decay,
        ))

    @staticmethod
    def _momentum(beta1, step, schedule_decay):
        """Warmed momentum coefficient used at ``step``."""
        return beta1 * (1. - 0.5 * (0.96 ** (step * schedule_decay)))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            eps = group['eps']
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']
            schedule_decay = group['schedule_decay']

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                state = self.state[param]
                if len(state) == 0:
                    # First visit of this parameter: create its optimiser state.
                    state.update(
                        step=0,
                        m_schedule=1.,
                        exp_avg=torch.zeros_like(param),
                        exp_avg_sq=torch.zeros_like(param),
                    )

                schedule_prev = state['m_schedule']
                first_moment = state['exp_avg']
                second_moment = state['exp_avg_sq']
                state['step'] += 1
                t = state['step']

                if weight_decay != 0:
                    grad = grad.add(param, alpha=weight_decay)

                # Momentum for this step and the next, plus their running products.
                mu_now = self._momentum(beta1, t, schedule_decay)
                mu_next = self._momentum(beta1, t + 1, schedule_decay)
                schedule_now = schedule_prev * mu_now
                schedule_next = schedule_now * mu_next
                state['m_schedule'] = schedule_now

                # In-place update of the running first and second moments.
                first_moment.mul_(beta1).add_(grad, alpha=1. - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)

                bias_correction2 = 1 - beta2 ** t
                denom = (second_moment.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Gradient part first, then the look-ahead momentum part.
                param.addcdiv_(grad, denom,
                               value=-lr * (1. - mu_now) / (1. - schedule_now))
                param.addcdiv_(first_moment, denom,
                               value=-lr * mu_next / (1. - schedule_next))

        return loss