In [None]:
import os
import sys
sys.path.insert(1, "/kaggle/input/3d-unet")
# sys.path.insert(1, "/kaggle/input/qcanet-pytorch/QCANet")

In [None]:
import torch
import cv2
import numpy as np
from skimage import io
from pytorch3dunet.unet3d.model import UNet3D
# from src.lib.model import Model_L2, Model_L4

<h1> Ablation studies </h1>

<p> Below are three variants of the QCANet NSN model. The first is the baseline; the second replaces the ReLU activation function with the Sigmoid activation function; and the third removes the residual (skip) connections between the convolutional and deconvolutional layers by feeding all-zero tensors in place of the encoder feature maps. </p>

In [None]:
from torch import nn

class Model_L2(nn.Module):
    """Baseline two-level 3D encoder-decoder segmentation network (ReLU).

    Encoder: two Conv3d + BatchNorm3d + ReLU stages per level, with max pooling
    between levels.  Decoder: ConvTranspose3d upsampling followed by a
    concatenation with the matching encoder output (``syn1`` / ``syn0``) and two
    Conv3d + BatchNorm3d + ReLU stages.  A final 1x1x1 convolution produces
    ``n_class`` logits per voxel.

    The first convolution expects a single input channel, i.e. ``x`` of shape
    ``(N, 1, D, H, W)``.  Because the input is pooled twice and upsampled twice
    by ``pool_size``, D/H/W should be divisible by ``pool_size ** 2`` so that
    the upsampled tensors match the skip tensors when concatenated.

    Args:
        ndim: Accepted for signature compatibility; not used by this class.
        n_class: Number of output channels (classes).
        init_channel: Number of channels produced by the first convolution.
        kernel_size: Kernel size of every convolution except the final 1x1x1
            one; padding is ``int(kernel_size / 2)``.
        pool_size: Kernel size and stride of both the max pooling and the
            transposed-convolution upsampling.
        ap_factor: Channel growth factor between levels.
        gpu: Stored as ``self.gpu``; not otherwise used by this class.
        loss_func: String that is evaluated to a loss class, which is then
            instantiated without arguments.
    """

    def __init__(
            self,
            ndim=3,
            n_class=2,
            init_channel=2,
            kernel_size=3,
            pool_size=2,
            ap_factor=2,
            gpu=-1,
            loss_func='nn.CrossEntropyLoss'
        ):
        super(Model_L2, self).__init__()
        self.gpu = gpu
        self.pool_size = pool_size
        self.phase = 'train'
        # NOTE(review): eval() executes arbitrary code -- only pass trusted
        # strings as `loss_func`.
        self.loss_func = eval(loss_func)()

        # Channel widths per level.  Each width is truncated to an int on its
        # own so the widths of concatenated tensors add up exactly, even when
        # ap_factor is fractional.
        ch1 = int(init_channel * (ap_factor ** 1))
        ch2 = int(init_channel * (ap_factor ** 2))
        ch3 = int(init_channel * (ap_factor ** 3))
        pad = int(kernel_size / 2)

        # Encoder convolutions (attribute names and creation order are kept so
        # existing state_dicts and seeded initialisation stay compatible).
        self.c0 = nn.Conv3d(1, init_channel, kernel_size, 1, pad)
        self.c1 = nn.Conv3d(init_channel, ch1, kernel_size, 1, pad)

        self.c2 = nn.Conv3d(ch1, ch1, kernel_size, 1, pad)
        self.c3 = nn.Conv3d(ch1, ch2, kernel_size, 1, pad)

        self.c4 = nn.Conv3d(ch2, ch2, kernel_size, 1, pad)
        self.c5 = nn.Conv3d(ch2, ch3, kernel_size, 1, pad)

        # Decoder: upsample, then convolve the concatenation [upsampled, skip].
        self.dc0 = nn.ConvTranspose3d(ch3, ch3, self.pool_size, self.pool_size, 0)
        self.dc1 = nn.Conv3d(ch3 + ch2, ch2, kernel_size, 1, pad)
        self.dc2 = nn.Conv3d(ch2, ch2, kernel_size, 1, pad)

        self.dc3 = nn.ConvTranspose3d(ch2, ch2, self.pool_size, self.pool_size, 0)
        self.dc4 = nn.Conv3d(ch2 + ch1, ch1, kernel_size, 1, pad)
        self.dc5 = nn.Conv3d(ch1, ch1, kernel_size, 1, pad)

        # 1x1x1 projection to class logits.
        self.dc6 = nn.Conv3d(ch1, n_class, 1, 1)

        self.bnc0 = nn.BatchNorm3d(init_channel)
        self.bnc1 = nn.BatchNorm3d(ch1)

        self.bnc2 = nn.BatchNorm3d(ch1)
        self.bnc3 = nn.BatchNorm3d(ch2)

        self.bnc4 = nn.BatchNorm3d(ch2)
        self.bnc5 = nn.BatchNorm3d(ch3)

        self.bndc1 = nn.BatchNorm3d(ch2)
        self.bndc2 = nn.BatchNorm3d(ch2)
        self.bndc4 = nn.BatchNorm3d(ch1)
        self.bndc5 = nn.BatchNorm3d(ch1)

        self.pool = nn.MaxPool3d(pool_size, pool_size)

    def _calc(self, x):
        """Run the encoder-decoder and return raw class logits."""
        # Level 0 (full resolution)
        h = torch.relu(self.bnc0(self.c0(x)))
        syn0 = torch.relu(self.bnc1(self.c1(h)))
        # Level 1 (1/pool_size resolution)
        h = self.pool(syn0)
        h = torch.relu(self.bnc2(self.c2(h)))
        syn1 = torch.relu(self.bnc3(self.c3(h)))
        # Bottleneck
        h = self.pool(syn1)
        h = torch.relu(self.bnc4(self.c4(h)))
        h = torch.relu(self.bnc5(self.c5(h)))
        # Decoder level 1: upsample and concatenate with the encoder output.
        h = torch.cat([self.dc0(h), syn1], dim=1)
        del syn1  # drop the local reference as soon as it is no longer needed
        h = torch.relu(self.bndc1(self.dc1(h)))
        h = torch.relu(self.bndc2(self.dc2(h)))
        # Decoder level 0
        h = torch.cat([self.dc3(h), syn0], dim=1)
        del syn0
        h = torch.relu(self.bndc4(self.dc4(h)))
        h = torch.relu(self.bndc5(self.dc5(h)))
        return self.dc6(h)

    def forward(self, x, t=None, seg=True):
        """Segment ``x`` and optionally compute the loss against ``t``.

        Args:
            x: Input volume batch with one channel.
            t: Target tensor; required when ``seg`` is False.  It is cast to
                float before being handed to ``self.loss_func``.
            seg: If True return only the class probabilities; otherwise return
                ``(loss, probabilities)``.

        Returns:
            Softmax probabilities over dim 1 when ``seg`` is True, else a tuple
            ``(loss, probabilities)`` where the probabilities are detached from
            the autograd graph.
        """
        h = self._calc(x)
        pred = torch.softmax(h, dim=1)
        if seg:
            return pred
        # The loss is computed on the raw logits, not on the probabilities.
        loss = self.loss_func(h, t.float())
        return loss, pred.detach()
        
