In [None]:
# Install python packages
!pip install -q pandas pyathena boto3 requests pydicom IPython jupyter notebook aioboto3 pylibjpeg pylibjpeg-openjpeg pillow synapseclient nibabel pydicom nifti2dicom matplotlib split-folders torchinfo segmentation-models-pytorch-3d livelossplot torchmetrics tensorboard nilearn

# # Install playwright and dependencies to export to PDF
# !pip install -q nbconvert[webpdf]
# !playwright install-deps
# !playwright install chromium

In [None]:
from dataclasses import	dataclass
from IPython.display import	clear_output, display
from PIL import ImageShow
from scipy import stats
from torch.optim import Adam
# from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from typing	import List, Union
import glob
import json
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os
import pandas as pd
import pydicom
import random
import time
import torch
import torch.nn	as nn
import torch.nn.functional as F

import warnings
warnings.simplefilter("ignore")

from tqdm import TqdmExperimentalWarning
from tqdm.autonotebook import tqdm
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)

# Report whether PyTorch can see a CUDA device.
print("The GPU is connected! :)" if torch.cuda.is_available() else "The GPU is not connected :(")

class GlobalConfig:
    """Notebook-wide settings: dataset locations, RNG seed and image modalities."""
    # Dataset root; the sub-folder paths below are derived from it.
    root_dir = 'brats2023/Data/BraTS-GLI'
    train_root_dir = root_dir + '/train'
    test_root_dir = root_dir + '/validate'
    model_path = root_dir + '/models'
    # Seed handed to seed_everything().
    seed = 55
    # Modality suffixes used in the NIfTI file names.
    modalities = ['t1c', 't1n', 't2f', 't2w']
    train_data_csv = './train_data.csv'

