In [1]:
import os
import shutil
import glob

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import nibabel as nib
from sklearn.metrics import confusion_matrix, accuracy_score
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
from torchmetrics.classification import BinaryJaccardIndex, Dice

# How to read and plot a nii file

In [None]:
nii_img = nib.load("/kaggle/input/brats2021/BraTS2021_00008_t1ce.nii/00000088_brain_t1ce.nii")
nii_data = nii_img.get_fdata()
nii_aff  = nii_img.affine
nii_hdr  = nii_img.header

print(nii_aff ,'\n',nii_hdr)
print(nii_data.shape)

# Define the number of columns for subplot
num_columns = 5

if(len(nii_data.shape)==3):
    start_slice = 70
    end_slice = 89
    num_slices = end_slice - start_slice + 1
    num_rows = num_slices // num_columns + 1

    fig, ax = plt.subplots(num_rows, num_columns, figsize=(10,10))

    for i in range(start_slice, end_slice+1):
        row = (i-start_slice) // num_columns
        col = (i-start_slice) % num_columns
        ax[row, col].imshow(nii_data[:,:,i])
        ax[row, col].axis('off')

    # Remove empty subplots
    for j in range(i-start_slice+1, num_rows*num_columns):
        fig.delaxes(ax.flatten()[j])

    plt.show()

In [None]:
nii_img = nib.load("/kaggle/input/brats2021/00008_post_normalized.nii")
nii_data = nii_img.get_fdata()
nii_aff  = nii_img.affine
nii_hdr  = nii_img.header

print(nii_aff ,'\n',nii_hdr)
print(nii_data.shape)

# Define the number of columns for subplot
num_columns = 5

if(len(nii_data.shape)==3):
    start_slice = 70
    end_slice = 89
    num_slices = end_slice - start_slice + 1
    num_rows = num_slices // num_columns + 1

    fig, ax = plt.subplots(num_rows, num_columns, figsize=(10,10))

    for i in range(start_slice, end_slice+1):
        row = (i-start_slice) // num_columns
        col = (i-start_slice) % num_columns
        ax[row, col].imshow(nii_data[:,:,i])
        ax[row, col].axis('off')

    # Remove empty subplots
    for j in range(i-start_slice+1, num_rows*num_columns):
        fig.delaxes(ax.flatten()[j])

    plt.show()

In [None]:
nii_img = nib.load("/kaggle/input/brats2021/BraTS2021_00008_seg.nii/00000088_final_seg.nii")
nii_data = nii_img.get_fdata()
nii_aff  = nii_img.affine
nii_hdr  = nii_img.header

print(nii_aff ,'\n',nii_hdr)
print(nii_data.shape)

# Define the number of columns for subplot
num_columns = 5

if(len(nii_data.shape)==3):
    start_slice = 70
    end_slice = 89
    num_slices = end_slice - start_slice + 1
    num_rows = num_slices // num_columns + 1

    fig, ax = plt.subplots(num_rows, num_columns, figsize=(10,10))

    for i in range(start_slice, end_slice+1):
        row = (i-start_slice) // num_columns
        col = (i-start_slice) % num_columns
        ax[row, col].imshow(nii_data[:,:,i])
        ax[row, col].axis('off')

    # Remove empty subplots
    for j in range(i-start_slice+1, num_rows*num_columns):
        fig.delaxes(ax.flatten()[j])

    plt.show()

In [2]:
class BraTSDataset(Dataset):    
    def __init__(self, data_root_folder, folder = '', n_sample=None):
        main_folder = os.path.join(data_root_folder, folder)
        self.folder_path = os.path.join(main_folder, 'slice')

    def __getitem__(self, index):
        file_name = os.listdir(self.folder_path)[index]
        sample = torch.from_numpy(np.load(os.path.join(self.folder_path, file_name)))
        img_as_tensor = np.expand_dims(sample[0,:,:], axis=0)
        mask_as_tensor = np.expand_dims(sample[1,:,:], axis=0)
        return {
            'image': img_as_tensor,
            'mask': mask_as_tensor,
            'img_id': file_name
        }
 
    def __len__(self):
        return len(os.listdir(self.folder_path))

