In [1]:
import os
import sys
sys.path.insert(1, "/kaggle/input/3d-unet")

In [2]:
import torch
import cv2
import numpy as np
from skimage import io
from pytorch3dunet.unet3d.model import UNet3D

In [3]:
def binarize(mask: torch.tensor) -> torch.tensor:
    return torch.where(mask != 0, 1, 0)

In [4]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from scipy.ndimage import zoom
import cv2
import torch.nn.functional as F

class EmbryoDataset(Dataset):
    
    def __init__(self, image_dir_path: str, gt_dir_path: str, augment: bool = False):
        self.image_dir_path = image_dir_path
        self.gt_dir_path = gt_dir_path
        self._filenames = os.listdir(image_dir_path)
        self._augment = augment
        
        if self._augment:
            self._filenames = self._filenames * 4
        
    def __len__(self):
        return len(self._filenames)
    
    def __preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
        min_value = torch.min(image)
        max_value = torch.max(image)
        
        image = (image - min_value) / (max_value - min_value)
        
        return image
    
    def __preprocess_gt(self, gt: torch.Tensor) -> torch.Tensor:
        gt = binarize(gt)
        
        return gt
    
    def __preprocess(self, raw: np.ndarray, type: int) -> np.ndarray:
        preprocessed = raw.astype(np.float32)
        x, y = preprocessed.shape[1], preprocessed.shape[2]
        x_scale_factor = 128 / x
        y_scale_factor = 128 / y
        preprocessed = zoom(preprocessed, (2.1875, x_scale_factor, y_scale_factor), order=3)
        
        preprocessed = torch.from_numpy(preprocessed)
        preprocessed = preprocessed.to("cuda", dtype=torch.float32)
        preprocessed = preprocessed.unsqueeze(dim=0)
        
        if type == 0:
            preprocessed = self.__preprocess_image(preprocessed)
        if type == 1:
            preprocessed = self.__preprocess_gt(preprocessed)
            
        return preprocessed

    def __getitem__(self, idx):
        if self._augment:
            original_idx = idx // 4
        else:
            original_idx = idx

        curr_filename = self._filenames[original_idx]
        curr_image_path = os.path.join(self.image_dir_path, curr_filename)
        curr_gt_path = os.path.join(self.gt_dir_path, curr_filename)

        image_raw = io.imread(curr_image_path)
        gt_raw = io.imread(curr_gt_path)
        
        x = self.__preprocess(image_raw, type=0)
        y = self.__preprocess(gt_raw, type=1)

        if self._augment:
            transform_type = idx % 4
        
            # Apply the specific transformation
            if transform_type == 1:  # Flip horizontally
                x = transforms.functional.hflip(x)
                y = transforms.functional.hflip(y)
            elif transform_type == 2:  # Flip vertically
                x = transforms.functional.vflip(x)
                y = transforms.functional.vflip(y)
            elif transform_type == 3:  # Flip both horizontally and vertically
                x = transforms.functional.hflip(x)
                y = transforms.functional.hflip(y)
        
        return (x, y)

The loss gave us negative outputs, so i rewrote the loss function to follow https://arxiv.org/pdf/1606.04797.pdf since that is what the github mentions. This is the formula used:

$$D = \frac{2\sum_i^N p_i g_i}{\sum_i^N p_i^2 + \sum_i^N g_i^2}$$

where the sums run over the $N$ voxels, of the predicted binary segmentation volume $p_i \in P$ and the ground truth binary volume $g_i \in G$.

In [5]:
def dice_loss(input, target):
    # Flatten the tensors to make sure you can sum over all voxels
    input_flat = input.view(-1)
    target_flat = target.view(-1)
    
    intersection = 2.0 * (input_flat * target_flat).sum()
    denominator = input_flat.pow(2).sum() + target_flat.pow(2).sum()
    
    dice_score = intersection / denominator.clamp(min=1e-6)
    return 1 - dice_score