def seed_everything(seed: int):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, plus every visible CUDA device when CUDA is available).

    Args:
        seed: integer seed applied to all generators.
    """
    # `random` must be seeded too: the notebook uses it to pick a patient.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers multi-GPU machines, not just the current device.
        torch.cuda.manual_seed_all(seed)

# Instantiate the shared configuration and seed the RNGs with its `seed` value.
config = GlobalConfig()
seed_everything(config.seed)

In [None]:
# Let's view the nifti files and 50th slice of all modalities.
def load_nifti_file(file_path):
    """Load a NIfTI file and return its voxel array, flipped and as stored.

    Args:
        file_path: path handed to ``nib.load``.

    Returns:
        Tuple ``(flipped, voxels, image)``: ``image`` is the object returned by
        ``nib.load``, ``voxels`` is ``image.get_fdata()`` and ``flipped`` is an
        independent copy of ``voxels`` reversed along axis 0.
    """
    image = nib.load(file_path)
    voxels = image.get_fdata()
    # Reverse the first axis; copy so the result does not alias `voxels`.
    flipped = voxels[::-1, ...].copy()

    return flipped, voxels, image

def get_patient_sorted_files(folder_path, required_modalities):
    """Collect per-patient NIfTI paths for patients that have the expected files.

    Every sub-directory of ``folder_path`` is treated as one patient. The
    modality of a file is the text after the last ``-`` in its name with the
    ``.nii.gz`` suffix removed. A patient is kept only when the modalities found
    in its folder match ``required_modalities`` exactly (order does not matter);
    otherwise a message is printed and the patient is skipped.

    Args:
        folder_path: directory containing one sub-directory per patient.
        required_modalities: iterable of modality suffixes each patient must have.

    Returns:
        A list with one dict per kept patient, in sorted patient order. Each dict
        has ``'patient_id'`` first, followed by one key per modality (in sorted
        order) mapping to that modality's file path.
    """
    expected = sorted(required_modalities)
    records = []
    for patient_id in sorted(os.listdir(folder_path)):
        patient_dir = os.path.join(folder_path, patient_id)
        if not os.path.isdir(patient_dir):
            # Stray files next to the patient folders are not patients;
            # calling os.listdir on them would raise NotADirectoryError.
            continue
        found = sorted(f.split("-")[-1].split('.nii.gz')[0] for f in os.listdir(patient_dir))
        if found != expected:
            print(f"Expected modalities {expected} not found in {patient_dir} (found {found})")
            continue
        record = {'patient_id': patient_id}
        for mod in found:
            record[mod] = os.path.join(patient_dir, f"{patient_id}-{mod}.nii.gz")
        records.append(record)
    return records

# Index the patient folders. Validation folders must contain exactly
# `config.modalities`; training folders must additionally contain 'seg'.
validate_data_df = pd.DataFrame.from_records(get_patient_sorted_files(config.test_root_dir, config.modalities))
training_data_df = pd.DataFrame.from_records(get_patient_sorted_files(config.train_root_dir, ['seg']+config.modalities))

# randrange excludes the upper bound. random.randint(0, len(df)) could return
# len(df), which makes the .iloc lookup below raise IndexError.
patient_index = random.randrange(len(training_data_df))
print(f"Patient : {patient_index}")
patient_id = training_data_df.iloc[patient_index]['patient_id']
patient_folder = f"{config.train_root_dir}/{patient_id}"
slice_number = 50
training_modalities = config.modalities + ['seg']

# One panel per modality (plus the segmentation), all at the same slice index.
fig, axs = plt.subplots(1, len(training_modalities), figsize=(16, 8))
for ax, modality in zip(axs.flat, training_modalities):
    flip_data, data, niidata = load_nifti_file(os.path.join(patient_folder, f"{patient_id}-{modality}.nii.gz"))
    ax.imshow(flip_data[:, :, slice_number], cmap='gray')
    ax.set_title(modality)
    ax.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Using training_data_df

class BratsDataset(Dataset):
    """Dataset returning one patient's stacked image volumes and, if present, masks.

    ``df`` needs a ``'patient_id'`` column plus one column per modality holding
    the path of a NIfTI file. An optional ``'seg'`` column holds the path of
    the label map, which is converted into three binary masks.
    """

    def __init__(self, df: pd.DataFrame):
        self.df = df
        self.modalities = config.modalities

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"Id", "image"[, "mask"]}`` for the row at position ``idx``.

        ``image`` stacks the normalized modality volumes on axis 0; ``mask``
        (only when the frame has a ``'seg'`` column) stacks the three binary
        masks. In both arrays the three trailing axes are reversed relative to
        the loaded volumes: ``(C, a, b, c) -> (C, c, b, a)``.
        """
        # Positional lookup: DataLoader samplers yield 0..len-1, which only
        # matches label-based .loc when the frame has a default RangeIndex.
        row = self.df.iloc[idx]
        id_ = row['patient_id']
        df_modalities = [col for col in self.df.columns if col != 'patient_id']

        images = []
        mask = None
        for modality in df_modalities:
            volume = self.load_img(row[modality])
            if modality == 'seg':
                mask = self.preprocess_mask_labels(volume)
            else:
                images.append(self.normalize(volume))
        img = np.stack(images)
        img = np.moveaxis(img, (0, 1, 2, 3), (0, 3, 2, 1))

        sample = {"Id": id_, "image": img}
        if mask is not None:
            sample["mask"] = mask
        return sample

    def load_img(self, file_path):
        """Load a NIfTI file and return its stored data as a NumPy array."""
        data = nib.load(file_path)
        data = np.asarray(data.dataobj)
        return data

    def normalize(self, data: np.ndarray):
        """Min-max scale ``data`` to [0, 1]; a constant volume maps to all zeros."""
        data_min = np.min(data)
        value_range = np.max(data) - data_min
        # Avoid 0/0 -> NaN on constant (e.g. empty) volumes.
        scale = value_range if value_range != 0 else 1
        return (data - data_min) / scale

    def resize(self, data: np.ndarray, target_shape=(78, 120, 120)):
        """Resample a 3-D volume to ``target_shape`` with trilinear interpolation.

        Intensities are interpolated only, never rescaled. The result is a
        float32 NumPy array of shape ``target_shape``.
        """
        volume = torch.from_numpy(np.ascontiguousarray(data, dtype=np.float32))[None, None]
        resized = F.interpolate(volume, size=tuple(target_shape), mode='trilinear', align_corners=False)
        return resized[0, 0].numpy()

    def preprocess_mask_labels(self, mask: np.ndarray):
        """Turn a label map into three stacked binary masks (WT, TC, ET).

        * WT: labels 1, 2 and the enhancing-tumour label
        * TC: label 1 and the enhancing-tumour label
        * ET: the enhancing-tumour label only

        BraTS 2023 label maps use 3 for enhancing tumour while earlier editions
        used 4; both values are accepted so no raw label leaks into the output.
        The result has the input's dtype, values in {0, 1}, and its three
        trailing axes reversed: ``(3, a, b, c) -> (3, c, b, a)``.
        """
        et_labels = (3, 4)
        mask_WT = np.isin(mask, (1, 2) + et_labels)
        mask_TC = np.isin(mask, (1,) + et_labels)
        mask_ET = np.isin(mask, et_labels)

        stacked = np.stack([mask_WT, mask_TC, mask_ET]).astype(mask.dtype)
        stacked = np.moveaxis(stacked, (0, 1, 2, 3), (0, 3, 2, 1))

        return stacked

