In [None]:
import os
import shutil
import tempfile
import pandas as pd
import time
import gc
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import monai
import nibabel as nib
from monai.utils import set_determinism, first

from monai.transforms import *

from monai.config import print_config
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric

from monai.data import (
    DataLoader,
    Dataset,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

from skimage import measure
from sklearn.metrics import roc_auc_score

import torch
import glob
print_config()

In [None]:
# BG segmentation

In [None]:
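# Data loading (custom data path and dataset) is omitted here; that cell must
# define `train_loader`, `valid_loader` and `log_dir`, which the cells below use.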

In [None]:
# 3-D BasicUNet with one input channel and two output channels.
model = monai.networks.nets.BasicUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    features=(32, 32, 64, 128, 256, 32),
)
device = "cuda:0"  # single CUDA device; train/validation move batches here
model.to(device)

In [None]:
import torch.nn as nn
"""
Loss functions: AWing and its mask-weighted wrapper Loss_weighted.

References
----------
Gros C et al., MedIA (2021). DOI: https://doi.org/10.1016/j.media.2021.102038
"""

class AWing(nn.Module):
    """Element-wise Adaptive Wing (AWing) loss.

    With ``p = alpha - y`` and ``d = |y - y_pred|`` each element gets::

        omega * log(1 + |(y - y_pred) / epsilon| ** p)   if d <  theta
        A * d - C                                        if d >= theta

    ``A`` equals the slope of the logarithmic piece at ``d == theta`` and ``C``
    is chosen so both pieces agree there. Elements whose ``d`` is NaN fall in
    neither case and are left at 0. No reduction is applied.
    """

    def __init__(self, alpha=2.1, omega=8, epsilon=1, theta=0.5):
        super().__init__()
        # Hyper-parameters are plain Python floats (no parameters or buffers).
        self.alpha, self.omega = float(alpha), float(omega)
        self.epsilon, self.theta = float(epsilon), float(theta)

    def forward(self, y_pred, y):
        """Return the per-element loss; ``y_pred`` and ``y`` must share one shape."""
        ratio = self.theta / self.epsilon
        power = self.alpha - y
        # Coefficients of the linear branch; they depend on the target only.
        slope = self.omega * (1/(1+ratio**power))*power*(ratio**(power-1))/self.epsilon
        offset = self.theta*slope - self.omega*torch.log(1+ratio**power)

        diff = y - y_pred
        gap = torch.abs(diff)
        near = gap < self.theta
        far = gap >= self.theta

        out = torch.zeros_like(y_pred)
        out[near] = self.omega*torch.log(1+torch.abs(diff[near]/self.epsilon)**power[near])
        out[far] = slope[far]*gap[far] - offset[far]
        return out

class Loss_weighted(nn.Module):
    """Mean AWing loss with extra weight on masked elements.

    Every element's AWing loss is multiplied by ``W * M + 1``. For a binary
    mask, elements with ``M == 1`` therefore count ``W + 1`` times and all
    others count once. The weighted map is averaged to a scalar.
    """

    def __init__(self, W=10, alpha=2.1, omega=8, epsilon=1, theta=0.5):
        super().__init__()
        self.W = float(W)
        self.Awing = AWing(alpha, omega, epsilon, theta)

    def forward(self, y_pred, y, M):
        """Return the scalar weighted loss of ``y_pred`` vs ``y`` under mask ``M``."""
        weight_map = self.W * M.float() + 1.
        per_element = self.Awing(y_pred, y)
        return (per_element * weight_map).mean()

In [None]:
# Per-element AWing loss weighted by (W * mask + 1): mask == 1 weighs W + 1 = 11.
loss_function = Loss_weighted(W=10)
# Let cuDNN benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

In [None]:
ReLU = torch.nn.ReLU()

def normReLU(input):
    """Normalised ReLU: ``relu(input) / max(relu(input))``.

    Negative values are clamped to 0 and the result is scaled so its largest
    element is 1. When no element is positive, or ``input`` is empty, an
    all-zero tensor with ``input``'s shape, dtype and device is returned.

    Parameters
    ----------
    input : torch.Tensor
        Tensor of any shape (the notebook calls it on one (sample, channel) map).

    Returns
    -------
    torch.Tensor
        Same shape as ``input``; values in ``[0, 1]`` for finite input.
    """
    rectified = ReLU(input)  # computed once instead of three times
    if rectified.numel() == 0:
        # torch.max would raise on an empty tensor.
        return torch.zeros_like(input)
    peak = torch.max(rectified)
    if peak != 0:
        return rectified / peak
    # zeros_like matches input's device and dtype; torch.zeros(shape) would
    # always be a float32 tensor on the CPU.
    return torch.zeros_like(input)

In [None]:
def validation(epoch_iterator_val):
    """Evaluate ``model`` on the validation set and return its Dice score.

    Each (sample, channel) map of the raw output goes through ``normReLU``, as
    in ``train()``, and channel 1 is thresholded at 0.42 to give a binary mask.
    Masks and labels are one-hot encoded with ``post_pred`` / ``post_label``
    and accumulated in the global ``dice_metric``. For the first batch, a
    preview of up to two samples (input / output channel 1 / label, each summed
    along axis 2) is saved to ``<log_dir>/validation/valid_step<global_step>.png``
    and shown.

    Relies on notebook globals: ``model``, ``device``, ``normReLU``,
    ``dice_metric``, ``post_label``, ``post_pred``, ``log_dir`` and
    ``global_step`` (used only for the progress text and preview file name).

    Parameters
    ----------
    epoch_iterator_val : tqdm
        tqdm-wrapped validation loader yielding dicts with "t2" and "seg".

    Returns
    -------
    float
        ``dice_metric`` aggregated over all validation batches, or NaN if the
        loader yielded nothing. The metric is reset and the model's previous
        train/eval mode restored before returning, even on error.
    """
    was_training = model.training
    model.eval()
    mean_dice_val = float("nan")
    try:
        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val):
                val_inputs, val_labels = batch["t2"].to(device), batch["seg"].to(device)

                val_outputs = model(val_inputs)
                # Same per-(sample, channel) activation as used in train().
                for b in range(val_outputs.shape[0]):
                    for c in range(val_outputs.shape[1]):
                        val_outputs[b, c] = normReLU(val_outputs[b, c])
                val_outputs_bimask = (val_outputs[:, 1] > 0.42).unsqueeze(1)

                if step == 0:
                    preview_dir = os.path.join(log_dir, 'validation')
                    os.makedirs(preview_dir, exist_ok=True)
                    # Rows: first two samples (fewer if the batch is smaller);
                    # columns: input / output channel 1 / label.
                    panels = (
                        ('input', val_inputs, 0),
                        ('output', val_outputs, 1),
                        ('label', val_labels, 0),
                    )
                    fig, axes = plt.subplots(2, 3, dpi=256)
                    for ax in axes.flat:
                        ax.axis('off')
                    for row in range(min(2, val_inputs.shape[0])):
                        for col, (title, volume, channel) in enumerate(panels):
                            projection = torch.sum(volume[row, channel].detach().cpu(), axis=2)
                            axes[row, col].imshow(projection, cmap='gray')
                            axes[row, col].set_title(title)
                    fig.savefig(os.path.join(preview_dir, f'valid_step{global_step}.png'))
                    plt.show()
                    plt.close(fig)

                val_labels_convert = [post_label(t) for t in decollate_batch(val_labels)]
                val_output_convert = [post_pred(t) for t in decollate_batch(val_outputs_bimask)]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # aggregate() covers every batch accumulated so far, so after the
                # last batch it is the Dice over the whole validation set; the mean
                # of these running values would over-weight the first batches.
                mean_dice_val = dice_metric.aggregate().item()
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" % (global_step, 10.0, mean_dice_val)
                )
    finally:
        dice_metric.reset()
        model.train(was_training)
    return mean_dice_val

In [None]:
def train(global_step, train_loader, dice_val_best, global_step_best):
    """Train ``model`` for one pass over ``train_loader`` (BG segmentation).

    For each batch, every (sample, channel) map of the network output goes
    through ``normReLU``. Channels ``1:`` are then scored against
    ``batch["seg"]`` with the mask-weighted ``loss_function``, where the mask is
    the label with positive values set to 1.

    Every ``eval_num`` global steps, and at ``max_iterations``, the model is
    validated and the Dice history is plotted to ``<log_dir>/Dice_plot.png``.
    On improvement the weights are saved to ``<log_dir>/best_model.pth``.

    ``scheduler`` is stepped every 100 global steps. After the pass, the mean
    loss is recorded and plotted to ``<log_dir>/loss.png``, and the step and
    best Dice are appended to ``<log_dir>/train.txt``.

    Relies on notebook globals: ``model``, ``device``, ``optimizer``,
    ``scheduler``, ``loss_function``, ``normReLU``, ``validation``,
    ``valid_loader``, ``log_dir``, ``max_iterations``, ``eval_num``,
    ``metric_values`` and ``epoch_loss_values``.

    Parameters
    ----------
    global_step : int
        Optimisation steps taken before this call.
    train_loader : iterable
        Yields dicts with "t2" (input) and "seg" (label) tensors.
    dice_val_best : float
        Best validation Dice so far.
    global_step_best : int
        Global step at which ``dice_val_best`` was reached.

    Returns
    -------
    tuple
        Updated ``(global_step, dice_val_best, global_step_best)``.
    """
    model.train()
    epoch_loss = 0
    step = 0
    epoch_iterator = tqdm(
        train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True
    )
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = batch["t2"].to(device), batch["seg"].to(device)

        # Weight mask for Loss_weighted: positive label values become 1.
        mask = y.clone()
        mask[mask > 0] = 1

        logit_map = model(x)
        # normReLU each (sample, channel) map independently.
        for b in range(logit_map.shape[0]):
            for c in range(logit_map.shape[1]):
                logit_map[b, c] = normReLU(logit_map[b, c])

        loss = loss_function(logit_map[:, 1:], y, mask)
        loss.backward()
        loss_value = loss.item()
        epoch_loss += loss_value
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description(
            "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss_value)
        )
        if (
            global_step % eval_num == 0 and global_step != 0
        ) or global_step == max_iterations:
            # validation() reads the notebook-level `global_step` for its progress
            # text and preview file name. The parameter here shadows that global,
            # which would otherwise stay at its start-of-epoch value.
            globals()["global_step"] = global_step
            epoch_iterator_val = tqdm(
                valid_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True
            )
            dice_val = validation(epoch_iterator_val)
            # validation() calls model.eval(); the remaining steps of this
            # epoch must run in training mode again.
            model.train()
            metric_values.append(dice_val)

            plt.figure(1, figsize=(12,8))
            plt.plot(metric_values)
            plt.xlabel(f"{eval_num}")
            plt.savefig(os.path.join(log_dir, "Dice_plot.png"))
            plt.close(1)

            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(
                    model.state_dict(), os.path.join(log_dir, "best_model.pth")
                )
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )

        global_step += 1
        if global_step % 100 == 0:
            scheduler.step()

    # Mean loss over this pass; NaN (rather than ZeroDivisionError) if the
    # loader yielded no batches.
    epoch_loss = epoch_loss / step if step else float("nan")
    epoch_loss_values.append(epoch_loss)

    plt.figure(1, figsize=(12, 8))
    plt.plot(epoch_loss_values)
    plt.savefig(os.path.join(log_dir,"loss.png"))
    plt.close(1)

    # `with` guarantees the log file is closed after appending.
    with open(f'{log_dir}/train.txt', 'a') as file:
        file.write(f'global_step : {str(global_step)}\n')
        file.write('current dice best : ')
        file.write(str(dice_val_best)+'\n')

    return global_step, dice_val_best, global_step_best

In [None]:
# Training schedule: the loop runs while global_step < max_iterations and
# validates every eval_num steps.
max_iterations = 6000
eval_num = 64

# validation() one-hot encodes labels and predicted masks into 2 channels;
# Dice is computed on the foreground channel only.
post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=False, to_onehot=2)
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

# Training state threaded through successive train() calls.
global_step, global_step_best = 0, 0
dice_val_best = 0.0
epoch_loss_values, metric_values = [], []

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
# LR multiplier 0.95 ** k after k scheduler steps (train() steps it every 100 global steps).
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer=optimizer,
    lr_lambda=lambda k: 0.95 ** k,
    last_epoch=-1,
    verbose=True,
)

# Each train() call is one pass over train_loader; repeat until max_iterations.
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(
        global_step, train_loader, dice_val_best, global_step_best
    )

In [None]:
# PVS segmentation
# Trained on the corrected training labels.

In [None]:
import os
import shutil
import tempfile
import pandas as pd
import time
import gc
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import monai
import nibabel as nib
from monai.utils import set_determinism, first

from monai.transforms import *

from monai.config import print_config
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric

from monai.data import (
    DataLoader,
    Dataset,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

from skimage import measure
from sklearn.metrics import roc_auc_score

import torch
import glob
print_config()

In [None]:
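# Data loading (custom data path and dataset) is omitted here; that cell must
# define `train_loader`, `valid_loader` and `log_dir`, which the cells below use.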

In [None]:
model = monai.networks.nets.BasicUNet(
    spatial_dims=3,  # volumetric input
    in_channels=1,
    out_channels=2,
    features=(32, 32, 64, 128, 256, 32),
)

In [None]:
device = "cuda:0"  # single CUDA device; train/validation move batches here
model.to(device)

In [None]:
# Dice + cross-entropy on softmax(output); the target is one-hot encoded by the loss.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
# Let cuDNN benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

In [None]:
import skimage
def validation(epoch_iterator_val):
    """Evaluate ``model`` on the validation set and return its Dice score.

    Labels are binarised (``seg > 0`` -> 1). ``loss_function`` is DiceCELoss
    with ``softmax=True``, so the raw outputs are logits. They are therefore
    converted to probabilities with a softmax over the channel axis, and
    channel 1 is thresholded at 0.42 to give the predicted foreground mask.

    Masks and labels are one-hot encoded with ``post_pred`` / ``post_label``
    and accumulated in the global ``dice_metric``. For the first batch, a
    preview of up to two samples (input / probability channel 1 / label, each
    summed along axis 2) is saved to
    ``<log_dir>/validation/valid_step<global_step>.png`` and shown.

    Relies on notebook globals: ``model``, ``device``, ``dice_metric``,
    ``post_label``, ``post_pred``, ``log_dir`` and ``global_step`` (used only
    for the progress text and preview file name).

    Parameters
    ----------
    epoch_iterator_val : tqdm
        tqdm-wrapped validation loader yielding dicts with "t2" and "seg".

    Returns
    -------
    float
        ``dice_metric`` aggregated over all validation batches, or NaN if the
        loader yielded nothing. The metric is reset and the model's previous
        train/eval mode restored before returning, even on error.
    """
    was_training = model.training
    model.eval()
    mean_dice_val = float("nan")
    try:
        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val):
                val_inputs, val_labels = batch["t2"].to(device), batch["seg"].to(device)
                val_labels[val_labels > 0] = 1

                val_outputs = model(val_inputs)
                # Threshold probabilities, not raw logits.
                val_probs = torch.softmax(val_outputs, dim=1)
                val_outputs_bimask = (val_probs[:, 1] > 0.42).unsqueeze(1)

                # Preview the first batch only (step counts from 0).
                if step == 0:
                    preview_dir = os.path.join(log_dir, 'validation')
                    os.makedirs(preview_dir, exist_ok=True)
                    # Rows: first two samples (fewer if the batch is smaller);
                    # columns: input / foreground probability / label.
                    panels = (
                        ('input', val_inputs, 0),
                        ('output', val_probs, 1),
                        ('label', val_labels, 0),
                    )
                    fig, axes = plt.subplots(2, 3, dpi=256)
                    for ax in axes.flat:
                        ax.axis('off')
                    for row in range(min(2, val_inputs.shape[0])):
                        for col, (title, volume, channel) in enumerate(panels):
                            projection = torch.sum(volume[row, channel].detach().cpu(), axis=2)
                            axes[row, col].imshow(projection, cmap='gray')
                            axes[row, col].set_title(title)
                    fig.savefig(os.path.join(preview_dir, f'valid_step{global_step}.png'))
                    plt.show()
                    plt.close(fig)

                val_labels_convert = [post_label(t) for t in decollate_batch(val_labels)]
                val_output_convert = [post_pred(t) for t in decollate_batch(val_outputs_bimask)]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # aggregate() covers every batch accumulated so far, so after the
                # last batch it is the Dice over the whole validation set; the mean
                # of these running values would over-weight the first batches.
                mean_dice_val = dice_metric.aggregate().item()
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" % (global_step, 10.0, mean_dice_val)
                )
    finally:
        dice_metric.reset()
        model.train(was_training)
    return mean_dice_val

In [None]:
def train(global_step, train_loader, dice_val_best, global_step_best):
    """Train ``model`` for one pass over ``train_loader`` (PVS segmentation).

    Labels are binarised in place (``seg > 0`` -> 1), and the raw network
    output is scored with ``loss_function``.

    Every ``eval_num`` global steps, and at ``max_iterations``, the model is
    validated and the Dice history is plotted to ``<log_dir>/dice_plot.png``.
    On improvement the weights are saved to ``<log_dir>/model/best_model.pth``;
    that directory is created if needed.

    ``scheduler`` is stepped every 160 global steps. After the pass, the mean
    loss is recorded and plotted to ``<log_dir>/loss.png``, and the step and
    best Dice are appended to ``<log_dir>/train.txt``.

    Relies on notebook globals: ``model``, ``device``, ``optimizer``,
    ``scheduler``, ``loss_function``, ``validation``, ``valid_loader``,
    ``log_dir``, ``max_iterations``, ``eval_num``, ``metric_values`` and
    ``epoch_loss_values``.

    Parameters
    ----------
    global_step : int
        Optimisation steps taken before this call.
    train_loader : iterable
        Yields dicts with "t2" (input) and "seg" (label) tensors.
    dice_val_best : float
        Best validation Dice so far.
    global_step_best : int
        Global step at which ``dice_val_best`` was reached.

    Returns
    -------
    tuple
        Updated ``(global_step, dice_val_best, global_step_best)``.
    """
    model.train()
    epoch_loss = 0.0
    step = 0
    epoch_iterator = tqdm(
        train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True
    )
    for step, batch in enumerate(epoch_iterator, start=1):
        t2, seg = batch["t2"].to(device), batch["seg"].to(device)
        seg[seg > 0] = 1  # binary target: every positive label value is foreground

        pred = model(t2)
        loss = loss_function(pred, seg)

        loss.backward()
        loss_value = loss.item()
        epoch_loss += loss_value
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description(
            "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss_value)
        )
        if (
            global_step % eval_num == 0 and global_step != 0
        ) or global_step == max_iterations:
            # validation() reads the notebook-level `global_step` for its progress
            # text and preview file name. The parameter here shadows that global,
            # which would otherwise stay at its start-of-epoch value.
            globals()["global_step"] = global_step
            epoch_iterator_val = tqdm(
                valid_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True
            )
            dice_val = validation(epoch_iterator_val)
            # validation() calls model.eval(); the remaining steps of this
            # epoch must run in training mode again.
            model.train()
            metric_values.append(dice_val)

            plt.figure(1, figsize=(12,8))
            plt.plot(metric_values)
            plt.xlabel(f"{eval_num}")
            plt.savefig(os.path.join(log_dir, "dice_plot.png"))
            plt.close(1)

            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                model_dir = os.path.join(log_dir, "model")
                os.makedirs(model_dir, exist_ok=True)  # torch.save does not create it
                torch.save(
                    model.state_dict(), os.path.join(model_dir, "best_model.pth")
                )
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        global_step += 1
        if global_step % 160 == 0:
            scheduler.step()

    # Mean loss over this pass; NaN (rather than ZeroDivisionError) if the
    # loader yielded no batches.
    epoch_loss = epoch_loss / step if step else float("nan")
    epoch_loss_values.append(epoch_loss)

    plt.figure(1, figsize=(12, 8))
    plt.plot(epoch_loss_values)
    plt.savefig(os.path.join(log_dir,"loss.png"))
    plt.close(1)

    # `with` guarantees the log file is closed after appending.
    with open(f'{log_dir}/train.txt', 'a') as file:
        file.write(f'global_step : {str(global_step)}\n')
        file.write('current dice best : ')
        file.write(str(dice_val_best)+'\n')

    return global_step, dice_val_best, global_step_best

In [None]:
# Training schedule: the loop runs while global_step < max_iterations and
# validates every eval_num steps.
max_iterations = 6000
eval_num = 64

# validation() one-hot encodes binary labels and predicted masks into 2
# channels; Dice is computed on the foreground channel only.
post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=False, to_onehot=2)
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

# Training state threaded through successive train() calls.
global_step, global_step_best = 0, 0
dice_val_best = 0.0
epoch_loss_values, metric_values = [], []

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
# LR multiplier 0.95 ** k after k scheduler steps (train() steps it every 160 global steps).
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer=optimizer,
    lr_lambda=lambda k: 0.95 ** k,
    last_epoch=-1,
    verbose=True,
)

# Each train() call is one pass over train_loader; repeat until max_iterations.
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(
        global_step, train_loader, dice_val_best, global_step_best
    )