## 3D Unet hyperparameter tuning

In [6]:
train_set = EmbryoDataset(image_dir_path="/kaggle/input/dl-reprod/Images/train/Images", gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/train/GroundTruth_QCANet", augment=True)
test_set = EmbryoDataset(image_dir_path="/kaggle/input/dl-reprod/Images/test/Images", gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/test/GroundTruth_QCANet", augment=False)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

In [7]:
model = UNet3D(in_channels=1, out_channels=1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=4.617614178775834e-05)

In [8]:
import optuna
from optuna.samplers import TPESampler
from tqdm import tqdm

def objective(trial):
    # Suggested hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    model.train()
    for epoch in range(10):
        total_loss = 0
        for data in tqdm(train_loader):
            x, y = data
            optimizer.zero_grad()
            y_pred = model.forward(x)
            loss = dice_loss(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        average_loss = total_loss / len(train_loader)

    # Here you can also implement validation and use the validation loss as the return value
    return average_loss

# Optuna study
study = optuna.create_study(direction="minimize", sampler=TPESampler())
study.optimize(objective, n_trials=10)

# Best trial result
print("Best trial:")
trial = study.best_trial
print(f"Value: {trial.value}")
print("Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

[I 2024-04-14 12:17:10,437] A new study created in memory with name: no-name-7561f283-4f88-414e-b4df-c50a35654119
  0%|          | 1/484 [00:07<57:43,  7.17s/it]
[W 2024-04-14 12:17:17,616] Trial 0 failed with parameters: {'learning_rate': 6.084507237396047e-05} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_103/329070226.py", line 22, in objective
    total_loss += loss.item()
KeyboardInterrupt
[W 2024-04-14 12:17:17,618] Trial 0 failed with value None.


KeyboardInterrupt: 

# NSN and NDN train and test functions

In [9]:
def separate_channels(mask: torch.tensor) -> torch.tensor:
    class_labels = [0,1]
    gt = torch.zeros((len(class_labels), mask.shape[1], mask.shape[2], mask.shape[3]))
    
    for class_label in class_labels:
        gt[class_label] = torch.where(mask == class_label, 1, 0)
        
    return gt

In [10]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from scipy.ndimage import zoom
import cv2
import torch.nn.functional as F

class EmbryoDataset(Dataset):
    
    def __init__(self, image_dir_path: str, gt_dir_path: str, augment: bool = False):
        self.image_dir_path = image_dir_path
        self.gt_dir_path = gt_dir_path
        self._filenames = os.listdir(image_dir_path)
        self._augment = augment
        
        if self._augment:
            self._filenames = self._filenames * 4
        
    def __len__(self):
        return len(self._filenames)
    
    def __preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
        min_value = torch.min(image)
        max_value = torch.max(image)
        
        image = (image - min_value) / (max_value - min_value)
        
        return image
    
    def __preprocess_gt(self, gt: torch.Tensor) -> torch.Tensor:
        gt = separate_channels(gt)
        
        return gt
    
    def __preprocess(self, raw: np.ndarray, type: int) -> np.ndarray:
        preprocessed = raw.astype(np.float32)
        x, y = preprocessed.shape[1], preprocessed.shape[2]
        x_scale_factor = 128 / x
        y_scale_factor = 128 / y
        preprocessed = zoom(preprocessed, (2.1875, x_scale_factor, y_scale_factor), order=3)
        
        preprocessed = torch.from_numpy(preprocessed)
        preprocessed = preprocessed.to("cuda", dtype=torch.float32)
        preprocessed = preprocessed.unsqueeze(dim=0)
        
        if type == 0:
            preprocessed = self.__preprocess_image(preprocessed)
        if type == 1:
            preprocessed = self.__preprocess_gt(preprocessed)
            
        return preprocessed

    def __getitem__(self, idx):
        if self._augment:
            original_idx = idx // 4
        else:
            original_idx = idx

        curr_filename = self._filenames[original_idx]
        curr_image_path = os.path.join(self.image_dir_path, curr_filename)
        curr_gt_path = os.path.join(self.gt_dir_path, curr_filename)

        image_raw = io.imread(curr_image_path)
        gt_raw = io.imread(curr_gt_path)
        
        x = self.__preprocess(image_raw, type=0)
        y = self.__preprocess(gt_raw, type=1)

        if self._augment:
            transform_type = idx % 4
        
            # Apply the specific transformation
            if transform_type == 1:  # Flip horizontally
                x = transforms.functional.hflip(x)
                y = transforms.functional.hflip(y)
            elif transform_type == 2:  # Flip vertically
                x = transforms.functional.vflip(x)
                y = transforms.functional.vflip(y)
            elif transform_type == 3:  # Flip both horizontally and vertically
                x = transforms.functional.hflip(x)
                y = transforms.functional.hflip(y)
        
        return (x.cuda(), y.cuda())

In [11]:
from tqdm import tqdm

def train_epoch():
    total_loss = 0
    with tqdm(total=len(train_set)) as pbar:
        for i, data in enumerate(train_loader):
            x, y = data
            optimizer.zero_grad()
            y_pred = model.forward(x)
            # loss, y_pred = model.forward(x, t=y, seg=False)
            loss = dice_loss(y_pred, y)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pbar.update(1)
    return total_loss / len(train_loader)

def test_epoch():
    total_loss = 0
    total_iou = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            x_test, y_test = data
            y_pred = model(x_test)
            loss = dice_loss(y_pred, y_test)
            iou_value = iou(y_pred, y_test)
            total_loss += loss.item()
            total_iou += iou_value.item()
    return total_loss / len(test_loader), total_iou / len(test_loader)

## NSN hyperparameter tuning

In [12]:
from torch import nn

class Model_L2(nn.Module):
    def __init__(
            self,
            ndim=3,
            n_class=2,
            init_channel=2,
            kernel_size=3,
            pool_size=2,
            ap_factor=2,
            gpu=-1,
            loss_func='nn.CrossEntropyLoss'
        ):
        super(Model_L2, self).__init__()
        self.gpu = gpu
        self.pool_size = pool_size
        self.phase = 'train'
        self.loss_func = eval(loss_func)()

        self.c0=nn.Conv3d(1, init_channel, kernel_size, 1, int(kernel_size/2))
        self.c1=nn.Conv3d(init_channel, int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))

        self.c2=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))
        self.c3=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))

        self.c4=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))
        self.c5=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 3)), kernel_size, 1, int(kernel_size/2))

        self.dc0=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 3)), self.pool_size, self.pool_size, 0)
        self.dc1=nn.Conv3d(int(init_channel * (ap_factor ** 2) + init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))
        self.dc2=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))

        self.dc3=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), self.pool_size, self.pool_size, 0)
        self.dc4=nn.Conv3d(int(init_channel * (ap_factor ** 1) + init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))
        self.dc5=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))

        self.dc6=nn.Conv3d(int(init_channel * (ap_factor ** 1)), n_class, 1, 1)

        self.bnc0=nn.BatchNorm3d(init_channel)
        self.bnc1=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))

        self.bnc2=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))
        self.bnc3=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))

        self.bnc4=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bnc5=nn.BatchNorm3d(int(init_channel * (ap_factor ** 3)))

        self.bndc1=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bndc2=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bndc4=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))
        self.bndc5=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))

        self.pool = nn.MaxPool3d(pool_size, pool_size)

    def _calc(self, x):
        e0 = F.relu(self.bnc0(self.c0(x)))
        syn0 = F.relu(self.bnc1(self.c1(e0)))
        del e0
        e1 = self.pool(syn0)
        e2 = F.relu(self.bnc2(self.c2(e1)))
        syn1 = F.relu(self.bnc3(self.c3(e2)))
        del e1, e2
        e3 = self.pool(syn1)
        e4 = F.relu(self.bnc4(self.c4(e3)))
        e5 = F.relu(self.bnc5(self.c5(e4)))
        del e3, e4
        d0 = torch.cat([self.dc0(e5), syn1], dim=1)
        del e5, syn1
        d1 = F.relu(self.bndc1(self.dc1(d0)))
        d2 = F.relu(self.bndc2(self.dc2(d1)))
        del d0, d1
        d3 = torch.cat([self.dc3(d2), syn0], dim=1)
        del d2, syn0
        d4 = F.relu(self.bndc4(self.dc4(d3)))
        d5 = F.relu(self.bndc5(self.dc5(d4)))
        del d3, d4
        d6 = self.dc6(d5)
        del d5
        return d6

    def forward(self, x, t=None, seg=True):
        h = self._calc(x)
        if seg:
            pred = F.softmax(h, dim=1)
            del h
            return pred
        else:
            loss = self.loss_func(h, t.float())
            pred = F.softmax(h, dim=1)
            del h
            return loss, pred.data
        