In [None]:
def _montage(volume, fill=None):
    """Tile the 2-D slices of a (K, M, N) array into one near-square 2-D image.

    Slices are placed row by row on a ``ceil(sqrt(K))``-wide grid. Grid cells
    without a slice are set to ``fill`` (the array mean when ``fill`` is None).
    """
    n_slices, height, width = volume.shape
    grid = int(np.ceil(np.sqrt(n_slices)))
    fill_value = volume.mean() if fill is None else fill
    canvas = np.full((grid * height, grid * width), fill_value, dtype=volume.dtype)
    for k in range(n_slices):
        row, col = divmod(k, grid)
        canvas[row * height:(row + 1) * height, col * width:(col + 1) * width] = volume[k]
    return canvas


def test_data_loader(batch_size, num_workers):
    """Smoke-test the data pipeline on one batch and plot a slice montage.

    Builds a DataLoader over ``training_data_df``, prints the ids, shapes and
    value statistics of the first batch, then overlays the first mask channel
    on the first image channel of the first sample, all slices tiled together.

    Args:
        batch_size: batch size given to the DataLoader.
        num_workers: number of DataLoader worker processes.
    """
    dataset = BratsDataset(training_data_df)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned memory only helps (and only avoids a warning) when CUDA is present.
        pin_memory=torch.cuda.is_available(),
        shuffle=True,
    )

    data = next(iter(dataloader))
    print(data['Id'], data['image'].shape, data['mask'].shape)

    # First sample, first channel of image and mask.
    # Both are assumed to be 3-D volumes here - TODO confirm for other datasets.
    img_tensor = data['image'][0].squeeze()[0].cpu().detach().numpy()
    mask_tensor = data['mask'][0].squeeze()[0].squeeze().cpu().detach().numpy()
    print("Num uniq Image values :", len(np.unique(img_tensor)))
    print("Min/Max Image values:", img_tensor.min(), img_tensor.max())
    print("Num uniq Mask values:", np.unique(mask_tensor, return_counts=True))

    image = np.rot90(_montage(img_tensor))
    # Zero fill keeps the unused grid cells transparent in the overlay.
    mask = np.rot90(_montage(mask_tensor, fill=0))

    fig, ax = plt.subplots(1, 1, figsize = (20, 20))
    ax.imshow(image, cmap ='bone')
    ax.imshow(np.ma.masked_where(mask == 0, mask), cmap='cool', alpha=0.6)
    plt.show()

# Sanity-check the loader on a single batch of four samples before training.
test_data_loader(batch_size=4, num_workers=1)

