In [None]:
# Colab setup: mount Google Drive, clone the project onto it on first run,
# switch to the MICCAI branch and install the project's requirements.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Colab\ Notebooks
# Clone only when the project directory does not already exist on Drive (idempotent re-runs).
![ ! -d "/content/drive/MyDrive/Colab Notebooks/COPM-Project/" ] && git clone https://github.com/egilltor17/COPM-Project.git
# Later cells import project modules (dataset, loss, net, parameter); presumably they are
# resolved relative to this working directory -- keep this %cd before running them.
%cd /content/drive/MyDrive/Colab\ Notebooks/COPM-Project/
!git checkout MICCAI
!pip install -r requirements.txt

In [None]:
# Experiment tracking with Weights & Biases.
# !watch nvidia-smi
!pip install wandb
import wandb
# NOTE(review): no project/config is passed here and no wandb.log call is visible in this
# notebook, so the run records no training metrics -- confirm whether logging was intended.
wandb.init()

In [None]:
import os
import re
import sys
sys.path.append(os.path.split(sys.path[0])[0])

from time import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from dataset.dataset import Dataset

from loss.Dice import DiceLoss
from loss.ELDice import ELDiceLoss
from loss.WBCE import WCELoss
from loss.Jaccard import JaccardLoss
from loss.SS import SSLoss
from loss.Tversky import TverskyLoss
from loss.Hybrid import HybridLoss
from loss.BCE import BCELoss

from net import net

import parameter as para

# Training state shared with the plotting cell.
step_list = [0]  # x-axis counter, one entry per logged step
loss_plot = []   # final-output loss (loss4) recorded at every logged step

# Restrict visible GPUs; this only takes effect if CUDA has not been initialised yet.
os.environ['CUDA_VISIBLE_DEVICES'] = para.gpu
cudnn.benchmark = para.cudnn_benchmark

# Load Network
# net = torch.nn.DataParallel(net).cuda()
net = net.cuda()
# Checkpoint file name (relative to para.module_path) to resume from; empty = train from scratch.
model_path = ""
start_epoch = 0
if model_path:
    # Checkpoints are saved as "<module_path>_net<epoch>-<loss>-<mean loss>.pth"; recover <epoch>.
    match = re.search(r"_net(\d+)", model_path)
    if match is None:
        raise ValueError(f"Cannot parse the epoch from checkpoint name {model_path!r}")
    start_epoch = int(match.group(1))
    net.load_state_dict(torch.load(para.module_path + model_path))
net.train()

# Load Dataset
# NOTE(review): both arguments are the training-set path; if Dataset expects separate CT and
# segmentation directories this is probably wrong -- confirm against dataset.Dataset.
train_ds = Dataset(para.training_set_path, para.training_set_path)

# train_dl = DataLoader(dataset=train_ds, batch_size=para.batch_size, shuffle=True, num_workers=para.num_workers, pin_memory=para.pin_memory)
train_dl = DataLoader(dataset=train_ds, batch_size=1, shuffle=True, pin_memory=False)
# len(train_dl) counts batches; the dataset length is the number of samples.
print("Nr of training samples:", len(train_ds))

# Loss functions; the active one is selected by index (5 -> TverskyLoss).
loss_func_list = [DiceLoss(), ELDiceLoss(), WCELoss(), JaccardLoss(), SSLoss(), TverskyLoss(), HybridLoss(), BCELoss()]
loss_func = loss_func_list[5]

# Define Optimizer
opt = torch.optim.Adam(net.parameters(), lr=para.learning_rate)

# Learning rate decay: para.learning_rate_decay is passed as MultiStepLR's milestones (epochs).
lr_decay = torch.optim.lr_scheduler.MultiStepLR(opt, para.learning_rate_decay)

# Deep-supervision weight for the auxiliary outputs; attenuated during training.
alpha = para.alpha

In [None]:
# Training the network
# NOTE(review): when resuming, epoch `start_epoch` is trained again and neither lr_decay nor
# alpha is fast-forwarded to that epoch -- confirm this is acceptable.
print(f"Training epochs: {start_epoch}-{para.Epoch}")
start = time()
for epoch in range(start_epoch, para.Epoch + 1):
    mean_loss = []  # final-output loss of every step in this epoch
    for step, (ct, seg) in enumerate(train_dl):
        # Halve the in-plane input resolution by keeping every 2nd row and column of the
        # last two axes. `seg` stays at full resolution -- presumably the network upsamples
        # its output by 2 in-plane; TODO confirm.
        ct = ct[:, :, :, ::2, ::2]

        ct = ct.cuda()
        seg = seg.cuda()

        opt.zero_grad()
        outputs = net(ct)

        # Deep supervision: outputs[0..2] are weighted by alpha, outputs[3] is the final output.
        loss1 = loss_func(outputs[0], seg)
        loss2 = loss_func(outputs[1], seg)
        loss3 = loss_func(outputs[2], seg)
        loss4 = loss_func(outputs[3], seg)
        loss = (loss1 + loss2 + loss3) * alpha + loss4
        mean_loss.append(loss4.item())

        loss.backward()
        opt.step()

        if step % 5 == 0:
            step_list.append(step_list[-1] + 1)
            loss_plot.append(loss4.item())

            print('epoch:{}, step:{}, loss1:{:.3f}, loss2:{:.3f}, loss3:{:.3f}, loss4:{:.3f}, time:{:.3f} min'
                  .format(epoch, step, loss1.item(), loss2.item(), loss3.item(), loss4.item(), (time() - start) / 60))

    # Save a checkpoint every 25 epochs; the name records the epoch, the last total loss
    # and this epoch's mean final-output loss.
    if epoch % 25 == 0 and epoch != 0:
        epoch_mean_loss = sum(mean_loss) / len(mean_loss)
        torch.save(net.state_dict(),
                   para.module_path + '_net{}-{:.3f}-{:.3f}.pth'.format(epoch, loss.item(), epoch_mean_loss))

    # Attenuate the deep-supervision coefficient every 40 epochs.
    if epoch % 40 == 0 and epoch != 0:
        alpha *= 0.8

    lr_decay.step()