class Model_L2_Sigmoid(nn.Module):
    def __init__(
            self,
            ndim=3,
            n_class=2,
            init_channel=2,
            kernel_size=3,
            pool_size=2,
            ap_factor=2,
            gpu=-1,
            loss_func='nn.CrossEntropyLoss'
        ):
        super(Model_L2_Sigmoid, self).__init__()
        self.gpu = gpu
        self.pool_size = pool_size
        self.phase = 'train'
        self.loss_func = eval(loss_func)()

        self.c0=nn.Conv3d(1, init_channel, kernel_size, 1, int(kernel_size/2))
        self.c1=nn.Conv3d(init_channel, int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))

        self.c2=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))
        self.c3=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))

        self.c4=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))
        self.c5=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 3)), kernel_size, 1, int(kernel_size/2))

        self.dc0=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 3)), self.pool_size, self.pool_size, 0)
        self.dc1=nn.Conv3d(int(init_channel * (ap_factor ** 2) + init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))
        self.dc2=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))

        self.dc3=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), self.pool_size, self.pool_size, 0)
        self.dc4=nn.Conv3d(int(init_channel * (ap_factor ** 1) + init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))
        self.dc5=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))

        self.dc6=nn.Conv3d(int(init_channel * (ap_factor ** 1)), n_class, 1, 1)

        self.bnc0=nn.BatchNorm3d(init_channel)
        self.bnc1=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))

        self.bnc2=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))
        self.bnc3=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))

        self.bnc4=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bnc5=nn.BatchNorm3d(int(init_channel * (ap_factor ** 3)))

        self.bndc1=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bndc2=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bndc4=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))
        self.bndc5=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))

        self.pool = nn.MaxPool3d(pool_size, pool_size)

    def _calc(self, x):
        e0 = F.sigmoid(self.bnc0(self.c0(x)))
        syn0 = F.sigmoid(self.bnc1(self.c1(e0)))
        del e0
        e1 = self.pool(syn0)
        e2 = F.sigmoid(self.bnc2(self.c2(e1)))
        syn1 = F.sigmoid(self.bnc3(self.c3(e2)))
        del e1, e2
        e3 = self.pool(syn1)
        e4 = F.sigmoid(self.bnc4(self.c4(e3)))
        e5 = F.sigmoid(self.bnc5(self.c5(e4)))
        del e3, e4
        d0 = torch.cat([self.dc0(e5), syn1], dim=1)
        del e5, syn1
        d1 = F.sigmoid(self.bndc1(self.dc1(d0)))
        d2 = F.sigmoid(self.bndc2(self.dc2(d1)))
        del d0, d1
        d3 = torch.cat([self.dc3(d2), syn0], dim=1)
        del d2, syn0
        d4 = F.sigmoid(self.bndc4(self.dc4(d3)))
        d5 = F.sigmoid(self.bndc5(self.dc5(d4)))
        del d3, d4
        d6 = self.dc6(d5)
        del d5
        return d6

    def forward(self, x, t=None, seg=True):
        h = self._calc(x)
        if seg:
            pred = F.softmax(h, dim=1)
            del h
            return pred
        else:
            loss = self.loss_func(h, t.float())
            pred = F.softmax(h, dim=1)
            del h
            return loss, pred.data
        
