In [1]:
!pip install monai
!pip install lightning




In [2]:
import os
import shutil
import glob

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
#import nibabel as nib
from sklearn.metrics import confusion_matrix, accuracy_score
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
from torchmetrics.classification import BinaryJaccardIndex, Dice
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from torch.nn.modules.loss import BCEWithLogitsLoss


In [3]:
from monai.data import DataLoader 
# , ArrayDataset
# from torch.optim.lr_scheduler import CosineAnnealingLR
# from monai.transforms import (
#     EnsureChannelFirst,
#     AsDiscrete,
#     Compose,
#     LoadImage,
#     ScaleIntensity,
# )
import pytorch_lightning as pl
import lightning
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


2024-04-14 21:52:50.123784: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-14 21:52:50.123844: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-14 21:52:50.125431: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Dataset Class

In [15]:
class BraTSDataset(Dataset):    
    def __init__(self, data_root_folder, folder = '', n_sample=None):
        main_folder = os.path.join(data_root_folder, folder)
        self.folder_path = os.path.join(main_folder, 'slice')
        #self.file_names = sorted(os.listdir(self.folder_path))[:n_sample]


    def __getitem__(self, index):
        file_name = os.listdir(self.folder_path)[index]
        #file_name = self.file_names[index]
        sample = torch.from_numpy(np.load(os.path.join(self.folder_path, file_name)))
        img_as_tensor = np.expand_dims(sample[0,:,:], axis=0)
        mask_as_tensor = np.expand_dims(sample[1,:,:], axis=0)
        return {
            'image': img_as_tensor,
            'mask': mask_as_tensor,
            'img_id': file_name
        }
 
    def __len__(self):
        return len(os.listdir(self.folder_path))
        #return len(self.file_names)



# Load Dataset

In [16]:
data_root_folder = '/kaggle/input/full_raw - Copy'
train_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'train')
val_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'val')
test_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'test')

In [17]:
BATCH_SIZE = 16
#device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [18]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=2)
validation_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=2)

# Sub Classes for U-Net and Attention U-Net

In [19]:
class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class resconv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(resconv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):

        residual = self.Conv_1x1(x)
        x = self.conv(x)

        return residual + x


class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

# U-Net

In [20]:
class U_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, first_layer_numKernel=64, name = "U_Net"):
        super(U_Net, self).__init__()
        self.name = name
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=first_layer_numKernel)
        self.Conv2 = conv_block(ch_in=first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Conv3 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Conv4 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Conv5 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=16 * first_layer_numKernel)

        self.Up5 = up_conv(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Up_conv5 = conv_block(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)

        self.Up4 = up_conv(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Up_conv4 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)

        self.Up3 = up_conv(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Up_conv3 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)

        self.Up2 = up_conv(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)
        self.Up_conv2 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)

        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(first_layer_numKernel, output_ch, kernel_size=1, stride=1, padding=0) # Use sigmoid activation for binary segmentation
        )
        # self.Conv_1x1 =  nn.Conv2d(first_layer_numKernel, output_ch, kernel_size = 1, stride = 1, padding = 0)

    def forward(self, x):

        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        
        return d1

In [21]:
class U_Net_DDP(pl.LightningModule):
    def __init__(self, net, lr, loss, dice, jaccard):
        super().__init__()
        self.net = net
        self.lr = lr
        self.loss = loss 
        self.dice = dice
        self.jaccard = jaccard
        
    def forward(self, x):
        return self.net(x)
    
    def training_step(self, batch, batch_idx):
        imgs = batch['image'].float()
        true_masks = batch['mask']

        y_pred = self(imgs)
        loss = self.loss(y_pred, true_masks.float())
        y_pred = (y_pred >= 0.5).float()

        batch_dice_score = dice_metric(y_pred, true_masks)
        batch_jaccard_score = jaccard_index_metric(y_pred, true_masks)
        
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("dice", batch_dice_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("jaccard", batch_jaccard_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        imgs = batch['image'].float()
        true_masks = batch['mask']
        
        y_pred = self(imgs)
        loss = self.loss(y_pred, true_masks.float())
        y_pred = (y_pred >= 0.5).float()

        batch_dice_score = dice_metric(y_pred, true_masks)
        batch_jaccard_score = jaccard_index_metric(y_pred, true_masks)
        
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_dice", batch_dice_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_jaccard", batch_dice_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        #scheduler = CosineAnnealingLR(optimizer, self.trainer.max_epochs * 200, 0)
        return [optimizer] #, [scheduler]

In [22]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.max_validation_dice = float('-inf')

    def early_stop(self, validation_dice):
        if validation_dice > self.max_validation_dice:
            self.max_validation_dice = validation_dice
            self.counter = 0
        elif validation_dice < (self.max_validation_dice + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [23]:
#torch.multiprocessing.set_start_method('spawn')
checkpointing = ModelCheckpoint(monitor="val_loss",
                                dirpath='/kaggle/working/',
                                filename='unet-epoch-{epoch}-{val_loss:.2f}-{val_dice:.2f}')
es = EarlyStopping(monitor="val_loss")

trainer = pl.Trainer(precision=16, 
                     devices=2, 
                     accelerator="gpu",
                     strategy="ddp_notebook", 
                     max_epochs=5, 
                     callbacks=[es, checkpointing])

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs


In [24]:
net = U_Net(img_ch=1, output_ch=1)
lr = 1e-3
#loss = nn.BCELoss()
loss = nn.BCEWithLogitsLoss()
dice_metric = Dice()
jaccard_index_metric = BinaryJaccardIndex()
model = U_Net_DDP(net, lr, loss, dice_metric, jaccard_index_metric)

In [25]:
trainer.fit(model, train_dataloader, validation_dataloader)

INFO: ----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /kaggle/working exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 13. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [27]:
!cd /kaggle/working  # Assuming the folder is in the working directory


In [29]:
!tar -czvf my_work.zip -C . .


./
./unet-epoch-epoch=1-val_loss=0.03-val_dice=0.41.ckpt
./lightning_logs/
./lightning_logs/version_1/
./lightning_logs/version_1/events.out.tfevents.1713130388.93725de63ae6.177.0
./lightning_logs/version_1/hparams.yaml
./lightning_logs/version_5/
./lightning_logs/version_5/events.out.tfevents.1713131230.93725de63ae6.589.0
./lightning_logs/version_5/hparams.yaml
./lightning_logs/version_4/
./lightning_logs/version_4/events.out.tfevents.1713130935.93725de63ae6.482.0
./lightning_logs/version_4/hparams.yaml
./lightning_logs/version_3/
./lightning_logs/version_3/hparams.yaml
./lightning_logs/version_3/events.out.tfevents.1713130903.93725de63ae6.375.0
./lightning_logs/version_7/
./lightning_logs/version_7/events.out.tfevents.1713131446.93725de63ae6.1566.0
./lightning_logs/version_7/hparams.yaml
./lightning_logs/version_9/
./lightning_logs/version_9/hparams.yaml
./lightning_logs/version_9/events.out.tfevents.1713131784.93725de63ae6.1817.0
./lightning_logs/version_8/
./lightning_logs/version_

In [30]:
!zip lightning_logs.zip


zip error: Nothing to do! (lightning_logs.zip)