In [None]:
# Loss functions

def dice_coef_metric(probabilities: torch.Tensor,
                     truth: torch.Tensor,
                     treshold: float = 0.5,
                     eps: float = 1e-9) -> np.ndarray:
    """
    Mean Dice score (a.k.a. F1) over the samples of a batch.
    Params:
        probabilities: model outputs after the activation function.
        truth: ground-truth values, same shape as `probabilities`.
        treshold: probabilities >= this value count as positive predictions.
        eps: small constant added to numerator and denominator.
        Returns: mean of the per-sample Dice scores, each clamped to [0, 1].
    """
    binarized = (probabilities >= treshold).float()
    assert(binarized.shape == truth.shape)
    per_sample = []
    for pred_i, true_i in zip(binarized, truth):
        true_sum = true_i.sum()
        pred_sum = pred_i.sum()
        if true_sum == 0 and pred_sum == 0:
            # Nothing to find and nothing predicted counts as a perfect match.
            per_sample.append(1.0)
            continue
        overlap = 2.0 * (true_i * pred_i).sum()
        ratio = (overlap + eps) / (true_sum + pred_sum + eps)
        per_sample.append(min(max(ratio.item(), 0.0), 1.0))
    return np.mean(per_sample)


def jaccard_coef_metric(probabilities: torch.Tensor,
               truth: torch.Tensor,
               treshold: float = 0.5,
               eps: float = 1e-9) -> np.ndarray:
    """
    Mean Jaccard index (a.k.a. IoU) over the samples of a batch.
    Params:
        probabilities: model outputs after the activation function.
        truth: ground-truth values, same shape as `probabilities`.
        treshold: probabilities >= this value count as positive predictions.
        eps: small constant added to numerator and denominator.
        Returns: mean of the per-sample Jaccard scores, each clamped to [0, 1].
    """
    binarized = (probabilities >= treshold).float()
    assert(binarized.shape == truth.shape)
    per_sample = []
    for pred_i, true_i in zip(binarized, truth):
        pred_sum = pred_i.sum()
        true_sum = true_i.sum()
        if true_sum == 0 and pred_sum == 0:
            # Nothing to find and nothing predicted counts as a perfect match.
            per_sample.append(1.0)
            continue
        overlap = (pred_i * true_i).sum()
        combined = (pred_sum + true_sum) - overlap
        ratio = (overlap + eps) / (combined + eps)
        per_sample.append(min(max(ratio.item(), 0.0), 1.0))
    return np.mean(per_sample)