class Model_L2_NoResiduals(nn.Module):
    def __init__(
            self,
            ndim=3,
            n_class=2,
            init_channel=2,
            kernel_size=3,
            pool_size=2,
            ap_factor=2,
            gpu=-1,
            loss_func='nn.CrossEntropyLoss'
        ):
        super(Model_L2_NoResiduals, self).__init__()
        self.gpu = gpu
        self.pool_size = pool_size
        self.phase = 'train'
        self.loss_func = eval(loss_func)()

        self.c0=nn.Conv3d(1, init_channel, kernel_size, 1, int(kernel_size/2))
        self.c1=nn.Conv3d(init_channel, int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))

        self.c2=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))
        self.c3=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))

        self.c4=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))
        self.c5=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 3)), kernel_size, 1, int(kernel_size/2))

        self.dc0=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 3)), self.pool_size, self.pool_size, 0)
        self.dc1=nn.Conv3d(int(init_channel * (ap_factor ** 2) + init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))
        self.dc2=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))

        self.dc3=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), self.pool_size, self.pool_size, 0)
        self.dc4=nn.Conv3d(int(init_channel * (ap_factor ** 1) + init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))
        self.dc5=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))

        self.dc6=nn.Conv3d(int(init_channel * (ap_factor ** 1)), n_class, 1, 1)

        self.bnc0=nn.BatchNorm3d(init_channel)
        self.bnc1=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))

        self.bnc2=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))
        self.bnc3=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))

        self.bnc4=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bnc5=nn.BatchNorm3d(int(init_channel * (ap_factor ** 3)))

        self.bndc1=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bndc2=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bndc4=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))
        self.bndc5=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))

        self.pool = nn.MaxPool3d(pool_size, pool_size)

    def _calc(self, x):
        e0 = F.relu(self.bnc0(self.c0(x)))
        syn0 = F.relu(self.bnc1(self.c1(e0)))
        del e0
        e1 = self.pool(syn0)
        e2 = F.relu(self.bnc2(self.c2(e1)))
        syn1 = F.relu(self.bnc3(self.c3(e2)))
        del e1, e2
        e3 = self.pool(syn1)
        e4 = F.relu(self.bnc4(self.c4(e3)))
        e5 = F.relu(self.bnc5(self.c5(e4)))
        del e3, e4
        d0 = torch.cat([self.dc0(e5), torch.zeros_like(syn1)], dim=1)
        del e5, syn1
        d1 = F.relu(self.bndc1(self.dc1(d0)))
        d2 = F.relu(self.bndc2(self.dc2(d1)))
        del d0, d1
        d3 = torch.cat([self.dc3(d2), torch.zeros_like(syn0)], dim=1)
        del d2, syn0
        d4 = F.relu(self.bndc4(self.dc4(d3)))
        d5 = F.relu(self.bndc5(self.dc5(d4)))
        del d3, d4
        d6 = self.dc6(d5)
        del d5
        return d6

    def forward(self, x, t=None, seg=True):
        h = self._calc(x)
        if seg:
            pred = F.softmax(h, dim=1)
            del h
            return pred
        else:
            loss = self.loss_func(h, t.float())
            pred = F.softmax(h, dim=1)
            del h
            return loss, pred.data