class Model_L2_Sigmoid(nn.Module):
    """Ablation variant of the two-level 3D network using sigmoid activations.

    The layer layout is the same encoder-decoder as the ReLU baseline: two
    Conv3d + BatchNorm3d stages per level, max pooling between levels,
    ConvTranspose3d upsampling and concatenation with the matching encoder
    output, and a final 1x1x1 convolution to ``n_class`` logits.  Every hidden
    activation is a sigmoid.

    The first convolution expects a single input channel, i.e. ``x`` of shape
    ``(N, 1, D, H, W)``.  D/H/W should be divisible by ``pool_size ** 2`` so
    that the upsampled tensors match the skip tensors when concatenated.

    Args:
        ndim: Accepted for signature compatibility; not used by this class.
        n_class: Number of output channels (classes).
        init_channel: Number of channels produced by the first convolution.
        kernel_size: Kernel size of every convolution except the final 1x1x1
            one; padding is ``int(kernel_size / 2)``.
        pool_size: Kernel size and stride of both the max pooling and the
            transposed-convolution upsampling.
        ap_factor: Channel growth factor between levels.
        gpu: Stored as ``self.gpu``; not otherwise used by this class.
        loss_func: String that is evaluated to a loss class, which is then
            instantiated without arguments.
    """

    def __init__(
            self,
            ndim=3,
            n_class=2,
            init_channel=2,
            kernel_size=3,
            pool_size=2,
            ap_factor=2,
            gpu=-1,
            loss_func='nn.CrossEntropyLoss'
        ):
        super(Model_L2_Sigmoid, self).__init__()
        self.gpu = gpu
        self.pool_size = pool_size
        self.phase = 'train'
        # NOTE(review): eval() executes arbitrary code -- only pass trusted
        # strings as `loss_func`.
        self.loss_func = eval(loss_func)()

        # Channel widths per level.  Each width is truncated to an int on its
        # own so the widths of concatenated tensors add up exactly, even when
        # ap_factor is fractional.
        ch1 = int(init_channel * (ap_factor ** 1))
        ch2 = int(init_channel * (ap_factor ** 2))
        ch3 = int(init_channel * (ap_factor ** 3))
        pad = int(kernel_size / 2)

        # Encoder convolutions (attribute names and creation order are kept so
        # existing state_dicts and seeded initialisation stay compatible).
        self.c0 = nn.Conv3d(1, init_channel, kernel_size, 1, pad)
        self.c1 = nn.Conv3d(init_channel, ch1, kernel_size, 1, pad)

        self.c2 = nn.Conv3d(ch1, ch1, kernel_size, 1, pad)
        self.c3 = nn.Conv3d(ch1, ch2, kernel_size, 1, pad)

        self.c4 = nn.Conv3d(ch2, ch2, kernel_size, 1, pad)
        self.c5 = nn.Conv3d(ch2, ch3, kernel_size, 1, pad)

        # Decoder: upsample, then convolve the concatenation [upsampled, skip].
        self.dc0 = nn.ConvTranspose3d(ch3, ch3, self.pool_size, self.pool_size, 0)
        self.dc1 = nn.Conv3d(ch3 + ch2, ch2, kernel_size, 1, pad)
        self.dc2 = nn.Conv3d(ch2, ch2, kernel_size, 1, pad)

        self.dc3 = nn.ConvTranspose3d(ch2, ch2, self.pool_size, self.pool_size, 0)
        self.dc4 = nn.Conv3d(ch2 + ch1, ch1, kernel_size, 1, pad)
        self.dc5 = nn.Conv3d(ch1, ch1, kernel_size, 1, pad)

        # 1x1x1 projection to class logits.
        self.dc6 = nn.Conv3d(ch1, n_class, 1, 1)

        self.bnc0 = nn.BatchNorm3d(init_channel)
        self.bnc1 = nn.BatchNorm3d(ch1)

        self.bnc2 = nn.BatchNorm3d(ch1)
        self.bnc3 = nn.BatchNorm3d(ch2)

        self.bnc4 = nn.BatchNorm3d(ch2)
        self.bnc5 = nn.BatchNorm3d(ch3)

        self.bndc1 = nn.BatchNorm3d(ch2)
        self.bndc2 = nn.BatchNorm3d(ch2)
        self.bndc4 = nn.BatchNorm3d(ch1)
        self.bndc5 = nn.BatchNorm3d(ch1)

        self.pool = nn.MaxPool3d(pool_size, pool_size)

    def _calc(self, x):
        """Run the encoder-decoder and return raw class logits."""
        # Level 0 (full resolution)
        h = torch.sigmoid(self.bnc0(self.c0(x)))
        syn0 = torch.sigmoid(self.bnc1(self.c1(h)))
        # Level 1 (1/pool_size resolution)
        h = self.pool(syn0)
        h = torch.sigmoid(self.bnc2(self.c2(h)))
        syn1 = torch.sigmoid(self.bnc3(self.c3(h)))
        # Bottleneck
        h = self.pool(syn1)
        h = torch.sigmoid(self.bnc4(self.c4(h)))
        h = torch.sigmoid(self.bnc5(self.c5(h)))
        # Decoder level 1: upsample and concatenate with the encoder output.
        h = torch.cat([self.dc0(h), syn1], dim=1)
        del syn1  # drop the local reference as soon as it is no longer needed
        h = torch.sigmoid(self.bndc1(self.dc1(h)))
        h = torch.sigmoid(self.bndc2(self.dc2(h)))
        # Decoder level 0
        h = torch.cat([self.dc3(h), syn0], dim=1)
        del syn0
        h = torch.sigmoid(self.bndc4(self.dc4(h)))
        h = torch.sigmoid(self.bndc5(self.dc5(h)))
        return self.dc6(h)

    def forward(self, x, t=None, seg=True):
        """Segment ``x`` and optionally compute the loss against ``t``.

        Args:
            x: Input volume batch with one channel.
            t: Target tensor; required when ``seg`` is False.  It is cast to
                float before being handed to ``self.loss_func``.
            seg: If True return only the class probabilities; otherwise return
                ``(loss, probabilities)``.

        Returns:
            Softmax probabilities over dim 1 when ``seg`` is True, else a tuple
            ``(loss, probabilities)`` where the probabilities are detached from
            the autograd graph.
        """
        h = self._calc(x)
        pred = torch.softmax(h, dim=1)
        if seg:
            return pred
        # The loss is computed on the raw logits, not on the probabilities.
        loss = self.loss_func(h, t.float())
        return loss, pred.detach()
        
