## Chained deep learning using generalized cross entropy for multiple annotators segmentation



## Loss functions for segmentation in deep learning


Given a $K$-class multiple-annotator segmentation problem with a dataset like the following:

$$\mathbf X \in \mathbb{R}^{W \times H}, \{ \mathbf Y_r \in \{0,1\}^{W \times H \times K} \}_{r=1}^R; \;\; \mathbf {\hat Y} \in [0,1]^{W\times H \times K} = f(\mathbf X)$$

The segmentation function maps the input image to the output mask as follows:

$$f: \mathbb  R ^{W\times H} \to [0,1]^{W\times H\times K}$$

Each annotator mask $\mathbf Y_r$ satisfies the following condition, i.e. it is a one-hot (softmax-like) representation across the $K$ classes:

$$\mathbf Y_r[w,h,:] \mathbf{1} ^ \top _ K = 1; \;\; w \in W, h \in H$$

Now, suppose there exists an estimate of each annotator's reliability map, $\Lambda_r; \; r \in R$:


$$\bigg\{ \Lambda_r (\mathbf X; \theta ) \in [0,1] ^{W\times H} \bigg\}_{r=1}^R $$


Then, our $TGCE_{SS}$:


$$TGCE_{SS}(\mathbf{Y}_r,f(\mathbf X;\theta) | \mathbf{\Lambda}_r (\mathbf X;\theta)) =\mathbb E_{r} \left\{ \mathbb E_{w,h}  \left\{ \Lambda_r (\mathbf X; \theta) \circ \mathbb E_k \bigg\{    \mathbf Y_r \circ \bigg( \frac{\mathbf 1 _{W\times H \times K} - f(\mathbf X;\theta) ^{\circ q }}{q} \bigg); k \in K  \bigg\}  + \\ \left(\mathbf 1 _{W \times H } - \Lambda _r (\mathbf X;\theta)\right) \circ \bigg(   \frac{\mathbf 1_{W\times H} - (\frac {1}{K} \mathbf 1_{W\times H})^{\circ q}}{q} \bigg); w \in W, h \in H \right\};r\in R\right\} $$


Where $q \in (0,1)$

Total Loss for a given batch holding $N$ samples:

$$\mathscr{L}\left(\mathbf{Y}_r[n],f(\mathbf X[n];\theta) | \mathbf{\Lambda}_r (\mathbf X[n];\theta)\right)  = \frac{1}{N} \sum_{n=1}^{N} TGCE_{SS}(\mathbf{Y}_r[n],f(\mathbf X[n];\theta) | \mathbf{\Lambda}_r (\mathbf X[n];\theta))$$

## Download and importing of libraries

In [1]:
%%capture
# Pin Keras to 2.15.0 for the rest of the notebook.
!pip install --upgrade keras==2.15.0
# GCPDS segmentation helpers (provides the OxfordIiitPet dataset wrapper imported below).
# NOTE(review): output is appended to the regular file /tmp/null; /dev/null was
# presumably intended — confirm.
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.image_segmentation.git >> /tmp/null
# classification_models provides the `Classifiers` registry imported below.
!pip install git+https://github.com/qubvel/classification_models.git

In [None]:
# General Libraries
import os
import time
import shutil
import random
import warnings
import pandas as pd
import seaborn as sns
from enum import auto, Enum
from functools import partial
from datetime import datetime

# Image Processing Libraries
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Deep Learning Libraries
import tensorflow as tf
from tensorflow import keras
from keras_tuner import Objective
from keras_tuner import HyperModel
import tensorflow.keras.backend as K
from keras.layers import Layer, Activation
from keras_tuner import BayesianOptimization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.utils import get_custom_objects
from keras_tuner.engine.hyperparameters import HyperParameters
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras_tuner import BayesianOptimization as OriginalBayesianOptimization
from keras_tuner.oracles import BayesianOptimizationOracle as OriginalBayesianOptimizationOracle

# Deep Learning Libraries - GCPDS
from gcpds.image_segmentation.datasets.segmentation import OxfordIiitPet

# Machine Learning Libraries - Sklearn
import sklearn
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C, DotProduct, ExpSineSquared, Matern

# Deep Learning Libraries - TensorFlow specific
from tensorflow.keras.losses import Loss
from tensorflow.keras.metrics import Metric
from tensorflow.keras import Model, layers, regularizers

# Classification models 
from classification_models.keras import Classifiers

# Other Libraries
import gc
import json
import gdown
import itertools
import SimpleITK as sitk
from PIL import ImageFont
from dataclasses import dataclass
from matplotlib.style import available
from tensorflow.python.framework.ops import EagerTensor

warnings.filterwarnings("ignore") # Disable warnings

## Download OxfordPet dataset and generation of synthetic annotators with the introduction of different signal-to-noise ratio values

In [None]:
# Download OxfordPet dataset

# Calling the dataset object yields the three splits unpacked below. Each
# element is later consumed as a 4-tuple (img, mask, label, id_img).
dataset = OxfordIiitPet()
train_dataset, val_dataset, test_dataset = dataset()

## Original Masks

In [4]:
# Number of samples per batch produced by the tf.data pipelines below.
BATCH_SIZE = 128
# Spatial size every image and mask is resized to (passed to tf.image.resize).
TARGET_SHAPE = 256, 256

def fussion_mask(mask: EagerTensor) -> EagerTensor:
    """Collapse a three-channel mask into a single object-plus-border channel.

    The mask is split along axis 2 into exactly three channels, read here as
    object, background and border. The object and border channels are summed
    and the background channel is dropped.

    Parameters:
        mask (EagerTensor): Segmentation mask whose axis 2 holds the three
            channels (object, background, border).

    Returns:
        EagerTensor: A tensor with the same leading dimensions as ``mask`` and
            a last dimension of size 1, holding ``object + border``.

    """
    foreground, _background, outline = tf.unstack(mask, axis=2)

    # Same shape as the input, except the channel axis shrinks to one.
    fused_shape = list(mask.shape)[:-1] + [1]

    merged = tf.stack([foreground + outline])
    return tf.reshape(merged, fused_shape)


def map_dataset(dataset, target_shape, batch_size):
    """Turn a raw dataset into batches of (image, single-channel mask) pairs.

    Three element-wise transformations are applied in order, then the result
    is batched:

    1. keep only the image and the mask of each (img, mask, label, id_img) element;
    2. resize both to ``target_shape``;
    3. fuse the mask channels with ``fussion_mask``.

    Args:
        dataset (tf.data.Dataset): Dataset whose elements are 4-tuples
            (img, mask, label, id_img).
        target_shape (tuple): Spatial size passed to ``tf.image.resize``.
        batch_size (int): Number of elements per batch.

    Returns:
        The transformed and batched dataset.

    """
    def _keep_image_and_mask(img, mask, label, id_img):
        return img, mask

    def _resize_pair(img, mask):
        return tf.image.resize(img, target_shape), tf.image.resize(mask, target_shape)

    def _fuse_channels(img, mask):
        return img, fussion_mask(mask)

    prepared = dataset
    for transform in (_keep_image_and_mask, _resize_pair, _fuse_channels):
        prepared = prepared.map(transform, num_parallel_calls=tf.data.AUTOTUNE)

    return prepared.batch(batch_size)