In [13]:
def iou(y_pred: torch.Tensor, y_gt: torch.Tensor, smooth=1e-6):
    """
    Calculate Intersection over Union (IoU) for 3D semantic segmentation masks.

    Parameters:
    - outputs: a tensor of shape (N, C, D, H, W) where
      N is the batch size,
      C is the number of classes,
      D is the depth,
      H and W are the height and width of the masks.
      The tensor should contain binary predictions (0 or 1).
    - labels: a tensor of the same shape as outputs containing the ground truth masks.
    - smooth: a small constant added to avoid division by zero.

    Returns:
    - IoU: The Intersection over Union score for each class.
    """
    # Ensure that both outputs and labels are booleans
    y_pred = y_pred > 0.5
    y_gt = y_gt > 0.5
    
    # Intersection and Union
    intersection = (y_pred & y_gt).float().sum(dim=(2, 3, 4)) # Sum over the spatial dimensions
    union = (y_pred | y_gt).float().sum(dim=(2, 3, 4)) # Sum over the spatial dimensions
    
    # Compute the IoU and handle cases where the union is 0
    iou = (intersection + smooth) / (union + smooth)
    
    return torch.mean(iou)  # Return the average IoU over the batch

In [14]:
def objective(trial):
    # Define the hyperparameters to be tuned
    lr = trial.suggest_loguniform('lr', 1e-5, 1e-3)
