In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
import re
import h5py
import math
import shutil
import random
import nibabel
import tarfile
import nibabel as nib
import SimpleITK as sitk
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
from nilearn import plotting
from nilearn.plotting import plot_anat, plot_roi
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split


# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(filename, os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install einops
!pip install medpy
from einops import rearrange
from medpy import metric

### Preset

In [None]:
# Create the working directories; -p makes the cell safe to re-run when they already exist.
!mkdir -p checkpoint results data dataset

In [None]:
# shutil.rmtree('./dataset')

In [None]:
train_set = {
    'root': './data',    # input path to the original dataset with 4 labels
    'out': './dataset',  # output path for the preprocessed dataset
    'flist': '/kaggle/input/train-list/train.txt',  # training IDs
}

# Extract the raw archive only when both the raw folder and the preprocessed
# folder are empty, i.e. nothing has been unpacked or converted in this session.
if not os.listdir(train_set['root']) and not os.listdir(train_set['out']):
    # The context manager closes the archive even if extraction fails part-way.
    with tarfile.open("/kaggle/input/brats-2021-task1/BraTS2021_Training_Data.tar") as archive:
        archive.extractall(train_set['root'])

### Data Preprocessing

In [None]:
# MRI modalities, stacked in this order as the channel axis of every case
modalities = ('flair', 't1ce', 't1', 't2')

def process_h5(path, out_path):
    """Convert one case to a z-scored HDF5 file and delete its raw folder.

    Args:
        path: case prefix of the form ``<root>/<case>/<case>_``; the modality
            name (or ``seg``) plus ``.nii.gz`` is appended to locate each volume.
        out_path: directory that receives ``<case>_mri_norm2.h5``.

    The output file holds ``image`` (float32, one channel per entry of
    ``modalities``) and ``label`` (uint8). Each channel is z-scored using the
    foreground voxels only, so the background stays exactly zero.
    Cases without a segmentation file are skipped silently.
    """
    if not os.path.exists(path + 'seg.nii.gz'):
        return

    case_name = os.path.basename(path)
    output = os.path.join(out_path, case_name) + 'mri_norm2.h5'
    # Raw case folder is the directory that contains the volumes of this case.
    path_to_rm = os.path.dirname(path)

    # Already converted on a previous run: skip the expensive load and only
    # reclaim the disk space taken by the raw volumes.
    if os.path.exists(output):
        shutil.rmtree(path_to_rm)
        return

    def _read(suffix):
        # SimpleITK returns arrays as (D, H, W); reorder them to (H, W, D).
        volume = sitk.GetArrayFromImage(sitk.ReadImage(path + suffix + '.nii.gz'))
        return volume.transpose(1, 2, 0)

    label = _read('seg').astype(np.uint8)
    # Stack the modalities: N x (H, W, D) -> (N, H, W, D)
    images = np.stack([_read(modal) for modal in modalities], 0).astype(np.float32)

    # A voxel is foreground when the sum over the modalities is positive.
    mask = images.sum(0) > 0
    for channel in images:  # `channel` is a view, so edits land in `images`
        foreground = channel[mask]
        if foreground.size == 0:
            continue  # nothing to normalise in an all-background volume
        std = foreground.std()
        # A constant foreground has std == 0; only centre it to avoid NaN/inf.
        channel[mask] = (foreground - foreground.mean()) / (std if std > 0 else 1.0)

    print(case_name, images.shape, label.shape)
    with h5py.File(output, 'w') as f:
        f.create_dataset('image', data=images, compression="gzip")
        f.create_dataset('label', data=label, compression="gzip")

    # Remove the original volumes to save space.
    shutil.rmtree(path_to_rm)

def pre_data(dset, n_samples=125):
    """Preprocess a random subset of the cases listed in ``dset['flist']``.

    Args:
        dset: dict with keys ``root`` (raw data folder), ``out`` (output
            folder) and ``flist`` (text file with one case ID per line).
        n_samples: maximum number of cases to convert. Fewer are used when the
            list is shorter, instead of raising ``ValueError``.
    """
    root, out_path = dset['root'], dset['out']
    file_list = os.path.join(root, dset['flist'])
    # Close the list file promptly and ignore blank lines.
    with open(file_list) as f:
        subjects = [line.strip() for line in f if line.strip()]

    names = ['BraTS2021_' + sub for sub in subjects]
    names = random.sample(names, min(n_samples, len(names)))
    paths = [os.path.join(root, name, name + '_') for name in names]

    for path in tqdm(paths):
        process_h5(path, out_path)

    print('Finished')

pre_data(train_set)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
print("CUDA is available. GPU will be used for training." if use_cuda
      else "CUDA is not available. Training will be on CPU.")
device = torch.device("cuda" if use_cuda else "cpu")

In [None]:
# Case IDs are the first 15 characters of every preprocessed file name.
# os.listdir returns entries in arbitrary order, so sort for a stable split.
train_and_test_set_ids = sorted(os.listdir(train_set['out']))
train_and_test_ids = [i[:15] for i in train_and_test_set_ids]