# Build the batched (image, fused mask) pipelines for each split using the
# module-level TARGET_SHAPE and BATCH_SIZE.
original_train = map_dataset(train_dataset, TARGET_SHAPE, BATCH_SIZE)
original_val = map_dataset(val_dataset, TARGET_SHAPE, BATCH_SIZE)
original_test = map_dataset(test_dataset, TARGET_SHAPE, BATCH_SIZE)

In [None]:
# Visual sanity check: take the first training batch and show one image next
# to its fused mask.
for img,mask in original_train.take(1):
    print(mask.shape)
    fig, axes = plt.subplots(1, 2 , figsize=(10,5))
    # Sample index 110 is hard-coded, so the batch must hold at least 111 samples.
    axes[0].imshow(img[110])
    axes[0].set_title('Image')
    axes[0].axis('off')
    # The fused mask has a single channel; select it to get a 2-D image.
    axes[1].imshow(mask[110][:,:,0])
    axes[1].set_title('Original mask')
    axes[1].axis('off')

### Loading of the different parts of the dataset

In [None]:
# Load part of the training split into memory as two tensors (images, masks)

original_X_train = []
original_y_train = []

# Only the first 8 batches are materialised.
for img, mask in original_train.take(8):
    original_X_train.append(img)
    original_y_train.append(mask)

# Concatenate the per-batch tensors along the batch axis.
original_X_train, original_y_train = tf.concat(original_X_train, axis=0), tf.concat(original_y_train, axis=0)
print(f'Tensor dimensions with training images: {original_X_train.shape} \nTensor dimensions with training masks: {original_y_train.shape}')

In [None]:
# Load part of the validation split into memory as two tensors (images, masks)

original_X_val = []
original_y_val = []

# Only the first 2 batches are materialised.
for img, mask in original_val.take(2):
    original_X_val.append(img)
    original_y_val.append(mask)

# Concatenate the per-batch tensors along the batch axis.
original_X_val, original_y_val = tf.concat(original_X_val, axis=0), tf.concat(original_y_val, axis=0)
print(f'Tensor dimensions with validation images: {original_X_val.shape}\nTensor dimensions with validation masks: {original_y_val.shape}')

In [None]:
# Load part of the test split into memory as two tensors (images, masks)

original_X_test = []
original_y_test = []

# Only the first 2 batches are materialised.
for img, mask in original_test.take(2):
    original_X_test.append(img)
    original_y_test.append(mask)

# Concatenate the per-batch tensors along the batch axis.
original_X_test, original_y_test = tf.concat(original_X_test, axis=0), tf.concat(original_y_test, axis=0)
print(f'Tensor dimensions with test images: {original_X_test.shape}\nTensor dimensions with test masks: {original_y_test.shape}')

## Synthetic masks

In [None]:
# Download trained Unet network for OxfordPet segmentation task from Drive

model_url = "https://drive.google.com/file/d/1x39L3QNDMye1SJhKh1gf4YS-HRFLTs6G/view?usp=drive_link"
model_uri = model_url.split("/")[5]
!gdown $model_uri

model_extension = "keras"
paths = []

for file in os.listdir("."):
  if file.endswith(model_extension):
    paths.append(file)

model_path = paths[0]
print(f"Loading {model_path}...")
model_ann  = tf.keras.models.load_model(model_path, compile = False)

In [None]:
# Find last encoder convolution layer

def find_last_encoder_conv_layer(model):
    '''
    Locate the last Conv2D layer that appears before the first UpSampling2D layer.

    The layers are scanned in order; the scan stops at the first UpSampling2D
    layer, which is taken as the start of the decoder.

    Parameters:
    model (keras.Model): The Keras model whose ``layers`` list is scanned.

    Returns:
    int: Index (into ``model.layers``) of the last Conv2D layer seen before the
    first UpSampling2D layer, or 0 if no Conv2D layer was seen.
    '''
    last_seen = 0
    for position, current in enumerate(model.layers):
        if isinstance(current, keras.layers.UpSampling2D):
            break
        if isinstance(current, keras.layers.Conv2D):
            last_seen = position
    return last_seen

# Index of the layer that will receive the noise; kept as a notebook-level
# global because later cells read it directly.
last_conv_encoder_layer = find_last_encoder_conv_layer(model_ann)
last_conv_encoder_layer

In [11]:
# Compute and add noise to the target layer

def compute_snr(signal: np.ndarray, noise_std: float) -> float:
    """Compute the Signal-to-Noise Ratio (SNR) in decibels.

    The signal power is taken as the mean of the squared signal values and the
    noise power as the squared noise standard deviation.

    Parameters:
        signal (np.ndarray): Signal samples (an array or a scalar).
        noise_std (float): The standard deviation of the noise.

    Returns:
        float: ``10 * log10(mean(signal ** 2) / noise_std ** 2)``.

    """
    return 10 * np.log10(np.mean(signal ** 2) / noise_std ** 2)


class SnrType(Enum):
    """Scale in which a requested SNR value is expressed.

    Attributes:
        log (int): The SNR value is given in decibels.
        linear (int): The SNR value is given as a plain power ratio.

    """
    log = 0
    linear = 1


def add_noise_to_layer_weights(model, layer, noise_snr, snr_type: SnrType = SnrType.log, verbose=0):
    """Adds Gaussian noise to the first weight array of one layer of the model.

    The noise standard deviation is derived from the power of the layer's
    first weight array (``get_weights()[0]``) and the requested SNR. The noisy
    weights are written back to the same layer.

    Parameters:
        model (tf.keras.Model): The model to modify (in place).
        layer (int): Index of the layer whose weights will be modified.
        noise_snr (float): Desired signal-to-noise ratio (SNR) for the added noise.
        snr_type (SnrType): Scale of ``noise_snr``, either ``SnrType.log``
            (decibels) or ``SnrType.linear``. Defaults to SnrType.log.
        verbose (int): Verbosity mode. If greater than 0, prints information about the noise
            and signal powers. Defaults to 0.

    Returns:
        float: The SNR in decibels computed from the clean weights and the
            noise standard deviation (i.e. the target SNR, not a measurement
            of the sampled noise).

    Raises:
        ValueError: If ``snr_type`` is not a member of ``SnrType``.

    """
    layer_weights = model.layers[layer].get_weights()
    kernel = layer_weights[0]

    sig_power = np.mean(kernel ** 2)

    if snr_type == SnrType.log:
        noise_power = sig_power / (10 ** (noise_snr / 10))
    elif snr_type == SnrType.linear:
        noise_power = sig_power / noise_snr
    else:
        # Previously this fell through and failed later with an unrelated
        # "noise_power referenced before assignment" error.
        raise ValueError(f"Unsupported snr_type: {snr_type!r}")

    noise_std = noise_power ** (1 / 2)

    snr = compute_snr(kernel, noise_std)

    if verbose > 0:
        print(f"Adding noise for SNR: {noise_snr}\n\n")
        print(f"Signal power: {sig_power}")
        print(f"Noise power: {noise_power}\n\n")

    # Draw noise with the kernel's own shape instead of a hard-coded
    # (128, 128) block, so layers with any channel count are supported. The
    # in-place add keeps the kernel's dtype.
    kernel += np.random.randn(*kernel.shape) * noise_std

    # Write back to the layer that was requested, not to a notebook global.
    model.layers[layer].set_weights(layer_weights)
    return snr

In [12]:
# Define the signal-to-noise ratio values for each synthetic annotator.
# One disturbed model is built per value; the values are passed to
# add_noise_to_layer_weights with its default SnrType.log scale.
values_to_test = [20,0,-15]