#     init_channel = trial.suggest_categorical('init_channel', [2, 4, 8])
#     kernel_size = trial.suggest_categorical('kernel_size', [3, 5])
#     pool_size = trial.suggest_categorical('pool_size', [2, 3])
#     ap_factor = trial.suggest_categorical('ap_factor', [1, 2])
    num_epochs = trial.suggest_int('num_epochs', 5, 150)

    # Model, dataset, and DataLoader setup
    model = Model_L2_NoResiduals(n_class=2, gpu=0)
    model.cuda()

    train_set = EmbryoDataset(image_dir_path="/kaggle/input/dl-reprod/Images/train/Images", gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/train/GroundTruth_NSN", augment=True)
    test_set = EmbryoDataset(image_dir_path="/kaggle/input/dl-reprod/Images/test/Images", gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/test/GroundTruth_QCANet", augment=False)
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)


    for epoch in range(num_epochs):
        model.train()
        train_loss = train_epoch()
        model.eval()
        test_loss, test_iou = test_epoch()


    return test_loss

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=10)

print('Number of finished trials:', len(study.trials))
print('Best trial:', study.best_trial.params)

[I 2024-04-14 12:17:32,929] A new study created in memory with name: no-name-13ed868b-da00-4417-b2a3-1a49c843febc
  lr = trial.suggest_loguniform('lr', 1e-5, 1e-3)
  0%|          | 0/484 [00:03<?, ?it/s]
[W 2024-04-14 12:17:36,310] Trial 0 failed with parameters: {'lr': 0.00027384916370804755, 'num_epochs': 67} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_103/2176007645.py", line 23, in objective
    train_loss = train_epoch()
  File "/tmp/ipykernel_103/1540288744.py", line 15, in train_epoch
    total_loss += loss.item()
KeyboardInterrupt
[W 2024-04-14 12:17:36,311] Trial 0 failed with value None.


KeyboardInterrupt: 

## NDN Hyperparameter Tuning