class Model_L2_NoResiduals(nn.Module):
    """Ablation variant of the two-level 3D network without skip information.

    The layer layout is the same encoder-decoder as the ReLU baseline, so the
    parameter shapes are unchanged.  In the decoder, however, the encoder
    outputs (``syn1`` / ``syn0``) are replaced by all-zero tensors of the same
    shape before concatenation, so no encoder features reach the decoder
    through the skip path.

    The first convolution expects a single input channel, i.e. ``x`` of shape
    ``(N, 1, D, H, W)``.  D/H/W should be divisible by ``pool_size ** 2`` so
    that the upsampled tensors match the zero tensors when concatenated.

    Args:
        ndim: Accepted for signature compatibility; not used by this class.
        n_class: Number of output channels (classes).
        init_channel: Number of channels produced by the first convolution.
        kernel_size: Kernel size of every convolution except the final 1x1x1
            one; padding is ``int(kernel_size / 2)``.
        pool_size: Kernel size and stride of both the max pooling and the
            transposed-convolution upsampling.
        ap_factor: Channel growth factor between levels.
        gpu: Stored as ``self.gpu``; not otherwise used by this class.
        loss_func: String that is evaluated to a loss class, which is then
            instantiated without arguments.
    """

    def __init__(
            self,
            ndim=3,
            n_class=2,
            init_channel=2,
            kernel_size=3,
            pool_size=2,
            ap_factor=2,
            gpu=-1,
            loss_func='nn.CrossEntropyLoss'
        ):
        super(Model_L2_NoResiduals, self).__init__()
        self.gpu = gpu
        self.pool_size = pool_size
        self.phase = 'train'
        # NOTE(review): eval() executes arbitrary code -- only pass trusted
        # strings as `loss_func`.
        self.loss_func = eval(loss_func)()

        # Channel widths per level.  Each width is truncated to an int on its
        # own so the widths of concatenated tensors add up exactly, even when
        # ap_factor is fractional.
        ch1 = int(init_channel * (ap_factor ** 1))
        ch2 = int(init_channel * (ap_factor ** 2))
        ch3 = int(init_channel * (ap_factor ** 3))
        pad = int(kernel_size / 2)

        # Encoder convolutions (attribute names and creation order are kept so
        # existing state_dicts and seeded initialisation stay compatible).
        self.c0 = nn.Conv3d(1, init_channel, kernel_size, 1, pad)
        self.c1 = nn.Conv3d(init_channel, ch1, kernel_size, 1, pad)

        self.c2 = nn.Conv3d(ch1, ch1, kernel_size, 1, pad)
        self.c3 = nn.Conv3d(ch1, ch2, kernel_size, 1, pad)

        self.c4 = nn.Conv3d(ch2, ch2, kernel_size, 1, pad)
        self.c5 = nn.Conv3d(ch2, ch3, kernel_size, 1, pad)

        # Decoder: upsample, then convolve the concatenation [upsampled, zeros].
        self.dc0 = nn.ConvTranspose3d(ch3, ch3, self.pool_size, self.pool_size, 0)
        self.dc1 = nn.Conv3d(ch3 + ch2, ch2, kernel_size, 1, pad)
        self.dc2 = nn.Conv3d(ch2, ch2, kernel_size, 1, pad)

        self.dc3 = nn.ConvTranspose3d(ch2, ch2, self.pool_size, self.pool_size, 0)
        self.dc4 = nn.Conv3d(ch2 + ch1, ch1, kernel_size, 1, pad)
        self.dc5 = nn.Conv3d(ch1, ch1, kernel_size, 1, pad)

        # 1x1x1 projection to class logits.
        self.dc6 = nn.Conv3d(ch1, n_class, 1, 1)

        self.bnc0 = nn.BatchNorm3d(init_channel)
        self.bnc1 = nn.BatchNorm3d(ch1)

        self.bnc2 = nn.BatchNorm3d(ch1)
        self.bnc3 = nn.BatchNorm3d(ch2)

        self.bnc4 = nn.BatchNorm3d(ch2)
        self.bnc5 = nn.BatchNorm3d(ch3)

        self.bndc1 = nn.BatchNorm3d(ch2)
        self.bndc2 = nn.BatchNorm3d(ch2)
        self.bndc4 = nn.BatchNorm3d(ch1)
        self.bndc5 = nn.BatchNorm3d(ch1)

        self.pool = nn.MaxPool3d(pool_size, pool_size)

    def _calc(self, x):
        """Run the encoder-decoder and return raw class logits."""
        # Level 0 (full resolution)
        h = torch.relu(self.bnc0(self.c0(x)))
        syn0 = torch.relu(self.bnc1(self.c1(h)))
        # Level 1 (1/pool_size resolution)
        h = self.pool(syn0)
        h = torch.relu(self.bnc2(self.c2(h)))
        syn1 = torch.relu(self.bnc3(self.c3(h)))
        # Bottleneck
        h = self.pool(syn1)
        h = torch.relu(self.bnc4(self.c4(h)))
        h = torch.relu(self.bnc5(self.c5(h)))
        # An all-zeros tensor is concatenated instead of syn1 so that no
        # information from the earlier layers reaches the decoder.
        h = torch.cat([self.dc0(h), torch.zeros_like(syn1)], dim=1)
        del syn1  # drop the local reference as soon as it is no longer needed
        h = torch.relu(self.bndc1(self.dc1(h)))
        h = torch.relu(self.bndc2(self.dc2(h)))
        # Same zero substitution for the full-resolution skip tensor.
        h = torch.cat([self.dc3(h), torch.zeros_like(syn0)], dim=1)
        del syn0
        h = torch.relu(self.bndc4(self.dc4(h)))
        h = torch.relu(self.bndc5(self.dc5(h)))
        return self.dc6(h)

    def forward(self, x, t=None, seg=True):
        """Segment ``x`` and optionally compute the loss against ``t``.

        Args:
            x: Input volume batch with one channel.
            t: Target tensor; required when ``seg`` is False.  It is cast to
                float before being handed to ``self.loss_func``.
            seg: If True return only the class probabilities; otherwise return
                ``(loss, probabilities)``.

        Returns:
            Softmax probabilities over dim 1 when ``seg`` is True, else a tuple
            ``(loss, probabilities)`` where the probabilities are detached from
            the autograd graph.
        """
        h = self._calc(x)
        pred = torch.softmax(h, dim=1)
        if seg:
            return pred
        # The loss is computed on the raw logits, not on the probabilities.
        loss = self.loss_func(h, t.float())
        return loss, pred.detach()

