In [None]:
import os
import gzip

def get_all_path_by_ext(root, extension, match=""):
    """
    Collect the paths of all files below ``root`` (searched recursively) whose name
    ends with ``extension`` and contains ``match``.

    Parameters:
        root (str): The root directory to search.
        extension (str): The file extension to search for (e.g., '.gz'). A leading
            dot is added when it is missing.
        match (str, optional): A string that the filename must contain. Defaults to "" (no filter).

    Returns:
        list: Matching file paths using forward slashes as separators, sorted
            lexicographically so the result does not depend on the order in which
            the filesystem enumerates directory entries.
    """
    # Ensure the extension has a dot at the start
    if not extension.startswith("."):
        extension = "." + extension

    file_paths = [
        os.path.join(dirpath, filename).replace("\\", "/")
        for dirpath, _, filenames in os.walk(root)
        for filename in filenames
        if filename.endswith(extension) and match in filename
    ]

    # os.walk yields entries in filesystem order; sorting keeps lists built by
    # separate calls aligned and makes seeded train/test splits reproducible.
    file_paths.sort()
    return file_paths



In [None]:
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
def load_image(img_path):
    """
    Helper function to load an image, convert it to RGB and return it as a numpy array.

    Parameters:
        img_path (str): Path to the image.

    Returns:
        np.ndarray: Loaded RGB image as a numpy array of shape (height, width, 3).
    """
    # The context manager releases the underlying file handle once the pixel
    # data has been copied into the array.
    with Image.open(img_path) as img:
        # convert("RGB") is what makes the documented "RGB" contract hold for
        # grayscale, palette and RGBA inputs as well.
        return np.array(img.convert("RGB"))