In [15]:
class Model_L4(nn.Module):
    def __init__(
            self,
            ndim=3,
            n_class=2,
            init_channel=2,
            kernel_size=3,
            pool_size=2,
            ap_factor=2,
            gpu=-1,
        ):
        super(Model_L4, self).__init__()
        self.gpu = gpu
        self.pool_size = pool_size
        self.phase = 'train'

        self.c0=nn.Conv3d(1, init_channel, kernel_size, 1, int(kernel_size/2))
        self.c1=nn.Conv3d(init_channel, int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))
        self.c2=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))
        self.c3=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))

        self.c4=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))
        self.c5=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 3)), kernel_size, 1, int(kernel_size/2))

        self.c6=nn.Conv3d(int(init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 3)), kernel_size, 1, int(kernel_size/2))
        self.c7=nn.Conv3d(int(init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 4)), kernel_size, 1, int(kernel_size/2))

        self.c8=nn.Conv3d(int(init_channel * (ap_factor ** 4)), int(init_channel * (ap_factor ** 4)), kernel_size, 1, int(kernel_size/2))
        self.c9=nn.Conv3d(int(init_channel * (ap_factor ** 4)), int(init_channel * (ap_factor ** 5)), kernel_size, 1, int(kernel_size/2))

        self.dc0=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 5)), int(init_channel * (ap_factor ** 5)), self.pool_size, self.pool_size, 0)
        self.dc1=nn.Conv3d(int(init_channel * (ap_factor ** 4) + init_channel * (ap_factor ** 5)), int(init_channel * (ap_factor ** 4)), kernel_size, 1, int(kernel_size/2))
        self.dc2=nn.Conv3d(int(init_channel * (ap_factor ** 4)), int(init_channel * (ap_factor ** 4)), kernel_size, 1, int(kernel_size/2))

        self.dc3=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 4)), int(init_channel * (ap_factor ** 4)), self.pool_size, self.pool_size, 0)
        self.dc4=nn.Conv3d(int(init_channel * (ap_factor ** 3) + init_channel * (ap_factor ** 4)), int(init_channel * (ap_factor ** 3)), kernel_size, 1, int(kernel_size/2))
        self.dc5=nn.Conv3d(int(init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 3)), kernel_size, 1, int(kernel_size/2))

        self.dc6=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 3)), self.pool_size, self.pool_size, 0)
        self.dc7=nn.Conv3d(int(init_channel * (ap_factor ** 2) + init_channel * (ap_factor ** 3)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))
        self.dc8=nn.Conv3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), kernel_size, 1, int(kernel_size/2))

        self.dc9=nn.ConvTranspose3d(int(init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 2)), self.pool_size, self.pool_size, 0)
        self.dc10=nn.Conv3d(int(init_channel * (ap_factor ** 1) + init_channel * (ap_factor ** 2)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))
        self.dc11=nn.Conv3d(int(init_channel * (ap_factor ** 1)), int(init_channel * (ap_factor ** 1)), kernel_size, 1, int(kernel_size/2))

        self.dc12=nn.Conv3d(int(init_channel * (ap_factor ** 1)), n_class, 1, 1)

        self.bnc0=nn.BatchNorm3d(init_channel)
        self.bnc1=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))

        self.bnc2=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))
        self.bnc3=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))

        self.bnc4=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bnc5=nn.BatchNorm3d(int(init_channel * (ap_factor ** 3)))

        self.bnc6=nn.BatchNorm3d(int(init_channel * (ap_factor ** 3)))
        self.bnc7=nn.BatchNorm3d(int(init_channel * (ap_factor ** 4)))

        self.bnc8=nn.BatchNorm3d(int(init_channel * (ap_factor ** 4)))
        self.bnc9=nn.BatchNorm3d(int(init_channel * (ap_factor ** 5)))
        self.bndc1=nn.BatchNorm3d(int(init_channel * (ap_factor ** 4)))
        self.bndc2=nn.BatchNorm3d(int(init_channel * (ap_factor ** 4)))
        self.bndc4=nn.BatchNorm3d(int(init_channel * (ap_factor ** 3)))
        self.bndc5=nn.BatchNorm3d(int(init_channel * (ap_factor ** 3)))
        self.bndc7=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bndc8=nn.BatchNorm3d(int(init_channel * (ap_factor ** 2)))
        self.bndc10=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))
        self.bndc11=nn.BatchNorm3d(int(init_channel * (ap_factor ** 1)))

        self.pool = nn.MaxPool3d(pool_size, pool_size)

    def _calc(self, x):
        e0 = F.relu(self.bnc0(self.c0(x)))
        syn0 = F.relu(self.bnc1(self.c1(e0)))
        del e0
        e1 = self.pool(syn0)
        e2 = F.relu(self.bnc2(self.c2(e1)))
        syn1 = F.relu(self.bnc3(self.c3(e2)))
        del e1, e2
        e3 = self.pool(syn1)
        e4 = F.relu(self.bnc4(self.c4(e3)))
        syn2 = F.relu(self.bnc5(self.c5(e4)))
        del e3, e4
        e5 = self.pool(syn2)
        e6 = F.relu(self.bnc6(self.c6(e5)))
        syn3 = F.relu(self.bnc7(self.c7(e6)))
        del e5, e6
        e7 = self.pool(syn3)
        e8 = F.relu(self.bnc8(self.c8(e7)))
        e9 = F.relu(self.bnc9(self.c9(e8)))
        del e7, e8
        d0 = torch.cat([self.dc0(e9), syn3], dim=1)
        del e9, syn3
        d1 = F.relu(self.bndc1(self.dc1(d0)))
        d2 = F.relu(self.bndc2(self.dc2(d1)))
        del d0, d1
        d3 = torch.cat([self.dc3(d2), syn2], dim=1)
        del d2, syn2
        d4 = F.relu(self.bndc4(self.dc4(d3)))
        d5 = F.relu(self.bndc5(self.dc5(d4)))
        del d3, d4
        d6 = torch.cat([self.dc6(d5), syn1], dim=1)
        del d5, syn1
        d7 = F.relu(self.bndc7(self.dc7(d6)))
        d8 = F.relu(self.bndc8(self.dc8(d7)))
        del d6, d7
        d9 = torch.cat([self.dc9(d8), syn0], dim=1)
        del d8, syn0
        d10 = F.relu(self.bndc10(self.dc10(d9)))
        d11 = F.relu(self.bndc11(self.dc11(d10)))
        del d9, d10

        d12 = self.dc12(d11)
        del d11
        return d12

    def forward(self, x, t=None, seg=True):
        h = self._calc(x)
        pred = F.softmax(h, dim=1)
        del h
        return pred