# Creation of the different models and their perturbations starting from the base model
def produce_disturbed_models(values_to_test, base_model_path):
    """Build one noisy copy of the base model per requested SNR value.

    For every value, the base model is loaded afresh from disk (without
    compiling) and noise is added to the layer whose index is stored in the
    module-level ``last_conv_encoder_layer``.

    Parameters:
        values_to_test (list): SNR values, one per disturbed model.
        base_model_path (str): The file path to the base model to load.

    Returns:
        Tuple containing two lists, in the order of ``values_to_test``:
            - The disturbed models.
            - The SNR value returned by ``add_noise_to_layer_weights`` for each model.

    """
    models, snr_values = [], []

    for snr_target in values_to_test:
        noisy_model = tf.keras.models.load_model(base_model_path, compile=False)
        snr_values.append(
            add_noise_to_layer_weights(noisy_model, last_conv_encoder_layer, snr_target)
        )
        models.append(noisy_model)

    return models, snr_values


# One disturbed model per entry of values_to_test; `disturbance_models` is
# read as a global by map_dataset_MA below.
disturbance_models, snr_values = produce_disturbed_models(values_to_test, model_path)

In [13]:
# Disturbance processing with different SNR ratios values for each database partition using the modified networks

# Number of samples per batch in the synthetic-annotator pipelines.
BATCH_SIZE = 128
# Spatial size of the images and masks produced by the pipelines.
TARGET_SHAPE = (256, 256)
# Spatial size images are resized to before being fed to the disturbed models.
ORIGINAL_MODEL_SHAPE = 256, 256
# Number of synthetic annotators.
# NOTE(review): presumably must equal len(values_to_test) — confirm.
NUM_ANNOTATORS = 3

def disturb_mask(model, image, model_shape, target_shape):
    """Produce a mask for ``image`` by running it through ``model``.

    The image is resized to ``model_shape``, passed through the model, and the
    model output is resized to ``target_shape``.

    Parameters:
        model (tf.keras.Model): Model used to generate the mask.
        image (tf.Tensor): Input image tensor.
        model_shape (tuple): Spatial size the image is resized to before the model call.
        target_shape (tuple): Spatial size the model output is resized to.

    Returns:
        The resized model output.

    """
    model_input = tf.image.resize(image, model_shape)
    prediction = model(model_input)
    return tf.image.resize(prediction, target_shape)


def mix_channels(mask, num_annotators):
    """Pair a mask with its complement along a new second-to-last axis.

    The result stacks ``mask`` and ``1 - mask`` on a new axis inserted at
    position -2, so that axis has size 2 and the original last axis stays last.

    Parameters:
        mask (tensor): Input mask tensor.
        num_annotators (int): Accepted for interface compatibility; not used
            by the computation.

    Returns:
        A tensor with one more dimension than ``mask``, whose second-to-last
        axis holds ``[mask, 1 - mask]``.

    """
    complement = 1 - mask
    return tf.stack([mask, complement], axis=-2)


def add_noisy_annotators(img: EagerTensor, models, model_shape, target_shape) -> EagerTensor:
    """Generate one mask per annotator model for the same input image.

    Every model in ``models`` is applied to ``img`` through ``disturb_mask``.
    The per-model outputs are stacked on a new leading axis and the axes are
    then permuted with ``[2, 3, 1, 4, 0]``, which moves the model (annotator)
    axis to the last position.

    Parameters:
        img (EagerTensor): The input image passed to every model.
        models (list): Annotator models used to generate the masks.
        model_shape: Spatial size the image is resized to before each model call.
        target_shape: Spatial size each model output is resized to.

    Returns:
        EagerTensor: The permuted stack of per-model masks, with the annotator
        axis last.

    """
    per_model_masks = [
        disturb_mask(model, img, model_shape=model_shape, target_shape=target_shape)
        for model in models
    ]
    return tf.transpose(per_model_masks, [2, 3, 1, 4, 0])


def map_dataset_MA(dataset, target_shape, model_shape, batch_size, num_annotators):
    """Build the multi-annotator version of a dataset.

    Each element goes through the following steps, in order, before batching:

    1. keep only the image and mask of the (img, mask, label, id_img) tuple;
    2. resize both to ``target_shape``;
    3. replace the mask with the masks generated by the module-level
       ``disturbance_models`` for that image;
    4. reshape the generated masks to (height, width, 1, annotators);
    5. add the complement of every mask with ``mix_channels``;
    6. drop the singleton axis at position 2.

    Parameters:
        dataset (tf.data.Dataset): Dataset whose elements are 4-tuples
            (img, mask, label, id_img).
        target_shape (tuple): Spatial size of the resulting images and masks.
        model_shape (tuple): Spatial size images are resized to before being
            fed to the annotator models.
        batch_size (int): Number of elements per batch.
        num_annotators (int): Forwarded to ``mix_channels``.

    Returns:
        The transformed and batched dataset.

    """
    def _keep_image_and_mask(img, mask, label, id_img):
        return img, mask

    def _resize_pair(img, mask):
        return tf.image.resize(img, target_shape), tf.image.resize(mask, target_shape)

    def _annotate(img, mask):
        # The original mask is discarded; the models see a batch of one image.
        noisy_masks = add_noisy_annotators(tf.expand_dims(img, 0),
                                           disturbance_models,
                                           model_shape=model_shape,
                                           target_shape=target_shape)
        return img, noisy_masks

    def _collapse_to_annotators(img, mask):
        rows, cols, annotators = mask.shape[0], mask.shape[1], mask.shape[-1]
        return img, tf.reshape(mask, (rows, cols, 1, annotators))

    def _add_complement(img, mask):
        return img, mix_channels(mask, num_annotators)

    def _drop_singleton(img, mask):
        return img, tf.squeeze(mask, axis=2)

    pipeline = (
        _keep_image_and_mask,
        _resize_pair,
        _annotate,
        _collapse_to_annotators,
        _add_complement,
        _drop_singleton,
    )

    prepared = dataset
    for transform in pipeline:
        prepared = prepared.map(transform, num_parallel_calls=tf.data.AUTOTUNE)

    return prepared.batch(batch_size)



# Multi-annotator pipelines for each split, all built with the same settings.
synthetic_train = map_dataset_MA(
    train_dataset,
    target_shape=TARGET_SHAPE,
    model_shape=ORIGINAL_MODEL_SHAPE,
    batch_size=BATCH_SIZE,
    num_annotators=NUM_ANNOTATORS)
synthetic_val = map_dataset_MA(
    val_dataset,
    target_shape=TARGET_SHAPE,
    model_shape=ORIGINAL_MODEL_SHAPE,
    batch_size=BATCH_SIZE,
    num_annotators=NUM_ANNOTATORS)

synthetic_test = map_dataset_MA(
    test_dataset,
    target_shape=TARGET_SHAPE,
    model_shape=ORIGINAL_MODEL_SHAPE,
    batch_size=BATCH_SIZE,
    num_annotators=NUM_ANNOTATORS)

In [None]:
# Plotting the different perturbations to a sample and the resulting dimensions

for img,mask in synthetic_train.take(1):
  print(f"Mask shape: {mask.shape} (batch_size * h * w * k * r) Img shape {img.shape}")
  # One column per annotator; two rows, one per entry of the k axis.
  fig, axes = plt.subplots(2,NUM_ANNOTATORS)
  fig.set_size_inches(16,7)
  for i in range(NUM_ANNOTATORS):
    # Top row: first entry of the k axis (the mask itself, per mix_channels).
    axes[0][i].imshow((mask)[0,:,:,0,i])
    axes[0][i].set_title(f"Mask for annotator {i}")
    axes[0][i].axis('off')
    # Bottom row: last entry of the k axis (the complement, 1 - mask).
    axes[1][i].imshow((mask)[0,:,:,-1,i])
    axes[1][i].axis('off')