# Load Dataset

In [3]:
data_root_folder = '/kaggle/input/brats-dataset/full_raw - Copy'
train_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'train')
val_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'val')
test_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'test')

In [4]:
BATCH_SIZE = 16
LR = 0.001
EPOCHS = 1000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# U-Net

In [6]:
class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class resconv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(resconv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):

        residual = self.Conv_1x1(x)
        x = self.conv(x)

        return residual + x


class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

In [7]:
class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(32, F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(32, F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(1, 1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)

        return x * psi


class U_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, first_layer_numKernel=64, name = "U_Net"):
        super(U_Net, self).__init__()
        self.name = name
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=first_layer_numKernel)
        self.Conv2 = conv_block(ch_in=first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Conv3 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Conv4 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Conv5 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=16 * first_layer_numKernel)

        self.Up5 = up_conv(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Up_conv5 = conv_block(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)

        self.Up4 = up_conv(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Up_conv4 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)

        self.Up3 = up_conv(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Up_conv3 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)

        self.Up2 = up_conv(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)
        self.Up_conv2 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)

        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(first_layer_numKernel, output_ch, kernel_size=1, stride=1, padding=0), nn.Sigmoid()  # Use sigmoid activation for binary segmentation
        )
        # self.Conv_1x1 =  nn.Conv2d(first_layer_numKernel, output_ch, kernel_size = 1, stride = 1, padding = 0)

    def forward(self, x):
        print(x)
        print(x.shape)
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        return d1

In [8]:
unet = U_Net(img_ch=1, output_ch=1).to(device)

In [9]:
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)
loss_function = nn.CrossEntropyLoss() #.to(device)

