In [114]:
import os
import sys
sys.path.insert(1, "/kaggle/input/3d-unet")

In [115]:
import torch
import cv2
import numpy as np
from skimage import io
from pytorch3dunet.unet3d.model import UNet3D

<h1> Dataset and auxiliary functions </h1>

In [117]:
def binarize(mask: torch.tensor) -> torch.tensor:
    """
    Converts instance segmentation labels, i.e. 1,2,3,4...
    into semantic segmentation labels, i.e. 0,1
    because instance segmentation is supposed to be done in post-processing,
    as described in the QCANet paper.
    """
    return torch.where(mask != 0, 1, 0)

In [118]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from scipy.ndimage import zoom
import cv2
import torch.nn.functional as F

class EmbryoDataset(Dataset):
    
    def __init__(self, image_dir_path: str, gt_dir_path: str, augment: bool = False):
        self.image_dir_path = image_dir_path
        self.gt_dir_path = gt_dir_path
        self._filenames = os.listdir(image_dir_path)
        self._augment = augment
        
        # If augmentation is used, the filenames are duplicated because rotated versions of
        # each image are produced, as described in the QCANet paper
        if self._augment:
            self._filenames = self._filenames * 4
        
    def __len__(self):
        return len(self._filenames)
    
    def __preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
        # Preprocessing images as described in the QCANet paper
        min_value = torch.min(image)
        max_value = torch.max(image)
        
        image = (image - min_value) / (max_value - min_value)
        
        return image
    
    def __preprocess_gt(self, gt: torch.Tensor) -> torch.Tensor:
        # Ground truth preprocessing
        gt = binarize(gt)
        
        return gt
    
    def __preprocess(self, raw: np.ndarray, type: int) -> np.ndarray:
        preprocessed = raw.astype(np.float32)
        x, y = preprocessed.shape[1], preprocessed.shape[2]
        x_scale_factor = 128 / x
        y_scale_factor = 128 / y
        # Z-axis stretched by a factor of 2.1875 as described in the QCANet paper
        # Y and X axes normalized to 128 pixels to keep sizes consistent between samples
        preprocessed = zoom(preprocessed, (2.1875, x_scale_factor, y_scale_factor), order=3)
        
        preprocessed = torch.from_numpy(preprocessed)
        preprocessed = preprocessed.to("cuda", dtype=torch.float32)
        preprocessed = preprocessed.unsqueeze(dim=0)
        
        if type == 0:
            preprocessed = self.__preprocess_image(preprocessed)
        if type == 1:
            preprocessed = self.__preprocess_gt(preprocessed)
            
        return preprocessed

    def __getitem__(self, idx):
        # With augmentation enabled, we take 4 copies of one image with different
        # rotation for each one
        if self._augment:
            original_idx = idx // 4
        else:
            original_idx = idx

        curr_filename = self._filenames[original_idx]
        curr_image_path = os.path.join(self.image_dir_path, curr_filename)
        curr_gt_path = os.path.join(self.gt_dir_path, curr_filename)

        image_raw = io.imread(curr_image_path)
        gt_raw = io.imread(curr_gt_path)
        
        x = self.__preprocess(image_raw, type=0)
        y = self.__preprocess(gt_raw, type=1)

        # With augmentation enabled, we take 4 copies of one image with different
        # rotation for each one
        if self._augment:
            transform_type = idx % 4
        
            # Apply the specific transformation
            if transform_type == 1:  # Flip horizontally
                x = transforms.functional.hflip(x)
                y = transforms.functional.hflip(y)
            elif transform_type == 2:  # Flip vertically
                x = transforms.functional.vflip(x)
                y = transforms.functional.vflip(y)
            elif transform_type == 3:  # Flip both horizontally and vertically
                x = transforms.functional.hflip(x)
                y = transforms.functional.hflip(y)
        
        return (x.cuda(), y.cuda())