### Loading of the different parts of the dataset

In [None]:
# Load part of the training split into memory as two tensors (images, masks)

synthetic_X_train = []
synthetic_y_train = []

# Only the first 8 batches are materialised.
for img, mask in synthetic_train.take(8):
    synthetic_X_train.append(img)
    # Merge the class (k) and annotator (r) axes into one channel axis of size
    # 2 * NUM_ANNOTATORS. The batch dimension is inferred (-1) so that a final,
    # smaller batch does not break the reshape.
    synthetic_y_train.append(tf.reshape(mask,[-1, TARGET_SHAPE[0], TARGET_SHAPE[1], NUM_ANNOTATORS*2]))

# Concatenate the per-batch tensors along the batch axis.
synthetic_X_train, synthetic_y_train = tf.concat(synthetic_X_train, axis=0), tf.concat(synthetic_y_train, axis=0)
print(f'Tensor dimensions with training images: {synthetic_X_train.shape} \nTensor dimensions with training masks: {synthetic_y_train.shape}')

In [None]:
# Load part of the validation split into memory as two tensors (images, masks)

synthetic_X_val = []
synthetic_y_val = []

# Only the first 2 batches are materialised.
for img, mask in synthetic_val.take(2):
    synthetic_X_val.append(img)
    # Merge the class (k) and annotator (r) axes into one channel axis of size
    # 2 * NUM_ANNOTATORS. The batch dimension is inferred (-1) so that a final,
    # smaller batch does not break the reshape.
    synthetic_y_val.append(tf.reshape(mask,[-1, TARGET_SHAPE[0], TARGET_SHAPE[1], NUM_ANNOTATORS*2]))

# Concatenate the per-batch tensors along the batch axis.
synthetic_X_val, synthetic_y_val = tf.concat(synthetic_X_val, axis=0), tf.concat(synthetic_y_val, axis=0)
print(f'Tensor dimensions with validation images: {synthetic_X_val.shape}\nTensor dimensions with validation masks: {synthetic_y_val.shape}')

In [None]:
# Load part of the test split into memory as two tensors (images, masks)

synthetic_X_test = []
synthetic_y_test = []

# Only the first 2 batches are materialised.
for img, mask in synthetic_test.take(2):
    synthetic_X_test.append(img)
    # Merge the class (k) and annotator (r) axes into one channel axis of size
    # 2 * NUM_ANNOTATORS. The batch dimension is inferred (-1) so that a final,
    # smaller batch does not break the reshape.
    synthetic_y_test.append(tf.reshape(mask,[-1, TARGET_SHAPE[0], TARGET_SHAPE[1], NUM_ANNOTATORS*2]))

# Concatenate the per-batch tensors along the batch axis.
synthetic_X_test, synthetic_y_test = tf.concat(synthetic_X_test, axis=0), tf.concat(synthetic_y_test, axis=0)
print(f'Tensor dimensions with test images: {synthetic_X_test.shape}\nTensor dimensions with test masks: {synthetic_y_test.shape}')

## $TGCE_{SS}$ loss: generalized cross entropy for multiple-annotator segmentation

$$TGCE_{SS}(\mathbf{Y}_r,f(\mathbf X;\theta) | \mathbf{\Lambda}_r (\mathbf X;\theta)) =\mathbb E_{r} \left\{ \mathbb E_{w,h} \left\{  \Lambda_r (\mathbf X; \theta) \circ \mathbb E_k \bigg\{    \mathbf Y_r \circ \bigg( \frac{\mathbf 1 _{W\times H \times K} - f(\mathbf X;\theta) ^{\circ q }}{q} \bigg); k \in K  \bigg\}  + \\ \left(\mathbf 1 _{W \times H } - \Lambda _r (\mathbf X;\theta) \right) \circ \bigg(   \frac{\mathbf 1_{W\times H} - (\frac {1}{K} \mathbf 1_{W\times H})^{\circ q}}{q} \bigg); w \in W, h \in H \right\};r\in R\right\} $$

In [18]:
# Custom loss function: TGCE SS

class TGCE_SS(Loss):
    """TGCE_SS loss for segmentation with multiple annotators.

    The prediction tensor is expected to pack, along its last axis, ``K_``
    class probabilities, followed by ``R`` annotator-reliability maps, followed
    by ``fill_channels`` filler channels that are discarded. The target tensor
    packs ``K_ * R`` channels that are unfolded to ``(K_, R)``.

    For every annotator, a generalized cross-entropy term ``(1 - p ** q) / q``
    is weighted by that annotator's reliability map, and a constant
    uniform-prediction penalty is weighted by ``1 - reliability``.

    Methods
    ----------
    call(y_true, y_pred)
    get_config()

    """

    def __init__(self, q=0.1, name='TGCE_SS', R=3, K_=2, smooth=1e-5, fill_channels=1, **kwargs):
        """Initializes the TGCE_SS loss object.

        Parameters:
            q (float): Exponent of the generalized cross-entropy term.
                Defaults to 0.1.
            name (str): Name of the loss function. Defaults to 'TGCE_SS'.
            R (int): Number of annotators. Defaults to 3.
            K_ (int): Number of classes. Defaults to 2.
            smooth (float): Small constant added to the denominator and to the
                uniform probability ``1 / K_``. Defaults to 1e-5.
            fill_channels (int): Number of trailing filler channels in the
                prediction that are ignored. Defaults to 1.
            **kwargs: Additional arguments passed to the parent class.

        """
        self.q = q
        self.R = R
        self.K_ = K_
        self.smooth = smooth
        self.fill_channels = fill_channels
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):
        """Computes the TGCE_SS loss.

        Parameters:
            y_true (tensor): Annotator masks with ``K_ * R`` channels on the
                last axis.
            y_pred (tensor): Network output with ``K_ + R + fill_channels``
                channels on the last axis.

        Returns:
            A scalar tensor containing the loss value.
        """
        # `y_pred[..., :-0]` would be an empty slice, so only strip the filler
        # channels when there are any.
        if self.fill_channels > 0:
            y_pred = y_pred[..., :-self.fill_channels]

        # Unfold the channel axis into (K_, R). The shape is built dynamically
        # so that an unknown batch size at graph-tracing time is supported.
        y_true = tf.cast(y_true, y_pred.dtype)
        unfolded_shape = tf.concat([tf.shape(y_true)[:-1], [self.K_, self.R]], axis=0)
        y_true = tf.reshape(y_true, unfolded_shape)

        Lambda_r = y_pred[..., self.K_:]  # Annotators reliability maps
        y_pred_ = y_pred[..., :self.K_]  # Class probabilities
        # Repeat f(x) once per annotator: (..., K_) -> (..., K_, R)
        y_pred_ = tf.repeat(y_pred_[..., tf.newaxis], repeats=[self.R], axis=-1)

        epsilon = 1e-8  # Small constant to avoid divisions by zero
        y_pred_ = tf.clip_by_value(y_pred_, epsilon, 1.0 - epsilon)
        denominator = self.q + epsilon + self.smooth

        # Scalars broadcast against the tensors, so no explicitly shaped
        # tensors of ones (which need a static batch size) are required.
        term_r = tf.math.reduce_mean(
            tf.math.multiply(y_true, (1.0 - tf.pow(y_pred_, self.q)) / denominator),
            axis=-2)
        uniform_prob = tf.cast(1 / self.K_ + self.smooth, y_pred.dtype)
        term_c = tf.math.multiply(
            1.0 - Lambda_r,
            (1.0 - tf.pow(uniform_prob, self.q)) / denominator)

        loss = tf.math.reduce_mean(tf.math.multiply(Lambda_r, term_r) + term_c)

        # Replace a NaN loss with 1e-8. tf.where is used instead of a Python
        # `if` on a tensor so that the substitution is part of the graph.
        loss = tf.where(tf.math.is_nan(loss), tf.cast(1e-8, loss.dtype), loss)

        # Divided by the highest possible loss value.
        # NOTE(review): 0.5857603 appears to match q=0.5, K_=2 rather than the
        # default q=0.1 — confirm before changing q.
        return loss / tf.cast(0.5857603, loss.dtype)

    def get_config(self):
        """Gets the configuration of the loss function.

        Returns:
            A dictionary containing every constructor argument, so that the
            loss can be rebuilt from its configuration.

        """
        base_config = super().get_config()
        return {
            **base_config,
            "q": self.q,
            "R": self.R,
            "K_": self.K_,
            "smooth": self.smooth,
            "fill_channels": self.fill_channels,
        }

