In [1]:
!pip install monai
!pip install lightning


Collecting monai
  Downloading monai-1.3.0-202310121228-py3-none-any.whl.metadata (10 kB)
Downloading monai-1.3.0-202310121228-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.3.0
Collecting lightning
  Downloading lightning-2.2.2-py3-none-any.whl.metadata (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m907.0 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Downloading lightning-2.2.2-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: lightning
Successfully installed lightning-2.2.2


In [2]:
import os
import shutil
import glob

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
#import nibabel as nib
from sklearn.metrics import confusion_matrix, accuracy_score
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
from torchmetrics.classification import BinaryJaccardIndex, Dice, JaccardIndex
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from torch.nn.modules.loss import BCEWithLogitsLoss
from monai.losses.dice import *

2024-04-20 02:14:06.947242: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-20 02:14:06.947377: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-20 02:14:07.053294: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
from monai.data import DataLoader 
# , ArrayDataset
# from torch.optim.lr_scheduler import CosineAnnealingLR
# from monai.transforms import (
#     EnsureChannelFirst,
#     AsDiscrete,
#     Compose,
#     LoadImage,
#     ScaleIntensity,
# )
import pytorch_lightning as pl
import lightning
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


# Dataset Class

In [4]:
class BraTSDataset(Dataset):    
    def __init__(self, data_root_folder, folder = '', n_sample=None):
        main_folder = os.path.join(data_root_folder, folder)
        self.folder_path = os.path.join(main_folder, 'slice')
        #self.file_names = sorted(os.listdir(self.folder_path))[:n_sample]


    def __getitem__(self, index):
        file_name = os.listdir(self.folder_path)[index]
        #file_name = self.file_names[index]
        sample = np.load(os.path.join(self.folder_path, file_name))
        #eps = 0.0001
        img = sample[0,:,:]
        #img = img.resize((256, 256)) 
        diff = np.subtract(img.max(), img.min(), dtype=np.float64)
        denom = np.clip(diff, a_min=1e-8, a_max=None)
        img = (img - img.min()) / denom
        mask = sample[1, :, :]
        #mask= mask.resize((256, 256)) 
        mask[mask>0.0] = 1.0
        mask[mask==0.0] = 0
        img_as_tensor = np.expand_dims(img, axis=0)
        mask_as_tensor = np.expand_dims(mask, axis=0)
        img_as_tensor = torch.from_numpy(img_as_tensor)
        mask_as_tensor = torch.from_numpy(mask_as_tensor)
        
        #return img_as_tensor, mask_as_tensor
        return {
            'image': img_as_tensor.type(torch.FloatTensor),
            'mask': mask_as_tensor.type(torch.LongTensor),
            'img_id': file_name
        }
 
    def __len__(self):
        return len(os.listdir(self.folder_path))
        #return len(self.file_names)



# Load Dataset

In [5]:
data_root_folder = '/kaggle/input/brats-dataset/full_raw - Copy'
train_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'train')
val_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'val')
test_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'test')

In [6]:
BATCH_SIZE = 16
#device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=2)
validation_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=2)

## Pre-trained model

In [None]:
!pip install pretrained-backbones-unet


In [None]:
from backbones_unet.model.unet import Unet
from backbones_unet.utils.dataset import SemanticSegmentationDataset
from backbones_unet.model.losses import DiceLoss
from backbones_unet.utils.trainer import Trainer

In [None]:
next(iter(train_dataloader))

In [None]:
model = Unet(
    backbone='convnext_base', # backbone network name
    in_channels=1,            # input channels (1 for gray-scale images, 3 for RGB, etc.)
    num_classes=2,            # output channels (number of classes in your dataset)
)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, 1e-4) 

trainer = Trainer(
    model,                    # UNet model with pretrained backbone
    criterion=DiceLoss(),     # loss function for model convergence
    optimizer=optimizer,      # optimizer for regularization
    epochs=10                 # number of epochs for model training
)

trainer.fit(train_dataloader, validation_dataloader)

# Sub Classes for U-Net and Attention U-Net

In [8]:
class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class resconv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(resconv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):

        residual = self.Conv_1x1(x)
        x = self.conv(x)

        return residual + x


class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.GroupNorm(32, ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

# U-Net

In [13]:
class U_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=2, first_layer_numKernel=64, name = "U_Net"):
        super(U_Net, self).__init__()
        self.name = name
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=first_layer_numKernel)
        self.Conv2 = conv_block(ch_in=first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Conv3 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Conv4 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Conv5 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=16 * first_layer_numKernel)

        self.Up5 = up_conv(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)
        self.Up_conv5 = conv_block(ch_in=16 * first_layer_numKernel, ch_out=8 * first_layer_numKernel)

        self.Up4 = up_conv(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)
        self.Up_conv4 = conv_block(ch_in=8 * first_layer_numKernel, ch_out=4 * first_layer_numKernel)

        self.Up3 = up_conv(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)
        self.Up_conv3 = conv_block(ch_in=4 * first_layer_numKernel, ch_out=2 * first_layer_numKernel)

        self.Up2 = up_conv(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)
        self.Up_conv2 = conv_block(ch_in=2 * first_layer_numKernel, ch_out=first_layer_numKernel)

        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(first_layer_numKernel, output_ch, kernel_size=1, stride=1, padding=0), nn.Sigmoid() # Use sigmoid activation for binary segmentation
        )
        # self.Conv_1x1 =  nn.Conv2d(first_layer_numKernel, output_ch, kernel_size = 1, stride = 1, padding = 0)

    def forward(self, x):

        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        
        return d1

