In [2]:
!pip install monai

Collecting monai
  Downloading monai-1.3.0-202310121228-py3-none-any.whl.metadata (10 kB)
Downloading monai-1.3.0-202310121228-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.3.0


In [3]:
import os
import shutil
import glob

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
#import nibabel as nib
from sklearn.metrics import confusion_matrix, accuracy_score
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
from torchmetrics.classification import BinaryJaccardIndex, Dice
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp

In [4]:
from monai.data import DataLoader 
# , ArrayDataset
# from torch.optim.lr_scheduler import CosineAnnealingLR
# from monai.transforms import (
#     EnsureChannelFirst,
#     AsDiscrete,
#     Compose,
#     LoadImage,
#     ScaleIntensity,
# )
import pytorch_lightning as pl

2024-04-14 02:23:31.428163: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-14 02:23:31.428296: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-14 02:23:31.681524: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Dataset Class

In [5]:
class BraTSDataset(Dataset):    
    def __init__(self, data_root_folder, folder = '', n_sample=None):
        main_folder = os.path.join(data_root_folder, folder)
        self.folder_path = os.path.join(main_folder, 'slice')

    def __getitem__(self, index):
        file_name = os.listdir(self.folder_path)[index]
        sample = torch.from_numpy(np.load(os.path.join(self.folder_path, file_name)))
        img_as_tensor = np.expand_dims(sample[0,:,:], axis=0)
        mask_as_tensor = np.expand_dims(sample[1,:,:], axis=0)
        return {
            'image': img_as_tensor,
            'mask': mask_as_tensor,
            'img_id': file_name
        }
 
    def __len__(self):
        return len(os.listdir(self.folder_path))

# Load Dataset

In [6]:
data_root_folder = '/kaggle/input/full_raw - Copy'
train_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'train')
val_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'val')
test_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'test')

In [7]:
BATCH_SIZE = 16
#device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=2)
validation_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=2)

# Sub Classes for U-Net and Attention U-Net

In [9]:
class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class resconv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(resconv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):

        residual = self.Conv_1x1(x)
        x = self.conv(x)

        return residual + x


class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

# U-Net

In [10]:
class U_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, first_layer_numKernel=64, name = "U_Net"):
        super(U_Net, self).__init__()
        self.name = name
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=first_layer_numKernel)
        self.Conv2 = conv_block(ch_in=first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Conv3 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Conv4 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Conv5 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=16 * first_layer_numKernel)

        self.Up5 = up_conv(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Up_conv5 = conv_block(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)

        self.Up4 = up_conv(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Up_conv4 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)

        self.Up3 = up_conv(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Up_conv3 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)

        self.Up2 = up_conv(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)
        self.Up_conv2 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)

        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(first_layer_numKernel, output_ch, kernel_size=1, stride=1, padding=0), nn.Sigmoid()  # Use sigmoid activation for binary segmentation
        )
        # self.Conv_1x1 =  nn.Conv2d(first_layer_numKernel, output_ch, kernel_size = 1, stride = 1, padding = 0)

    def forward(self, x):

        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        
        return d1

In [11]:
class U_Net_DDP(pl.LightningModule):
    def __init__(self, net, lr, loss, dice, jaccard):
        super().__init__()
        self.net = net
        self.lr = lr
        self.loss = loss 
        self.dice = dice
        self.jaccard = jaccard
        
    def forward(self, x):
        return self.net(x)
    
    def training_step(self, batch, batch_idx):
        imgs = batch['image'].float()
        true_masks = batch['mask']

        y_pred = self(imgs)
        loss = self.loss(y_pred, true_masks.float())
        y_pred = (y_pred >= 0.5).float()

        batch_dice_score = dice_metric(y_pred, true_masks)
        batch_jaccard_score = jaccard_index_metric(y_pred, true_masks)
        
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("dice", batch_dice_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("jaccard", batch_jaccard_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        imgs = batch['image'].float()
        true_masks = batch['mask']
        
        y_pred = self(imgs)
        loss = self.loss(y_pred, true_masks.float())
        y_pred = (y_pred >= 0.5).float()

        batch_dice_score = dice_metric(y_pred, true_masks)
        batch_jaccard_score = jaccard_index_metric(y_pred, true_masks)
        
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_dice", batch_dice_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_jaccard", batch_dice_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True
        
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        #scheduler = CosineAnnealingLR(optimizer, self.trainer.max_epochs * 200, 0)
        return [optimizer] #, [scheduler]

In [12]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.max_validation_dice = float('-inf')

    def early_stop(self, validation_dice):
        if validation_dice > self.max_validation_dice:
            self.max_validation_dice = validation_dice
            self.counter = 0
        elif validation_dice < (self.max_validation_dice + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [13]:
#torch.multiprocessing.set_start_method('spawn')
checkpointing = ModelCheckpoint(monitor="val_loss",
                                dirpath='/kaggle/working/',
                                filename='unet-epoch-{epoch}-{val_loss:.2f}-{val_dice:.2f}')
es = EarlyStopping(monitor="val_loss")

trainer = pl.Trainer(devices=2, accelerator="gpu",strategy="ddp_notebook", max_epochs=5, callbacks=[es, checkpointing])

In [14]:
net = U_Net(img_ch=1, output_ch=1)
lr = 1e-3
loss = nn.BCELoss()
dice_metric = Dice()
jaccard_index_metric = BinaryJaccardIndex()
model = U_Net_DDP(net, lr, loss, dice_metric, jaccard_index_metric)

In [None]:
trainer.fit(model, train_dataloader, validation_dataloader)

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:441: It is recommended to use `self.log('dice_score', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:441: It is recommended to use `self.log('batch_jaccard_score', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Training: |          | 0/? [00:00<?, ?it/s]