In [16]:
def objective(trial):
    # Define the hyperparameters to be tuned
    lr = trial.suggest_loguniform('lr', 1e-5, 1e-3)
#     init_channel = trial.suggest_categorical('init_channel', [2, 4, 8])
#     kernel_size = trial.suggest_categorical('kernel_size', [3, 5])
#     pool_size = trial.suggest_categorical('pool_size', [2, 3])
#     ap_factor = trial.suggest_categorical('ap_factor', [1, 2])
    num_epochs = trial.suggest_int('num_epochs', 5, 150)

    # Model, dataset, and DataLoader setup
    model = Model_L4(n_class=1, gpu=1)
    model.cuda()

    train_set = EmbryoDataset(image_dir_path="/kaggle/input/dl-reprod/Images/train/Images", gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/train/GroundTruth_NDN", augment=True)
    test_set = EmbryoDataset(image_dir_path="/kaggle/input/dl-reprod/Images/test/Images", gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/test/GroundTruth_QCANet", augment=False)
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    for epoch in range(num_epochs):
        model.train()
        train_loss = train_epoch()
        model.eval()
        test_loss, test_iou = test_epoch()


    return test_loss

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=10)

print('Number of finished trials:', len(study.trials))
print('Best trial:', study.best_trial.params)

[I 2024-04-14 12:17:40,418] A new study created in memory with name: no-name-ffd54f88-e1e3-46b8-a886-5f76cbcbd478
  lr = trial.suggest_loguniform('lr', 1e-5, 1e-3)
  0%|          | 0/484 [00:01<?, ?it/s]
[W 2024-04-14 12:17:41,950] Trial 0 failed with parameters: {'lr': 0.000846484066828721, 'num_epochs': 77} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_103/294003187.py", line 22, in objective
    train_loss = train_epoch()
  File "/tmp/ipykernel_103/1540288744.py", line 6, in train_epoch
    for i, data in enumerate(train_loader):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_

KeyboardInterrupt: 