In [119]:
train_set = EmbryoDataset(image_dir_path="/kaggle/input/dl-reprod/Images/train/Images", gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/train/GroundTruth_NSN", augment=True)
test_set = EmbryoDataset(image_dir_path="/kaggle/input/dl-reprod/Images/test/Images", gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/test/GroundTruth_QCANet", augment=False)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

<h1> Metric functions </h1>

DiceLoss was used as a loss function in the QCANet paper. Implemented based on https://arxiv.org/pdf/1606.04797.pdf

$$D = \frac{2\sum_i^N p_i g_i}{\sum_i^N p_i^2 + \sum_i^N g_i^2}$$

where the sums run over the $N$ voxels, of the predicted binary segmentation volume $p_i \in P$ and the ground truth binary volume $g_i \in G$.

In [120]:
def dice_loss(input, target):
    # Flatten the tensors to make sure you can sum over all voxels
    input_flat = input.view(-1)
    target_flat = target.view(-1)
    
    intersection = 2.0 * (input_flat * target_flat).sum()
    denominator = input_flat.pow(2).sum() + target_flat.pow(2).sum()
    
    dice_score = intersection / denominator.clamp(min=1e-6)
    return 1 - dice_score

In [121]:
def iou(y_pred: torch.Tensor, y_gt: torch.Tensor, smooth=1e-6):
    """
    Calculate Intersection over Union (IoU) for 3D semantic segmentation masks.

    Parameters:
    - outputs: a tensor of shape (N, C, D, H, W) where
      N is the batch size,
      C is the number of classes,
      D is the depth,
      H and W are the height and width of the masks.
      The tensor should contain binary predictions (0 or 1).
    - labels: a tensor of the same shape as outputs containing the ground truth masks.
    - smooth: a small constant added to avoid division by zero.

    Returns:
    - IoU: The Intersection over Union score for each class.
    """
    # Ensure that both outputs and labels are booleans
    y_pred = y_pred > 0.5
    y_gt = y_gt > 0.5
    
    # Intersection and Union
    intersection = (y_pred & y_gt).float().sum(dim=(2, 3, 4)) # Sum over the spatial dimensions
    union = (y_pred | y_gt).float().sum(dim=(2, 3, 4)) # Sum over the spatial dimensions
    
    # Compute the IoU and handle cases where the union is 0
    iou = (intersection + smooth) / (union + smooth)
    
    return torch.mean(iou)  # Return the average IoU over the batch

<p> Bayesian hyperparameter optimization to determine hyperparameters not given in the paper </p>

In [None]:
import optuna
from optuna.samplers import TPESampler
from tqdm import tqdm

def objective(trial):
    # Suggested hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    model.train()
    for epoch in range(10):
        total_loss = 0
        for data in tqdm(train_loader):
            x, y = data
            optimizer.zero_grad()
            y_pred = model.forward(x)
            loss = dice_loss(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        average_loss = total_loss / len(train_loader)

    # Here you can also implement validation and use the validation loss as the return value
    return average_loss

# Optuna study
study = optuna.create_study(direction="minimize", sampler=TPESampler())
study.optimize(objective, n_trials=10)

# Best trial result
print("Best trial:")
trial = study.best_trial
print(f"Value: {trial.value}")
print("Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")



<h1> Components of the training loop </h1>

In [122]:
model = UNet3D(in_channels=1, out_channels=1)
model.cuda()
# Learning rate determined by hyperparameter optimization below
optimizer = torch.optim.Adam(model.parameters(), lr=5e-05)

In [123]:
from tqdm import tqdm

def train_epoch():
    total_loss = 0
    with tqdm(total=len(train_set)) as pbar:
        for i, data in enumerate(train_loader):
            x, y = data
            optimizer.zero_grad()
            y_pred = model.forward(x)
            # loss, y_pred = model.forward(x, t=y, seg=False)
            loss = dice_loss(y_pred, y)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pbar.update(1)
    return total_loss / len(train_loader)

In [124]:
def test_epoch():
    total_loss = 0
    total_iou = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            x_test, y_test = data
            y_pred = model(x_test)
            loss = dice_loss(y_pred, y_test)
            iou_value = iou(y_pred, y_test)
            total_loss += loss.item()
            total_iou += iou_value.item()
    return total_loss / len(test_loader), total_iou / len(test_loader)

In [None]:
losses = []
maxloss = 1
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = train_epoch()
    print(f'Epoch {epoch+1}/{num_epochs}, Train loss: {train_loss}')
    model.eval()
    test_loss, test_iou = test_epoch()
    print(f'Epoch {epoch+1}/{num_epochs}, Test loss: {test_loss}, Test IoU: {test_iou}')
    if (test_loss < maxloss):
        maxloss = test_loss
        torch.save(model.state_dict(), 'nsn_weights.pth')

100%|██████████| 484/484 [13:50<00:00,  1.72s/it]


Epoch 1/10, Train loss: 0.8472217211792291
Epoch 1/10, Test loss: 0.9174671877514232, Test IoU: 0.027402645331511103


100%|██████████| 484/484 [13:31<00:00,  1.68s/it]


Epoch 2/10, Train loss: 0.819889424507283
Epoch 2/10, Test loss: 0.9122002856297926, Test IoU: 0.03608741985358806


 18%|█▊        | 89/484 [02:29<10:57,  1.67s/it]

In [None]:
import matplotlib.pyplot as plt

plt.plot(losses)
plt.title('Training Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()