In [None]:
import matplotlib.pyplot as plt

# Final-output loss (loss4), recorded once every 5 training steps by the training loop.
fig, ax = plt.subplots()
ax.plot(loss_plot)
ax.set(title="Training loss (final output)",
       xlabel="Logged step (every 5 iterations)",
       ylabel="Loss")
plt.show()

In [None]:
!pwd
!git status
!echo -e "[user]\n\tname = egilltor17\n\temail = egilltor17@ru.is" > ~/.gitconfig
!git add .
!git commit -m "Changes from Google Colab"
#!git pull origin MICCAI
!git push origin MICCAI

In [None]:
# NOTE(review): unpinned upgrade. Newer scikit-image releases appear to have changed APIs the
# validation cell relies on (the second positional parameter of `measure.label`, and the
# `in_place` argument of `morphology.remove_small_holes`) -- verify, and pin a version that
# matches that code (e.g. `%pip install scikit-image==X.Y`).
!pip install --upgrade scikit-image

In [5]:
### Validation ###
import os
import copy
import collections
from time import time

import torch
import numpy as np
import pandas as pd
import scipy.ndimage as ndimage
import SimpleITK as sitk
import skimage
import skimage.measure as measure
import skimage.morphology as morphology


from net import ResUNet as UNet
from utilities.calculate_metrics import Metirc

import parameter as para

os.environ['CUDA_VISIBLE_DEVICES'] = para.gpu

# Running sums across all cases; their ratio is printed as the global Dice.
dice_intersection = 0.0
dice_union = 0.0

# Per-case bookkeeping: file name and seconds spent on that case.
file_name = []
time_pre_case = []

# Per-case evaluation metrics, one list per metric, in report-column order.
liver_score = collections.OrderedDict(
    (metric_name, []) for metric_name in ('dice', 'jacard', 'voe', 'fnr', 'fpr', 'assd', 'rmsd', 'msd')
)

In [None]:
# Define network and load parameters
# net = torch.nn.DataParallel(UNet(training=False)).cuda()
model_path = "net1000-1.792-0.558.pth"
net = UNet(training=False).cuda()
net.load_state_dict(torch.load(para.module_path + model_path))
net.eval()

# CT intensities are divided by this after HU windowing.
HU_SCALE = 200
# In-plane size of the probability map; assumes network output and ground truth are 512x512.
OUT_HW = 512


def _preprocess_ct(ct_array):
    """Window to [para.lower, para.upper], divide by HU_SCALE and downsample in-plane.

    Returns a float32 array (slices, H * para.down_scale, W * para.down_scale).
    """
    windowed = np.clip(ct_array, para.lower, para.upper).astype(np.float32) / HU_SCALE
    # order=3: cubic spline interpolation; the zoomed array keeps the float32 dtype.
    return ndimage.zoom(windowed, (1, para.down_scale, para.down_scale), order=3)


def _pad_to_window(ct_array):
    """Pad volumes shorter than para.size slices at the end so one full window fits.

    The pad value is para.lower / HU_SCALE, i.e. what every voxel at or below para.lower
    becomes after _preprocess_ct. Returns (array, original slice count).
    """
    depth = ct_array.shape[0]
    if depth >= para.size:
        return ct_array, depth
    padded = np.full((para.size,) + ct_array.shape[1:], para.lower / HU_SCALE, dtype=np.float32)
    padded[:depth] = ct_array
    return padded, depth


def _window_starts(num_slices):
    """Start slices of the sliding windows: every para.stride slices, plus one window flush
    with the last slice if the regular grid does not already end there.

    Assumes num_slices >= para.size. Each window is added once, so no window is double-weighted.
    """
    last_start = num_slices - para.size
    starts = list(range(0, last_start + 1, para.stride))
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