## Layers definition

In [19]:
# Definition of layers for the neural network structure

# 3x3 same-padded convolution with ReLU; filters, initializer and name are
# supplied at the call site.
DefaultConv2D = partial(layers.Conv2D,
                        kernel_size=3, activation='relu', padding="same")

# Same as DefaultConv2D but with a dilation rate of 10. The layer name is
# fixed here, so a second instance needs an explicit `name=` override.
DilatedConv = partial(layers.Conv2D,
                        kernel_size=3, activation='relu', padding="same", dilation_rate=10, name="DilatedConv")

# Max pooling with a pool size of 2.
DefaultPooling = partial(layers.MaxPool2D,
                        pool_size=2)

# Upsampling by a factor of 2 in both spatial dimensions.
upsample = partial(layers.UpSampling2D, (2,2))

def residual_block(x, filters, kernel_initializer, block_name):
    """Apply a two-convolution residual block to ``x``.

    The branch is Conv -> BatchNorm -> Conv -> BatchNorm; its output is added
    to the unmodified input and passed through a ReLU.

    Parameters:
        x (tensor): Input feature map; it is also used as the shortcut.
        filters (int): Number of filters of both convolutions.
        kernel_initializer: Initializer passed to both convolutions.
        block_name (str): Suffix used to build the names of the block's layers.

    Returns:
        The output tensor of the block.

    """
    branch = DefaultConv2D(filters, kernel_initializer=kernel_initializer, name=f'Conv_{block_name}_1')(x)
    branch = layers.BatchNormalization(name=f'Batch_{block_name}_1')(branch)
    branch = DefaultConv2D(filters, kernel_initializer=kernel_initializer, name=f'Conv_{block_name}_2')(branch)
    branch = layers.BatchNormalization(name=f'Batch_{block_name}_2')(branch)

    summed = layers.Add(name=f'ResAdd_{block_name}')([x, branch])
    return layers.Activation('relu', name=f'ResAct_{block_name}')(summed)

## Kernel initializers

In [20]:
def kernel_initializer(seed):
    """Build a seeded Glorot (Xavier) uniform initializer.

    Parameters:
        seed (int): Seed forwarded to the initializer.

    Returns:
        A ``tf.keras.initializers.GlorotUniform`` instance created with ``seed``.

    """
    glorot = tf.keras.initializers.GlorotUniform(seed=seed)
    return glorot

## Activation function: Sparse Softmax

$$\text{SparseSoftmax}(x) = \frac{\exp(x - \text{max}(x))}{\text{sum}(\exp(x - \text{max}(x)))}$$

Where:

- $x$ is the input tensor.
- $\text{max}(x)$ is the maximum value in the tensor $x$.
- $\text{sum}$ is the sum of the exponential values of $x - \text{max}(x)$.

In [21]:
class SparseSoftmax(Layer):
    """Layer applying a max-shifted softmax over the last axis.

    The input is shifted by its maximum along the last axis before
    exponentiation, and the backend epsilon is added to the normalising sum.

    Methods
    ----------
    build(input_shape)
    call(x)
    compute_output_shape(input_shape)

    """

    def __init__(self, name='SparseSoftmax', **kwargs):
        """Creates the layer.

        Parameters:
            name (str): Layer name. Defaults to 'SparseSoftmax'.
            **kwargs: Additional arguments forwarded to the parent class.

        """
        super().__init__(name=name, **kwargs)

    def build(self, input_shape):
        """Builds the layer; it delegates to the parent implementation.

        Parameters:
            input_shape (tuple): Shape of the input tensor.

        """
        super().build(input_shape)

    def call(self, x):
        """Applies the activation to ``x``.

        Parameters:
            x (tensor): Input tensor.

        Returns:
            A tensor of the same shape holding the softmax over the last axis.

        """
        shifted = x - K.max(x, axis=-1, keepdims=True)
        exponentials = K.exp(shifted)
        normaliser = K.sum(exponentials, axis=-1, keepdims=True) + K.epsilon()
        return exponentials / normaliser

    def compute_output_shape(self, input_shape):
        """Returns the output shape, which equals the input shape.

        Parameters:
            input_shape (tuple): Shape of the input tensor.

        Returns:
            ``input_shape`` unchanged.

        """
        return input_shape

In [22]:
# Register a SparseSoftmax instance in Keras' custom-object registry under the
# key 'sparse_softmax'.
tf.keras.utils.get_custom_objects()['sparse_softmax'] = SparseSoftmax()

## Unet Model

