In [2]:
import os
import shutil
import glob

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import nibabel as nib
from sklearn.metrics import confusion_matrix, accuracy_score
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
from torchmetrics.classification import BinaryJaccardIndex, Dice
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp

# Dataset Class

In [3]:
class BraTSDataset(Dataset):    
    def __init__(self, data_root_folder, folder = '', n_sample=None):
        main_folder = os.path.join(data_root_folder, folder)
        self.folder_path = os.path.join(main_folder, 'slice')

    def __getitem__(self, index):
        file_name = os.listdir(self.folder_path)[index]
        sample = torch.from_numpy(np.load(os.path.join(self.folder_path, file_name)))
        img_as_tensor = np.expand_dims(sample[0,:,:], axis=0)
        mask_as_tensor = np.expand_dims(sample[1,:,:], axis=0)
        return {
            'image': img_as_tensor,
            'mask': mask_as_tensor,
            'img_id': file_name
        }
 
    def __len__(self):
        return len(os.listdir(self.folder_path))

# Load Dataset

In [4]:
data_root_folder = '/kaggle/input/brats-dataset/full_raw - Copy'
train_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'train')
val_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'val')
test_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'test')

In [5]:
BATCH_SIZE = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Sub Classes for U-Net and Attention U-Net

In [7]:
class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class resconv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(resconv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):

        residual = self.Conv_1x1(x)
        x = self.conv(x)

        return residual + x


class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

# U-Net

In [8]:
class U_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, first_layer_numKernel=64, name = "U_Net"):
        super(U_Net, self).__init__()
        self.name = name
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=first_layer_numKernel)
        self.Conv2 = conv_block(ch_in=first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Conv3 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Conv4 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Conv5 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=16 * first_layer_numKernel)

        self.Up5 = up_conv(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Up_conv5 = conv_block(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)

        self.Up4 = up_conv(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Up_conv4 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)

        self.Up3 = up_conv(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Up_conv3 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)

        self.Up2 = up_conv(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)
        self.Up_conv2 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)

        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(first_layer_numKernel, output_ch, kernel_size=1, stride=1, padding=0), nn.Sigmoid()  # Use sigmoid activation for binary segmentation
        )
        # self.Conv_1x1 =  nn.Conv2d(first_layer_numKernel, output_ch, kernel_size = 1, stride = 1, padding = 0)

    def forward(self, x):

        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        
        #print(d1)

        return d1

In [9]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.max_validation_dice = float('-inf')

    def early_stop(self, validation_dice):
        if validation_dice > self.max_validation_dice:
            self.max_validation_dice = validation_dice
            self.counter = 0
        elif validation_dice < (self.max_validation_dice + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [10]:
LR = 5e-3
EPOCHS = 10 # this maybe too big?


unet = U_Net(img_ch=1, output_ch=1).to(device)

optimizer = torch.optim.AdamW(unet.parameters(), lr=LR)
loss_function = nn.BCELoss()
early_stopper = EarlyStopper(patience=3, min_delta=0.02)

In [11]:
print(sum(p.numel() for p in unet.parameters() if p.requires_grad))

34525889


In [16]:
# The training function
def train_net(net, epochs, train_dataloader, valid_dataloader, optimizer, loss_function):

    if not os.path.isdir('{0}'.format(net.name)):
        os.mkdir('{0}'.format(net.name))
    
    n_train = len(train_dataloader)
    n_valid = len(valid_dataloader)

    train_loss, valid_loss = [], []
    train_dice, valid_dice = [], []
    train_jaccard, valid_jaccard = [], []
    dice_metric = Dice()
    jaccard_index_metric = BinaryJaccardIndex()

    # Training
    for epoch in range(epochs):
        net.train()
        train_batch_loss, train_batch_dice, train_batch_jaccard = [], [], []

        for i, batch in enumerate(tqdm(train_dataloader)):
            imgs = batch['image'].to(device).float()
            true_masks = batch['mask']

            # Produce the estimated mask using current weights
            y_pred = net(imgs).cpu()

            loss = loss_function(y_pred, true_masks.float())
            train_batch_loss.append(loss.item())
            
            y_pred = (y_pred >= 0.5).float()

            batch_dice_score = dice_metric(y_pred, true_masks)
            train_batch_dice.append(batch_dice_score)

            batch_jaccard_score = jaccard_index_metric(y_pred, true_masks)
            train_batch_jaccard.append(batch_jaccard_score)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step() 
            if i == 50:
                break

            print(f'EPOCH {epoch + 1}/{epochs} - Training Batch {i+1}/{n_train} - Loss: {loss.item()}, DICE score: {batch_dice_score}, Jaccard score: {batch_jaccard_score}', end='\r')
        
        train_loss.append(np.mean(train_batch_loss))
        train_dice.append(np.mean(train_batch_dice))
        train_jaccard.append(np.mean(train_batch_jaccard))

        net.eval()
        valid_batch_loss, valid_batch_dice, valid_batch_jaccard = [], [], []

        # Validation
        with torch.no_grad():
            for i, batch in enumerate(tqdm(valid_dataloader)):
                imgs = batch['image'].to(device).float()
                true_masks = batch['mask'] #.to(device).float()
    
                y_pred = net(imgs).cpu()

                loss = loss_function(y_pred, true_masks.float())
                valid_batch_loss.append(loss.item())
                
                y_pred = (y_pred >= 0.5).float()

                batch_dice_score = dice_metric(y_pred, true_masks)
                valid_batch_dice.append(batch_dice_score)

                batch_jaccard_score = jaccard_index_metric(y_pred, true_masks)
                valid_batch_jaccard.append(batch_jaccard_score)
     
                print(f'EPOCH {epoch + 1}/{epochs} - Validation Batch {i+1}/{n_valid} - Loss: {loss.item()}, DICE score: {batch_dice_score}, Jaccard score: {batch_jaccard_score}', end='\r')

        valid_loss.append(np.mean(valid_batch_loss))
        valid_dice.append(np.mean(valid_batch_dice))
        valid_jaccard.append(np.mean(valid_batch_jaccard))
        
        if early_stopper.early_stop(np.mean(valid_batch_dice)):             
            break 

        print(f'EPOCH {epoch + 1}/{epochs} - Training Loss: {np.mean(train_batch_loss)}, Training DICE score: {np.mean(train_batch_dice)}, Training Jaccard score: {np.mean(train_batch_jaccard)}, Validation Loss: {np.mean(valid_batch_loss)}, Validation DICE score: {np.mean(valid_batch_dice)}, Validation Jaccard score: {np.mean(valid_batch_jaccard)}')

        torch.save(net.state_dict(), f'{net.name}/epoch_{epoch+1:3}.pth')

    return train_loss, train_dice, train_jaccard, valid_loss, valid_dice, valid_jaccard

In [17]:
train_loss, train_dice, train_jaccard, valid_loss, valid_dice, valid_jaccard = train_net(unet, EPOCHS, train_dataloader, validation_dataloader, optimizer, loss_function)

  0%|          | 0/8496 [00:00<?, ?it/s]

EPOCH 1/10 - Training Batch 50/8496 - Loss: 0.03032330609858036, DICE score: 0.0, Jaccard score: 0.00

  0%|          | 0/1812 [00:00<?, ?it/s]

EPOCH 1/10 - Validation Batch 72/1812 - Loss: 0.03402853384613991, DICE score: 0.0, Jaccard score: 0.00

KeyboardInterrupt: 

#fixme: 