In [14]:
def dice_coeff_binary(y_pred, y_true):
        """Values must be only zero or one."""
        y_pred[y_pred >= 0.5] = 1
        y_pred[y_pred < 0.5] = 0
        eps = 0.0001
        inter = torch.dot(y_pred.view(-1).float(), y_true.view(-1).float())
        union = torch.sum(y_pred.float()) + torch.sum(y_true.float())
        return ((2 * inter.float() + eps) / (union.float() + eps))

In [11]:
next(iter(train_dataloader))['mask'].shape

torch.Size([16, 1, 240, 240])

In [15]:
class U_Net_DDP(pl.LightningModule):
    def __init__(self, net, lr, loss, jaccard, batch_size):
        super().__init__()
        self.net = net
        self.lr = lr
        self.loss = loss 
        #self.dice = dice
        self.jaccard = jaccard
        self.sigmoid = nn.Sigmoid()
        self.batch_size = batch_size
        
    def forward(self, x):
        return self.net(x)
    
    def training_step(self, batch, batch_idx):
        imgs = batch['image']
        true_masks = batch['mask'] #.unsqueeze(1)
        y_pred = self(imgs)  
        loss = self.loss(y_pred, true_masks)
        
        #y_pred = (y_pred >= 0.5).float()
        y_pred = torch.argmax(y_pred, dim=1)
        y_pred = y_pred.unsqueeze(1)

        batch_dice_score = dice_coeff_binary(y_pred, true_masks)
        batch_jaccard_score = jaccard_index_metric(y_pred, true_masks)
        
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size = self.batch_size)
        self.log("train_dice", batch_dice_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True,  batch_size = self.batch_size)
        self.log("train_jaccard", batch_jaccard_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True,  batch_size = self.batch_size)

        return loss
    
    def validation_step(self, batch, batch_idx):
        imgs = batch['image'] #.float()
        true_masks = batch['mask'] #.unsqueeze(1)
        
        y_pred = self(imgs)
        #y_pred = (y_pred >= 0.5).float()
        loss = self.loss(y_pred, true_masks)

        y_pred = torch.argmax(y_pred, dim=1)
        y_pred = y_pred.unsqueeze(1)
        

        batch_dice_score = dice_coeff_binary(y_pred, true_masks)
        batch_jaccard_score = jaccard_index_metric(y_pred, true_masks)
        
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True,  batch_size = self.batch_size)
        self.log("val_dice", batch_dice_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True,  batch_size = self.batch_size)
        self.log("val_jaccard", batch_dice_score, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True,  batch_size = self.batch_size)
        
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        #scheduler = CosineAnnealingLR(optimizer, self.trainer.max_epochs * 200, 0)
        return [optimizer] #, [scheduler]

In [16]:
#torch.multiprocessing.set_start_method('spawn')
checkpointing = ModelCheckpoint(monitor="val_loss",
                                dirpath='/kaggle/working/',
                                filename='unet-epoch-{epoch}-{val_loss:.2f}-{val_dice:.2f}-{val_jaccard:.2f}', 
                                save_top_k=-1)
es = EarlyStopping(monitor="val_loss")

trainer = pl.Trainer(precision=16, 
                     devices=2, 
                     accelerator="gpu",
                     strategy="ddp_notebook", 
                     max_epochs=10, 
                     callbacks=[es, checkpointing])

/opt/conda/lib/python3.10/site-packages/lightning_fabric/connector.py:563: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs


In [17]:
net = U_Net(img_ch=1, output_ch=2)
lr = 1e-3
loss = nn.CrossEntropyLoss()
#loss = nn.BCEWithLogitsLoss()
#loss = GeneralizedDiceFocalLoss()
#dice_metric = dice_coeff_binary()
jaccard_index_metric = BinaryJaccardIndex()
# model.load_state_dict(checkpoint['state_dict'])
# #optimizer.load_state_dict(checkpoint['optimizer_states'])
# epoch = checkpoint['epoch']
#checkpoint = torch.load('/kaggle/input/unet/pytorch/unet-epoch-1/1/unet-epoch-epoch1-val_loss0.03-val_dice0.41.ckpt')
#unet_weights = checkpoint['state_dict']

model = U_Net_DDP(net, lr, loss, jaccard_index_metric, BATCH_SIZE)


In [18]:
trainer.fit(model, train_dataloader, validation_dataloader)

RuntimeError: Lightning can't create new processes if CUDA is already initialized. Did you manually call `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any other way? Please remove any such calls, or change the selected strategy. You will have to restart the Python kernel.

In [None]:
!cd /kaggle/working  # Assuming the folder is in the working directory


In [None]:
!ls

In [None]:
!tar -czvf checkpoints.zip -C . .