In [None]:
def separate_channels(mask: torch.Tensor) -> torch.Tensor:
    """Convert a label mask into one binary channel per class.

    Args:
        mask: Label tensor indexed as ``(1, D, H, W)``; dims 1-3 define the
            spatial shape of the result.

    Returns:
        Tensor of shape ``(2, D, H, W)`` in the default float dtype, allocated
        on the same device as ``mask``.  Channel ``i`` is 1 where ``mask``
        equals class label ``i`` and 0 elsewhere; voxels whose value matches
        no class label are 0 in every channel.
    """
    class_labels = [0, 1]
    # Allocate on the mask's device to avoid a device round trip per class.
    gt = torch.zeros(
        (len(class_labels), mask.shape[1], mask.shape[2], mask.shape[3]),
        device=mask.device,
    )

    # Index by channel position so the labels need not equal their positions.
    for channel, class_label in enumerate(class_labels):
        gt[channel] = (mask == class_label).to(gt.dtype)

    return gt

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from scipy.ndimage import zoom
import cv2
import torch.nn.functional as F

class EmbryoDataset(Dataset):
    """Dataset of 3D image stacks paired with ground-truth label stacks.

    Image and ground-truth files are matched by file name: every name found in
    ``image_dir_path`` is also looked up in ``gt_dir_path``.  Each stack is
    resampled to 128 x 128 in the last two axes and scaled by ``Z_ZOOM`` along
    the first axis; images are min-max normalised and ground truths are split
    into one binary channel per class.

    With ``augment=True`` every file yields four samples: the original, a
    horizontal flip, a vertical flip, and a flip along both axes.

    Args:
        image_dir_path: Directory containing the image stacks.
        gt_dir_path: Directory containing the ground-truth stacks.
        augment: Whether to expose the four flip variants of every file.
        device: Device the returned tensors are placed on (default ``"cuda"``,
            matching the previous hard-coded behaviour).
    """

    # Zoom factor applied along the first (slice) axis.
    # NOTE(review): presumably compensates for anisotropic z-spacing -- confirm.
    Z_ZOOM = 2.1875
    # Target size of the last two axes after resampling.
    XY_SIZE = 128
    # Number of samples generated per file when augmentation is enabled.
    N_VARIANTS = 4

    def __init__(self, image_dir_path: str, gt_dir_path: str, augment: bool = False,
                 device: str = "cuda"):
        self.image_dir_path = image_dir_path
        self.gt_dir_path = gt_dir_path
        # os.listdir returns names in arbitrary order; sort for reproducibility.
        self._filenames = sorted(os.listdir(image_dir_path))
        self._augment = augment
        self._device = device

        if self._augment:
            # Only the length matters here: __getitem__ maps an index back to
            # its file with ``idx // N_VARIANTS``.
            self._filenames = self._filenames * self.N_VARIANTS

    def __len__(self):
        return len(self._filenames)

    def __preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
        """Min-max normalise ``image`` to the range [0, 1]."""
        min_value = torch.min(image)
        max_value = torch.max(image)

        # Clamp the range so a constant image yields zeros instead of NaNs.
        value_range = (max_value - min_value).clamp(min=1e-8)

        return (image - min_value) / value_range

    def __preprocess_gt(self, gt: torch.Tensor) -> torch.Tensor:
        """Split the label volume into one binary channel per class."""
        return separate_channels(gt)

    def __preprocess(self, raw: np.ndarray, type: int) -> torch.Tensor:
        """Resample a raw 3D stack and convert it into a model-ready tensor.

        Args:
            raw: 3D array as read from disk.
            type: 0 for an image (min-max normalised), 1 for a ground truth
                (split into per-class channels).  Any other value returns the
                resampled tensor without further processing.

        Returns:
            Tensor on ``self._device`` whose first dimension is the channel.
        """
        volume = raw.astype(np.float32)
        x, y = volume.shape[1], volume.shape[2]
        # Label volumes must be resampled with nearest-neighbour interpolation:
        # spline interpolation produces fractional values that no longer equal
        # any class label.
        order = 0 if type == 1 else 3
        volume = zoom(
            volume,
            (self.Z_ZOOM, self.XY_SIZE / x, self.XY_SIZE / y),
            order=order,
        )

        tensor = torch.from_numpy(volume)
        tensor = tensor.to(self._device, dtype=torch.float32)
        tensor = tensor.unsqueeze(dim=0)  # add the channel dimension

        if type == 0:
            tensor = self.__preprocess_image(tensor)
        if type == 1:
            tensor = self.__preprocess_gt(tensor)

        return tensor

    def __getitem__(self, idx):
        """Return the ``(image, ground_truth)`` tensor pair for sample ``idx``."""
        if self._augment:
            # Consecutive groups of N_VARIANTS indices share one source file.
            original_idx, transform_type = divmod(idx, self.N_VARIANTS)
        else:
            original_idx, transform_type = idx, 0

        curr_filename = self._filenames[original_idx]
        curr_image_path = os.path.join(self.image_dir_path, curr_filename)
        curr_gt_path = os.path.join(self.gt_dir_path, curr_filename)

        image_raw = io.imread(curr_image_path)
        gt_raw = io.imread(curr_gt_path)

        x = self.__preprocess(image_raw, type=0)
        y = self.__preprocess(gt_raw, type=1)

        # 0: unchanged, 1: horizontal flip, 2: vertical flip, 3: both flips.
        if transform_type in (1, 3):
            x = transforms.functional.hflip(x)
            y = transforms.functional.hflip(y)
        if transform_type in (2, 3):
            x = transforms.functional.vflip(x)
            y = transforms.functional.vflip(y)

        return (x.to(self._device), y.to(self._device))