def _predict_probability(ct_array):
    """Run the network over sliding windows.

    Returns (summed probabilities, number of windows covering each voxel).
    """
    num_slices = ct_array.shape[0]
    count = np.zeros((num_slices, OUT_HW, OUT_HW), dtype=np.int16)
    probability_map = np.zeros((num_slices, OUT_HW, OUT_HW), dtype=np.float32)
    with torch.no_grad():
        for start_slice in _window_starts(num_slices):
            window = slice(start_slice, start_slice + para.size)
            ct_tensor = torch.FloatTensor(ct_array[window]).cuda()
            ct_tensor = ct_tensor.unsqueeze(dim=0).unsqueeze(dim=0)  # add batch and channel axes
            outputs = net(ct_tensor)

            count[window] += 1
            probability_map[window] += np.squeeze(outputs.cpu().detach().numpy())

            # Release the GPU output before the next window to limit memory use.
            del outputs
    return probability_map, count


def _largest_component(pred_seg):
    """Keep the largest face-connected component of a binary mask and fill small holes.

    Returns a uint8 mask; an empty prediction stays empty.
    """
    # connectivity=1 is what the old `neighbors=4` argument meant; passing 4 positionally
    # would be interpreted as `background=4` by current scikit-image.
    labels = measure.label(pred_seg, connectivity=1)
    props = measure.regionprops(labels)
    if not props:
        return np.zeros(pred_seg.shape, dtype=np.uint8)
    largest = max(props, key=lambda prop: prop.area)
    liver = labels == largest.label
    # Use the returned array: `in_place=True` no longer exists in newer scikit-image.
    liver = morphology.remove_small_holes(liver, para.maximum_hole, connectivity=2)
    return liver.astype(np.uint8)


# Reset per-case accumulators so re-running this cell does not append to stale results.
file_name.clear()
time_pre_case.clear()
for metric_values in liver_score.values():
    metric_values.clear()
dice_intersection = 0.0
dice_union = 0.0

# Sorted for a reproducible case order in the exported results.
for file_index, file in enumerate(sorted(os.listdir(para.test_ct_path))):
    start = time()
    print(file_index, start)
    file_name.append(file)

    # Read CT volume, preprocess it and pad it to at least one window.
    ct = sitk.ReadImage(os.path.join(para.test_ct_path, file), sitk.sitkInt16)
    ct_array = _preprocess_ct(sitk.GetArrayFromImage(ct))
    ct_array, depth = _pad_to_window(ct_array)

    # A voxel is foreground when its mean probability over covering windows reaches
    # para.threshold. Then drop any padded slices.
    probability_map, count = _predict_probability(ct_array)
    pred_seg = (probability_map >= para.threshold * count).astype(np.uint8)[:depth]

    # Read the ground truth; every non-zero label is merged into one foreground class.
    seg = sitk.ReadImage(os.path.join(para.test_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8)
    seg_array = sitk.GetArrayFromImage(seg)
    seg_array[seg_array > 0] = 1

    liver_seg = _largest_component(pred_seg)

    # Calculate segmentation evaluation metrics.
    # NOTE(review): SimpleITK spacing is (x, y, z) while the arrays are (z, y, x) -- confirm
    # the order Metirc expects.
    liver_metric = Metirc(seg_array, liver_seg, ct.GetSpacing())
    dice_result = liver_metric.get_dice_coefficient()

    liver_score['dice'].append(dice_result[0])
    liver_score['jacard'].append(liver_metric.get_jaccard_index())
    liver_score['voe'].append(liver_metric.get_VOE())
    liver_score['fnr'].append(liver_metric.get_FNR())
    liver_score['fpr'].append(liver_metric.get_FPR())
    liver_score['assd'].append(liver_metric.get_ASSD())
    liver_score['rmsd'].append(liver_metric.get_RMSD())
    liver_score['msd'].append(liver_metric.get_MSD())

    dice_intersection += dice_result[1]
    dice_union += dice_result[2]

    # Save the prediction as .nii with the CT's geometry.
    pred_image = sitk.GetImageFromArray(liver_seg)
    pred_image.SetDirection(ct.GetDirection())
    pred_image.SetOrigin(ct.GetOrigin())
    pred_image.SetSpacing(ct.GetSpacing())
    sitk.WriteImage(pred_image, os.path.join(para.pred_path, file.replace('volume', 'pred')))

    speed = time() - start
    time_pre_case.append(speed)

    print(file_index, 'this case use {:.3f} s'.format(speed))
    print('-----------------------')
    print('-----------------------')

In [None]:
# Write evaluation indicators into exel
liver_data = pd.DataFrame(liver_score, index=file_name)
liver_data['time'] = time_pre_case

liver_statistics = pd.DataFrame(index=['mean', 'std', 'min', 'max'], columns=list(liver_data.columns))
liver_statistics.loc['mean'] = liver_data.mean()
liver_statistics.loc['std'] = liver_data.std()
liver_statistics.loc['min'] = liver_data.min()
liver_statistics.loc['max'] = liver_data.max()

writer = pd.ExcelWriter('./result.xlsx')
liver_data.to_excel(writer, 'liver')
liver_statistics.to_excel(writer, 'liver_statistics')
writer.save()

# Print dice global
print('dice global:', dice_intersection / dice_union)