## Downloading and importing libraries

In [1]:
%%capture
!pip install --upgrade keras==2.15.0
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.image_segmentation.git >> /tmp/null

In [None]:
# General Libraries
import os
import time
import shutil
import random
import warnings
import pandas as pd
import seaborn as sns
from enum import auto, Enum
from functools import partial
from datetime import datetime

# Image Processing Libraries
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Deep Learning Libraries
import tensorflow as tf
from tensorflow import keras
from keras_tuner import Objective
from keras_tuner import HyperModel
import tensorflow.keras.backend as K
from keras.layers import Layer, Activation
from keras_tuner import BayesianOptimization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.utils import get_custom_objects
from keras_tuner.engine.hyperparameters import HyperParameters
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Deep Learning Libraries - GCPDS
from gcpds.image_segmentation.datasets.segmentation import OxfordIiitPet

# Deep Learning Libraries - TensorFlow specific
from tensorflow.keras.losses import Loss
from tensorflow.keras.metrics import Metric
from tensorflow.keras import Model, layers, regularizers

# Other Libraries
import gc
import json
import gdown
import itertools
import SimpleITK as sitk
from PIL import ImageFont
from dataclasses import dataclass
from matplotlib.style import available
from tensorflow.python.framework.ops import EagerTensor

warnings.filterwarnings("ignore") # Disable warnings

## Downloading the Oxford-IIIT Pet dataset and generating synthetic annotators with different signal-to-noise ratio (SNR) values

In [None]:
# Download OxfordPet dataset
# Calling the OxfordIiitPet instance returns the three splits, unpacked here as
# train / validation / test. The map_dataset* helpers below unpack each element
# as (img, mask, label, id_img).

dataset = OxfordIiitPet()
train_dataset, val_dataset, test_dataset = dataset()

## Original Masks

In [4]:
# Shared constants for the "original masks" pipeline.
BATCH_SIZE = 128  # samples per batch produced by map_dataset
NUM_EPOCHS = 150  # not used by any cell visible here; presumably for later training cells
TARGET_SHAPE = 128, 128  # (height, width) passed to tf.image.resize for images and masks

def fussion_mask(mask: EagerTensor) -> EagerTensor:
    """Fuse the object and border channels of a segmentation mask into one channel.

    The mask's last axis is expected to hold three channels in the order
    (object, background, border). The background channel is dropped and the
    object and border channels are summed into a single foreground channel.

    Parameters:
        mask (EagerTensor): Mask whose last axis holds the
            (object, background, border) channels, e.g. shape (H, W, 3).

    Returns:
        EagerTensor: Tensor with the same leading dimensions and one trailing
            channel, e.g. (H, W, 1), containing object + border.
    """
    # Index the channel axis directly rather than unstacking and reshaping via the
    # static `mask.shape`: this also works when the spatial dimensions are unknown
    # at trace time (inside tf.data.map) and for batched (B, H, W, 3) input.
    return tf.expand_dims(mask[..., 0] + mask[..., 2], axis=-1)


def map_dataset(dataset, target_shape, batch_size):
    """Turn raw ``(image, mask, label, id)`` samples into batched ``(image, mask)`` pairs.

    Each sample's label and id are dropped. The image and the mask are resized to
    ``target_shape`` with ``tf.image.resize`` (default bilinear method, so mask
    border pixels may hold fractional values). The mask channels are then fused
    with ``fussion_mask``, and the result is batched.

    Args:
        dataset (tf.data.Dataset): Yields (image, mask, label, id) tuples.
        target_shape (tuple): (height, width) for the resized images and masks.
        batch_size (int): Number of samples per batch.

    Returns:
        tf.data.Dataset: Batches of (image, single-channel mask) pairs.
    """
    def _prepare(img, mask, label, id_img):
        resized_img = tf.image.resize(img, target_shape)
        resized_mask = tf.image.resize(mask, target_shape)
        return resized_img, fussion_mask(resized_mask)

    prepared = dataset.map(_prepare, num_parallel_calls=tf.data.AUTOTUNE)
    return prepared.batch(batch_size)

# Build the preprocessed, batched version of each split in one pass.
original_train, original_val, original_test = (
    map_dataset(split, TARGET_SHAPE, BATCH_SIZE)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Preview sample 110 of the first training batch next to its fused mask.
for img, mask in original_train.take(1):
    print(mask.shape)
    fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5))
    for ax, data, title in ((image_ax, img[110], 'Image'),
                            (mask_ax, mask[110][:, :, 0], 'Original mask')):
        ax.imshow(data)
        ax.set_title(title)
        ax.axis('off')

### Loading of the different parts of the dataset

In [None]:
# Loading of the training part of the database in a tensor manner

# Take the first 8 batches and concatenate images and masks along the batch axis.
original_X_train, original_y_train = (
    tf.concat(list(parts), axis=0) for parts in zip(*original_train.take(8))
)
print(f'Tensor dimensions with training images: {original_X_train.shape} \nTensor dimensions with training masks: {original_y_train.shape}')

In [None]:
# Loading of the validation part of the database in a tensor manner

# Take the first 2 batches and concatenate images and masks along the batch axis.
original_X_val, original_y_val = (
    tf.concat(list(parts), axis=0) for parts in zip(*original_val.take(2))
)
print(f'Tensor dimensions with validation images: {original_X_val.shape}\nTensor dimensions with validation masks: {original_y_val.shape}')

In [None]:
# Loading of the testing part of the database in a tensor manner

# Take the first 2 batches and concatenate images and masks along the batch axis.
original_X_test, original_y_test = (
    tf.concat(list(parts), axis=0) for parts in zip(*original_test.take(2))
)
print(f'Tensor dimensions with test images: {original_X_test.shape}\nTensor dimensions with test masks: {original_y_test.shape}')

## Synthetic masks

In [None]:
# Download trained Unet network for OxfordPet segmentation task from Drive

model_url = "https://drive.google.com/file/d/1x39L3QNDMye1SJhKh1gf4YS-HRFLTs6G/view?usp=drive_link"
model_uri = model_url.split("/")[5]
!gdown $model_uri

model_extension = "keras"
paths = []

for file in os.listdir("."):
  if file.endswith(model_extension):
    paths.append(file)

model_path = paths[0]
print(f"Loading {model_path}...")
model_ann  = tf.keras.models.load_model(model_path, compile = False)

In [None]:
# Find last encoder convolution layer

def find_last_encoder_conv_layer(model):
    '''
    Return the index of the last Conv2D layer that precedes the first UpSampling2D layer.

    Layers are scanned in ``model.layers`` order. The scan stops at the first
    ``keras.layers.UpSampling2D``, which is treated as the start of the decoder.

    Parameters:
    model (keras.Model): Model whose layers are scanned.

    Returns:
    int: Index into ``model.layers`` of the last Conv2D seen before the first
    UpSampling2D, or of the last Conv2D overall if there is no UpSampling2D.
    Returns 0 when no Conv2D precedes the stop point.
    '''
    last_index = 0
    for index, current in enumerate(model.layers):
        if isinstance(current, keras.layers.UpSampling2D):
            break
        if isinstance(current, keras.layers.Conv2D):
            last_index = index
    return last_index

# Index of the layer that add_noise_to_layer_weights / produce_disturbed_models perturb;
# the bare name on the last line displays it as the cell output.
last_conv_encoder_layer = find_last_encoder_conv_layer(model_ann)
last_conv_encoder_layer

In [11]:
# Compute and add noise to the target layer

def compute_snr(signal: np.ndarray, noise_std: float) -> float:
    """Compute the Signal-to-Noise Ratio (SNR) in decibels.

    The signal power is the mean of the squared signal values. The noise power
    is ``noise_std ** 2``. The result is ``10 * log10(signal_power / noise_power)``.

    Parameters:
        signal (array-like): Signal samples, e.g. a layer's weight array. Lists
            and scalars are accepted and converted with ``np.asarray``.
        noise_std (float): Standard deviation of the (zero-mean) noise.

    Returns:
        float: The SNR in decibels.
    """
    signal_power = np.mean(np.asarray(signal) ** 2)
    return 10 * np.log10(signal_power / noise_std ** 2)

class SnrType(Enum):
    """Scale in which a requested signal-to-noise ratio (SNR) is expressed.

    ``add_noise_to_layer_weights`` uses it to turn the requested SNR into a noise
    power.

    Attributes:
        log (int): The SNR is given in decibels, i.e. 10 * log10(P_signal / P_noise).
        linear (int): The SNR is given as the plain power ratio P_signal / P_noise.

    """
    log = 0     # decibel scale
    linear = 1  # plain power ratio

def add_noise_to_layer_weights(model, layer, noise_snr, snr_type: SnrType = SnrType.log, verbose=0):
    """Add Gaussian noise to the first weight array of one layer at a requested SNR.

    The signal power is the mean squared value of ``get_weights()[0]`` for the
    chosen layer (the kernel, for a Conv2D). The noise power is chosen so that
    signal power / noise power equals ``noise_snr``. I.i.d. normal noise with that
    power is then added to the whole array. Any other weight arrays (e.g. a bias)
    are written back unchanged. The layer is modified in place.

    Parameters:
        model (tf.keras.Model): Model whose layer is perturbed in place.
        layer (int): Index into ``model.layers`` of the layer to perturb. The same
            index is used to read and to write the weights.
        noise_snr (float): Requested SNR, in dB for ``SnrType.log`` or as a plain
            power ratio for ``SnrType.linear``.
        snr_type (SnrType): Scale of ``noise_snr``. Defaults to ``SnrType.log``.
        verbose (int): If greater than 0, prints the requested SNR and the signal
            and noise powers. Defaults to 0.

    Returns:
        float: The SNR in dB implied by the chosen noise level. It is computed with
        ``compute_snr`` on the clean weights, before the noise is added.

    Raises:
        ValueError: If ``snr_type`` is neither ``SnrType.log`` nor ``SnrType.linear``.
    """
    layer_weights = model.layers[layer].get_weights()
    kernel = layer_weights[0]

    sig_power = np.mean(kernel ** 2)

    if snr_type == SnrType.log:
        noise_power = sig_power / (10 ** (noise_snr / 10))
    elif snr_type == SnrType.linear:
        noise_power = sig_power / noise_snr
    else:
        # Fail loudly instead of reaching the code below with noise_power unbound.
        raise ValueError(f"Unsupported snr_type: {snr_type!r}")

    noise_std = noise_power ** (1 / 2)

    snr = compute_snr(kernel, noise_std)

    if verbose > 0:
        print(f"Adding noise for SNR: {noise_snr}\n\n")
        print(f"Signal power: {sig_power}")
        print(f"Noise power: {noise_power}\n\n")

    # Draw noise with the kernel's own shape so any kernel layout works. The in-place
    # add keeps the kernel's dtype and updates layer_weights[0] directly.
    kernel += np.random.randn(*kernel.shape) * noise_std

    # Write back to the same layer the weights were read from.
    model.layers[layer].set_weights(layer_weights)
    return snr

In [12]:
# Define the signal-to-noise ratio values for each synthetic annotator
# One perturbed model per value. produce_disturbed_models uses the default
# SnrType.log, so these are in dB. At -15 dB the noise power is about 31.6x the
# weight power.
values_to_test = [20,0,-15]

# Creation of the different models and their perturbations starting from the base model
def produce_disturbed_models(values_to_test, base_model_path, layer=None):
    """Build one noise-perturbed copy of a saved model per requested SNR value.

    For each value, the base model is loaded fresh from ``base_model_path``, so
    every copy starts from clean weights. Noise is then added to one layer with
    ``add_noise_to_layer_weights`` at the default SNR scale.

    Parameters:
        values_to_test (list): Requested SNR values, one per output model.
        base_model_path (str): Path of the saved Keras model to load.
        layer (int, optional): Index of the layer to perturb. When omitted, the
            module-level ``last_conv_encoder_layer`` is used, which is the layer
            this function has always perturbed.

    Returns:
        tuple: ``(models, snr_values)``, where ``models`` lists the perturbed
        models and ``snr_values`` lists the SNR returned for each one, both in
        the order of ``values_to_test``.
    """
    if layer is None:
        layer = last_conv_encoder_layer

    models = []
    snr_values = []
    for value in values_to_test:
        # Reload instead of reusing one instance: each annotator needs independent weights.
        model_ = tf.keras.models.load_model(base_model_path, compile=False)
        snr_values.append(add_noise_to_layer_weights(model_, layer, value))
        models.append(model_)

    return models, snr_values


# One perturbed copy of the model at model_path per value in values_to_test (same order).
disturbance_models, snr_values = produce_disturbed_models(values_to_test, model_path)

In [13]:
# Disturbance processing with different SNR ratios values for each database partition using the modified networks

BATCH_SIZE = 128                  # same value as in the original-masks cell
TARGET_SHAPE = (128, 128)         # (height, width) of images and synthetic masks in the output datasets
ORIGINAL_MODEL_SHAPE = 256, 256   # (height, width) images are resized to before the annotator models
NUM_ANNOTATORS = 3                # presumably must equal len(disturbance_models); plots index the annotator axis with it

def disturb_mask(model, image, model_shape, target_shape):
    """Predict a mask with ``model`` and bring it to ``target_shape``.

    The image is first resized to ``model_shape``, then passed through ``model``.
    The prediction is resized to ``target_shape``.

    Parameters:
        model (tf.keras.Model): Annotator model (typically a noise-perturbed copy).
        image (tf.Tensor): Image batch to segment.
        model_shape (tuple): (height, width) the model is fed with.
        target_shape (tuple): (height, width) of the returned mask.

    Returns:
        The model's prediction resized to ``target_shape``.
    """
    model_input = tf.image.resize(image, model_shape)
    prediction = model(model_input)
    return tf.image.resize(prediction, target_shape)


def mix_channels(mask, num_annotators):
    """Pair each annotator's mask with its complement along a new class axis.

    A new axis of size 2 is inserted just before the last axis. Index 0 holds
    ``mask`` and index 1 holds ``1 - mask``.

    Parameters:
        mask (tensor): Mask tensor whose last axis indexes annotators, e.g.
            (H, W, 1, num_annotators) as built in ``map_dataset_MA``.
        num_annotators (int): Number of annotators. It is not used by the
            computation; the annotator count is taken from ``mask``'s last axis.

    Returns:
        Tensor of shape ``mask.shape[:-1] + (2, mask.shape[-1])``, e.g.
        (H, W, 1, 2, num_annotators).
    """
    complement = 1 - mask
    return tf.stack((mask, complement), axis=-2)


def add_noisy_annotators(img: EagerTensor, models, model_shape, target_shape) -> EagerTensor:
    """Segment ``img`` with every annotator model and stack the predictions.

    Each model produces one mask via ``disturb_mask``. The predictions are stacked
    on a new leading annotator axis and then permuted so the annotator axis comes
    last.

    Parameters:
        img (EagerTensor): Image batch, e.g. a single image expanded to (1, H, W, C).
        models (list): Annotator models, one per synthetic annotator.
        model_shape: (height, width) the images are resized to before each model.
        target_shape: (height, width) the predictions are resized to.

    Returns:
        EagerTensor: For a batch of one and model outputs of shape (1, H, W, C),
        a tensor of shape (H, W, 1, C, len(models)).
    """
    predictions = [
        disturb_mask(annotator, img, model_shape=model_shape, target_shape=target_shape)
        for annotator in models
    ]
    # (A, B, H, W, C) -> (H, W, B, C, A)
    stacked = tf.stack(predictions, axis=0)
    return tf.transpose(stacked, perm=[2, 3, 1, 4, 0])


def map_dataset_MA(dataset, target_shape, model_shape, batch_size, num_annotators, models=None):
    """Build a batched dataset of images paired with synthetic multi-annotator masks.

    Each sample's own mask, label and id are discarded. The image is resized to
    ``target_shape`` and run through every annotator model. The annotators'
    predictions, each paired with its complement, form the new mask.

    Parameters:
        dataset (tf.data.Dataset): Yields (image, mask, label, id) tuples.
        target_shape (tuple): (height, width) of the output images and masks.
        model_shape (tuple): (height, width) the annotator models are fed with.
        batch_size (int): Number of samples per batch.
        num_annotators (int): Forwarded to ``mix_channels``.
        models (list, optional): Annotator models. Defaults to the module-level
            ``disturbance_models``, the list this function has always used.

    Returns:
        tf.data.Dataset of (image, masks) batches. ``masks`` has shape
        (B, H, W, 2, len(models)): index 0 of the fourth axis is each annotator's
        mask and index 1 its complement.
    """
    if models is None:
        models = disturbance_models

    def _annotate(img):
        # Run all annotators on this image as a batch of one: (H, W, 1, C, A).
        masks = add_noisy_annotators(tf.expand_dims(img, 0), models,
                                     model_shape=model_shape,
                                     target_shape=target_shape)
        # (H, W, 1, C, A) -> (H, W, 1, A); the element count only matches when the
        # models emit a single output channel (C == 1).
        masks = tf.reshape(masks, (masks.shape[0], masks.shape[1], 1, masks.shape[-1]))
        # (H, W, 1, A) -> (H, W, 1, 2, A) -> (H, W, 2, A)
        masks = tf.squeeze(mix_channels(masks, num_annotators), axis=2)
        return img, masks

    # Only the resized image is needed: the annotators replace the original mask.
    dataset_ = dataset.map(lambda img, mask, label, id_img: tf.image.resize(img, target_shape),
                           num_parallel_calls=tf.data.AUTOTUNE)
    dataset_ = dataset_.map(_annotate, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset_.batch(batch_size)



# Build the synthetic multi-annotator version of each split with identical settings.
synthetic_train, synthetic_val, synthetic_test = (
    map_dataset_MA(
        split,
        target_shape=TARGET_SHAPE,
        model_shape=ORIGINAL_MODEL_SHAPE,
        batch_size=BATCH_SIZE,
        num_annotators=NUM_ANNOTATORS,
    )
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Plotting the different perturbations to a sample and the resulting dimensions

for img, mask in synthetic_train.take(1):
    print(f"Mask shape: {mask.shape} (batch_size * h * w * k * r) Img shape {img.shape}")
    fig, axes = plt.subplots(2, NUM_ANNOTATORS)
    fig.set_size_inches(16, 7)
    for annotator in range(NUM_ANNOTATORS):
        top, bottom = axes[0][annotator], axes[1][annotator]
        # Row 0: k index 0 (the annotator's mask); row 1: last k index (its complement).
        top.imshow(mask[0, :, :, 0, annotator])
        top.set_title(f"Mask for annotator {annotator}")
        top.axis('off')
        bottom.imshow(mask[0, :, :, -1, annotator])
        bottom.axis('off')

### Loading of the different parts of the dataset

In [None]:
# Loading of the training part of the database in a tensor manner

synthetic_X_train, synthetic_y_train = [], []
for batch_img, batch_mask in synthetic_train.take(8):
    # k index 0 holds each annotator's predicted mask; binarise it at 0.5.
    first_class = batch_mask[..., 0, :]
    synthetic_X_train.append(batch_img)
    synthetic_y_train.append(tf.cast(first_class > 0.5, first_class.dtype))

synthetic_X_train = tf.concat(synthetic_X_train, axis=0)
synthetic_y_train = tf.concat(synthetic_y_train, axis=0)
print(f'Tensor dimensions with training images: {synthetic_X_train.shape} \nTensor dimensions with training masks: {synthetic_y_train.shape}')

In [None]:
# Loading of the validation part of the database in a tensor manner

synthetic_X_val, synthetic_y_val = [], []
for batch_img, batch_mask in synthetic_val.take(2):
    # k index 0 holds each annotator's predicted mask; binarise it at 0.5.
    first_class = batch_mask[..., 0, :]
    synthetic_X_val.append(batch_img)
    synthetic_y_val.append(tf.cast(first_class > 0.5, first_class.dtype))

synthetic_X_val = tf.concat(synthetic_X_val, axis=0)
synthetic_y_val = tf.concat(synthetic_y_val, axis=0)
print(f'Tensor dimensions with validation images: {synthetic_X_val.shape}\nTensor dimensions with validation masks: {synthetic_y_val.shape}')

In [None]:
# Loading of the testing part of the database in a tensor manner

synthetic_X_test, synthetic_y_test = [], []
for batch_img, batch_mask in synthetic_test.take(2):
    # k index 0 holds each annotator's predicted mask; binarise it at 0.5.
    first_class = batch_mask[..., 0, :]
    synthetic_X_test.append(batch_img)
    synthetic_y_test.append(tf.cast(first_class > 0.5, first_class.dtype))

synthetic_X_test = tf.concat(synthetic_X_test, axis=0)
synthetic_y_test = tf.concat(synthetic_y_test, axis=0)
print(f'Tensor dimensions with test images: {synthetic_X_test.shape}\nTensor dimensions with test masks: {synthetic_y_test.shape}')

## Simultaneous Truth And Performance Level Estimation (STAPLE)

STAPLE models the performance of each expert (or algorithm) $k$ through the conditional probabilities

$\lambda_k(r,s) = P\left(B_k(m) = s \mid T(m) = r\right), \quad r, s \in \{0, 1\}$

Where:
- $T(m)$ is the (unknown) true label of pixel $m$.
- $B_k(m)$ is the label that the k-th expert or algorithm assigns to pixel $m$.

The sensitivity of the k-th expert is the probability of correctly labelling a foreground pixel. Its specificity is the probability of correctly labelling a background pixel:

$\text{sensitivity}_k = \lambda_k(1,1)$ and $\text{specificity}_k = \lambda_k(0,0)$

The remaining terms, $\lambda_k(1,0) = 1 - \text{sensitivity}_k$ and $\lambda_k(0,1) = 1 - \text{specificity}_k$, are the corresponding error rates.

STAPLE uses the Expectation-Maximization (EM) algorithm to iteratively estimate the performance level of each segmentation. It generates a consensus segmentation that weights the decisions of reliable experts more than those of less reliable ones.

[Source](https://link.springer.com/content/pdf/10.1007/978-3-642-11216-4_21.pdf)

In [90]:
def staple(masks, n_annotators):
    """
    Fuse several binary annotator masks into one consensus mask per sample with STAPLE.

    For every sample, the first ``n_annotators`` masks (last axis) are converted
    to SimpleITK images and passed to ``sitk.STAPLE`` with foreground value 1.0.
    The per-pixel outputs are stacked over samples and thresholded at 0.5.

    Parameters:
        masks (tf.Tensor | np.ndarray): Binary masks of shape
            (samples, height, width, annotators). Values are cast to int16.
        n_annotators (int): Number of annotator masks to fuse per sample.

    Returns:
        tf.Tensor: Tensor of shape (samples, height, width) with values in {0, 1}.
        The samples axis is present even for a single sample. An empty input
        yields shape (0, height, width).
    """
    # np.asarray accepts both eager tensors and numpy arrays.
    masks_np = np.asarray(masks)
    n_samples, height, width = masks_np.shape[:3]

    consensus = []
    for sample in range(n_samples):
        # (H, W, A) -> (A, H, W) so each annotator's mask is a 2-D slice; STAPLE needs integer labels.
        annotations = np.transpose(masks_np[sample], (2, 0, 1)).astype(np.int16)
        seg_stack = [sitk.GetImageFromArray(annotations[annotator]) for annotator in range(n_annotators)]

        # Run STAPLE algorithm; 1.0 specifies the foreground value
        staple_seg_sitk = sitk.STAPLE(seg_stack, 1.0)
        consensus.append(sitk.GetArrayFromImage(staple_seg_sitk))

    # Stack once at the end instead of re-concatenating a growing tensor per sample.
    if consensus:
        probabilities = np.stack(consensus, axis=0)
    else:
        probabilities = np.zeros((0, height, width), dtype=np.float64)
    probabilities = tf.convert_to_tensor(probabilities)

    # Thresholded to return a binary mask
    return tf.where(probabilities > 0.5, tf.ones_like(probabilities), tf.zeros_like(probabilities))

In [91]:
# One STAPLE consensus mask per sample, fusing the synthetic annotators' masks.
# Use the shared NUM_ANNOTATORS constant rather than a hard-coded 3.
staple_y_train = staple(synthetic_y_train, NUM_ANNOTATORS)
staple_y_val = staple(synthetic_y_val, NUM_ANNOTATORS)
staple_y_test = staple(synthetic_y_test, NUM_ANNOTATORS)

In [None]:
# Visual comparison for one training sample: the image (original and synthetic
# splits), its original mask, the three synthetic annotator masks and the STAPLE mask.
# NOTE(review): this assumes original_* and synthetic_* hold the same samples in the
# same order (both are the first batches of train_dataset) - confirm the split is not reshuffled.
rows, columns = 4,3
sample = 333
fig, axes = plt.subplots(rows,columns,figsize=(15,8))

axes[0,0].imshow(original_X_train[sample])
axes[0,0].set_title('Image in original train')
axes[0,2].imshow(synthetic_X_train[sample])
axes[0,2].set_title('Image in synthetic train')

axes[1,1].imshow(original_y_train[sample])
axes[1,1].set_title('Original mask for image')

axes[2,0].imshow(synthetic_y_train[sample,:,:,0])
axes[2,0].set_title('Annotator 1 mask for image')
axes[2,1].imshow(synthetic_y_train[sample,:,:,1])
axes[2,1].set_title('Annotator 2 mask for image')
axes[2,2].imshow(synthetic_y_train[sample,:,:,2])
axes[2,2].set_title('Annotator 3 mask for image')

axes[3,1].imshow(staple_y_train[sample])
axes[3,1].set_title('STAPLE mask for image')

# Hide the frame of every grid cell, used or not. These calls are made only for
# their side effect, so use a plain loop; a list comprehension would also build
# (and display as cell output) a throw-away list of Nones.
for ax in axes.flat:
    ax.axis('off')

## Definition of performance metrics

### DICE metric

$$\text{Dice} = {2 \cdot |\text{Intersection}| + \text{smooth} \over |\text{Union}| + \text{smooth}}$$

Where:

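$|\text{Intersection}| = \sum_{i=1}^{N} y\_{true\_i} \cdot y\_{pred\_i}$, $|\text{Union}| = \sum_{i=1}^{N} y\_{true\_i} + \sum_{i=1}^{N} y\_{pred\_i}$ (here "Union" is the sum of the two mask sizes, not the set union; this is why the intersection is doubled in the numerator)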


- $N$ is the total number of elements in the segmentation masks.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted segmentation masks, respectively.
- $\text{smooth}$ is a smoothing parameter to avoid division by zero.

In [93]:
# Definition of the DiceCoefficientMetric

def dice_metric(y_true, y_pred, axis=(1, 2), smooth=1e-5, num_annotators=3):
    """Computes the Dice coefficient metric for evaluating semantic segmentation.

    The Dice coefficient is computed per sample as
    (2 * intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)
    and then averaged.

    Parameters:
        y_true (tensor): Ground truth masks with a trailing singleton channel,
            e.g. (N, H, W, 1). The channel is squeezed away.
        y_pred (tensor): Predicted masks without that channel, e.g. (N, H, W).
        axis (tuple of int): Axes summed over per sample. Defaults to (1, 2).
        smooth (float): Smoothing term added to numerator and denominator.
            Defaults to 1e-5.
        num_annotators (int): Accepted for signature compatibility with the other
            metrics in this notebook. It is not used by the computation.
            Defaults to 3.

    Returns:
        A scalar tensor with the mean Dice coefficient.
    """
    y_true = tf.cast(tf.squeeze(y_true, axis=-1), tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred, axis=axis)
    # Sum of both mask sizes (not the set union), hence the factor 2 below.
    union = tf.reduce_sum(y_true, axis=axis) + tf.reduce_sum(y_pred, axis=axis)
    dice = (2. * intersection + smooth) / (union + smooth)
    return tf.reduce_mean(dice)

### Jaccard metric

$$\text{Jaccard} = {|\text{Intersection}| + \text{smooth} \over |\text{Union}| + \text{smooth}}$$

Where:

$|\text{Intersection}| = \sum_{i=1}^{N} y\_{true\_i} \cdot y\_{pred\_i}$, $|\text{Union}| = \sum_{i=1}^{N} y\_{true\_i} + \sum_{i=1}^{N} y\_{pred\_i} - |\text{Intersection}|$
- $N$ is the total number of elements in the segmentation masks.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted segmentation masks, respectively.
- $\text{smooth}$ is a small smoothing parameter to avoid division by zero.

In [94]:
# Definition of the JaccardMetric

def jaccard_metric(y_true, y_pred, axis=(1, 2), smooth=1e-5, num_annotators=3):
    """Compute the mean Jaccard index (Intersection over Union) between masks.

    Per sample the score is
    (intersection + smooth) / (sum(y_true) + sum(y_pred) - intersection + smooth),
    and the scores are then averaged.

    Parameters:
        y_true (tensor): Ground truth masks with a trailing singleton channel,
            which is squeezed away.
        y_pred (tensor): Predicted masks without that channel.
        axis (tuple of int): Axes summed over per sample. Defaults to (1, 2).
        smooth (float): Smoothing term added to numerator and denominator. Defaults to 1e-5.
        num_annotators (int): Not used by the computation. Defaults to 3.

    Returns:
        A scalar tensor with the mean Jaccard index.
    """
    truth = tf.cast(tf.squeeze(y_true, axis=-1), tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * pred, axis=axis)
    total = tf.reduce_sum(truth, axis=axis) + tf.reduce_sum(pred, axis=axis)
    return tf.reduce_mean((overlap + smooth) / (total - overlap + smooth))

### Sensitivity metric

$$\text{Sensitivity} = {\text{True Positives} \over \text{Actual Positives} + \text{smooth}}$$

Where:

$\text{True Positives} = \sum_{i=1}^{N} y\_{true\_i} \cdot y\_{pred\_i}$, $\text{Actual Positives} = \sum_{i=1}^{N} y\_{true\_i}$


- $N$ is the total number of elements in the labels.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted labels, respectively.
- $\text{smooth}$ is a small value added to the denominator to avoid division by zero.

In [95]:
# Definition of the SensitivityMetric

def sensitivity_metric(y_true, y_pred, axis=(1, 2), smooth=1e-5, num_annotators=3):
    """Compute the mean sensitivity (recall, true positive rate) between masks.

    Per sample the score is sum(y_true * y_pred) / (sum(y_true) + smooth), and
    the scores are then averaged.

    Parameters:
        y_true (tensor): Ground truth masks with a trailing singleton channel,
            which is squeezed away.
        y_pred (tensor): Predicted probabilities or labels without that channel.
        axis (tuple): Axes summed over per sample. Defaults to (1, 2).
        smooth (float): Added to the denominator only. Defaults to 1e-5.
        num_annotators (int): Not used by the computation. Defaults to 3.

    Returns:
        A scalar tensor with the mean sensitivity.
    """
    truth = tf.cast(tf.squeeze(y_true, axis=-1), tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    hits = tf.reduce_sum(truth * pred, axis=axis)
    positives = tf.reduce_sum(truth, axis=axis)
    return tf.reduce_mean(hits / (positives + smooth))

### Specificity metric

$$\text{Specificity} = {\text{True Negatives} \over \text{Actual Negatives} + \text{smooth}}$$

Where:

$\text{True Negatives} = \sum_{i=1}^{N} (1 - y\_{true\_i}) \cdot (1 - y\_{pred\_i})$, $\text{Actual Negatives} = \sum_{i=1}^{N} (1 - y\_{true\_i})$

- $N$ is the total number of elements in the segmentation masks.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted segmentation masks (binary labels or probabilities), respectively.
- $\text{smooth}$ is a smoothing term to avoid division by zero.

In [96]:
# Definition of the SpecificityMetric

def specificity_metric(y_true, y_pred, axis=(1, 2), smooth=1e-5, num_annotators=3):
    """Compute the mean specificity (true negative rate) between masks.

    Per sample the score is
    sum((1 - y_true) * (1 - y_pred)) / (sum(1 - y_true) + smooth),
    and the scores are then averaged.

    Parameters:
        y_true (tensor): Ground truth binary masks with a trailing singleton
            channel, which is squeezed away.
        y_pred (tensor): Predicted probabilities or binary masks without that channel.
        axis (tuple): Axes summed over per sample. Defaults to (1, 2).
        smooth (float): Added to the denominator only. Defaults to 1e-5.
        num_annotators (int): Not used by the computation. Defaults to 3.

    Returns:
        A scalar tensor with the mean specificity.
    """
    not_truth = 1 - tf.cast(tf.squeeze(y_true, axis=-1), tf.float32)
    not_pred = 1 - tf.cast(y_pred, tf.float32)
    true_neg = tf.reduce_sum(not_truth * not_pred, axis=axis)
    actual_neg = tf.reduce_sum(not_truth, axis=axis)
    return tf.reduce_mean(true_neg / (actual_neg + smooth))

## Measuring performance with metrics for STAPLE

In [None]:
# NOTE(review): original_y_* are resized with tf.image.resize and never thresholded,
# while staple_y_* are binary. Confirm soft references are intended.
dice_staple_y_train, dice_staple_y_val, dice_staple_y_test = (
    dice_metric(reference, consensus)
    for reference, consensus in (
        (original_y_train, staple_y_train),
        (original_y_val, staple_y_val),
        (original_y_test, staple_y_test),
    )
)
print(dice_staple_y_train, dice_staple_y_val, dice_staple_y_test)

In [None]:
jaccard_staple_y_train, jaccard_staple_y_val, jaccard_staple_y_test = (
    jaccard_metric(reference, consensus)
    for reference, consensus in (
        (original_y_train, staple_y_train),
        (original_y_val, staple_y_val),
        (original_y_test, staple_y_test),
    )
)
print(jaccard_staple_y_train, jaccard_staple_y_val, jaccard_staple_y_test)

In [None]:
sensitivity_staple_y_train, sensitivity_staple_y_val, sensitivity_staple_y_test = (
    sensitivity_metric(reference, consensus)
    for reference, consensus in (
        (original_y_train, staple_y_train),
        (original_y_val, staple_y_val),
        (original_y_test, staple_y_test),
    )
)
print(sensitivity_staple_y_train, sensitivity_staple_y_val, sensitivity_staple_y_test)

In [None]:
specificity_staple_y_train, specificity_staple_y_val, specificity_staple_y_test = (
    specificity_metric(reference, consensus)
    for reference, consensus in (
        (original_y_train, staple_y_train),
        (original_y_val, staple_y_val),
        (original_y_test, staple_y_test),
    )
)
print(specificity_staple_y_train, specificity_staple_y_val, specificity_staple_y_test)