# Randomly select up to 125 cases out of the 1251 available. A dedicated,
# seeded generator keeps the selection identical between runs, and the cap
# avoids a ValueError when fewer than 125 cases were preprocessed.
split_rng = random.Random(21)
train_and_test_ids = split_rng.sample(train_and_test_ids, min(125, len(train_and_test_ids)))

# 80% train, then the remaining 20% split evenly into validation and test.
train_ids, val_test_ids = train_test_split(train_and_test_ids, test_size=0.2,random_state=21)
val_ids, test_ids = train_test_split(val_test_ids, test_size=0.5,random_state=21)
print("Using {} images for training, {} images for validation, {} images for testing.".format(len(train_ids),len(val_ids),len(test_ids)))

train_ids.sort()
val_ids.sort()
test_ids.sort()

# Persist each split as one case ID per line.
for split_file, split_ids in (('./train.txt', train_ids),
                              ('./valid.txt', val_ids),
                              ('./test.txt', test_ids)):
    with open(split_file,'w') as f:
        f.write('\n'.join(split_ids))

### Models

In [None]:
# U-Net
class InConv(nn.Module):
    """Input stem: one DoubleConv mapping `in_ch` channels to `out_ch`."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)

class Down(nn.Module):
    """Encoder stage: max-pool (kernel 2, stride 2) followed by a DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool3d(2, 2)
        convs = DoubleConv(in_ch, out_ch)
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)