In [12]:
# The training function
def train_net(net, epochs, train_dataloader, valid_dataloader, optimizer, loss_function):

    if not os.path.isdir('{0}'.format(net.name)):
        os.mkdir('{0}'.format(net.name))

    n_train = len(train_dataloader)
    n_valid = len(valid_dataloader)

    train_loss = list()
    valid_loss = list()
    train_dice = list()
    valid_dice = list()
    train_jaccard = list()
    valid_jaccard = list()
    dice_metric = Dice().to(device)
    jaccard_index_metric = BinaryJaccardIndex().to(device)
    for epoch in range(epochs):

        ################################################################################################################################
        ########################################################### Training ###########################################################
        ################################################################################################################################
        net.train()
        train_batch_loss = list()
        train_batch_dice = list()
        train_batch_jaccard = list()

        for i, batch in enumerate(tqdm(train_dataloader)):
            
            # Load a batch and pass it to the GPU
            imgs = batch['image'].to(device).float()
            true_masks = batch['mask'].to(device).float()

            # Produce the estimated mask using current weights
            y_pred = net(imgs)
            
            # Compute the loss for this batch and append it to the epoch loss
            loss = loss_function(y_pred, true_masks)
            batch_loss = loss.item()
            train_batch_loss.append(batch_loss)
            
            # Make the binary mask to compute the DICE score. Since the y_pred is a Pytoch tensor, we use `torch.argmax()` instead of `np.argmax()`.
            # the axis must be 1 instead of 0 because the format is [batch, channel, height, width]
            pred_binary = torch.argmax(y_pred, axis=1)
            
            # Compute the DICE score for this batch and append it to the epoch dice
            batch_dice_score = dice_metric(pred_binary, true_masks.to(torch.int)).to(device)
            train_batch_dice.append(batch_dice_score)

            # Compute the Jaccard score here and append the score to the list
            batch_jaccard_score = jaccard_index_metric(pred_binary.unsqueeze(1), true_masks).to(device)
            train_batch_jaccard.append(batch_jaccard_score)

            # Reset gradient values
            optimizer.zero_grad()

            # Compute the backward losses
            loss.backward()

            # Update the weights
            optimizer.step()

            # Print the progress
            print(f'EPOCH {epoch + 1}/{epochs} - Training Batch {i+1}/{n_train} - Loss: {batch_loss}, DICE score: {batch_dice_score}, Jaccard score: {batch_jaccard_score}            ', end='\r')

        average_training_loss = np.array(train_batch_loss).mean()
        average_training_dice = np.array(train_batch_dice).mean()
        average_training_jaccard = np.array(train_batch_jaccard).mean()
        train_loss.append(average_training_loss)
        train_dice.append(average_training_dice)
        train_jaccard.append(average_training_jaccard)

        ################################################################################################################################
        ########################################################## Validation ##########################################################
        ################################################################################################################################

        net.eval()
        valid_batch_loss = list()
        valid_batch_dice = list()
        valid_batch_jaccard = list()

        # This part is almost the same as training with the difference that we will set all layers to evaluation mode (effects some layers such as BN and Dropout) and also
        # we don't need to calculate the gradient since we are only evaluating current state of the model. This will speed up the process and cause it to consume less memory.
        with torch.no_grad():
            for i, batch in enumerate(tqdm(valid_dataloader)):

                # Load a batch and pass it to the GPU
                imgs = batch['image'].to(device).float()
                true_masks = batch['mask'].to(device).float()
    
                # Produce the estimated mask using current weights
                y_pred = net(imgs)

                # Compute the loss for this batch and append it to the epoch loss
                loss = loss_function(y_pred, true_masks)
                batch_loss = loss.item()
                valid_batch_loss.append(batch_loss)

                # Make the binary mask to compute the DICE score. Since the y_pred is a Pytoch tensor, we use `torch.argmax()` instead of `np.argmax()`.
                # the axis must be 1 instead of 0 because the format is [batch, channel, height, width]
                pred_binary = torch.argmax(y_pred, axis=1)

                # Compute the DICE score for this batch and append it to the epoch dice
                batch_dice_score = dice_metric(pred_binary, true_masks.to(torch.int)).to(device)
                valid_batch_dice.append(batch_dice_score)
                
                # Compute the Jaccard score here and append the score to the list
                batch_jaccard_score = jaccard_index_metric(pred_binary.unsqueeze(1), true_masks).to(device)
                valid_batch_jaccard.append(batch_jaccard_score)

                # Print the progress
                print(f'EPOCH {epoch + 1}/{epochs} - Validation Batch {i+1}/{n_valid} - Loss: {batch_loss}, DICE score: {batch_dice_score}, Jaccard score: {batch_jaccard_score}            ', end='\r')

        average_validation_loss = np.array(valid_batch_loss).mean()
        average_validation_dice = np.array(valid_batch_dice).mean()
        average_validation_jaccard = np.array(valid_batch_jaccard).mean()
        valid_loss.append(average_validation_loss)
        valid_dice.append(average_validation_dice)
        valid_jaccard.append(average_validation_jaccard)

        print(f'EPOCH {epoch + 1}/{epochs} - Training Loss: {average_training_loss}, Training DICE score: {average_training_dice}, Training Jaccard score: {average_training_jaccard}, Validation Loss: {average_validation_loss}, Validation DICE score: {average_validation_dice}, Validation Jaccard score: {average_validation_jaccard}')

        ################################################################################################################################
        ###################################################### Saveing Checkpoints #####################################################
        ################################################################################################################################
        torch.save(net.state_dict(), f'{net.name}/epoch_{epoch+1:03}.pth')

    return train_loss, train_dice, train_jaccard, valid_loss, valid_dice, valid_jaccard

In [13]:
train_loss, train_dice, train_jaccard, valid_loss, valid_dice, valid_jaccard = train_net(unet, EPOCHS, train_dataloader, validation_dataloader, optimizer, loss_function)

  0%|          | 0/8496 [00:00<?, ?it/s]

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0.

KeyboardInterrupt: 

#fixme: 