def find_roi_and_crop(image, seg, padding_ratio=0.2):
    """
    Crop the image to the bounding box of the lesion given by a segmentation mask.

    The image is first resized to the mask's height and width. The bounding box of
    all non-zero mask pixels is then enlarged by ``padding_ratio`` (half of the
    padding on each side, clipped to the image borders), and the crop is resized
    back to the mask's height and width.

    Parameters:
        image (np.ndarray): Original image.
        seg (np.ndarray): 2-D segmentation mask, where the lesion region has a non-zero value.
        padding_ratio (float): Padding added around the lesion region, as a fraction
            of the bounding-box width/height.

    Returns:
        np.ndarray or None: The cropped image resized to the mask's height and width,
            or None when the mask contains no lesion pixels or the crop is empty.
    """
    seg_h, seg_w = seg.shape[0], seg.shape[1]

    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    image = cv2.resize(image, (seg_w, seg_h))

    y_indices, x_indices = np.where(seg != 0)
    if len(x_indices) == 0 or len(y_indices) == 0:
        return None

    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)

    pad_w = int((x_max - x_min) * padding_ratio)
    pad_h = int((y_max - y_min) * padding_ratio)

    x_start = max(0, x_min - pad_w // 2)
    y_start = max(0, y_min - pad_h // 2)
    # x_max / y_max are inclusive pixel indices while slice ends are exclusive,
    # hence the +1 (otherwise the last lesion row/column is cut off).
    x_stop = min(image.shape[1], x_max + pad_w // 2 + 1)
    y_stop = min(image.shape[0], y_max + pad_h // 2 + 1)

    cropped_image = image[y_start:y_stop, x_start:x_stop]
    if cropped_image.size == 0:
        return None

    return cv2.resize(cropped_image, (seg_w, seg_h))
def plot_image(image, title="Image"):
    """
    Plot the image with a title.

    Parameters:
        image (np.ndarray): Image to plot.
        title (str): Title of the plot.

    Returns:
        None
    """
    fig, ax = plt.subplots()
    ax.imshow(image)
    # The title argument was previously accepted but never applied.
    ax.set_title(title)
    plt.show()

def save_image(image, save_path):
    """
    Save the image to a specified path, creating the parent directory if needed.

    Parameters:
        image (np.ndarray): Image to save. Should be in RGB format.
        save_path (str): Path to save the image.

    Returns:
        None
    """
    # dirname is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    Image.fromarray(image).save(save_path)
    print(f"Image saved at: {save_path}")



In [None]:
import os
import nibabel as nib
import torch 
from torch.utils.data import Dataset
from torchvision import transforms

class NiftiDatasetForSegmentation(Dataset):
    """
    Dataset that returns, per case, a stack of centred slices (along axis 2) taken
    from several NIfTI volumes together with the matching segmentation slices.

    All slices of a case are cropped to one common bounding box: the smallest box
    that encloses every non-zero pixel of every image volume within the selected
    slice window.

    Parameters:
        nii_path_tuple_list (list): One tuple of image file paths per case; each
            path becomes one channel.
        nii_seg_path_list (list): One segmentation file path per case, in the same
            order as ``nii_path_tuple_list``.
        volume_slices (int): Number of consecutive slices taken around the middle
            of axis 2.
        transform_img (callable, optional): Applied to every (channels, H, W) image
            slice tensor; must return a Tensor.
        transform_mask (callable, optional): Applied to every (1, H, W) mask slice
            tensor; must return a Tensor.
    """

    def __init__(self, nii_path_tuple_list, nii_seg_path_list, volume_slices, transform_img=None, transform_mask=None):
        self.nii_path_tuple_list = nii_path_tuple_list
        self.nii_seg_path_list = nii_seg_path_list
        self.volume_slices = volume_slices
        self.transform_img = transform_img
        self.transform_mask = transform_mask

    def __len__(self):
        """Return the number of cases."""
        return len(self.nii_path_tuple_list)

    def _slice_window_start(self, seg_depth, min_depth):
        """Return the first index of the centred slice window, validating that it fits."""
        start = seg_depth // 2 - self.volume_slices // 2
        stop = start + self.volume_slices
        # A negative start would silently wrap around to the last slices and an
        # overshoot would surface as an opaque IndexError, so fail early instead.
        if start < 0 or stop > min_depth:
            raise ValueError(
                f"Cannot take {self.volume_slices} centred slices (indices {start}..{stop - 1}) "
                f"from volumes with {min_depth} slices along axis 2."
            )
        return start

    def _enclosing_bbox(self, volumes, start, plane_shape):
        """Return inclusive (x_min, x_max, y_min, y_max) enclosing all non-zero pixels in the window."""
        bounding_boxes = []
        for slice_index in range(start, start + self.volume_slices):
            for volume in volumes:
                x_nonzero, y_nonzero = torch.nonzero(
                    torch.from_numpy(volume[:, :, slice_index]), as_tuple=True
                )
                if x_nonzero.numel() > 0:
                    bounding_boxes.append((
                        x_nonzero.min().item(), x_nonzero.max().item(),
                        y_nonzero.min().item(), y_nonzero.max().item(),
                    ))

        if not bounding_boxes:
            # No non-zero pixel anywhere in the window: keep the entire plane.
            return 0, plane_shape[0] - 1, 0, plane_shape[1] - 1

        return (
            min(bbox[0] for bbox in bounding_boxes),
            max(bbox[1] for bbox in bounding_boxes),
            min(bbox[2] for bbox in bounding_boxes),
            max(bbox[3] for bbox in bounding_boxes),
        )

    def __getitem__(self, index):
        """
        Load one case and return its cropped slice stack.

        Returns:
            tuple: ``(X, y)`` where ``X`` has shape (channels, volume_slices, H, W)
                and ``y`` has shape (1, volume_slices, H, W) when the transforms
                keep the channel dimension.

        Raises:
            ValueError: If the centred window of ``volume_slices`` slices does not
                fit inside the loaded volumes.
            TypeError: If a transform returns something other than a Tensor.
        """
        modality_volumes = [nib.load(nii_path).get_fdata() for nii_path in self.nii_path_tuple_list[index]]
        seg_volume = nib.load(self.nii_seg_path_list[index]).get_fdata()

        min_depth = min(volume.shape[2] for volume in modality_volumes + [seg_volume])
        start = self._slice_window_start(seg_volume.shape[2], min_depth)

        x_min, x_max, y_min, y_max = self._enclosing_bbox(modality_volumes, start, seg_volume.shape)
        rows = slice(x_min, x_max + 1)  # bounding box limits are inclusive
        cols = slice(y_min, y_max + 1)

        X, y = [], []
        for slice_index in range(start, start + self.volume_slices):
            # One channel per image volume, cropped to the common bounding box
            slice_channels_data_combined = torch.cat(
                [
                    torch.from_numpy(volume[rows, cols, slice_index]).unsqueeze(0).float()
                    for volume in modality_volumes
                ],
                dim=0,
            )
            nii_seg_slice = torch.from_numpy(seg_volume[rows, cols, slice_index]).unsqueeze(0).float()

            # Apply transforms if provided
            if self.transform_img:
                slice_channels_data_combined = self.transform_img(slice_channels_data_combined)
            if self.transform_mask:
                nii_seg_slice = self.transform_mask(nii_seg_slice)

            # Check output data types
            if not isinstance(slice_channels_data_combined, torch.Tensor):
                raise TypeError(f"Expected output transform_img to be a Tensor, but got {type(slice_channels_data_combined)}.")
            if not isinstance(nii_seg_slice, torch.Tensor):
                raise TypeError(f"Expected output transform_mask to be a Tensor, but got {type(nii_seg_slice)}.")

            X.append(slice_channels_data_combined)
            y.append(nii_seg_slice)

        # (slices, channels, H, W) -> (channels, slices, H, W)
        X = torch.stack(X, dim=0).permute(1, 0, 2, 3)
        y = torch.stack(y).permute(1, 0, 2, 3)

        return X, y


In [None]:
import os
import sys
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch.nn.functional as F  
from torchvision import transforms
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle
from torch.utils.checkpoint import checkpoint

In [None]:
!pip install monai
!pip install einops
from monai.networks.nets import SwinUNETR

In [None]:
# Notebook-wide configuration.
# Directory reserved for model files (not referenced by the other visible cells).
MODEL_ROOT_PATH = "/kaggle/working/models"
# Number of centred slices taken from each volume by the dataset.
VOLUME_SLICES = 128
# Target size handed to transforms.Resize for image and mask slices.
IMAGE_SIZE = (128,128)
# Root directory that is searched recursively for the .nii files.
DATA_ROOT_PATH ="/kaggle/input/brats-africa-small"
# Upper bound of the epoch loop.
NUM_EPOCHS = 5
# Destination of checkpoints, loss plots and the pickled loss history.
OUTPUT_DIR = "/kaggle/working/"

In [None]:
# Collect one path list per image modality plus the segmentation masks.
# Each list is sorted so that entry i of every list belongs to the same case
# (assumes the files of a case share a directory or filename prefix - TODO confirm).
nii_flair_path_list = sorted(get_all_path_by_ext(DATA_ROOT_PATH, ".nii", "t2f"))
nii_t1ce_path_list = sorted(get_all_path_by_ext(DATA_ROOT_PATH, ".nii", "t1c"))
nii_seg_path_list = sorted(get_all_path_by_ext(DATA_ROOT_PATH, ".nii", "seg"))

# zip() silently truncates to the shortest list, which would pair images with
# the wrong masks when a file is missing, so fail loudly instead.
if not (len(nii_flair_path_list) == len(nii_t1ce_path_list) == len(nii_seg_path_list)):
    raise ValueError(
        "Mismatched file counts: "
        f"{len(nii_flair_path_list)} t2f, {len(nii_t1ce_path_list)} t1c, {len(nii_seg_path_list)} seg"
    )

nii_path_tuple_list = list(zip(nii_flair_path_list, nii_t1ce_path_list))

In [None]:
# Hold out 20% of the cases for validation; the fixed seed keeps the split stable
# for a given input order.
(
    train_nii_path_tuple_list,
    test_nii_path_tuple_list,
    train_nii_seg_path_list,
    test_nii_seg_path_list,
) = train_test_split(
    nii_path_tuple_list,
    nii_seg_path_list,
    test_size=0.2,
    random_state=42,
)

In [None]:
class Brats2023DataSet(NiftiDatasetForSegmentation):
    """Variant of the base dataset that returns the mask one-hot encoded over 4 classes."""

    def __getitem__(self, index):
        """
        Return ``(X, Y)`` for one case.

        ``X`` is passed through unchanged from the parent class. The parent's mask
        loses its leading dimension, label 4 is remapped to 3, and the result is
        one-hot encoded with the class axis moved to the front, as float.
        """
        volume, mask = super().__getitem__(index)

        labels = mask.squeeze(0).int()
        # Fold label 4 into class 3 so that all labels fit into 4 classes.
        labels = labels.masked_fill(labels == 4, 3)

        encoded = torch.nn.functional.one_hot(labels.long(), num_classes=4)
        # Class axis is last after one_hot; move it to the front.
        encoded = encoded.permute(3, 0, 1, 2)
        return volume, encoded.float()

In [None]:
# Image slices: resize to IMAGE_SIZE via a PIL round trip.
# NOTE(review): ToPILImage multiplies floating-point tensors by 255 and casts them
# to uint8, so it expects inputs in [0, 1]; the slices passed in here look like raw
# scanner intensities - confirm, and normalise before this step if so.
transform_img = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Mask slices: nearest-neighbour interpolation keeps every pixel an original class
# label. The default bilinear mode blends neighbouring labels into values that
# belong to no class (or to a different one after integer truncation).
transform_mask = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
])

In [None]:
# Training and validation datasets share the slice count and both transforms.
train_dataset = Brats2023DataSet(
    train_nii_path_tuple_list,
    train_nii_seg_path_list,
    VOLUME_SLICES,
    transform_img=transform_img,
    transform_mask=transform_mask,
)
val_dataset = Brats2023DataSet(
    test_nii_path_tuple_list,
    test_nii_seg_path_list,
    VOLUME_SLICES,
    transform_img=transform_img,
    transform_mask=transform_mask,
)

In [None]:
# One case per batch. Only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
# Validation keeps a fixed order so the per-batch losses recorded for each epoch
# line up case by case; shuffling changes nothing about the averaged result.
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [None]:
import matplotlib.pyplot as plt

def plot_loss(train_loss, valid_loss):
    """
    Plot training and validation loss curves in two separate figures and save them
    as ``train_loss.png`` and ``valid_loss.png`` inside ``OUTPUT_DIR``.

    Parameters:
        train_loss (list): Training loss values, one point per entry.
        valid_loss (list): Validation loss values, one point per entry.

    Returns:
        None
    """
    # Create a new figure and axis for each curve
    figure_1, train_ax = plt.subplots()
    figure_2, valid_ax = plt.subplots()

    # Plot the training loss values; the label is what legend() displays
    train_ax.plot(train_loss, label='Training Loss')
    train_ax.set_xlabel('Epoch')
    train_ax.set_ylabel('Training Loss')
    train_ax.legend()

    # Plot the validation loss values
    valid_ax.plot(valid_loss, label='Validation Loss')
    valid_ax.set_xlabel('Epoch')
    valid_ax.set_ylabel('Validation Loss')
    valid_ax.legend()

    # Save the images to the output directory
    figure_1.savefig(f"{OUTPUT_DIR}/train_loss.png")
    figure_2.savefig(f"{OUTPUT_DIR}/valid_loss.png")

    # Render the figures, then release them: this function is called once per
    # epoch, and figures that are never closed accumulate in memory.
    plt.show()
    plt.close(figure_1)
    plt.close(figure_2)


In [None]:
import torch
import torch.nn as nn


class EDiceLoss(nn.Module):
    """
    Dice loss tailored to Brats need.

    ``forward`` returns the soft Dice loss averaged over the channels of the
    target; ``metric`` returns hard Dice scores (threshold 0.5) per sample and
    channel.

    Parameters:
        do_sigmoid (bool): Apply a sigmoid to the inputs before computing Dice.
    """

    def __init__(self, do_sigmoid=True):
        super(EDiceLoss, self).__init__()
        self.do_sigmoid = do_sigmoid
        # Not read by this class; kept so existing attribute access keeps working.
        self.device = "cpu"

    def binary_dice(self, inputs, targets, label_index, metric_mode=False):
        """
        Compute Dice for a single channel.

        Parameters:
            inputs (Tensor): Predictions for one channel (logits when ``do_sigmoid``).
            targets (Tensor): Ground truth for the same channel.
            label_index (int): Channel index; currently unused.
            metric_mode (bool): When True, threshold the predictions at 0.5 and
                return the Dice score; otherwise return ``1 - soft Dice``.

        Returns:
            Tensor: Scalar tensor on the same device as ``inputs``.
        """
        smooth = 1.
        if self.do_sigmoid:
            inputs = torch.sigmoid(inputs)

        if metric_mode:
            inputs = inputs > 0.5
            if targets.sum() == 0:
                # Empty ground truth: perfect score only if nothing was predicted.
                # Built on the input's device instead of a hard-coded "cuda" so
                # the metric also works on CPU-only machines.
                score = 1. if inputs.sum() == 0 else 0.
                return torch.tensor(score, device=inputs.device)
            intersection = EDiceLoss.compute_intersection(inputs, targets)
            return (2 * intersection) / ((inputs.sum() + targets.sum()) * 1.0)

        intersection = EDiceLoss.compute_intersection(inputs, targets)
        dice = (2 * intersection + smooth) / (inputs.pow(2).sum() + targets.pow(2).sum() + smooth)
        return 1 - dice

    @staticmethod
    def compute_intersection(inputs, targets):
        """Return the sum of the element-wise product of inputs and targets."""
        intersection = torch.sum(inputs * targets)
        return intersection

    def forward(self, inputs, target):
        """Return the Dice loss averaged over dimension 1 (channels) of ``target``."""
        dice = 0
        for i in range(target.size(1)):
            dice = dice + self.binary_dice(inputs[:, i, ...], target[:, i, ...], i)

        final_dice = dice / target.size(1)
        return final_dice

    def metric(self, inputs, target):
        """Return a list (per sample) of lists (per channel) of hard Dice scores."""
        dices = []
        for j in range(target.size(0)):
            dice = []
            for i in range(target.size(1)):
                dice.append(self.binary_dice(inputs[j, i], target[j, i], i, True))
            dices.append(dice)
        return dices

In [None]:
# SwinUNETR taking 2 input channels and producing 4 output channels.
# use_checkpoint=True is passed through to MONAI unchanged.
model = SwinUNETR(in_channels=2, out_channels=4, feature_size=48, use_checkpoint=True)

In [None]:
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Spread batches over all GPUs when more than one is present.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
# .to(device) instead of .cuda(): the latter raises on CPU-only machines and
# defeats the fallback chosen above.
criterion = EDiceLoss().to(device)
metric = criterion.metric
optimizer = AdamW(model.parameters(), lr=1e-4)

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``dataloader``.

    Parameters:
        model: Network to train; switched to training mode.
        dataloader: Iterable with a length, yielding ``(inputs, targets)`` batches.
        criterion: Callable returning the loss for ``(outputs, targets)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.

    Returns:
        list: The loss value of every batch, in iteration order.
    """
    model.train()
    batch_losses = []
    progress = tqdm(dataloader, total=len(dataloader))

    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Gradients are cleared after the step so the next batch starts from zero.
        optimizer.zero_grad()

        loss_value = loss.detach().item()
        batch_losses.append(loss_value)
        progress.set_description(desc=f"Training Loss: {loss_value:.5f}")

        # Drop the references and release cached GPU memory before the next batch.
        del inputs, targets, outputs, loss
        torch.cuda.empty_cache()

    return batch_losses

def evaluate_one_epoch(model, dataloader, criterion, metric, device):
    """
    Evaluate the model on ``dataloader`` without tracking gradients.

    Parameters:
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Iterable with a length, yielding ``(inputs, targets)`` batches.
        criterion: Callable returning the loss for ``(outputs, targets)``.
        metric: Callable returning nested Dice scores for ``(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        list: The loss value of every batch, in iteration order. The Dice scores
            are only shown in the progress bar; they are collected but not returned.
    """
    model.eval()
    batch_losses = []
    collected_dices = []

    with torch.no_grad():
        progress = tqdm(dataloader, total=len(dataloader))
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            dices = metric(outputs, targets)
            mean_dice = torch.tensor(dices).mean().item()

            loss_value = loss.detach().item()
            batch_losses.append(loss_value)
            collected_dices.extend(dices)
            progress.set_description(desc=f"Valid Loss: {loss_value:.5f}/Dice score:{mean_dice:.5f}")

            # Drop the references and release cached GPU memory before the next batch.
            del inputs, targets, outputs, loss, dices
            torch.cuda.empty_cache()

    return batch_losses

In [None]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Set to the path of a checkpoint written by the loop below to resume training.
checkpoint_path = None
if checkpoint_path:
    # Named `ckpt_state` so the `checkpoint` function imported from
    # torch.utils.checkpoint is not shadowed. map_location lets a checkpoint saved
    # on a GPU be restored on whatever device is in use now.
    ckpt_state = torch.load(checkpoint_path, map_location=device)
    # NOTE(review): strict=False ignores missing/unexpected keys, so a checkpoint
    # saved from a DataParallel-wrapped model ("module." prefix) would load nothing
    # into an unwrapped model without any error - confirm the wrapping matches.
    model.load_state_dict(ckpt_state['model_state_dict'], strict=False)
    optimizer.load_state_dict(ckpt_state['optimizer_state_dict'])
    start_epoch = ckpt_state['epoch']
    loss_dict = ckpt_state['loss_dict']
else:
    start_epoch = 0
    loss_dict = {'train_loss': [], 'valid_loss': []}

# Train the model over all epochs
for epoch in range(start_epoch, NUM_EPOCHS):
    print("----------Epoch {}----------".format(epoch + 1))

    # Train the model for one epoch; the per-batch losses are appended to the history
    train_loss_list = train_one_epoch(model, train_dataloader, criterion, optimizer, device)
    loss_dict['train_loss'].extend(train_loss_list)

    # Run evaluation
    eval_loss_list = evaluate_one_epoch(model, val_dataloader, criterion, metric, device)
    loss_dict['valid_loss'].extend(eval_loss_list)

    # Save model checkpoint; 'epoch' stores the number of completed epochs, which
    # is the value the resume branch above uses as start_epoch
    ckpt_file_name = f"{OUTPUT_DIR}/epoch_{epoch+1}_model.pth"
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_dict': loss_dict,
    }, ckpt_file_name)

    plot_loss(loss_dict['train_loss'], loss_dict['valid_loss'])


# Persist the full loss history next to the checkpoints
with open(f"{OUTPUT_DIR}/loss_dict.pkl", "wb") as file:
    pickle.dump(loss_dict, file)

print("Training Finished!")

In [None]:
# Shell escape: run the NVIDIA command-line tool to print the current GPU status.
!nvidia-smi