class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    The Dice score is computed once over all elements of the batch (not per
    sample and then averaged) and the loss is ``1 - dice``.
    """
    def __init__(self, eps: float = 1e-9):
        super(DiceLoss, self).__init__()
        # Added to numerator and denominator to avoid 0/0 on empty inputs.
        self.eps = eps

    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """Return ``1 - dice`` for raw ``logits`` against ``targets`` of equal size."""
        num = targets.size(0)
        probability = torch.sigmoid(logits)
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
        # result of a transpose or permute.
        probability = probability.reshape(num, -1)
        targets = targets.reshape(num, -1)
        assert(probability.shape == targets.shape)
        intersection = 2.0 * (probability * targets).sum()
        union = probability.sum() + targets.sum()
        dice_score = (intersection + self.eps) / (union + self.eps)
        dice_score = torch.clamp(dice_score, 0.0, 1.0)
        return 1.0 - dice_score


class BCEDiceLoss(nn.Module):
    """Training objective: binary cross-entropy on logits plus Dice loss."""
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """Return ``BCEWithLogits(logits, targets) + Dice(logits, targets)``."""
        assert(logits.shape == targets.shape)
        dice_term = self.dice(logits, targets)
        bce_term = self.bce(logits, targets)
        return bce_term + dice_term


class Meter:
    '''Accumulates per-batch Dice and IoU scores and reports their averages.'''
    def __init__(self, treshold: float = 0.5):
        self.threshold: float = treshold
        self.dice_scores: list = []
        self.iou_scores: list = []

    def update(self, logits: torch.Tensor, targets: torch.Tensor):
        """
        Apply a sigmoid to the model logits, score the result against
        `targets` with both metrics and append the two scores to the lists.
        """
        probs = torch.sigmoid(logits)
        batch_dice, batch_iou = (
            dice_coef_metric(probs, targets, self.threshold),
            jaccard_coef_metric(probs, targets, self.threshold),
        )
        self.dice_scores.append(batch_dice)
        self.iou_scores.append(batch_iou)

    def get_metrics(self) -> np.ndarray:
        """
        Returns: the pair (mean dice, mean iou) over everything stored so far.
        """
        return np.mean(self.dice_scores), np.mean(self.iou_scores)



In [None]:
class DoubleConv(nn.Module):
    """(Conv3d 3x3x3 -> GroupNorm -> ReLU) * 2, with the spatial size preserved."""
    def __init__(self, in_channels, out_channels, num_groups=8):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels, the second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(num_groups=num_groups, num_channels=out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2, 2)
        convs = DoubleConv(in_channels, out_channels)
        self.encoder = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.encoder(x)


class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, DoubleConv."""

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        half = in_channels // 2
        self.up = (
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            if trilinear
            else nn.ConvTranspose3d(half, half, kernel_size=2, stride=2)
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # F.pad expects (before, after) pairs starting from the LAST dimension,
        # hence dims 4, 3, 2. Any odd remainder goes to the "after" side.
        padding = []
        for dim in (4, 3, 2):
            gap = x2.size()[dim] - x1.size()[dim]
            padding += [gap // 2, gap - gap // 2]
        x1 = F.pad(x1, padding)

        # The skip connection comes first on the channel axis.
        return self.conv(torch.cat([x2, x1], dim=1))


class Out(nn.Module):
    """Output head: a 1x1x1 convolution mapping feature channels to output channels."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


class UNet3d(nn.Module):
    """3D U-Net: one input DoubleConv, four Down stages, four Up stages, 1x1x1 head.

    Returns raw logits with ``n_classes`` channels (no activation is applied).
    """
    def __init__(self, in_channels, n_classes, n_channels):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels

        width = n_channels
        # Encoder: channel width doubles per stage until 8x, then stays at 8x.
        self.conv = DoubleConv(in_channels, width)
        self.enc1 = Down(width, 2 * width)
        self.enc2 = Down(2 * width, 4 * width)
        self.enc3 = Down(4 * width, 8 * width)
        self.enc4 = Down(8 * width, 8 * width)

        # Decoder: each stage's input width counts the upsampled tensor plus its skip.
        self.dec1 = Up(16 * width, 4 * width)
        self.dec2 = Up(8 * width, 2 * width)
        self.dec3 = Up(4 * width, width)
        self.dec4 = Up(2 * width, width)
        self.out = Out(width, n_classes)

    def forward(self, x):
        # Run the encoder, remembering every intermediate feature map.
        features = [self.conv(x)]
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features.append(encoder(features[-1]))

        # Start from the deepest map and merge the skips from deep to shallow.
        mask = features.pop()
        for decoder in (self.dec1, self.dec2, self.dec3, self.dec4):
            mask = decoder(mask, features.pop())
        return self.out(mask)

In [None]:
class Trainer:
    """Training/validation loop with gradient accumulation, LR scheduling and logging.

    Each epoch runs a "train" phase and a "val" phase, records loss, Dice and
    Jaccard per phase, steps a ReduceLROnPlateau scheduler on the validation
    loss, and saves ``best_model.pth`` whenever the validation loss improves.
    """
    def __init__(self,
                 net: nn.Module,
                 dataset: torch.utils.data.Dataset,
                 criterion: nn.Module,
                 lr: float,
                 accumulation_steps: int,
                 batch_size: int,
                 num_workers: int,
                 fold: int,
                 num_epochs: int,
                 # train_data_csv: config.train_data_csv,
                 display_plot: bool = True,
                 val_dataset: Union[torch.utils.data.Dataset, None] = None,
                ):
        """Set up the device, optimizer, scheduler and data loaders.

        Args:
            net: model to train; it is moved to CUDA when available, else CPU.
            dataset: training dataset yielding dicts with 'image' and 'mask'.
            criterion: loss called as ``criterion(logits, targets)``.
            lr: learning rate for Adam.
            accumulation_steps: number of samples to accumulate gradients
                over; divided by ``batch_size`` to get optimizer-step spacing.
            batch_size: batch size of the data loaders.
            num_workers: worker processes per data loader.
            fold: accepted for interface compatibility; not used.
            num_epochs: number of epochs executed by ``run``.
            display_plot: plot the training history after every epoch.
            val_dataset: optional held-out dataset for the "val" phase. When
                omitted, the "val" phase re-uses the training loader.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("device:", self.device)
        self.display_plot = display_plot
        self.net = net.to(self.device)
        self.criterion = criterion
        self.optimizer = Adam(self.net.parameters(), lr=lr)
        # `verbose` is deprecated in recent PyTorch releases, so learning-rate
        # changes are reported from run() instead.
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min",
                                           patience=2)
        # At least 1: a batch_size above accumulation_steps would otherwise
        # yield 0 and break the modulo/division in _do_epoch.
        self.accumulation_steps = max(1, accumulation_steps // batch_size)
        self.phases = ["train", "val"]
        self.num_epochs = num_epochs

        pin_memory = self.device == 'cuda'
        self.dataloaders = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=True,
        )
        if val_dataset is None:
            print("No val_dataset given: the 'val' phase re-uses the training data.")
            self.val_dataloader = self.dataloaders
        else:
            self.val_dataloader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory=pin_memory,
                shuffle=False,
            )
        self.best_loss = float("inf")
        self.losses = {phase: [] for phase in self.phases}
        self.dice_scores = {phase: [] for phase in self.phases}
        self.jaccard_scores = {phase: [] for phase in self.phases}

    def _compute_loss_and_outputs(self,
                                  images: torch.Tensor,
                                  targets: torch.Tensor):
        """Move a batch to the device, run the net and return ``(loss, logits)``."""
        images = images.to(self.device).float()
        targets = targets.to(self.device).float()
        logits = self.net(images)
        loss = self.criterion(logits, targets)
        return loss, logits

    def _do_epoch(self, epoch: int, phase: str):
        """Run one pass over the data for ``phase`` and return the mean loss.

        In the "train" phase gradients are accumulated and the optimizer steps
        every ``accumulation_steps`` batches, and once more on the final batch
        so that trailing gradients are not discarded. Any other phase runs
        with gradients disabled and the net in eval mode.
        """
        print(f"{phase} epoch: {epoch} | time: {time.strftime('%H:%M:%S')}")

        is_train = phase == "train"
        self.net.train(is_train)
        meter = Meter()
        dataloader = self.dataloaders if is_train else self.val_dataloader
        total_batches = len(dataloader)
        running_loss = 0.0
        self.optimizer.zero_grad()
        pbar = tqdm(enumerate(dataloader), total=total_batches, desc=f"{phase} {epoch}", leave=False)
        with torch.set_grad_enabled(is_train):
            for itr, data_batch in pbar:
                images, targets = data_batch['image'], data_batch['mask']
                loss, logits = self._compute_loss_and_outputs(images, targets)
                # Scale so the accumulated gradient is an average, not a sum.
                loss = loss / self.accumulation_steps
                if is_train:
                    loss.backward()
                    is_last_batch = (itr + 1) == total_batches
                    if (itr + 1) % self.accumulation_steps == 0 or is_last_batch:
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                running_loss += loss.item()
                meter.update(logits.detach().cpu(), targets.detach().cpu())

        # Undo the accumulation scaling to report the true mean batch loss.
        epoch_loss = (running_loss * self.accumulation_steps) / max(total_batches, 1)
        epoch_dice, epoch_iou = meter.get_metrics()

        self.losses[phase].append(epoch_loss)
        self.dice_scores[phase].append(epoch_dice)
        self.jaccard_scores[phase].append(epoch_iou)

        return epoch_loss

    def run(self):
        """Train for ``num_epochs`` epochs, checkpointing on validation improvements."""
        for epoch in range(self.num_epochs):
            self._do_epoch(epoch, "train")
            val_loss = self._do_epoch(epoch, "val")

            lr_before = [group['lr'] for group in self.optimizer.param_groups]
            self.scheduler.step(val_loss)
            lr_after = [group['lr'] for group in self.optimizer.param_groups]
            if lr_after != lr_before:
                print(f"Learning rate reduced: {lr_before} -> {lr_after}")

            if self.display_plot:
                self._plot_train_history()

            if val_loss < self.best_loss:
                print(f"\n{'#'*20}\nSaved new checkpoint\n{'#'*20}\n")
                self.best_loss = val_loss
                torch.save(self.net.state_dict(), "best_model.pth")
            print()
        self._save_train_history()

    def _plot_train_history(self):
        """Redraw the loss, Dice and Jaccard curves for both phases."""
        data = [self.losses, self.dice_scores, self.jaccard_scores]
        colors = ['deepskyblue', "crimson"]
        labels = [
            f"""
            train loss {self.losses['train'][-1]}
            val loss {self.losses['val'][-1]}
            """,

            f"""
            train dice score {self.dice_scores['train'][-1]}
            val dice score {self.dice_scores['val'][-1]}
            """,

            f"""
            train jaccard score {self.jaccard_scores['train'][-1]}
            val jaccard score {self.jaccard_scores['val'][-1]}
            """,
        ]

        clear_output(True)
        with plt.style.context("seaborn-v0_8-notebook"):
            fig, axes = plt.subplots(3, 1, figsize=(8, 10))
            for i, ax in enumerate(axes):
                ax.plot(data[i]['val'], c=colors[0], label="val")
                ax.plot(data[i]['train'], c=colors[-1], label="train")
                ax.set_title(labels[i])
                ax.legend(loc="upper right")

            plt.tight_layout()
            plt.show()
            # A new figure is created every epoch; close it to free its memory.
            plt.close(fig)

    def load_predtrain_model(self,
                             state_path: str):
        """Load model weights from ``state_path`` onto the trainer's device."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        self.net.load_state_dict(torch.load(state_path, map_location=self.device))
        print("Predtrain model loaded")

    def _save_train_history(self):
        """Write the final weights and the per-epoch metric log to disk."""
        torch.save(self.net.state_dict(),
                   "last_epoch_model.pth")

        # Columns: train_loss, val_loss, train_dice, val_dice, train_jaccard, val_jaccard.
        history = {}
        for suffix, log in (("_loss", self.losses),
                            ("_dice", self.dice_scores),
                            ("_jaccard", self.jaccard_scores)):
            for phase, values in log.items():
                history[phase + suffix] = values
        pd.DataFrame(history).to_csv("train_log.csv", index=False)

In [None]:
# Build the 3D U-Net: 4 input channels, 3 output channels, 24 base feature channels.
# The device is chosen at runtime so this cell also runs on CPU-only machines.
nodel = UNet3d(in_channels=4, n_classes=3, n_channels=24).to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
%%time

# Configure the training run: BCE + Dice objective, Adam at lr=5e-4, 50 epochs.
# With batch_size=1 and accumulation_steps=4 the optimizer steps every 4 batches.
trainer = Trainer(
            net=nodel,
            dataset=BratsDataset(training_data_df),
            criterion=BCEDiceLoss(),
            lr=5e-4,
            accumulation_steps=4,
            batch_size=1,
            num_workers=8,
            fold=0,
            num_epochs=50,)

# Start training; checkpoints and the training log are written by the Trainer.
trainer.run()