class OutConv(nn.Module):
    """Output head: a 1x1x1 convolution from `in_ch` to `out_ch` channels.

    No activation is applied, so the module returns raw per-class scores.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.conv(x)

class DoubleConv(nn.Module):
    """Two consecutive (3x3x3 Conv3d -> BatchNorm3d -> ReLU) stages.

    The first stage maps `in_ch` to `out_ch`; the second keeps `out_ch`.
    Padding 1 with stride 1 preserves the spatial size.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Layers are created in the same order as they are applied, which keeps
        # the Sequential indices (and therefore checkpoint keys) unchanged.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv3d(stage_in, out_ch, kernel_size=3, stride=1, padding=1))
            stages.append(nn.BatchNorm3d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class Up(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, DoubleConv.

    Args:
        in_ch: channels of the tensor coming from the deeper stage.
        skip_ch: channels of the encoder skip connection.
        out_ch: channels produced by the stage.
    """

    def __init__(self, in_ch, skip_ch,out_ch):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose3d(in_ch, in_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch+skip_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample `x1`, align it with the skip tensor `x2` and fuse both."""
        x1 = self.up(x1)
        # Pooling an odd-sized dimension drops one voxel, so the upsampled
        # tensor can be smaller than the skip. Zero-pad it to the skip's size
        # so the concatenation works; this is a no-op when the sizes match.
        if x1.shape[2:] != x2.shape[2:]:
            pads = []
            # F.pad expects (left, right) pairs starting from the LAST dimension.
            for up_size, skip_size in zip(reversed(x1.shape[2:]), reversed(x2.shape[2:])):
                diff = skip_size - up_size
                pads += [diff // 2, diff - diff // 2]
            x1 = F.pad(x1, pads)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class UNet(nn.Module):
    """3D U-Net with four down-sampling and four up-sampling stages.

    Args:
        in_channels: number of input channels.
        num_classes: number of output channels produced by the final 1x1x1 conv.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        f0, f1, f2, f3 = 32, 64, 128, 256

        # Encoder (the deepest stage keeps the channel count at f3)
        self.inc = InConv(in_channels, f0)
        self.down1 = Down(f0, f1)
        self.down2 = Down(f1, f2)
        self.down3 = Down(f2, f3)
        self.down4 = Down(f3, f3)

        # Decoder: Up(input channels, skip channels, output channels)
        self.up1 = Up(f3, f3, f2)
        self.up2 = Up(f2, f2, f1)
        self.up3 = Up(f1, f1, f0)
        self.up4 = Up(f0, f0, f0)
        self.outc = OutConv(f0, num_classes)

    def forward(self, x):
        # Encoder path; every intermediate result is reused as a skip connection.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        # Decoder path, consuming the skips from deepest to shallowest.
        out = self.up1(bottom, skip4)
        out = self.up2(out, skip3)
        out = self.up3(out, skip2)
        out = self.up4(out, skip1)
        return self.outc(out)

In [None]:
# Attention U-Net
class AttentionBlock(nn.Module):
    """Additive attention gate that rescales a skip tensor `x` using a gating tensor `g`.

    Both inputs are projected to `int_channels` with 1x1x1 conv + batch norm,
    summed, passed through ReLU and reduced to a single-channel sigmoid map
    that multiplies `x`.
    """

    def __init__(self, in_channels_x, in_channels_g, int_channels):
        super().__init__()

        def projection(in_channels, out_channels):
            return [nn.Conv3d(in_channels, out_channels, kernel_size=1),
                    nn.BatchNorm3d(out_channels)]

        self.Wx = nn.Sequential(*projection(in_channels_x, int_channels))
        self.Wg = nn.Sequential(*projection(in_channels_g, int_channels))
        self.psi = nn.Sequential(*projection(int_channels, 1), nn.Sigmoid())

    def forward(self, x, g):
        # Project the skip connection and the gating signal to a common width.
        skip_proj = self.Wx(x)
        gate_proj = self.Wg(g)
        attention = self.psi(F.relu(skip_proj + gate_proj, inplace=True))
        # The single-channel map broadcasts over the channels of x.
        return attention * x

class AttentionUpBlock(nn.Module):
    """Decoder stage with an attention-gated skip connection.

    The incoming tensor is resized to the skip's spatial size by trilinear
    interpolation, used as the gating signal for the skip, concatenated with
    the gated skip and passed through two DoubleConv blocks.
    """

    def __init__(self, in_channels_x, in_channels_g, out_channels):
        super().__init__()
        self.attention = AttentionBlock(in_channels_x, in_channels_g, in_channels_g)
        self.conv_bn1 = DoubleConv(in_channels_g * 2, out_channels)
        self.conv_bn2 = DoubleConv(out_channels, out_channels)

    def forward(self, x, x_skip):
        # `x` comes from the previous (deeper) block, `x_skip` from the encoder.
        target_size = x_skip.shape[2:]
        upsampled = F.interpolate(x, target_size, mode='trilinear', align_corners=False)

        # Gate the skip connection using the upsampled tensor as context.
        gated_skip = self.attention(x_skip, upsampled)

        # Concatenate along channels, then refine with both conv blocks.
        fused = torch.cat((gated_skip, upsampled), dim=1)
        return self.conv_bn2(self.conv_bn1(fused))


class AttentionUNet(nn.Module):
    """U-Net variant whose decoder stages gate the skip connections with attention.

    Args:
        in_channels: number of input channels.
        num_classes: number of output channels of the final 1x1x1 conv.
        feature_scale: divisor applied to the base channel widths.
    """

    def __init__(self, in_channels, num_classes, feature_scale=4):
        super().__init__()
        base_widths = (96, 192, 384, 768, 1280)
        feature = [int(width / feature_scale) for width in base_widths]
        # Only the first four widths are used by the layers below.
        f0, f1, f2, f3 = feature[:4]

        # Encoder
        self.inc = InConv(in_channels, f0)
        self.down1 = Down(f0, f1)
        self.down2 = Down(f1, f2)
        self.down3 = Down(f2, f3)
        self.down4 = Down(f3, f3)

        # Attention-gated decoder
        self.up1 = AttentionUpBlock(f3, f3, f2)
        self.up2 = AttentionUpBlock(f2, f2, f1)
        self.up3 = AttentionUpBlock(f1, f1, f0)
        self.up4 = AttentionUpBlock(f0, f0, f0)
        self.outc = OutConv(f0, num_classes)

    def forward(self, x):
        # Encoder path; each result doubles as a skip connection.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        # Decoder path from the deepest skip to the shallowest.
        out = self.up1(bottom, skip4)
        out = self.up2(out, skip3)
        out = self.up3(out, skip2)
        out = self.up4(out, skip1)
        return self.outc(out)

### Data Augmentation

In [None]:
class RandomCrop(object):
    """Randomly crop the same sub-volume from the image and the label of a sample.

    Args:
        output_size: sequence of three ints, the crop size along the three
            spatial axes of the image (which is indexed as channels first).
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        (c, w, h, d) = image.shape
        # np.random.randint excludes the upper bound, so +1 is needed to make
        # the last valid offset reachable and to accept a dimension that is
        # exactly as large as the crop (offset 0 is then the only choice).
        w1 = np.random.randint(0, w - self.output_size[0] + 1)
        h1 = np.random.randint(0, h - self.output_size[1] + 1)
        d1 = np.random.randint(0, d - self.output_size[2] + 1)

        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[:,w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        return {'image': image, 'label': label}


class CenterCrop(object):
    """Crop the central sub-volume of size `output_size` from image and label."""

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        spatial_shape = image.shape[1:4]

        # One slice per spatial axis, starting at the rounded half margin.
        window = []
        for full, size in zip(spatial_shape, self.output_size[:3]):
            start = int(round((full - size) / 2.))
            window.append(slice(start, start + size))
        window = tuple(window)

        cropped_label = label[window]
        cropped_image = image[(slice(None),) + window]
        return {'image': cropped_image, 'label': cropped_label}


class RandomRotFlip(object):
    """Randomly rotate (multiples of 90 degrees) and flip image and label together.

    The rotation acts in the plane of the first two spatial axes; the flip
    is applied along one randomly chosen spatial axis.
    """

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # Same number of quarter turns for every channel and for the label.
        quarter_turns = np.random.randint(0, 4)
        image = np.rot90(image, quarter_turns, axes=(1, 2))
        label = np.rot90(label, quarter_turns)

        # Image axes are (channel, spatial...), so the label axis is one lower.
        flip_axis = np.random.randint(1, 4)
        flipped_image = np.flip(image, axis=flip_axis).copy()
        flipped_label = np.flip(label, axis=flip_axis - 1).copy()

        return {'image': flipped_image, 'label': flipped_label}


def augment_gaussian_noise(data_sample, noise_variance=(0, 0.1)):
    """Return `data_sample` plus zero-mean Gaussian noise.

    `noise_variance` is a (low, high) pair. The noise level is `low` when both
    ends are equal, otherwise it is drawn uniformly from the range. Note that
    the drawn value is passed to `np.random.normal` as its scale argument.
    """
    low, high = noise_variance[0], noise_variance[1]
    variance = low if low == high else random.uniform(low, high)
    noise = np.random.normal(0.0, variance, size=data_sample.shape)
    return data_sample + noise


class GaussianNoise(object):
    """Add Gaussian noise to the image of a sample with probability `p`.

    The label is passed through unchanged.
    """

    def __init__(self, noise_variance=(0, 0.1), p=0.5):
        self.prob = p
        self.noise_variance = noise_variance

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        apply_noise = np.random.uniform() < self.prob
        if apply_noise:
            image = augment_gaussian_noise(image, self.noise_variance)
        return {'image': image, 'label': label}


class ToTensor(object):
    """Convert the ndarrays of a sample to tensors (float image, long label)."""

    def __call__(self, sample):
        return {
            'image': torch.from_numpy(sample['image']).float(),
            'label': torch.from_numpy(sample['label']).long(),
        }


class BraTS(Dataset):
    """Dataset over preprocessed cases stored as ``<case>_mri_norm2.h5`` files.

    Args:
        data_path: directory that contains the HDF5 files.
        file_path: text file listing one case ID per line (blank lines are ignored).
        transform: optional callable applied to the ``{'image', 'label'}`` dict.
    """

    def __init__(self,data_path, file_path,transform=None):
        with open(file_path, 'r') as f:
            case_ids = [line.strip() for line in f]
        # Skip blank lines so they do not turn into paths of non-existent files.
        self.paths = [os.path.join(data_path, case_id)+'_mri_norm2.h5' for case_id in case_ids if case_id]
        self.transform = transform

    def __getitem__(self, item):
        # Read both arrays fully into memory, then close the file so that
        # handles do not accumulate while iterating over the dataset.
        with h5py.File(self.paths[item], 'r') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]
        # Remap labels [0,1,2,4] -> [0,1,2,3] so they are contiguous class indices.
        label[label == 4] = 3
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample['image'], sample['label']

    def __len__(self):
        return len(self.paths)

    def collate(self, batch):
        # Concatenate the per-sample tensors field by field.
        return [torch.cat(v) for v in zip(*batch)]

### Loss and Evaluation Metrics

In [None]:
def Dice(output, target, eps=1e-3):
    """Mean smoothed Dice coefficient over the batch.

    Both tensors are summed over dims 1, 2 and -1, which leaves one score per
    batch element; `eps` smooths the ratio so empty masks do not divide by zero.
    """
    reduce_dims = (1, 2, -1)
    overlap = (output * target).sum(dim=reduce_dims) + eps
    total = output.sum(dim=reduce_dims) + target.sum(dim=reduce_dims) + eps * 2
    per_sample = 2 * overlap / total
    return per_sample.mean()


def cal_dice(output, target):
    '''
    Dice scores of the three evaluated regions.

    output: (b, num_class, d, h, w) class scores, target: (b, d, h, w) class indices
    Returns (ET, TC, WT):
      ET: label 4
      TC: label 1 + label 4
      WT: label 1 + label 2 + label 4
    Note: label 4 has already been replaced with 3, so class 3 is used below.
    '''
    pred = torch.argmax(output, dim=1)

    def region_dice(pred_region, target_region):
        return Dice(pred_region.float(), target_region.float())

    enhancing = region_dice(pred == 3, target == 3)
    tumor_core = region_dice((pred == 1) | (pred == 3), (target == 1) | (target == 3))
    whole_tumor = region_dice(pred != 0, target != 0)

    return enhancing, tumor_core, whole_tumor


class Loss(nn.Module):
    """Weighted cross-entropy combined with a batch-level soft Dice loss.

    total = (1 - alpha) * cross_entropy + alpha * (1 - dice)

    Args:
        n_classes: number of classes (size of dim 1 of the network output).
        weight: optional per-class weights for the cross-entropy term.
        alpha: mixing factor between the two terms.
    """

    def __init__(self, n_classes, weight=None, alpha=0.5):
        super(Loss, self).__init__()
        self.n_classes = n_classes
        if weight is not None:
            weight = torch.as_tensor(weight, dtype=torch.float)
        # Registered as a buffer so `.to(device)` / `.cuda()` move it with the
        # module. Unlike an unconditional `weight.cuda()`, this also works for
        # `weight=None` and on machines without a GPU.
        self.register_buffer('weight', weight)
        self.alpha = alpha

    def forward(self, input, target):
        """Return the scalar loss for scores `input` (b, n, h, w, s) and class indices `target` (b, h, w, s)."""
        smooth = 0.01

        probs = F.softmax(input, dim=1)
        onehot = F.one_hot(target, self.n_classes)
        # Flatten the spatial dims: (b, n, h, w, s) -> (b, n, h*w*s). The
        # one-hot target has the class axis last, so move it to dim 1 first.
        probs = probs.flatten(2)
        onehot = onehot.permute(0, 4, 1, 2, 3).flatten(2)

        # The Dice term ignores the background class (index 0).
        probs = probs[:, 1:, :]
        onehot = onehot[:, 1:, :].float()

        # Dice is computed over the whole batch at once for more stable training.
        inter = torch.sum(probs * onehot)
        union = torch.sum(probs) + torch.sum(onehot) + smooth
        dice = 2.0 * inter / union

        # Keep the class weights on the same device as the predictions.
        weight = None if self.weight is None else self.weight.to(input.device)
        loss = F.cross_entropy(input, target, weight=weight)

        total_loss = (1 - self.alpha) * loss + (1 - dice) * self.alpha

        return total_loss

### Training

In [None]:
# The learning rate is updated at every iteration, not just once per epoch:
# the training loop looks up the precomputed schedule at the current global
# iteration index (scheduler[it]) and writes that value into the optimizer's
# parameter group. Warmup followed by cosine decay helps the model converge
# more smoothly and can lead to better final performance.
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0.):
    """Build a per-iteration schedule: optional linear warmup, then cosine decay.

    Returns a 1-D array with `epochs * niter_per_ep` values. The warmup part
    rises linearly from `start_warmup_value` to `base_value`; the remaining
    iterations follow a half cosine from `base_value` towards `final_value`.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    steps = np.arange(total_iters - warmup_iters)
    cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))

    schedule = np.concatenate((warmup_part, cosine_part))
    assert len(schedule) == total_iters
    return schedule


def train_loop(model,optimizer,scheduler,criterion,train_loader,device,epoch):
    """Train for one epoch and return the epoch means of loss and Dice scores.

    `scheduler` is an indexable per-iteration learning-rate schedule; the
    entry for the current global iteration is written into the first
    parameter group of the optimizer before every step.
    """
    model.train()
    totals = {'loss': 0, 'dice1': 0, 'dice2': 0, 'dice3': 0}
    n_batches = len(train_loader)
    pbar = tqdm(train_loader)
    for step, (images, masks) in enumerate(pbar):
        # Global iteration index across epochs selects the learning rate.
        global_step = n_batches * epoch + step
        optimizer.param_groups[0]['lr'] = scheduler[global_step]

        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)
        dice_scores = cal_dice(outputs, masks)
        pbar.desc = "loss: {:.3f} ".format(loss.item())

        totals['loss'] += loss.item()
        for key, score in zip(('dice1', 'dice2', 'dice3'), dice_scores):
            totals[key] += score.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return {key: total / n_batches for key, total in totals.items()}


def val_loop(model,criterion,val_loader,device):
    """Evaluate the model on a loader without gradients.

    Returns the per-batch means of the loss and of the three Dice scores.
    """
    model.eval()
    totals = {'loss': 0, 'dice1': 0, 'dice2': 0, 'dice3': 0}
    n_batches = len(val_loader)
    with torch.no_grad():
        for images, masks in tqdm(val_loader):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)

            batch_loss = criterion(outputs, masks)
            dice_scores = cal_dice(outputs, masks)

            totals['loss'] += batch_loss.item()
            for key, score in zip(('dice1', 'dice2', 'dice3'), dice_scores):
                totals[key] += score.item()

    return {key: total / n_batches for key, total in totals.items()}


def train(model,optimizer,scheduler,criterion,train_loader,
          val_loader,epochs,device,train_log,valid_loss_min=999.0,
          save_path=None,weights_path=None):
    """Run the full training loop with per-epoch validation, logging and checkpoints.

    Args:
        train_log: text file the per-epoch summary lines are appended to.
        valid_loss_min: best validation loss so far; a lower loss overwrites
            the best-weights file.
        save_path: directory for the per-epoch checkpoints written when the
            validation loss did not improve. Defaults to the notebook-level
            ``args.save_path`` for backward compatibility.
        weights_path: file for the best weights. Defaults to ``args.weights``.

    Returns:
        (train_metrics_all, val_metrics_all): one metrics dict per epoch each.
    """
    # Explicit arguments avoid the hidden dependency on the global `args`.
    if save_path is None:
        save_path = args.save_path
    if weights_path is None:
        weights_path = args.weights
    os.makedirs(save_path, exist_ok=True)
    weights_dir = os.path.dirname(weights_path)
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)

    train_metrics_all = []
    val_metrics_all = []
    for e in range(epochs):
        # train for one epoch
        train_metrics = train_loop(model,optimizer,scheduler,criterion,train_loader,device,e)
        train_metrics_all.append(train_metrics)
        # evaluate after the epoch
        val_metrics = val_loop(model,criterion,val_loader,device)
        val_metrics_all.append(val_metrics)
        # valid_loss_min is reported as it was BEFORE this epoch's result is considered
        info1 = "Epoch:[{}/{}] valid_loss_min: {:.3f} train_loss: {:.3f} valid_loss: {:.3f} ".format(e+1,epochs,valid_loss_min,train_metrics["loss"],val_metrics["loss"])
        info2 = "Train--ET: {:.3f} TC: {:.3f} WT: {:.3f} ".format(train_metrics['dice1'],train_metrics['dice2'],train_metrics['dice3'])
        info3 = "Valid--ET: {:.3f} TC: {:.3f} WT: {:.3f} ".format(val_metrics['dice1'],val_metrics['dice2'],val_metrics['dice3'])
        print(info1)
        print(info2)
        print(info3)
        with open(train_log,'a') as f:
            f.write(info1 + '\n' + info2 + ' ' + info3 + '\n')

        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict()}
        if val_metrics['loss'] < valid_loss_min:
            # New best: overwrite the best-weights file.
            valid_loss_min = val_metrics['loss']
            torch.save(save_file, weights_path)
        else:
            torch.save(save_file,os.path.join(save_path,'checkpoint{}.pth'.format(e+1)))
    print("Finished Training!")
    return train_metrics_all, val_metrics_all


def main(args):
    """Build data, model, loss and schedule from `args`, train, then evaluate.

    Reads from `args`: seed, data_path, train_txt/valid_txt/test_txt,
    batch_size, model ('unet' selects U-Net, anything else Attention U-Net),
    lr, min_lr, epochs, warmup_epochs, weights and train_log.

    Returns:
        (train_metrics_all, val_metrics_all) as produced by `train`.
    """
    # Seed every RNG the pipeline draws from: the augmentations use NumPy and
    # Python's `random`, the model initialisation uses PyTorch.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)  # CPU
    torch.cuda.manual_seed_all(args.seed)  # all GPUs

    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly different
    # algorithms per run, which defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # datasets: random augmentation for training, deterministic centre crop otherwise
    patch_size = (160,160,128)
    train_dataset = BraTS(args.data_path,args.train_txt,transform=transforms.Compose([
        RandomRotFlip(),
        RandomCrop(patch_size),
        GaussianNoise(p=0.1),
        ToTensor()
    ]))
    val_dataset = BraTS(args.data_path,args.valid_txt,transform=transforms.Compose([
        CenterCrop(patch_size),
        ToTensor()
    ]))
    test_dataset = BraTS(args.data_path,args.test_txt,transform=transforms.Compose([
        CenterCrop(patch_size),
        ToTensor()
    ]))

    # data loaders
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, num_workers=12,
                              shuffle=True, pin_memory=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, num_workers=12, shuffle=False,
                            pin_memory=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, num_workers=12, shuffle=False,
                             pin_memory=True)

    print("using {} device.".format(device))
    print("using {} images for training, {} images for validation.".format(len(train_dataset), len(val_dataset)))
    # 1 - Necrotic tumor core (NT), 2 - Peritumoral edema (ED), 4 - Enhancing tumor (ET)
    # Evaluation metrics: ET (label 4), TC (label 1 + label 4), WT (label 1 + label 2 + label 4)
    if args.model == 'unet':
        print("The model we are using  is U-Net.")
        model = UNet(in_channels=4,num_classes=4)
    else:
        print("The model we are using is Attention U-Net.")
        model = AttentionUNet(in_channels=4, num_classes=4)
    # Use DataParallel for multi-GPU
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    num_params = sum(param.numel() for param in model.parameters())
    print(f'Total number of parameters: {num_params / 1e6:.3f} M')

    criterion = Loss(n_classes=4, weight=torch.tensor([0.2, 0.3, 0.25, 0.25])).to(device)
    # lr=0 is a placeholder: train_loop overwrites it from the schedule every iteration.
    optimizer = optim.SGD(model.parameters(),momentum=0.9, lr=0, weight_decay=5e-4)
    scheduler = cosine_scheduler(base_value=args.lr,final_value=args.min_lr,epochs=args.epochs,
                                 niter_per_ep=len(train_loader),warmup_epochs=args.warmup_epochs,start_warmup_value=5e-4)

    # resume from an existing checkpoint, if any
    if os.path.exists(args.weights):
        weight_dict = torch.load(args.weights, map_location=device)
        model.load_state_dict(weight_dict['model'])
        optimizer.load_state_dict(weight_dict['optimizer'])
        print('Successfully loading checkpoint.')

    train_metrics_all, val_metrics_all = train(model,optimizer,scheduler,criterion,train_loader,val_loader,args.epochs,device,train_log=args.train_log)

    # Finally, evaluate all the data again.
    # Note that the model parameters used here are those from the end of
    # training, not the best checkpoint, and that the training loader still
    # applies its random augmentations.
    metrics1 = val_loop(model, criterion, train_loader, device)
    metrics2 = val_loop(model, criterion, val_loader, device)
    metrics3 = val_loop(model, criterion, test_loader, device)

    print("Train -- loss: {:.3f} ET: {:.3f} TC: {:.3f} WT: {:.3f}".format(metrics1['loss'], metrics1['dice1'],metrics1['dice2'], metrics1['dice3']))
    print("Valid -- loss: {:.3f} ET: {:.3f} TC: {:.3f} WT: {:.3f}".format(metrics2['loss'], metrics2['dice1'], metrics2['dice2'], metrics2['dice3']))
    print("Test  -- loss: {:.3f} ET: {:.3f} TC: {:.3f} WT: {:.3f}".format(metrics3['loss'], metrics3['dice1'], metrics3['dice2'], metrics3['dice3']))

    return train_metrics_all, val_metrics_all

In [None]:
class Config:
    """Plain container for the hyper-parameters and file locations of one run."""

    def __init__(self):
        # Problem size and reproducibility
        self.num_classes = 4
        self.seed = 21

        # Optimisation schedule
        self.epochs, self.warmup_epochs = 60, 10
        self.batch_size = 1
        self.lr, self.min_lr = 0.004, 0.0015

        # Preprocessed data and the split files that list the case IDs
        self.data_path = train_set['out']
        self.train_txt, self.valid_txt, self.test_txt = './train.txt', './valid.txt', './test.txt'

        # Outputs: training log, best weights, per-epoch checkpoints
        self.train_log = './results/UNet.txt'
        self.weights = './results/UNet.pth'
        self.save_path = './checkpoint/UNet'

        # Architecture selector
        self.model = 'aunet'  # unet or aunet

args = Config()

In [None]:
# Start from a clean slate: drop the previous best weights and training log.
# (`self` does not exist at cell level; the paths live on `args`.)
for stale_file in (args.weights, args.train_log):
    if os.path.isfile(stale_file):
        os.remove(stale_file)

# Release memory left over from earlier cells before training starts.
import gc
gc.collect()
torch.cuda.empty_cache()

train_metrics_all, val_metrics_all = main(args)

### Plot

In [None]:
# Per-epoch mean loss on the training and validation loaders.
train_losses = [epoch_metrics['loss'] for epoch_metrics in train_metrics_all]
valid_losses = [epoch_metrics['loss'] for epoch_metrics in val_metrics_all]

fig, ax = plt.subplots(figsize=(10, 5))
for curve, curve_label in ((train_losses, 'Training Loss'), (valid_losses, 'Validation Loss')):
    ax.plot(curve, label=curve_label)

ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss Over Epochs')
ax.legend()

plt.show()

In [None]:
# 'dice1' is the Dice score of the enhancing tumor (ET) region, so the axis
# and title name the metric as Dice rather than accuracy.
train_et = [entry['dice1'] for entry in train_metrics_all]
val_et = [entry['dice1'] for entry in val_metrics_all]

plt.figure(figsize=(10, 5))
plt.plot(train_et, label='Training ET')
plt.plot(val_et, label='Validation ET')

plt.xlabel('Epoch')
plt.ylabel('ET Dice')
plt.title('Training and Validation ET Dice Over Epochs')
plt.legend()

plt.show()

In [None]:
# 'dice2' is the Dice score of the tumor core (TC) region, so the axis and
# title name the metric as Dice rather than accuracy.
train_tc = [entry['dice2'] for entry in train_metrics_all]
val_tc = [entry['dice2'] for entry in val_metrics_all]

plt.figure(figsize=(10, 5))
plt.plot(train_tc, label='Training TC')
plt.plot(val_tc, label='Validation TC')

plt.xlabel('Epoch')
plt.ylabel('TC Dice')
plt.title('Training and Validation TC Dice Over Epochs')
plt.legend()

plt.show()

In [None]:
# 'dice3' is the Dice score of the whole tumor (WT) region, so the axis and
# title name the metric as Dice rather than accuracy.
train_wt = [entry['dice3'] for entry in train_metrics_all]
val_wt = [entry['dice3'] for entry in val_metrics_all]

plt.figure(figsize=(10, 5))
plt.plot(train_wt, label='Training WT')
plt.plot(val_wt, label='Validation WT')

plt.xlabel('Epoch')
plt.ylabel('WT Dice')
plt.title('Training and Validation WT Dice Over Epochs')
plt.legend()

plt.show()

In [None]:
# Unweighted mean of the three regional Dice scores (ET, TC, WT) per epoch;
# the axis and title name the metric as Dice rather than accuracy.
train_mean = [(entry['dice3']+entry['dice2']+entry['dice1'])/3 for entry in train_metrics_all]
val_mean = [(entry['dice3']+entry['dice2']+entry['dice1'])/3 for entry in val_metrics_all]

plt.figure(figsize=(10, 5))
plt.plot(train_mean, label='Training Average')
plt.plot(val_mean, label='Validation Average')

plt.xlabel('Epoch')
plt.ylabel('Average Dice')
plt.title('Training and Validation Average Dice Over Epochs')
plt.legend()

plt.show()

### Testing by Sliding Window Inference

In [None]:
def calculate_metric_percase(pred, gt):
    """Return (dice, jaccard, hd95, asd) for one prediction / ground-truth pair.

    All four values come from `medpy.metric.binary`, called in that order.
    """
    measures = (metric.binary.dc, metric.binary.jc, metric.binary.hd95, metric.binary.asd)
    return tuple(measure(pred, gt) for measure in measures)

def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1):
    print(image.shape)
    c, ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    # print("{}, {}, {}".format(sx, sy, sz))
    score_map = np.zeros((num_classes, ) + image.shape[1:]).astype(np.float32)
    cnt = np.zeros(image.shape[1:]).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy*x, ww-patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y,hh-patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd-patch_size[2])
                test_patch = image[:,xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
                test_patch = np.expand_dims(test_patch,axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda()
                with torch.no_grad():
                    y1 = net(test_patch)
                    y = F.softmax(y1, dim=1)
                y = y.cpu().data.numpy()
                y = y[0,:,:,:,:]
                score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                  = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y
                cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                  = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
    score_map = score_map/np.expand_dims(cnt,axis=0)
    label_map = np.argmax(score_map, axis = 0)
    return label_map, score_map

def test_all_case(net, image_list, num_classes=2, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None):
    total_metric = 0.0
    for ith,image_path in enumerate(image_list):
        h5f = h5py.File(image_path+'_mri_norm2.h5', 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        label[label==4] = 3  # Change label from 4 to 3
        if preproc_fn is not None:
            image = preproc_fn(image)
        prediction, score_map = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        print(np.unique(prediction),np.unique(label))

        if np.sum(prediction)==0:
            single_metric = (0,0,0,0)
        else:
            single_metric = calculate_metric_percase(prediction, label[:])
        print('%02d,\t%.5f, %.5f, %.5f, %.5f' % (ith, single_metric[0], single_metric[1], single_metric[2], single_metric[3]))
        total_metric += np.asarray(single_metric)

        if save_result:
            nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + "%02d_pred.nii.gz"%(ith))
            # image only saves one modality
            nib.save(nib.Nifti1Image(image[0].astype(np.float32), np.eye(4)), test_save_path + "%02d_img.nii.gz"%(ith))
            nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path + "%02d_gt.nii.gz"%(ith))
    avg_metric = total_metric / len(image_list)
    print('average metric is {}'.format(avg_metric))

    return avg_metric

In [None]:
# Create the nested output folder in one step; -p makes the cell safe to re-run.
!mkdir -p predictions/unet

In [None]:
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
test_save_path = './predictions/unet/'
save_mode_path = args.weights

# Rebuild the same architecture that was trained.
if args.model == 'unet':
    net = UNet(in_channels=4,num_classes=4)
else:
    net = AttentionUNet(in_channels=4, num_classes=4)
counter = sum(p.numel() for p in net.parameters())
print('Param', counter)
# Wrap exactly like in training so the checkpoint keys match.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)
net = net.to(device)

# map_location lets a checkpoint saved on GPU be loaded on whatever device is
# available now (for example a CPU-only session).
checkpoint = torch.load(save_mode_path, map_location=device)
net.load_state_dict(checkpoint['model'])
print("init weight from {}".format(save_mode_path))
net.eval()

with open(args.test_txt, 'r') as f:
    image_list = [os.path.join(args.data_path, x.strip()) for x in f.readlines()]
print('Total number of images is', len(image_list))

# sliding-window inference over every test case
avg_metric = test_all_case(net, image_list, num_classes=4,
                            patch_size=(160,160,128), stride_xy=32, stride_z=16,
                            save_result=True,test_save_path=test_save_path)

In [None]:
# Zip the saved predictions so they can be downloaded and viewed in an app such as ITK-SNAP.
# The archive name combines the model name with the last three characters of str(args.min_lr).
shutil.make_archive(f'./{args.model}_lr{str(args.min_lr)[-3:]}', 'zip', test_save_path)

### Output Images Visualization

In [None]:
def display_nifti(file_path, title):
    """Show the middle slice along each of the three axes of a NIfTI volume.

    The volumes saved by the inference step keep the (H, W, D) layout produced
    during preprocessing, where the last axis indexes the axial slices. Fixing
    the last axis therefore gives the axial view; fixing axis 0 / axis 1 gives
    the coronal / sagittal views.
    NOTE(review): the coronal/sagittal assignment assumes the standard BraTS
    orientation of the source volumes; confirm against a rendered case.
    """
    img = nib.load(file_path)
    data = img.get_fdata()

    mid = [size // 2 for size in data.shape[:3]]
    views = (
        (data[mid[0], :, :], 'Coronal'),
        (data[:, mid[1], :], 'Sagittal'),
        (data[:, :, mid[2]], 'Axial'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (plane, plane_name) in zip(axes, views):
        ax.imshow(plane, cmap='viridis')
        ax.set_title(f'{title} - {plane_name}')
        ax.axis('off')
    plt.show()
    # Free the figure; many volumes are displayed in a loop.
    plt.close(fig)

# Render every NIfTI volume written by the inference step, in file-name order.
nifti_names = [name for name in sorted(os.listdir(test_save_path)) if name.endswith('.nii.gz')]
for file_name in nifti_names:
    display_nifti(os.path.join(test_save_path, file_name), file_name.split('.')[0])