In [None]:
# Training split: every file is exposed four times (original + flip variants).
train_set = EmbryoDataset(
    image_dir_path="/kaggle/input/dl-reprod/Images/train/Images",
    gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/train/GroundTruth_NSN",
    augment=True,
)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

# Test split: no augmentation, fixed order.
# NOTE(review): the ground truth is read from GroundTruth_QCANet here, while
# training uses GroundTruth_NSN -- confirm this difference is intended.
test_set = EmbryoDataset(
    image_dir_path="/kaggle/input/dl-reprod/Images/test/Images",
    gt_dir_path="/kaggle/input/dl-reprod/GroundTruth/test/GroundTruth_QCANet",
    augment=False,
)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

In [None]:
def dice_loss(input, target):
    """Soft Dice loss computed over all elements of both tensors.

    The score is ``2 * sum(input * target) / (sum(input**2) + sum(target**2))``
    and the loss is ``1 - score``.  The denominator is clamped to at least
    1e-6, so two all-zero tensors give a score of 0 and therefore a loss of 1.

    Args:
        input: Predicted values; any shape.
        target: Reference values with the same number of elements as ``input``.

    Returns:
        Zero-dimensional tensor holding the loss.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose or permute.
    input_flat = input.reshape(-1)
    target_flat = target.reshape(-1)

    numerator = 2.0 * (input_flat * target_flat).sum()
    denominator = input_flat.pow(2).sum() + target_flat.pow(2).sum()

    dice_score = numerator / denominator.clamp(min=1e-6)
    return 1 - dice_score

In [None]:
def iou(y_pred: torch.Tensor, y_gt: torch.Tensor, smooth=1e-6):
    """
    Mean Intersection over Union (IoU) for 3D segmentation masks.

    Both tensors are binarised with a 0.5 threshold, the IoU is computed
    separately for every (batch item, class) pair over the three spatial
    dimensions, and the results are averaged into a single value.

    Parameters:
    - y_pred: a tensor of shape (N, C, D, H, W) with the predictions, where
      N is the batch size,
      C is the number of classes,
      D is the depth,
      H and W are the height and width of the masks.
    - y_gt: a tensor of the same shape as y_pred containing the ground truth masks.
    - smooth: a small constant added to numerator and denominator to avoid
      division by zero; an empty union therefore scores 1.

    Returns:
    - A zero-dimensional tensor with the IoU averaged over batch and classes.
    """
    spatial_dims = (2, 3, 4)

    # Binarise both inputs
    pred_mask = y_pred > 0.5
    gt_mask = y_gt > 0.5

    # Voxel counts of the overlap and of the combined region
    overlap = torch.logical_and(pred_mask, gt_mask).float().sum(dim=spatial_dims)
    combined = torch.logical_or(pred_mask, gt_mask).float().sum(dim=spatial_dims)

    scores = (overlap + smooth) / (combined + smooth)

    return scores.mean()

In [None]:
# Pick the variant to train by swapping the constructor below for one of:
#   Model_L2_Sigmoid(n_class=2, gpu=0)
#   Model_L2_NoResiduals(n_class=2, gpu=0)
# nn.Module.cuda() moves the parameters to the GPU and returns the module itself.
model = Model_L2(n_class=2, gpu=0).cuda()
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-5)

In [None]:
from tqdm import tqdm

def train_epoch():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer`` and ``train_loader``.  The
    model is called in segmentation mode and the Dice loss between its output
    and the target is minimised (the model's own ``loss_func`` is not used).

    Returns:
        Mean Dice loss over all batches of the epoch.
    """
    total_loss = 0
    # The bar advances once per batch, so its total is the number of batches.
    with tqdm(total=len(train_loader)) as pbar:
        for x, y in train_loader:
            optimizer.zero_grad()
            # Call the module itself (not .forward) so registered hooks run.
            y_pred = model(x)
            loss = dice_loss(y_pred, y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.update(1)
    return total_loss / len(train_loader)

<p> Each model variant was trained for 10 epochs to see how the training loss behaves. </p>

In [None]:
# Train the currently selected model for a fixed number of epochs and print
# the mean training loss after each one.
# NOTE(review): in the lines visible here `losses` is never appended to and
# `maxloss` is never read -- confirm whether later cells rely on them.
losses = []
maxloss = 1
num_epochs = 10
for epoch in range(num_epochs):
    # Put the model in training mode before each pass.
    model.train()
    train_loss = train_epoch()
    print(f'Epoch {epoch+1}/{num_epochs}, Train loss: {train_loss}')