![Unet-TGCELoss](https://lh3.googleusercontent.com/fife/ALs6j_FQpKDYXh1PSuK4B7zca2t-aLNDOGiHv3Cx4wnqrj5tDZp7LFygbkh3Pw0t3sIu7Fly0fC7BrL5I8HME3fXUAFPgKIkzPuTL8W1m6kBgaqQf8yMOt89R7PUO4PNgy0dG5DbixazOdTRquMcEzJ1WiYWf33c1gIK9_8eJhDmmhAYutY-HkMQBerxetLXEdSc2euRjPf7yKb44U9MwhfhX7UWt4Scndk6Pn6FQIczftrSXrAkpm1wQ3KqEuvtfRyffpoYAc48uZ2SYu9G22X9ZIyWBiGhQbuOe1MpW27lJ8SAtXcG5BIt6jy55NTFSWeh0AhuVM-3gwKhTOYqwy4C0lcbXQJ8XobTaWAQWxSyCPHkGfsccxwkILRi9-13j7Bd22ZReapWyu7w0Yrn4I2qxpvJ7kiVjElZWw8foKitEHsELCqLOKcJXrGQGr0qo9maTPzU_aLVYObrDrvndmdOSfkThd2hYswinp75TaoDSMOa3XJIK-02Rv9uS3S8w6VaWX3kggtxDeGlUyfLs8zBOj9SI0sZ8X8l0fd3nFUgOb1lTf29Eod57vzVICMMQvWO0sgabPavO7QxiMt5EvcJHrdZz0U1OYtC91SAszpd58JQSKPChrRaYjIiPNZPWdkFkztzFhClTMlgw8U6SUTSWHYJEDXFeb15Dnv4kNENTV0zpBT6PCdrdpqsNMnhGPGpaCeVPivt85Em8mz1I5YApYKKstjP-L5GG3cCSM3KSe2lu13cVZoBHDXAVUfmqlGgj37MGYsAi5V8U-S4G7Xtun40AYqxRC4OhptA2AdShQ6JOdPfKAlmhLy4rH9Ae1Euu3KrCcYuFY_-_x0POr45C7ouP5uyOAgtWDvM75yyAYkSBVqMYKV-VFnLW_piSQNsZiQXwqjiCGYTi4eYJVsn8TatuJW7o6L1MlO5qdtdgrWNN4lNTl1_CFB8b55wrzZgCxjuuUUGRKU1hA8cEtECtktWbUhfrwvgdCU7-fvuUMIZF7OBizlzGXQ-tYYbQjGpUzfSq19efU_9Wqj8Fep5fNRKuLRTLVP4d-GLy7ecoHbYR-bDx-Ln3oTiiEJ7n1_H6vrRskybpNN_eGhokwxNy3_9h2jWeEAxwWJn6tMf19lkpwdsYanf-VTp5SthSoWb7BA7_qGGXScorMXbbFlKGjAfYKLuXLapcO2wy1yvQe87fggoTkyECfNg_3ie_TdEZlmhdvjAX0oNNOFH6KOsck3WbRdB3cOyICd872hkS0QNB8TxdPVRAR_2ifRxBKOygqBYA0WKEVl2VPCnDDfmwJp401ORkk6AFEZssaKzWinP9Kt3Xo436vo_983b8oNLowlhfeJpHR9iWl9iaHxwU_LShYvCrG8L_eLeaPFAVVXgpPF31KGgeKOSFoaEDvO5OpMrPJ0HrDZPZSXw6KMNtf0B0t_1R6OJoD9OpJSAvUVrVd9xxnsVJxZGSGFr28MNrSJ1gl9PRPyLalFiakDhxYO2PwshZJrzf-GaUANkRx-lJpnzPkqlqcMxl1HNDYSnY73A3sb3kLJe7m_eq0EfCpJkTK4oNPF6-fS1YcQguS8BKnlMXdWfTTywv6Cjhb7ROe1qB83tCtU7syhVEuTzMohgcm-5jT-AOtGJX-WGQny3mboNKZU0j_y9sHAJD7UmTihwutj4Hr8eF1yFUbOTdw=w1632-h928)

In [23]:
# Look up the ResNet34 entry in the classification_models registry; the lookup
# returns a pair, unpacked here as the model builder and its preprocessing function.
ResNet34, preprocess_input = Classifiers.get('resnet34')

In [24]:
def unet_model(input_shape=(256,256,3), name='UNET', out_channels=2, out_filler_channels=1, n_scorers=3):
    """
    Builds a U-Net style segmentation model on top of a ResNet34 encoder.

    Parameters:
    - input_shape (tuple): Shape of the input images, default is (256, 256, 3).
    - name (str): Name of the model, default is 'UNET'.
    - out_channels (int): Number of segmentation output channels, default is 2.
    - out_filler_channels (int): Number of filler output channels, default is 1.
    - n_scorers (int): Number of scorer (annotator reliability) channels, default is 3.

    Returns:
    - model (tensorflow.keras.Model): The constructed U-Net model.

    The encoder is a ResNet34 created with ImageNet weights and without its
    classification top; every backbone layer except the BatchNormalization
    layers is frozen. The decoder runs five stages, each made of an upsample,
    a concatenation with a skip connection (four backbone feature maps and
    finally the input image itself), a convolution and a residual block.

    The output tensor is the channel-wise concatenation, in this order, of:
      1. `out_channels` segmentation channels ('sparse_softmax' activation),
      2. `n_scorers` reliability channels ('sparse_softmax' activation),
      3. `out_filler_channels` filler channels (no explicit activation).
    """
    # Backbone
    backbone = ResNet34(input_shape=input_shape, weights='imagenet', include_top=False)

    for layer in backbone.layers: # Freeze all layers except BatchNormalization layers
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = False

    # Named `inputs` (not `input`) to avoid shadowing the Python builtin.
    inputs = backbone.input

    # Skip connections are tapped by layer index; the shape notes below are for
    # the default 256x256 input.
    # NOTE(review): these indices are tied to this particular ResNet34 build --
    # re-check them if the backbone implementation or version changes.
    level_1 = backbone.layers[5].output #128x128x64
    level_2 = backbone.layers[37].output #64x64x64
    level_3 = backbone.layers[74].output #32x32x128
    level_4 = backbone.layers[129].output #16x16x256
    x = backbone.layers[157].output #8x8x512

    # Decoder
    x = upsample(name='Up60')(x)  # 8x8 -> 16x16
    x = layers.Concatenate(name='Concat60')([level_4, x])
    x = DefaultConv2D(256, kernel_initializer=kernel_initializer(91), name='Conv60')(x)
    x = residual_block(x, 256, kernel_initializer(47), '60')

    x = upsample(name='Up70')(x)  # 16x16 -> 32x32
    x = layers.Concatenate(name='Concat70')([level_3, x])
    x = DefaultConv2D(128, kernel_initializer=kernel_initializer(21), name='Conv70')(x)
    x = residual_block(x, 128, kernel_initializer(96), '70')

    x = upsample(name='Up80')(x)  # 32x32 -> 64x64
    x = layers.Concatenate(name='Concat80')([level_2, x])
    x = DefaultConv2D(64, kernel_initializer=kernel_initializer(96), name='Conv80')(x)
    x = residual_block(x, 64, kernel_initializer(98), '80')

    x = upsample(name='Up90')(x)  # 64x64 -> 128x128
    x = layers.Concatenate(name='Concat90')([level_1, x])
    x = DefaultConv2D(32, kernel_initializer=kernel_initializer(35), name='Conv90')(x)
    x = residual_block(x, 32, kernel_initializer(7), '90')

    x = upsample(name='Up100')(x) # 128x128 -> 256x256
    # The last skip connection is the raw input image itself.
    x = layers.Concatenate(name='Concat100')([inputs, x])
    x = DefaultConv2D(16, kernel_initializer=kernel_initializer(45), name='Conv100')(x)
    x = residual_block(x, 16, kernel_initializer(7), '100')

    # Three 1x1 heads computed from the same decoder features.
    xy = DefaultConv2D(out_channels, kernel_size=(1, 1), activation='sparse_softmax',
                      kernel_initializer=kernel_initializer(42), name='Conv200')(x)
    x_lambda = DilatedConv(n_scorers, kernel_size=(1, 1), activation='sparse_softmax',
                        kernel_initializer=kernel_initializer(42), name='DilatedConv200-Lambda')(x)
    xyy = DefaultConv2D(out_filler_channels, kernel_size=(1, 1),
                      kernel_initializer=kernel_initializer(42),
                      name='Conv201')(x)

    # Channel order of the output: segmentation | reliability | filler.
    y = layers.Concatenate(name='Concat200')([xy, x_lambda, xyy])
    model = Model(inputs, y, name=name)

    return model

## Train model

In [53]:
class ChangeLearningRateCallback(keras.callbacks.Callback):
    """Keras callback that switches the optimizer's learning rate once.

    At the end of the epoch whose index equals ``switch_epoch`` the
    optimizer's learning rate is overwritten with ``new_lr`` and a message is
    printed. All other epochs are left untouched.

    Parameters:
        switch_epoch (int): Epoch index (as passed by Keras to
            ``on_epoch_end``, i.e. 0-based) at whose end the learning rate is
            changed. Defaults to 57.
        new_lr (float): Learning rate to set. Defaults to 0.000493612.
    """

    def __init__(self, switch_epoch=57, new_lr=0.000493612):
        super().__init__()
        self.switch_epoch = switch_epoch
        self.new_lr = new_lr

    def on_epoch_end(self, epoch, logs=None):
        """Set the new learning rate when ``epoch`` matches ``switch_epoch``."""
        if epoch == self.switch_epoch:
            keras.backend.set_value(self.model.optimizer.lr, self.new_lr)
            print(f"\nSet learning rate a {self.new_lr}")


# Instance with the default schedule, used by the training cell.
change_lr_callback = ChangeLearningRateCallback()

In [None]:
# Build the network with unet_model's default arguments.
model = unet_model()

# Compile with Adam and the TGCE_SS loss; R is the number of annotators and
# K_ / fill_channels are passed as 2 and 1, matching unet_model's default
# out_channels and out_filler_channels.
model.compile(optimizer=Adam(learning_rate=0.00093612),
              loss=TGCE_SS(q=0.48029, R=NUM_ANNOTATORS, K_=2, fill_channels=1))

# Train for 100 epochs on the synthetic (multi-annotator) data, validating on
# the synthetic validation split; change_lr_callback adjusts the learning rate
# part-way through training. The returned History feeds the loss plot.
history = model.fit(synthetic_X_train,synthetic_y_train, epochs=100, validation_data=(synthetic_X_val,synthetic_y_val), callbacks=[change_lr_callback])

In [None]:
# Extract the per-epoch training and validation loss recorded by model.fit
train_loss = history.history['loss']
val_loss = history.history['val_loss']

# 1-based epoch numbers for the x axis, one per recorded loss value
epochs = range(1, len(train_loss) + 1)

# Draw both curves on a single figure: training in blue, validation in red
plt.figure(figsize=(12, 6))
plt.plot(epochs, train_loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# Render the figure
plt.show()

## Visualize performance

In [None]:
# Run the trained model on each synthetic split; the results are reused by the
# visualization and metric cells, which read channel 0 of these predictions.
model_y_train = model.predict(synthetic_X_train)
model_y_val = model.predict(synthetic_X_val)
model_y_test = model.predict(synthetic_X_test)

In [None]:
# Qualitative check for one training sample on a 4x3 grid:
#   row 0: original image and synthetic image
#   row 1: original mask
#   row 2: the three annotator masks (channels 0..2 of synthetic_y_train)
#   row 3: channel 0 of the model prediction
rows, columns = 4,3
sample = 333
fig, axes = plt.subplots(rows,columns,figsize=(15,8))

axes[0,0].imshow(original_X_train[sample])
axes[0,0].set_title('Image in original train')
axes[0,2].imshow(synthetic_X_train[sample])
axes[0,2].set_title('Image in synthetic train')

axes[1,1].imshow(original_y_train[sample])
axes[1,1].set_title('Original mask for image')

axes[2,0].imshow(synthetic_y_train[sample,:,:,0])
axes[2,0].set_title('Annotator 1 mask for image')
axes[2,1].imshow(synthetic_y_train[sample,:,:,1])
axes[2,1].set_title('Annotator 2 mask for image')
axes[2,2].imshow(synthetic_y_train[sample,:,:,2])
axes[2,2].set_title('Annotator 3 mask for image')

axes[3,1].imshow(model_y_train[sample,:,:,0])
axes[3,1].set_title('Model mask for image')

# Hide the axis decorations on every panel, including the unused ones. A plain
# loop is used instead of a list comprehension: the call is made only for its
# side effect, and a trailing comprehension would make the notebook cell echo
# a list of return values.
for ax in axes.flat:
    ax.axis('off')

## Definition of performance metrics

### DICE metric

$$\text{Dice} = {2 \cdot |\text{Intersection}| + \text{smooth} \over |\text{Union}| + \text{smooth}}$$

Where:

$|\text{Intersection}| = \sum_{i=1}^{N} y\_{true\_i} \cdot y\_{pred\_i}$, $|\text{Union}| = \sum_{i=1}^{N} y\_{true\_i} + \sum_{i=1}^{N} y\_{pred\_i}$


- $N$ is the total number of elements in the segmentation masks.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted segmentation masks, respectively.
- $\text{smooth}$ is a smoothing parameter to avoid division by zero.

In [39]:
# Definition of the DiceCoefficientMetric

def dice_metric(y_true, y_pred, axis=(1, 2), smooth=1e-5, num_annotators=3):
    """Computes the Dice coefficient metric for evaluating semantic segmentation.

    This function calculates the Dice coefficient metric, which measures the similarity
    between ground truth and predicted segmentation masks. The prediction is
    binarized with a strict ``> 0.5`` threshold before the overlap is measured.

    Parameters:
        y_true (tensor): Ground truth segmentation masks. Its last axis is
            squeezed away, so it must be a singleton axis.
        y_pred (tensor): Predicted segmentation masks, shaped like ``y_true``
            without that last axis.
        axis (tuple of int): Axis along which to compute sums. Defaults to (1, 2).
        smooth (float): A smoothing parameter to avoid division by zero. Defaults to 1e-5.
        num_annotators (int): Number of annotators. Defaults to 3. Not used in
            the computation; accepted so the signature matches the other
            metric functions in this notebook.

    Returns:
        A scalar value representing the average Dice coefficient metric.
    """
    y_true = tf.cast(tf.squeeze(y_true, axis=-1), tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    # Hard-threshold the prediction into a {0, 1} mask.
    y_pred = tf.where(y_pred> 0.5, tf.ones_like(y_pred), tf.zeros_like(y_pred))
    intersection = tf.reduce_sum(y_true * y_pred, axis=axis)
    # "Union" here is the sum of both mask areas (the Dice denominator), not
    # the set union.
    union = tf.reduce_sum(y_true, axis=axis) + tf.reduce_sum(y_pred, axis=axis)
    dice = (2. * intersection + smooth) / (union + smooth)
    # Average the per-sample scores.
    return tf.reduce_mean(dice)

### Jaccard metric

$$\text{Jaccard} = {|\text{Intersection}| + \text{smooth} \over |\text{Union}| + \text{smooth}}$$

Where:

$|\text{Intersection}| = \sum_{i=1}^{N} y\_{true\_i} \cdot y\_{pred\_i}$, $|\text{Union}| = \sum_{i=1}^{N} y\_{true\_i} + \sum_{i=1}^{N} y\_{pred\_i} - |\text{Intersection}|$
- $N$ is the total number of elements in the segmentation masks.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted segmentation masks, respectively.
- $\text{smooth}$ is a small smoothing parameter to avoid division by zero.

In [29]:
# Definition of the JaccardMetric

def jaccard_metric(y_true, y_pred, axis=(1, 2), smooth=1e-5, num_annotators=3):
    """Return the mean smoothed Jaccard index (IoU) of predicted vs. true masks.

    The prediction is turned into a {0, 1} mask with a strict ``> 0.5``
    threshold; intersection and union are then accumulated over ``axis`` and
    the resulting ratios are averaged.

    Parameters:
        y_true (tensor): Ground truth segmentation masks. Its last axis is
            squeezed away, so it must be a singleton axis.
        y_pred (tensor): Predicted segmentation masks, shaped like ``y_true``
            without that last axis.
        axis (tuple of int): Axes along which to compute sums. Defaults to (1, 2).
        smooth (float): Added to numerator and denominator to avoid division
            by zero. Defaults to 1e-5.
        num_annotators (int): Number of annotators. Defaults to 3. Not used in
            the computation.

    Returns:
        A tensor representing the mean Jaccard similarity coefficient.

    """
    reference = tf.squeeze(y_true, axis=-1)
    reference = tf.cast(reference, tf.float32)

    # Strict threshold at 0.5, expressed as a boolean-to-float cast.
    binarized = tf.cast(tf.cast(y_pred, tf.float32) > 0.5, tf.float32)

    overlap = tf.reduce_sum(reference * binarized, axis=axis)
    area_sum = tf.reduce_sum(reference, axis=axis) + tf.reduce_sum(binarized, axis=axis)
    # Inclusion-exclusion: |A u B| = |A| + |B| - |A n B|
    combined = area_sum - overlap

    return tf.reduce_mean((overlap + smooth) / (combined + smooth))

### Sensitivity metric

$$\text{Sensitivity} = {\text{True Positives} \over \text{Actual Positives} + \text{smooth}}$$

Where:

$\text{True Positives} = \sum_{i=1}^{N} y\_{true\_i} \cdot y\_{pred\_i}$, $\text{Actual Positives} = \sum_{i=1}^{N} y\_{true\_i}$


- $N$ is the total number of elements in the labels.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted labels, respectively.
- $\text{smooth}$ is a small value added to the denominator to avoid division by zero.

In [30]:
# Definition of the SensitivityMetric

def sensitivity_metric(y_true, y_pred, axis=(1, 2), smooth=1e-5, num_annotators=3):
    """Return the mean sensitivity (recall) of predicted vs. true masks.

    The prediction is turned into a {0, 1} mask with a strict ``> 0.5``
    threshold. Sensitivity is the number of true positives divided by the
    number of actual positives (plus ``smooth``), accumulated over ``axis``
    and then averaged.

    Parameters:
        y_true (tensor): Ground truth labels. Its last axis is squeezed away,
            so it must be a singleton axis.
        y_pred (tensor): Predicted probabilities or labels, shaped like
            ``y_true`` without that last axis.
        axis (tuple): Axes over which to perform the reduction. Defaults to (1, 2).
        smooth (float): A small value added to the denominator only, to avoid
            division by zero. Defaults to 1e-5.
        num_annotators (int): Number of annotators. Defaults to 3. Not used in
            the computation.

    Returns:
        The mean of the sensitivity values computed over ``axis``.

    """
    reference = tf.squeeze(y_true, axis=-1)
    reference = tf.cast(reference, tf.float32)

    # Strict threshold at 0.5, expressed as a boolean-to-float cast.
    binarized = tf.cast(tf.cast(y_pred, tf.float32) > 0.5, tf.float32)

    hits = tf.reduce_sum(reference * binarized, axis=axis)
    positives = tf.reduce_sum(reference, axis=axis)

    return tf.reduce_mean(hits / (positives + smooth))

### Specificity metric

$$\text{Specificity} = {\text{True Negatives} \over \text{Actual Negatives} + \text{smooth}}$$

Where:

$\text{True Negatives} = \sum_{i=1}^{N} (1 - y\_{true\_i}) \cdot (1 - y\_{pred\_i})$, $\text{Actual Negatives} = \sum_{i=1}^{N} (1 - y\_{true\_i})$

- $N$ is the total number of samples.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the ground truth label and predicted probability (or binary prediction) for the i-th sample, respectively.
- $\text{smooth}$ is a smoothing term to avoid division by zero.

In [31]:
# Definition of the SpecificityMetric

def specificity_metric(y_true, y_pred, axis=(1, 2), smooth=1e-5, num_annotators=3):
    """Return the mean specificity of predicted vs. true masks.

    The prediction is turned into a {0, 1} mask with a strict ``> 0.5``
    threshold. Specificity is the number of true negatives divided by the
    number of actual negatives (plus ``smooth``), accumulated over ``axis``
    and then averaged.

    Parameters:
        y_true (tensor): Ground truth binary labels. Its last axis is squeezed
            away, so it must be a singleton axis.
        y_pred (tensor): Predicted probabilities or binary predictions, shaped
            like ``y_true`` without that last axis.
        axis (tuple): Axes over which to perform reduction. Defaults to (1, 2).
        smooth (float): Added to the denominator only, to avoid division by
            zero. Defaults to 1e-5.
        num_annotators (int): Number of annotators. Defaults to 3. Not used in
            the computation.

    Returns:
        A tensor representing the specificity metric.

    """
    reference = tf.squeeze(y_true, axis=-1)
    reference = tf.cast(reference, tf.float32)

    # Strict threshold at 0.5, expressed as a boolean-to-float cast.
    binarized = tf.cast(tf.cast(y_pred, tf.float32) > 0.5, tf.float32)

    # Complement masks: 1 where the pixel is background.
    negative_truth = 1 - reference
    negative_pred = 1 - binarized

    correct_rejections = tf.reduce_sum(negative_truth * negative_pred, axis=axis)
    negatives = tf.reduce_sum(negative_truth, axis=axis)

    return tf.reduce_mean(correct_rejections / (negatives + smooth))

## Measuring performance with metrics for model's predictions

In [None]:
# Dice score of the model's first output channel against the original masks,
# evaluated on the train, validation and test splits (in that order).
dice_model_y_train, dice_model_y_val, dice_model_y_test = [
    dice_metric(reference, prediction[:, :, :, 0])
    for reference, prediction in (
        (original_y_train, model_y_train),
        (original_y_val, model_y_val),
        (original_y_test, model_y_test),
    )
]
print(dice_model_y_train, dice_model_y_val, dice_model_y_test)

In [None]:
# Jaccard score of the model's first output channel against the original
# masks, evaluated on the train, validation and test splits (in that order).
jaccard_model_y_train, jaccard_model_y_val, jaccard_model_y_test = [
    jaccard_metric(reference, prediction[:, :, :, 0])
    for reference, prediction in (
        (original_y_train, model_y_train),
        (original_y_val, model_y_val),
        (original_y_test, model_y_test),
    )
]
print(jaccard_model_y_train, jaccard_model_y_val, jaccard_model_y_test)

In [None]:
# Sensitivity of the model's first output channel against the original masks,
# evaluated on the train, validation and test splits (in that order).
sensitivity_model_y_train, sensitivity_model_y_val, sensitivity_model_y_test = [
    sensitivity_metric(reference, prediction[:, :, :, 0])
    for reference, prediction in (
        (original_y_train, model_y_train),
        (original_y_val, model_y_val),
        (original_y_test, model_y_test),
    )
]
print(sensitivity_model_y_train, sensitivity_model_y_val, sensitivity_model_y_test)

In [None]:
# Specificity of the model's first output channel against the original masks,
# evaluated on the train, validation and test splits (in that order).
specificity_model_y_train, specificity_model_y_val, specificity_model_y_test = [
    specificity_metric(reference, prediction[:, :, :, 0])
    for reference, prediction in (
        (original_y_train, model_y_train),
        (original_y_val, model_y_val),
        (original_y_test, model_y_test),
    )
]
print(specificity_model_y_train, specificity_model_y_val, specificity_model_y_test)