In [1]:
%%capture
!pip install segmentation-models-pytorch==0.2.1 mlflow

In [2]:
# General Libraries
import os
import sys
import csv
import cv2
import yaml
import h5py
import errno
import shutil
import random
import mlflow
import pickle
import imageio
import inspect
import zipfile
import warnings
import argparse
import functools
import matplotlib
import collections
from collections.abc import Mapping 
import numpy as np
import pandas as pd
matplotlib.use('Agg')
from PIL import Image
from typing import Dict
import albumentations as albu
from enum import auto, Enum
import matplotlib.pyplot as plt
from IPython.display import display, Image

# Sklearn metrics Libraries
from sklearn.metrics import accuracy_score, jaccard_score

# Pytorch Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.distributions import Normal, Independent, kl
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader,TensorDataset

# Segmentation Models on Pytorch Libraries
from segmentation_models_pytorch.losses import DiceLoss, FocalLoss
from segmentation_models_pytorch.encoders import get_preprocessing_fn

## Probabilistic U-Net in PyTorch: library functions

In [3]:
class LIDC_IDRI(Dataset):
    """LIDC-IDRI slices with several annotator masks per image, loaded from pickle shards.

    Every file in ``dataset_location`` whose name contains ``'.pickle'`` is unpickled and must
    yield a dict whose values are dicts with ``'image'``, ``'masks'`` and ``'series_uid'`` keys.
    Images and masks are asserted to lie in [0, 1].

    WARNING: ``pickle.loads`` can execute arbitrary code - only point this at trusted data.

    dataset_location: directory containing the pickle shards (trailing slash optional)
    transform: optional callable applied to the (1, H, W) numpy image in ``__getitem__``
    """

    def __init__(self, dataset_location, transform=None):
        self.transform = transform
        # Per-instance storage. Class-level lists would be shared by every instance and
        # would keep accumulating samples each time a dataset is constructed.
        self.images = []
        self.labels = []
        self.series_uid = []

        # Files are read in chunks of at most 2**31 - 1 bytes, presumably to avoid
        # single reads above 2 GiB failing on some platforms.
        max_bytes = 2**31 - 1
        records = {}
        # Sorted so the sample order does not depend on the filesystem's listing order.
        for file in sorted(os.listdir(dataset_location)):
            filename = os.fsdecode(file)
            if '.pickle' in filename:
                print("Loading file", filename)
                file_path = os.path.join(dataset_location, filename)
                bytes_in = bytearray(0)
                input_size = os.path.getsize(file_path)
                with open(file_path, 'rb') as f_in:
                    for _ in range(0, input_size, max_bytes):
                        bytes_in += f_in.read(max_bytes)
                records.update(pickle.loads(bytes_in))

        for value in records.values():
            self.images.append(value['image'].astype(float))
            self.labels.append(value['masks'])
            self.series_uid.append(value['series_uid'])

        assert (len(self.images) == len(self.labels) == len(self.series_uid))

        for img in self.images:
            assert np.max(img) <= 1 and np.min(img) >= 0
        for label in self.labels:
            assert np.max(label) <= 1 and np.min(label) >= 0

    def __getitem__(self, index):
        """Return ``(image, label, series_uid)``.

        ``image`` is a float tensor with a leading channel axis added. ``label`` is one of the
        sample's masks, drawn uniformly at random on every access.
        """
        image = np.expand_dims(self.images[index], axis=0)

        # Randomly select one of the annotator masks for this image (four per image in LIDC-IDRI).
        masks = self.labels[index]
        label = masks[random.randint(0, len(masks) - 1)].astype(float)
        if self.transform is not None:
            image = self.transform(image)

        series_uid = self.series_uid[index]

        # Convert image and label to float32 CPU tensors.
        image = torch.from_numpy(image).type(torch.FloatTensor)
        label = torch.from_numpy(label).type(torch.FloatTensor)

        return image, label, series_uid

    # Override to give PyTorch size of dataset
    def __len__(self):
        return len(self.images)

In [4]:

class Encoder(nn.Module):
    """
    Convolutional feature extractor made of len(num_filters) blocks.

    Every block after the first opens with a 2x2 average pool (ceil mode). Each block then
    applies max(no_convs_per_block, 1) 3x3 convolutions, each followed by an in-place ReLU.
    With posterior=True one extra input channel is reserved for the concatenated mask.
    The weights of all convolutions are initialised with ``init_weights``.
    """
    def __init__(self, input_channels, num_filters, no_convs_per_block, initializers, padding=True, posterior=False):
        super(Encoder, self).__init__()
        self.contracting_path = nn.ModuleList()
        # The posterior encoder sees image + segmentation mask stacked on the channel axis.
        self.input_channels = input_channels + 1 if posterior else input_channels
        self.num_filters = num_filters
        pad = int(padding)

        modules = []
        prev_channels = self.input_channels
        for block_idx, width in enumerate(num_filters):
            if block_idx > 0:
                modules.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))
            # The first conv of a block maps prev_channels -> width; the rest keep width.
            for conv_idx in range(max(no_convs_per_block, 1)):
                in_ch = prev_channels if conv_idx == 0 else width
                modules += [nn.Conv2d(in_ch, width, kernel_size=3, padding=pad), nn.ReLU(inplace=True)]
            prev_channels = width

        self.layers = nn.Sequential(*modules)
        self.layers.apply(init_weights)

    def forward(self, input):
        return self.layers(input)

class AxisAlignedConvGaussian(nn.Module):
    """
    Maps an image (plus, for the posterior, a segmentation mask) to a diagonal Gaussian.

    The Encoder output is averaged over both spatial axes. A 1x1 convolution then produces
    2 * latent_dim values, which are split into the mean and the log standard deviation of an
    ``Independent(Normal(mu, exp(log_sigma)), 1)`` distribution.
    The show_* / sum_input attributes keep the last forward's intermediates for debugging.
    """
    def __init__(self, input_channels, num_filters, no_convs_per_block, latent_dim, initializers, posterior=False):
        super(AxisAlignedConvGaussian, self).__init__()
        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim
        self.posterior = posterior
        self.name = 'Posterior' if posterior else 'Prior'
        self.encoder = Encoder(self.input_channels, self.num_filters, self.no_convs_per_block, initializers, posterior=self.posterior)
        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1,1), stride=1)
        # Debug handles, overwritten by forward().
        self.show_img = self.show_seg = self.show_concat = self.show_enc = self.sum_input = 0

        nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')
        nn.init.normal_(self.conv_layer.bias)

    def forward(self, input, segm=None):
        # Posterior path: stack the mask onto the image along the channel axis.
        if segm is not None:
            self.show_img, self.show_seg = input, segm
            input = torch.cat((input, segm), dim=1)
            self.show_concat = input
            self.sum_input = torch.sum(input)

        features = self.encoder(input)
        self.show_enc = features

        # Global average over H, then W (keeps a 1x1 spatial map for the 1x1 conv).
        pooled = features.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

        # Squeeze the two singleton spatial axes explicitly so a batch of size 1 survives.
        stats = self.conv_layer(pooled).squeeze(2).squeeze(2)
        mu, log_sigma = stats[:, :self.latent_dim], stats[:, self.latent_dim:]

        # Diagonal-covariance multivariate normal, see https://github.com/pytorch/pytorch/pull/11178
        return Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)

class Fcomb(nn.Module):
    """
    Combines a latent sample with the UNet feature map using no_convs_fcomb 1x1 convolutions.

    The sample z (batch x latent_dim) is broadcast over the spatial grid of the feature map,
    concatenated on the channel axis, and passed through (no_convs_fcomb - 1) 1x1 conv + ReLU
    layers plus a final 1x1 conv that outputs num_classes logits.
    Only use_tile=True builds layers; with use_tile=False, ``forward`` returns None.
    """
    def __init__(self, num_filters, latent_dim, num_output_channels, num_classes, no_convs_fcomb, initializers, use_tile=True):
        super(Fcomb, self).__init__()
        self.num_channels = num_output_channels #output channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [2,3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb 
        self.name = 'Fcomb'

        if self.use_tile:
            layers = []

            #Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the last layer
            layers.append(nn.Conv2d(self.num_filters[0]+self.latent_dim, self.num_filters[0], kernel_size=1))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_fcomb-2):
                layers.append(nn.Conv2d(self.num_filters[0], self.num_filters[0], kernel_size=1))
                layers.append(nn.ReLU(inplace=True))

            self.layers = nn.Sequential(*layers)

            self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                self.layers.apply(init_weights_orthogonal_normal)
                self.last_layer.apply(init_weights_orthogonal_normal)
            else:
                self.layers.apply(init_weights)
                self.last_layer.apply(init_weights)

    def tile(self, a, dim, n_tile):
        """
        Repeat every slice of ``a`` along ``dim`` ``n_tile`` times, consecutively
        (e.g. [1, 2] -> [1, 1, 2, 2] for n_tile=2).

        This gives the same result as the repeat + index_select construction from
        https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3. It runs on ``a``'s own
        device instead of relying on a global ``device`` variable.
        """
        return torch.repeat_interleave(a, n_tile, dim=dim)

    def forward(self, feature_map, z):
        """
        z is (batch, latent_dim) and feature_map is (batch, channels, H, W).
        z is broadcast to (batch, latent_dim, H, W), concatenated to the feature map and decoded.
        """
        if self.use_tile:
            z = torch.unsqueeze(z,2)
            z = self.tile(z, 2, feature_map.shape[self.spatial_axes[0]])
            z = torch.unsqueeze(z,3)
            z = self.tile(z, 3, feature_map.shape[self.spatial_axes[1]])

            #Concatenate the feature map (output of the UNet) and the sample taken from the latent space
            feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
            output = self.layers(feature_map)
            return self.last_layer(output)


class ProbabilisticUnet(nn.Module):
    """
    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.
    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: list with the number of filters per encoder level
    latent_dim: dimension of the latent space
    no_convs_fcomb: number of 1x1 convolutions in the Fcomb head
    beta: weight of the KL term in the ELBO
    The prior/posterior encoders use 3 convolutions per block (self.no_convs_per_block).
    Sub-modules are moved to a module-level ``device`` that must exist before construction.
    """

    def __init__(self, input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=6, no_convs_fcomb=4, beta=10.0):
        super(ProbabilisticUnet, self).__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w':'he_normal', 'b':'normal'}
        self.beta = beta
        self.z_prior_sample = 0

        self.unet = Unet(self.input_channels, self.num_classes, self.num_filters, self.initializers, apply_last_layer=False, padding=True).to(device)
        self.prior = AxisAlignedConvGaussian(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim,  self.initializers,).to(device)
        self.posterior = AxisAlignedConvGaussian(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim, self.initializers, posterior=True).to(device)
        self.fcomb = Fcomb(self.num_filters, self.latent_dim, self.input_channels, self.num_classes, self.no_convs_fcomb, {'w':'orthogonal', 'b':'normal'}, use_tile=True).to(device)

    def forward(self, patch, segm, training=True):
        """
        Construct prior latent space for patch and run patch through UNet,
        in case training is True also construct posterior latent space.
        Results are stored on self (prior_latent_space, posterior_latent_space, unet_features).
        """
        if training:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch,False)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample
        and combining this with UNet features.
        testing == False draws a reparameterised (differentiable) sample; otherwise a plain sample.
        """
        if testing == False:
            z_prior = self.prior_latent_space.rsample()
            self.z_prior_sample = z_prior
        else:
            #You can choose whether you mean a sample or the mean here. For the GED it is important to take a sample.
            #z_prior = self.prior_latent_space.base_dist.loc 
            z_prior = self.prior_latent_space.sample()
            self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features,z_prior)


    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map
        use_posterior_mean: use posterior_mean instead of sampling z_q
        calculate_posterior: use a provided sample or sample from posterior latent space
        """
        if use_posterior_mean:
            # Independent has no `.loc` attribute; `.mean` returns the base Normal's loc.
            z_posterior = self.posterior_latent_space.mean
        elif calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        Calculate the KL divergence between the posterior and prior KL(Q||P)
        analytic: calculate KL analytically or via sampling from the posterior
        calculate_posterior: if we use sampling to approximate KL we can sample here or supply a sample
        """
        if analytic:
            # Older PyTorch lacked a registered KL for Independent distributions
            # (https://github.com/pytorch/pytorch/issues/13545); recent versions provide it.
            kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X).
        Returns -(summed BCE reconstruction loss + beta * mean KL); also stores kl,
        reconstruction, reconstruction_loss and mean_reconstruction_loss on self.
        """
        # Per-element loss; summed / averaged explicitly below.
        criterion = nn.BCEWithLogitsLoss(reduction='none')
        z_posterior = self.posterior_latent_space.rsample()
        
        self.kl = torch.mean(self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior))

        #Here we use the posterior sample sampled above
        self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, calculate_posterior=False, z_posterior=z_posterior)
        
        reconstruction_loss = criterion(input=self.reconstruction, target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)

In [5]:
class Unet(nn.Module):
    """
    A UNet (https://arxiv.org/abs/1505.04597) implementation.
    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: list with the number of filters per level
    initializers: forwarded to the Down/UpConvBlocks
    apply_last_layer: whether to apply the final 1x1 classification conv (False in the Probabilistic UNet)
    padding: if True, 3x3 convs are padded by 1 so spatial dimensions are preserved
    """

    def __init__(self, input_channels, num_classes, num_filters, initializers, apply_last_layer=True, padding=True):
        super(Unet, self).__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.padding = padding
        self.activation_maps = []
        self.apply_last_layer = apply_last_layer

        # Contracting path: level 0 keeps full resolution, deeper levels start by pooling.
        self.contracting_path = nn.ModuleList()
        channels = self.input_channels
        for depth, width in enumerate(self.num_filters):
            self.contracting_path.append(DownConvBlock(channels, width, initializers, padding, pool=depth > 0))
            channels = width

        # Expanding path: from the second-deepest level back to level 0, each step consuming
        # the upsampled features concatenated with the matching skip connection.
        self.upsampling_path = nn.ModuleList()
        for depth in reversed(range(len(self.num_filters) - 1)):
            skip_width = self.num_filters[depth]
            self.upsampling_path.append(UpConvBlock(channels + skip_width, skip_width, initializers, padding))
            channels = skip_width

        if self.apply_last_layer:
            self.last_layer = nn.Conv2d(channels, num_classes, kernel_size=1)

    def forward(self, x, val):
        """Run the UNet; when ``val`` is truthy the pre-classifier features are appended to activation_maps."""
        skips = []
        deepest = len(self.contracting_path) - 1
        for depth, down in enumerate(self.contracting_path):
            x = down(x)
            if depth != deepest:
                skips.append(x)

        # Deepest skip first.
        for up, bridge in zip(self.upsampling_path, reversed(skips)):
            x = up(x, bridge)

        # Kept for saving the activations and plotting.
        if val:
            self.activation_maps.append(x)

        return self.last_layer(x) if self.apply_last_layer else x

In [6]:
class DownConvBlock(nn.Module):
    """
    Optional 2x2 average pool, then three 3x3 convolutions, each followed by an in-place ReLU.
    The first conv maps input_dim -> output_dim and the next two keep output_dim channels.
    ``initializers`` is accepted for interface compatibility; weights always use ``init_weights``.
    """
    def __init__(self, input_dim, output_dim, initializers, padding, pool=True):
        super(DownConvBlock, self).__init__()
        pad = int(padding)
        modules = [nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)] if pool else []
        for in_ch in (input_dim, output_dim, output_dim):
            modules += [nn.Conv2d(in_ch, output_dim, kernel_size=3, stride=1, padding=pad), nn.ReLU(inplace=True)]

        self.layers = nn.Sequential(*modules)
        self.layers.apply(init_weights)

    def forward(self, patch):
        return self.layers(patch)


class UpConvBlock(nn.Module):
    """
    Upsamples by 2x (bilinear interpolation, or a 2x2 transposed conv when bilinear=False),
    concatenates the skip connection ``bridge`` on the channel axis and applies a
    DownConvBlock without pooling.
    NOTE(review): with bilinear=False the transposed conv expects input_dim channels, but
    ``x`` carries fewer channels than the post-concatenation width when used from Unet. Confirm
    before enabling that path.
    """
    def __init__(self, input_dim, output_dim, initializers, padding, bilinear=True):
        super(UpConvBlock, self).__init__()
        self.bilinear = bilinear

        if not bilinear:
            self.upconv_layer = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=2, stride=2)
            self.upconv_layer.apply(init_weights)

        self.conv_block = DownConvBlock(input_dim, output_dim, initializers, padding, pool=False)

    def forward(self, x, bridge):
        if not self.bilinear:
            up = self.upconv_layer(x)
        else:
            up = nn.functional.interpolate(x, mode='bilinear', scale_factor=2, align_corners=True)

        # Only the width is checked explicitly; a height mismatch surfaces in torch.cat.
        assert up.shape[3] == bridge.shape[3]
        return self.conv_block(torch.cat([up, bridge], 1))

In [7]:
def truncated_normal_(tensor, mean=0, std=1):
    """Fill ``tensor`` in place with normal samples truncated to mean +/- 2*std.

    Four standard-normal candidates are drawn per element and the first one strictly inside
    (-2, 2) is kept (index 0 if none qualifies). The result is scaled by ``std`` and shifted
    by ``mean``. Returns None.
    """
    candidates = tensor.new_empty(tuple(tensor.shape) + (4,)).normal_()
    in_range = (candidates < 2) & (candidates > -2)
    pick = in_range.max(-1, keepdim=True)[1]
    tensor.data.copy_(candidates.gather(-1, pick).squeeze(-1))
    tensor.data.mul_(std).add_(mean)

def init_weights(m):
    """He-normal weights and a near-zero truncated-normal bias for Conv2d / ConvTranspose2d.

    Meant for ``module.apply(init_weights)``. Other module types (including subclasses of the
    two conv types) are left untouched. Convolutions created with ``bias=False`` only get their
    weights initialised instead of crashing on ``m.bias is None``.
    """
    if type(m) in (nn.Conv2d, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            truncated_normal_(m.bias, mean=0, std=0.001)

def init_weights_orthogonal_normal(m):
    """Orthogonal weights and a near-zero truncated-normal bias for Conv2d / ConvTranspose2d.

    Meant for ``module.apply(...)``. Other module types are left untouched, and bias-free
    convolutions only get their weights initialised.
    """
    if type(m) in (nn.Conv2d, nn.ConvTranspose2d):
        nn.init.orthogonal_(m.weight)
        if m.bias is not None:
            truncated_normal_(m.bias, mean=0, std=0.001)

def l2_regularisation(m):
    """Sum of the (non-squared) L2 norms of every parameter tensor of ``m``.

    Returns None when ``m`` has no parameters.
    """
    norms = [param.norm(2) for param in m.parameters()]
    if not norms:
        return None
    total = norms[0]
    for value in norms[1:]:
        total = total + value
    return total

def save_mask_prediction_example(mask, pred, iter):
    """Save the first channel of ``pred`` and ``mask`` as greyscale PNGs in ``images/``.

    Writes ``images/<iter>_prediction.png`` and ``images/<iter>_mask.png``, creating the
    directory if needed. Each image is drawn on its own figure, which is closed after saving.
    Repeated calls therefore neither stack images on the shared current axes nor leak figures.
    """
    out_dir = 'images'
    os.makedirs(out_dir, exist_ok=True)
    for array, suffix in ((pred, "_prediction.png"), (mask, "_mask.png")):
        fig, ax = plt.subplots()
        ax.imshow(array[0, :, :], cmap='Greys')
        fig.savefig(os.path.join(out_dir, str(iter) + suffix))
        plt.close(fig)

## Crowds for automated histopathological image segmentation functions

In [8]:
def double_conv(in_channels, out_channels, step, norm):
    """Two (3x3 conv -> norm -> PReLU) stages as a 6-module ``nn.Sequential``.

    in_channels: channels of the input
    out_channels: channels produced by both convolutions
    step: stride of the first convolution (the second always uses stride 1)
    norm: 'in'  -> InstanceNorm2d
          'bn'  -> BatchNorm2d
          'ln'  -> GroupNorm(out_channels, out_channels). NOTE(review): one group per channel
                   normalises each channel separately (instance-norm-like), not across
                   channels like LayerNorm - confirm this is intended.
          'gn'  -> GroupNorm(out_channels // 8, out_channels)
    Raises ValueError for any other ``norm`` (this used to silently return None).
    """
    norm_factories = {
        'in': lambda channels: torch.nn.InstanceNorm2d(channels, affine=True),
        'bn': lambda channels: torch.nn.BatchNorm2d(channels, affine=True),
        'ln': lambda channels: torch.nn.GroupNorm(channels, channels, affine=True),
        'gn': lambda channels: torch.nn.GroupNorm(channels // 8, channels, affine=True),
    }
    if norm not in norm_factories:
        raise ValueError(f"Unknown norm {norm!r}; expected one of {sorted(norm_factories)}")
    make_norm = norm_factories[norm]

    # Modules are created in the same order (and at the same Sequential indices) as before,
    # so seeded initialisation and state_dict keys are unchanged.
    return torch.nn.Sequential(
        torch.nn.Conv2d(in_channels, out_channels, 3, stride=step, padding=1, groups=1, bias=False),
        make_norm(out_channels),
        torch.nn.PReLU(),
        torch.nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, groups=1, bias=False),
        make_norm(out_channels),
        torch.nn.PReLU(),
    )



class global_CM(torch.nn.Module):
    """ This defines the annotator network (CR Global)

    Maps an annotator encoding ``A_id`` to class_no**2 positive (Softplus) confusion-matrix
    entries, which are repeated over an input_height x input_width grid.
    Output shape: (batch, class_no**2, input_height, input_width). ``x`` is ignored.
    """

    def __init__(self, class_no, input_height, input_width, noisy_labels_no):
        super(global_CM, self).__init__()
        self.class_no = class_no
        self.noisy_labels_no = noisy_labels_no
        self.input_height = input_height
        self.input_width = input_width
        self.dense_output = torch.nn.Linear(noisy_labels_no, class_no ** 2)
        self.act = torch.nn.Softplus()

    def forward(self, A_id, x=None):
        # A_id presumably (batch, noisy_labels_no) - TODO confirm against callers.
        output = self.act(self.dense_output(A_id))
        # Broadcast the same matrix to every pixel of the configured grid (was hard-coded 512x512).
        all_weights = output.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, self.input_height, self.input_width)
        return all_weights.view(-1, self.class_no ** 2, self.input_height, self.input_width)


class conv_layers_image(torch.nn.Module):
    """Small CNN that embeds a feature map into a 64-d vector (used by image_CM).

    Four (conv -> batch-norm -> ReLU -> 2x2 max-pool) stages are followed by
    fc1 (4096 -> 128), batch-norm, ReLU and fc2 (128 -> 64).
    4096 = 4 channels * 32 * 32, which matches a 512x512 input after four 2x poolings.
    NOTE(review): conv_bn2 is shared by the last three stages and conv3 is applied twice,
    so those stages share weights and batch-norm statistics - confirm this is intended.
    """
    def __init__(self, in_channels):
        super(conv_layers_image, self).__init__()
        # 3x3 convolutions, stride 1, padding 1 (spatial size preserved).
        self.conv = torch.nn.Conv2d(in_channels, 8, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(8, 4, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(4, 4, 3, 1, 1)
        # Parameter-free operations.
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        # Normalisation layers.
        self.conv_bn = torch.nn.BatchNorm2d(8)
        self.conv_bn2 = torch.nn.BatchNorm2d(4)
        self.fc_bn = torch.nn.BatchNorm1d(128)
        self.flatten = torch.nn.Flatten()
        # Fully connected head.
        self.fc1 = torch.nn.Linear(4096, 128)
        self.fc2 = torch.nn.Linear(128, 64)

    def forward(self, x):
        stages = (
            (self.conv, self.conv_bn),
            (self.conv2, self.conv_bn2),
            (self.conv3, self.conv_bn2),
            (self.conv3, self.conv_bn2),
        )
        for conv, bn in stages:
            x = self.pool(self.relu(bn(conv(x))))
        hidden = self.relu(self.fc_bn(self.fc1(self.flatten(x))))
        return self.fc2(hidden)


class image_CM(torch.nn.Module):
    """ This defines the annotator network (CR Image)

    Fuses a 64-d annotator embedding with a 64-d image embedding (conv_layers_image on the
    16-channel feature map). It predicts one positive (Softplus) class_no x class_no confusion
    matrix per sample, repeated over the input_height x input_width grid.
    Output shape: (batch, class_no**2, input_height, input_width).
    """

    def __init__(self, class_no, input_height, input_width, noisy_labels_no):
        super(image_CM, self).__init__()
        self.class_no = class_no
        self.noisy_labels_no = noisy_labels_no
        self.input_height = input_height
        self.input_width = input_width
        self.conv_layers = conv_layers_image(16)
        self.dense_annotator = torch.nn.Linear(noisy_labels_no, 64)
        # 64 annotator features + 64 image features.
        self.dense_output = torch.nn.Linear(128, class_no ** 2)
        self.norm = torch.nn.BatchNorm1d(class_no ** 2)
        self.act = torch.nn.Softplus()

    def forward(self, A_id, x):
        annotator_feat = self.dense_annotator(A_id)
        image_feat = self.conv_layers(x)
        logits = self.norm(self.dense_output(torch.hstack((annotator_feat, image_feat))))
        matrices = self.act(logits.view(-1, self.class_no, self.class_no))
        tiled = matrices[..., None, None].repeat(1, 1, 1, self.input_height, self.input_width)
        return tiled.view(-1, self.class_no ** 2, self.input_height, self.input_width)



class cm_layers(torch.nn.Module):
    """ This defines the annotator network (CR Pixel)

    Predicts a per-pixel class_no x class_no confusion matrix for an annotator.
    ``x`` (batch, in_channels, H, W) goes through two double_conv blocks. It is then
    concatenated with a 64-d annotator embedding broadcast over H x W, batch-normalised,
    and mapped per pixel by two linear layers with Softplus activations.
    Output shape: (batch, class_no**2, H, W), all entries positive.

    Previously the layer sizes and the spatial grid were hard-coded (80 fused channels,
    25 = 5**2 outputs, 512x512). The size check also compared the whole batch with one
    sample, so any batch size > 1 raised. Sizes are now derived from in_channels, class_no
    and the input; for in_channels=16, class_no=5 the parameter shapes are unchanged.
    """

    def __init__(self, in_channels, norm, class_no, noisy_labels_no):
        super(cm_layers, self).__init__()
        self.conv_1 = double_conv(in_channels=in_channels, out_channels=in_channels, norm=norm, step=1)
        self.conv_2 = double_conv(in_channels=in_channels, out_channels=in_channels, norm=norm, step=1)
        self.class_no = class_no
        annotator_features = 64
        fused_channels = annotator_features + in_channels
        # Per-pixel MLP from fused features to the class_no**2 confusion-matrix entries.
        self.dense = torch.nn.Linear(fused_channels, class_no ** 2)
        self.dense2 = torch.nn.Linear(class_no ** 2, class_no ** 2)
        self.dense_annotator = torch.nn.Linear(noisy_labels_no, annotator_features)
        self.norm = torch.nn.BatchNorm2d(fused_channels, affine=True)
        # Despite its name, this is a Softplus (keeps matrix entries positive).
        self.relu = torch.nn.Softplus()
        self.act = torch.nn.Softmax(dim=3)  # currently unused in forward

    def forward(self, A_id, x):
        y = self.conv_2(self.conv_1(x))
        height, width = y.shape[2], y.shape[3]

        # A_id presumably (batch, noisy_labels_no) - TODO confirm against callers.
        annotator = self.relu(self.dense_annotator(A_id))
        annotator = annotator.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, height, width)

        y = self.norm(torch.cat((annotator, y), dim=1))
        # Channels last so the linear layers act per pixel.
        y = y.permute(0, 2, 3, 1)
        y = self.relu(self.dense(y))
        y = self.relu(self.dense2(y))

        # (batch, H, W, class_no**2) -> (batch, class_no**2, H, W)
        return y.permute(0, 3, 1, 2)
    
class Crowd_segmentationModel(torch.nn.Module):
    """ This defines the architecture of the chosen CR method

    Wraps a segmentation backbone (encoder / decoder / segmentation_head) and, depending on
    config['model']['crowd_type'] ('global', 'image' or 'pixel'), an annotator network that
    produces confusion matrices from the decoder features.
    forward returns (softmax class probabilities, confusion matrices or None).
    """
    def __init__(self, noisy_labels):
        super().__init__()
        self.seg_model = create_segmentation_backbone()
        self.activation = torch.nn.Softmax(dim=1)
        self.noisy_labels_no = len(noisy_labels)
        print("Number of annotators (model): ", self.noisy_labels_no)
        self.class_no = config['data']['class_no']
        self.crowd_type = config['model']['crowd_type']

        # crowd_type -> (log message, annotator-network factory); unknown types build nothing.
        builders = {
            'global': ("Global crowdsourcing",
                       lambda: global_CM(self.class_no, 512, 512, self.noisy_labels_no)),
            'image': ("Image dependent crowdsourcing",
                      lambda: image_CM(self.class_no, 512, 512, self.noisy_labels_no)),
            'pixel': ("Pixel dependent crowdsourcing",
                      lambda: cm_layers(in_channels=16, norm='in',
                                        class_no=self.class_no, noisy_labels_no=self.noisy_labels_no)),
        }
        if self.crowd_type in builders:
            message, build = builders[self.crowd_type]
            print(message)
            self.crowd_layers = build()

    def forward(self, x, A_id=None):
        features = self.seg_model.decoder(*self.seg_model.encoder(x))
        cm = None if A_id is None else self.crowd_layers(A_id, features)
        return self.activation(self.seg_model.segmentation_head(features)), cm

In [9]:
config = {}

def config_update(orig_dict, new_dict):
    """Recursively merge ``new_dict`` into ``orig_dict`` in place and return ``orig_dict``.

    Nested mappings are merged key by key. Every other value (lists included) replaces the
    existing entry wholesale. If ``orig_dict`` holds a non-mapping (e.g. a YAML ``null``)
    where ``new_dict`` has a mapping, the merged mapping replaces it instead of raising.
    """
    for key, val in new_dict.items():
        if isinstance(val, Mapping):
            existing = orig_dict.get(key)
            base = existing if isinstance(existing, Mapping) else {}
            orig_dict[key] = config_update(base, val)
        else:
            orig_dict[key] = val
    return orig_dict

def init_global_config(args):
    """Build the module-level ``config`` from the default, dataset and experiment YAML files.

    Later files override earlier ones via ``config_update``:
    ``args.default_config`` -> ``config['data']['dataset_config']`` ->
    ``<args.experiment_folder>/exp_config.yaml`` (only if it exists).

    Also sets ``config['logging']['experiment_folder']`` and
    ``config['logging']['run_name']``. The run name joins the last three path
    components of the experiment folder for crowd experiments and the last two
    otherwise; the path is normalised first, so trailing separators are ignored.

    When ``args.experiment_folder`` is ``'None'`` (or ``None``), results go to
    ``./output/`` and the run is named ``'default'``.

    NOTE(review): ``yaml.full_load`` is not safe for untrusted input; the config
    files are assumed to be local and trusted.
    """
    global config

    # load default config
    with open(args.default_config) as file:
        config = yaml.full_load(file)

    # load dataset config, overwrite parameters if double
    with open(config["data"]["dataset_config"]) as file:
        config_data_dependent = yaml.full_load(file)
    config = config_update(config, config_data_dependent)

    experiment_folder = args.experiment_folder
    if experiment_folder not in (None, 'None'):
        # load experiment config, overwrite parameters if double
        experiment_config = os.path.join(experiment_folder, 'exp_config.yaml')
        if os.path.exists(experiment_config):
            with open(experiment_config) as file:
                exp_config = yaml.full_load(file)
            config = config_update(config, exp_config)
        config['logging']['experiment_folder'] = experiment_folder
        # normpath strips trailing separators, so 'a/b/c/' gives 'b_c' and not 'c_'
        path_parts = os.path.normpath(experiment_folder).split(os.sep)
        n_parts = 3 if config['data']['crowd'] else 2
        config['logging']['run_name'] = "_".join(path_parts[-n_parts:])
    else:
        out_dir = './output/'
        os.makedirs(out_dir, exist_ok=True)
        warnings.warn("No experiment folder was given. Use ./output folder to store experiment results.")
        config['logging']['experiment_folder'] = out_dir
        config['logging']['run_name'] = 'default'

In [10]:
def start_logging(config_file='/kaggle/working/OxfordPet/config.yaml'):
    """Start an MLflow run and log the model/data configuration.

    Tracking URI is ``config['logging']['mlruns_folder']``, the experiment is named
    after ``config['data']['dataset_name']`` and the run after
    ``config['logging']['run_name']``.

    Args:
        config_file: config file uploaded as a run artifact. The default keeps the
            previously hard-coded Kaggle location; pass another path elsewhere.
    """
    mlflow.set_tracking_uri(config["logging"]["mlruns_folder"])

    data_config_log = config['data'].copy()
    # drop this because it is often too long to be logged (tolerate its absence)
    data_config_log.pop('visualize_images', None)

    mlflow.set_experiment(experiment_name=config["data"]["dataset_name"])
    mlflow.start_run(run_name=config['logging']['run_name'])
    print('tracking uri:', mlflow.get_tracking_uri())
    print('artifact uri:', mlflow.get_artifact_uri())
    mlflow.log_params(config['model'])
    mlflow.log_params(data_config_log)
    mlflow.log_artifact(config_file)


def log_results(results, mode, step=None):
    """Log each metric in ``results`` to MLflow as ``<mode>_<metric name>`` at ``step``."""
    prefixed_metrics = {mode + '_' + name: value for name, value in results.items()}
    mlflow.log_metrics(prefixed_metrics, step=step)

In [11]:
torch.backends.cudnn.deterministic = True
eps=1e-7


def noisy_label_loss(pred, cms, labels, ignore_index, min_trace = False, alpha=0.1, loss_mode=None):
    """Loss for the crowdsourcing (confusion-matrix) methods.

    Each pixel's predicted class distribution is multiplied by that pixel's
    confusion matrix to obtain a "noisy" distribution, which is compared with the
    annotator's labels. A trace term on the confusion matrices is then added
    (``min_trace=True``) or subtracted (``min_trace=False``).

    Args:
        pred: (b, c, h, w) class probabilities.
        cms: tensor viewable as (b, c*c, h, w); channel ``i*c + j`` is entry (i, j)
            of the per-pixel confusion matrix. Columns are normalised here.
        labels: integer class labels viewable as (b, h, w).
        ignore_index: label value excluded from the data term.
        min_trace: add the trace term if True, subtract it otherwise.
        alpha: weight of the trace term.
        loss_mode: ``'ce'``, ``'dice'`` or ``'focal'``.

    Returns:
        (loss, loss_ce, regularisation): total loss, data term, weighted trace term.

    Raises:
        ValueError: if ``loss_mode`` is not supported (previously an
            UnboundLocalError on ``loss_ce``).
    """
    b, c, h, w = pred.size()
    n_pixels = b * h * w

    # one probability column vector per pixel: (b*h*w, c, 1)
    pred_norm = pred.view(b, c, h * w).permute(0, 2, 1).contiguous().view(n_pixels, c, 1)
    # one c x c confusion matrix per pixel, columns normalised to sum to 1
    cm = cms.view(b, c ** 2, h * w).permute(0, 2, 1).contiguous().view(n_pixels, c, c)
    cm = cm / cm.sum(1, keepdim=True)

    pred_noisy = torch.bmm(cm, pred_norm).view(n_pixels, c)
    pred_noisy = pred_noisy.view(b, h * w, c).permute(0, 2, 1).contiguous().view(b, c, h, w)

    target = labels.view(b, h, w).long()
    if loss_mode == 'ce':
        loss_ce = nn.NLLLoss(reduction='mean', ignore_index=ignore_index)(torch.log(pred_noisy + eps), target)
    elif loss_mode == 'dice':
        loss_ce = DiceLoss(ignore_index=ignore_index, from_logits=False, mode='multiclass')(pred_noisy, target)
    elif loss_mode == 'focal':
        # NOTE(review): smp's FocalLoss appears to expect logits but receives probabilities here — confirm
        loss_ce = FocalLoss(reduction='mean', ignore_index=ignore_index, mode='multiclass')(pred_noisy, target)
    else:
        raise ValueError(f"Unknown loss_mode '{loss_mode}'; expected 'ce', 'dice' or 'focal'")

    # regularisation: mean trace of the per-pixel confusion matrices
    regularisation = torch.trace(torch.sum(cm, dim=0)) / n_pixels
    regularisation = alpha * regularisation

    if min_trace:
        loss = loss_ce + regularisation
    else:
        loss = loss_ce - regularisation

    return loss, loss_ce, regularisation

In [12]:
import segmentation_models_pytorch as smp


def create_segmentation_backbone():
    """Instantiate the smp architecture named by ``config['model']['backbone']``.

    ``'unet'`` builds ``smp.Unet`` and ``'linknet'`` builds ``smp.Linknet``. The
    encoder name and its pretrained weights (e.g. ``imagenet``) come from
    ``config['model']['encoder']``. The network takes 3-channel (RGB) input and
    outputs ``config['data']['class_no']`` channels.

    Raises:
        Exception: for any other backbone name.
    """
    n_classes = config['data']['class_no']

    backbone_name = config['model']['backbone']
    if backbone_name == 'unet':
        architecture = smp.Unet
    elif backbone_name == 'linknet':
        architecture = smp.Linknet
    else:
        raise Exception('Choose valid model backbone!')

    encoder_cfg = config['model']['encoder']
    return architecture(
        encoder_name=encoder_cfg['backbone'],
        encoder_weights=encoder_cfg['weights'],
        in_channels=3,
        classes=n_classes,
    )


class SegmentationModel(torch.nn.Module):
    """Plain (non-crowd) segmentation network: smp backbone followed by a softmax over classes."""

    def __init__(self):
        super().__init__()
        self.seg_model = create_segmentation_backbone()
        self.activation = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for the input batch ``x``."""
        logits = self.seg_model(x)
        return self.activation(logits)

In [13]:
def preprocess_input(
    x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs
):
    """Normalise an image array; mirrors smp's encoder preprocessing.

    Applied in order:
      1. ``input_space == "BGR"``: reverse the last axis on a copy.
      2. ``input_range`` given, its upper bound is 1 and ``x.max() > 1``: divide by 255.
      3. subtract ``mean`` if given, then divide by ``std`` if given.

    Extra keyword arguments are accepted and ignored. With all defaults the input
    object is returned unchanged.
    """
    out = x[..., ::-1].copy() if input_space == "BGR" else x

    if input_range is not None and out.max() > 1 and input_range[1] == 1:
        out = out / 255.0

    if mean is not None:
        out = out - np.array(mean)

    if std is not None:
        out = out / np.array(std)

    return out


def get_preprocessing_params():
    """Settings for ``preprocess_input`` without mean/std normalisation (RGB, rescaled to [0, 1])."""
    return {
        "input_range": (0, 1),
        "input_space": "RGB",
        "mean": None,
        "std": None,
    }


def get_preprocessing_fn_without_normalization():
    """Return ``preprocess_input`` bound to ``get_preprocessing_params()`` (no mean/std)."""
    return functools.partial(preprocess_input, **get_preprocessing_params())

In [14]:
def save_model(model):
    """Save the whole ``model`` object to ``<experiment_folder>/models/best_model.pth``."""
    models_dir = os.path.join(config['logging']['experiment_folder'], 'models')
    os.makedirs(models_dir, exist_ok=True)
    torch.save(model, os.path.join(models_dir, 'best_model.pth'))
    print('Best Model saved!')

def save_test_images(test_imgs: torch.Tensor, test_preds: np.ndarray, test_labels: np.ndarray, test_name: str, mode: str):
    """Save an input image, its prediction and its ground truth, and log them to MLflow.

    Files go to ``<experiment_folder>/qualitative_results/<mode>/`` as
    ``img_<name>``, ``pred_<name>`` and ``gt_<name>``, and are logged under the
    MLflow artifact path ``qualitative_results/<mode>``.

    Args:
        test_imgs: image tensor, written with torchvision's ``save_image``.
        test_preds: 2-D array of predicted class ids (cast to uint8).
        test_labels: 2-D array of ground-truth class ids (cast to uint8); its
            shape sets the output height and width.
        test_name: file name (with extension) shared by all outputs.
        mode: sub-folder name, e.g. ``'train'``, ``'val'`` or ``'test'``.
    """
    visual_dir = 'qualitative_results/' + mode
    out_dir = os.path.join(config['logging']['experiment_folder'], visual_dir)
    os.makedirs(out_dir, exist_ok=True)

    h, w = np.shape(test_labels)

    test_preds = np.asarray(test_preds, dtype=np.uint8)
    test_labels = np.asarray(test_labels, dtype=np.uint8)

    img_path = os.path.join(out_dir, 'img_' + test_name)
    save_image(test_imgs, img_path)

    pred_path = os.path.join(out_dir, 'pred_' + test_name)
    imageio.imsave(pred_path, convert_classes_to_rgb(test_preds, h, w))

    gt_path = os.path.join(out_dir, 'gt_' + test_name)
    imageio.imsave(gt_path, convert_classes_to_rgb(test_labels, h, w))

    # Upload only the files written by this call; logging the whole folder
    # re-uploaded every previously saved image each time.
    for path in (img_path, pred_path, gt_path):
        mlflow.log_artifact(path, visual_dir)

# TODO: write a function that saves the crowdsourcing outputs properly
def save_crowd_images(test_imgs: torch.Tensor, gt_pred: np.ndarray, test_preds: np.ndarray, test_labels: np.ndarray, test_name: str, annotator, cm):
    """Save crowd-training visualisations for one image/annotator pair and log them to MLflow.

    Written to ``<experiment_folder>/qualitative_results/train_crowd/``:
    ``img_<name>``, ``<annotator>_pred_<name>``, ``gt_pred_<name>``,
    ``<annotator>_gt_<name>`` and a matshow plot ``<annotator>_matrix_<name>``.

    Args:
        test_imgs: image tensor, written with torchvision's ``save_image``.
        gt_pred: 2-D array of predicted class ids for the clean (ground-truth) estimate.
        test_preds: 2-D array of predicted class ids for this annotator (cast to uint8).
        test_labels: 2-D array of this annotator's labels (cast to uint8); its
            shape sets the output height and width.
        test_name: file name (with extension) shared by all outputs.
        annotator: annotator name (string) used as a file-name prefix.
        cm: confusion-matrix tensor, plotted with ``plt.matshow``.
    """
    visual_dir = 'qualitative_results/' + "train_crowd"
    out_dir = os.path.join(config['logging']['experiment_folder'], visual_dir)
    os.makedirs(out_dir, exist_ok=True)

    h, w = np.shape(test_labels)

    test_preds = np.asarray(test_preds, dtype=np.uint8)
    test_labels = np.asarray(test_labels, dtype=np.uint8)

    img_path = os.path.join(out_dir, 'img_' + test_name)
    save_image(test_imgs, img_path)

    ann_pred_path = os.path.join(out_dir, annotator + '_pred_' + test_name)
    imageio.imsave(ann_pred_path, convert_classes_to_rgb(test_preds, h, w))

    gt_pred_path = os.path.join(out_dir, 'gt_pred_' + test_name)
    imageio.imsave(gt_pred_path, convert_classes_to_rgb(gt_pred, h, w))

    ann_gt_path = os.path.join(out_dir, annotator + '_gt_' + test_name)
    imageio.imsave(ann_gt_path, convert_classes_to_rgb(test_labels, h, w))

    cm = cm.detach().cpu().numpy()
    matrix_image = plt.matshow(cm)
    matrix_path = os.path.join(out_dir, annotator + '_matrix_' + test_name)
    matrix_image.figure.savefig(matrix_path)
    # matshow opens a new figure on every call; close it so figures do not pile up
    plt.close(matrix_image.figure)

    # upload only the files written by this call
    for path in (img_path, ann_pred_path, gt_pred_path, ann_gt_path, matrix_path):
        mlflow.log_artifact(path, visual_dir)


def save_image_color_legend():
    """Save a one-row figure showing each class colour and name, and log it to MLflow.

    Written to ``<experiment_folder>/qualitative_results/legend.png`` and logged
    under the ``qualitative_results`` artifact path.
    """
    visual_dir = 'qualitative_results/'
    out_dir = os.path.join(config['logging']['experiment_folder'], visual_dir)
    os.makedirs(out_dir, exist_ok=True)
    class_no = config['data']['class_no']
    class_names = config['data']['class_names']

    fig = plt.figure()

    size = 100

    for class_id in range(class_no):
        # a size x size swatch filled with this class id, rendered with the class colour
        swatch = convert_classes_to_rgb(np.ones(shape=[size, size]) * class_id, size, size)
        ax = fig.add_subplot(1, class_no, class_id + 1)
        ax.imshow(swatch)
        ax.set_title(class_names[class_id])
        ax.axis('off')

    legend_path = os.path.join(out_dir, 'legend.png')
    fig.savefig(legend_path)
    # release the figure; this function is called on every train/test run
    plt.close(fig)
    mlflow.log_artifact(legend_path, 'qualitative_results')


def convert_classes_to_rgb(seg_classes, h, w):
    """Colour an (h, w) map of class ids as an RGB uint8 image.

    Class ``k`` for ``k`` in ``range(config['data']['class_no'])`` gets
    ``palette[k]``; pixels with any other value stay black. The palette has
    five entries, so more than five classes raise an IndexError.
    """
    rgb = np.zeros((h, w, 3), dtype=np.uint8)
    n_classes = config['data']['class_no']

    palette = [[0, 179, 255], [153, 0, 0], [255, 102, 204], [0, 153, 51], [153, 0, 204]]

    for class_id in range(n_classes):
        # boolean (h, w) mask selects pixels; the RGB triple broadcasts over them
        rgb[seg_classes == class_id] = palette[class_id]

    return rgb


def save_results(results):
    """Write ``results`` as ``key,value`` rows to ``<experiment_folder>/quantitative_results/results.csv``."""
    results_dir = 'quantitative_results'
    out_dir = os.path.join(config['logging']['experiment_folder'], results_dir)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, 'results.csv')

    # newline='' as required by the csv module (avoids blank rows on Windows)
    with open(out_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerows([key, value] for key, value in results.items())

In [15]:
def dice_coef_binary(y_true, y_pred):
    """Smoothed Dice coefficient ``(2|A∩B| + s) / (|A| + |B| + s)`` of two binary masks.

    The smoothing term ``s = 1e-4`` makes two empty masks score 1.0.
    """
    smooth = 0.0001
    overlap = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    return (2. * overlap + smooth) / (total + smooth)

def dice_coef_multilabel(y_true, y_pred):
    """Per-class Dice scores for class ids ``0 .. class_no-1``, as a 1-D numpy array."""
    n_classes = config['data']['class_no']
    scores = [dice_coef_binary(y_true == k, y_pred == k) for k in range(n_classes)]
    return np.array(scores)

def segmentation_scores(label_trues, label_preds, metric_names):
    """Compute segmentation metrics between ground-truth and predicted class ids.

    Args:
        label_trues: ground-truth class ids (array-like, any shape; flattened).
        label_preds: predicted class ids; same length as ``label_trues``.
        metric_names: keys that must appear in the result (asserted).

    Returns:
        dict with ``macro_dice``, ``micro_dice`` (numerically ~ pixel accuracy for
        single-label multiclass data), ``dice_class_<id>_<name>`` per class,
        ``accuracy`` and ``miou`` (macro Jaccard).

    When ``config['data']['ignore_last_class']`` is set, pixels labelled
    ``class_no`` are dropped before scoring. Class ids are stored as int8, so
    they must be below 128.
    """
    results = {}
    class_no = config['data']['class_no']
    class_names = config['data']['class_names']
    ignore_last_class = config['data']['ignore_last_class']

    assert len(label_trues) == len(label_preds)

    # Flatten so sklearn sees 1-D multiclass targets; it rejects 2-D
    # (multiclass-multioutput) inputs.
    label_preds = np.array(label_preds, dtype='int8').ravel()
    label_trues = np.array(label_trues, dtype='int8').ravel()

    if ignore_last_class:
        keep = label_trues != class_no
        label_preds = label_preds[keep]
        label_trues = label_trues[keep]

    dice_per_class = dice_coef_multilabel(label_trues, label_preds)

    results['macro_dice'] = dice_per_class.mean()

    intersection = (label_preds == label_trues).sum(axis=None)
    sum_ = 2 * np.prod(label_preds.shape)
    results['micro_dice'] = ((2 * intersection + 1e-6) / (sum_ + 1e-6))

    for class_id in range(class_no):
        results['dice_class_' + str(class_id) + '_' + class_names[class_id]] = dice_per_class[class_id]

    results['accuracy'] = accuracy_score(label_trues, label_preds)
    results['miou'] = jaccard_score(label_trues, label_preds, average="macro")  # same as IoU!

    for metric in metric_names:
        assert metric in results.keys()

    return results

In [16]:
def get_training_augmentation():
    """Training augmentation pipeline, or None when augmentation is disabled.

    Geometric: horizontal/vertical flips and 90-degree rotations. Photometric:
    blur, brightness/contrast and hue/saturation jitter, with limits read from
    ``config['data']['augmentation']``. Every transform fires with p=0.5.
    """
    aug_config = config['data']['augmentation']
    if not aug_config['use_augmentation']:
        return None

    geometric = [
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
        albu.RandomRotate90(p=0.5),
    ]
    photometric = [
        albu.Blur(blur_limit=aug_config['gaussian_blur_kernel'], p=0.5),
        albu.RandomBrightnessContrast(brightness_limit=aug_config['brightness_limit'],
                                      contrast_limit=aug_config['contrast_limit'],
                                      p=0.5),
        albu.HueSaturationValue(hue_shift_limit=aug_config['hue_shift_limit'],
                                sat_shift_limit=aug_config['sat_shift_limit'],
                                p=0.5),
    ]
    return albu.Compose(geometric + photometric)


def get_validation_augmentation():
    """Pad images to at least 384 x 480 (height x width).

    NOTE(review): ``PadIfNeeded`` only enforces minimum sizes; it does not make
    arbitrary shapes divisible by 32. ``get_data_supervised`` does not use it.
    """
    padding = albu.PadIfNeeded(384, 480)
    return albu.Compose([padding])


def to_tensor(x, **kwargs):
    """Convert an HWC array to a float32 CHW array; extra keyword arguments are ignored."""
    chw = np.transpose(x, (2, 0, 1))
    return chw.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct the preprocessing transform.

    Args:
        preprocessing_fn (callable): normalisation applied to the image only
            (can be specific to each pretrained encoder).

    Returns:
        albumentations.Compose: ``preprocessing_fn`` on the image, followed by
        HWC -> CHW float32 conversion of both image and mask.
    """
    normalise = albu.Lambda(image=preprocessing_fn)
    to_chw = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalise, to_chw])

# =============================================

class CustomDataset(torch.utils.data.Dataset):
    """Single-annotation dataset: reads images, one-hot encodes masks, applies transforms.

    The mask of an image is the file with the same name in ``masks_dir``.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    Items are ``(image, mask, image_name, 0)``; the trailing 0 is a placeholder
    for the annotator id that ``Crowdsourced_Dataset`` returns.
    """
    def __init__(
            self,
            images_dir,
            masks_dir,
            augmentation=None,
            preprocessing=None
    ):
        # sorted so the sample order does not depend on the filesystem's listing order
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
        self.class_no = config['data']['class_no']
        self.class_values = self.set_class_values(self.class_no)
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, mask, image_name, 0)``; the mask has one channel per class value."""
        # read data; cv2.imread returns None (instead of raising) for missing/unreadable files
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {self.images_fps[i]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {self.masks_fps[i]}")

        # check dimensions
        if image.shape[:2] != mask.shape[:2]:
            raise ValueError(f"Inconsistent dimensions: image {image.shape[:2]} vs mask {mask.shape[:2]}\n")

        # one-hot encode: one channel per class value
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask, self.ids[i], 0

    def __len__(self):
        return len(self.ids)

    def set_class_values(self, class_no):
        """Mask values to encode: ``0..class_no-1``, plus ``class_no`` when the last class is ignored."""
        if config['data']['ignore_last_class']:
            class_values = list(range(class_no + 1))
        else:
            class_values = list(range(class_no))
        return class_values


class Crowdsourced_Dataset(torch.utils.data.Dataset):
    """Multi-annotator dataset: one random annotator's mask per image per access.

    ``masks_dir`` holds one sub-folder per annotator; folders named ``expert``,
    ``MV`` and ``STAPLE`` are excluded. Each access picks, in random order, the
    first annotator that has a mask file for the image.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to the folder with one sub-folder per annotator
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
        _set: unused; kept for backward compatibility.

    Items are ``(image, mask, image_name, annotator_id)`` where ``annotator_id``
    is a one-hot float tensor over ``self.annotators``.
    """
    def __init__(
            self,
            images_dir,
            masks_dir,
            augmentation=None,
            preprocessing=None,
            _set = None
    ):
        # sorted for a deterministic order independent of the filesystem
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]

        # sorted so that annotator index <-> annotator folder is stable across runs/machines
        self.annotators = sorted(e for e in os.listdir(masks_dir) if e not in ('expert', 'MV', 'STAPLE'))
        self.annotators_fps = [os.path.join(masks_dir, annotator) for annotator in self.annotators]
        self.masks_dir = masks_dir
        self.annotators_no = len(self.annotators)
        print("Number of images: ", len(self.ids))
        print("Annotators: ")
        print(*self.annotators, sep = "\n")
        print("Number of annotators: ", self.annotators_no)
        print("Paths of annotators ", *self.annotators_fps)
        self.class_no = config['data']['class_no']
        self.class_values = self.set_class_values(self.class_no)
        self.augmentation = augmentation
        self.preprocessing = preprocessing

        if config['data']['ignore_last_class']:
            self.ignore_index = int(self.class_no) # deleted class is always set to the last index
        else:
            self.ignore_index = -100 # this means no index ignored

    def __getitem__(self, i):
        """Return ``(image, one-hot mask, image_name, one-hot annotator id)``."""
        image_path = self.images_fps[i]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising for missing/unreadable files
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask, annotator_id = self._load_random_annotation(self.ids[i])

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image = sample['image']
            mask = sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image = sample['image']
            mask = sample['mask']
        return image, mask, self.ids[i], annotator_id

    def _load_random_annotation(self, image_id):
        """Return ``(one-hot mask, one-hot annotator id)`` from a random annotator that labelled ``image_id``.

        Raises:
            FileNotFoundError: if no annotator has a (readable) mask for the image.
        """
        # torch's RNG is seeded per DataLoader worker; numpy's global RNG may be
        # duplicated across workers, giving every worker the same annotator order.
        for ann_index in torch.randperm(self.annotators_no).tolist():
            mask_path = os.path.join(self.annotators_fps[ann_index], image_id)
            if not os.path.exists(mask_path):
                continue
            raw_mask = cv2.imread(mask_path, 0)
            if raw_mask is None:
                raise FileNotFoundError(f"Could not read mask: {mask_path}")
            # one-hot encode: one channel per class value
            mask = np.stack([(raw_mask == v) for v in self.class_values], axis=-1).astype('float')
            annotator_id = torch.zeros(self.annotators_no)
            annotator_id[ann_index] = 1
            return mask, annotator_id
        raise FileNotFoundError(f"No annotator mask found for image {image_id} in {self.masks_dir}")

    def __len__(self):
        return len(self.ids)

    def set_class_values(self, class_no):
        """Mask values to encode: ``0..class_no-1``, plus ``class_no`` when the last class is ignored."""
        if config['data']['ignore_last_class']:
            class_values = list(range(class_no + 1))
        else:
            class_values = list(range(class_no))
        return class_values


def get_data_supervised():
    """Build the train / validation / test DataLoaders from ``config['data']``.

    Training uses ``Crowdsourced_Dataset`` when ``config['data']['crowd']`` is set
    (and its annotator names are returned), otherwise ``CustomDataset`` with
    training augmentation. Validation and test always use ``CustomDataset``
    without augmentation.

    Returns:
        (trainloader, validateloader, testloader, annotators); ``annotators`` is an
        empty list for non-crowd training.
    """
    data_cfg = config['data']
    batch_size = config['model']['batch_size']
    normalization = data_cfg['normalization']
    crowd = data_cfg['crowd']

    def split_folder(split, kind):
        # <data root>/<configured sub-folder for this split and kind>
        return os.path.join(data_cfg['path'], data_cfg[split][kind])

    train_image_folder = split_folder('train', 'images')
    train_label_folder = split_folder('train', 'masks')
    val_image_folder = split_folder('val', 'images')
    val_label_folder = split_folder('val', 'masks')
    print(f'val_image_folder: {val_image_folder}, val_label_folder: {val_label_folder}')
    test_image_folder = split_folder('test', 'images')
    test_label_folder = split_folder('test', 'masks')

    if normalization:
        encoder_cfg = config['model']['encoder']
        preprocessing_fn = get_preprocessing_fn(encoder_cfg['backbone'], pretrained=encoder_cfg['weights'])
    else:
        preprocessing_fn = get_preprocessing_fn_without_normalization()
    preprocessing = get_preprocessing(preprocessing_fn)

    annotators = []
    train_augmentation = get_training_augmentation()
    if crowd:
        train_dataset = Crowdsourced_Dataset(train_image_folder, train_label_folder,
                                             augmentation=train_augmentation, preprocessing=preprocessing)
        annotators = train_dataset.annotators
    else:
        train_dataset = CustomDataset(train_image_folder, train_label_folder,
                                      augmentation=train_augmentation, preprocessing=preprocessing)
    validate_dataset = CustomDataset(val_image_folder, val_label_folder, preprocessing=preprocessing)
    test_dataset = CustomDataset(test_image_folder, test_label_folder, preprocessing=preprocessing)

    trainloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=8, drop_last=True)
    # NOTE(review): the evaluation loaders use batch_size as num_workers — looks unintended
    validateloader = data.DataLoader(validate_dataset, batch_size=batch_size, shuffle=False,
                                     num_workers=batch_size, drop_last=False)
    testloader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                 num_workers=batch_size, drop_last=False)

    return trainloader, validateloader, testloader, annotators

In [17]:
# Numerical-stability constant added inside torch.log(...) in the training losses.
# NOTE(review): re-assigns the same value as the earlier module-level `eps`; one definition would suffice.
eps=1e-7


class ModelHandler():
    def __init__(self, annotators):
        """Build the model selected by the config and move it to the available device.

        Args:
            annotators: annotator names from the training set; used only for crowd
                models, where their count sets the number of confusion matrices.
        """
        # architecture
        if config['model']['crowd_type'] == 'prob-unet':
            self.model = ProbabilisticUnet(3, config['data']['class_no'])
        elif config['data']['crowd']:
            self.model = Crowd_segmentationModel(annotators)
            # trace-term weight during warm-up; train() later sets config['model']['alpha']
            self.alpha = 1
            self.annotators = annotators
        else:
            self.model = SegmentationModel()

        # loss
        self.loss_mode = config['model']['loss']

        # Choose the device before moving the model so the CPU fallback is reachable
        # (calling .cuda() unconditionally fails on machines without a GPU).
        if torch.cuda.is_available():
            print('Running on GPU')
            self.device = torch.device('cuda')
        else:
            warnings.warn("Running on CPU because no GPU was found!")
            self.device = torch.device('cpu')
        self.model.to(self.device)

    def train(self, trainloader, validateloader):
        """Train ``self.model``, validating and logging after every epoch.

        The loss depends on the configuration:
          * ``crowd_type == 'prob-unet'``: negative ELBO + 1e-5 * L2 regularisation;
          * crowd (confusion-matrix) models: ``noisy_label_loss``;
          * otherwise: CE / Dice / Focal on the softmax output (``self.loss_mode``).

        At epoch 5 the warm-up ends: ``min_trace`` is switched on, ``self.alpha``
        becomes ``config['model']['alpha']`` and, for CM models, the optimizer is
        recreated with a lower crowd-layer learning rate. From epoch 10 on, the
        model is saved whenever validation macro Dice improves. CM models are
        finally evaluated on the training data with ``self.evaluate_crowd``.

        Args:
            trainloader: yields ``(images, one-hot labels, names, annotator ids)``.
            validateloader: passed to ``self.evaluate(..., mode='val')``.

        Raises:
            Exception: if ``config['model']['optimizer']`` is not supported.
            ValueError: if ``self.loss_mode`` is not supported (plain models).
        """
        model = self.model
        device = self.device
        max_score = 0
        c_weights = config['data']['class_weights']
        class_weights = torch.FloatTensor(c_weights).to(device)

        class_no = config['data']['class_no']
        epochs = config['model']['epochs']
        learning_rate = config['model']['learning_rate']
        batch_s = config['model']['batch_size']
        vis_train_images = config['data']['visualize_images']['train']
        save_image_color_legend()

        is_prob_unet = config['model']['crowd_type'] == 'prob-unet'
        is_cm_model = bool(config['data']['crowd']) and not is_prob_unet
        # plain segmentation: the only mode that yields y_pred for train-time metrics
        is_supervised = not is_prob_unet and not config['data']['crowd']

        # loop-invariant: index excluded from the losses
        if config['data']['ignore_last_class']:
            ignore_index = int(class_no)  # deleted class is always set to the last index
        else:
            ignore_index = -100  # this means no index ignored
        self.ignore_index = ignore_index

        if is_supervised:
            criterion = self._build_supervised_loss(ignore_index, class_weights)

        # Optimizer
        if is_cm_model:
            optimizer = torch.optim.Adam([
                {'params': model.seg_model.parameters()},
                {'params': model.crowd_layers.parameters(), 'lr': 1e-3}
            ], lr=learning_rate)
        elif config['model']['optimizer'] == 'adam':
            optimizer = torch.optim.Adam([
                dict(params=model.parameters(), lr=learning_rate),
            ])
        elif config['model']['optimizer'] == 'sgd_mom':
            optimizer = torch.optim.SGD([
                dict(params=model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True),
            ])
        else:
            raise Exception('Choose valid optimizer!')

        min_trace = config['model']['min_trace']
        warmup_epochs = 5  # 10 for cr_image_dice // 5 rest of the methods

        # Training loop
        for i in range(0, epochs):

            print('\nEpoch: {}'.format(i))
            model.train()

            # Stop of the warm-up period
            if i == warmup_epochs:
                print("Minimize trace activated!")
                min_trace = True
                self.alpha = config['model']['alpha']
                print("Alpha updated", self.alpha)

                if is_cm_model:
                    optimizer = torch.optim.Adam([
                        {'params': model.seg_model.parameters()},
                        {'params': model.crowd_layers.parameters(), 'lr': 1e-4}
                    ], lr=learning_rate)

            # Training in batches
            for j, (images, labels, imagename, ann_ids) in enumerate(trainloader):
                # Loading data to the device
                images = images.to(device).float()
                labels = labels.to(device).long()
                ann_ids = ann_ids.to(device).float()

                # zero the parameter gradients
                optimizer.zero_grad()

                # labels arrive one-hot over dim 1; convert to class-index maps
                _, labels = torch.max(labels, dim=1)

                # Forward + loss
                if is_prob_unet:
                    labels = labels[:, None, :, :]
                    model.forward(images, labels, training=True)
                    elbo = model.elbo(labels)
                    reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + l2_regularisation(
                        model.fcomb.layers)
                    loss = -elbo + 1e-5 * reg_loss
                elif is_cm_model:
                    y_pred, cms = model(images, ann_ids)
                    loss, _, _ = noisy_label_loss(y_pred, cms, labels, ignore_index,
                                                  min_trace, self.alpha, self.loss_mode)
                else:
                    y_pred = model(images)
                    loss = criterion(y_pred, labels)

                # Final prediction (plain models only)
                if is_supervised:
                    _, y_pred_max = torch.max(y_pred[:, 0:class_no], dim=1)

                # Backprop (skip the update on NaN losses)
                if not torch.isnan(loss):
                    loss.backward()
                    optimizer.step()

                # Save results in training (only for plain models)
                if j % int(config['logging']['interval']) == 0:
                    print("Iter {}/{} - batch loss : {:.4f}".format(j, len(trainloader), loss))
                    if is_supervised:
                        train_results = self.get_results(y_pred_max, labels)
                        log_results(train_results, mode='train', step=(i * len(trainloader) * batch_s + j))
                        for k in range(len(imagename)):
                            if imagename[k] in vis_train_images:
                                labels_save = labels[k].cpu().detach().numpy()
                                y_pred_max_save = y_pred_max[k].cpu().detach().numpy()
                                images_save = images[k]
                                save_test_images(images_save, y_pred_max_save, labels_save, imagename[k], 'train')

            # Save validation results
            val_results = self.evaluate(validateloader, mode='val')  # TODO: validate crowd
            log_results(val_results, mode='val', step=int((i + 1) * len(trainloader) * batch_s))
            mlflow.log_metric('finished_epochs', i + 1, int((i + 1) * len(trainloader) * batch_s))

            # Save model (only from epoch 10 on)
            metric_for_saving = val_results['macro_dice']
            if max_score < metric_for_saving and i > 9:
                save_model(model)
                max_score = metric_for_saving

            # LR decay
            if i > config['model']['lr_decay_after_epoch']:
                for g in optimizer.param_groups:
                    g['lr'] = g['lr'] / (1 + config['model']['lr_decay_param'])

            # Show annotator matrices of the last training batch at pixel (100, 100)
            if is_cm_model:
                _, ann_id = torch.max(ann_ids, dim=1)
                for ann_ix, cm in enumerate(cms):
                    print(f'ann_id: {ann_id}, ann_ix:{ann_ix}, cm shape: {cm.shape}')
                    # (class_no**2, H, W) -> (class_no, class_no, H, W) for any class count
                    cm = cm.view(class_no, class_no, *cm.shape[-2:])
                    cm_ = cm[:, :, 100, 100]
                    cm_ = cm_ / cm_.sum(0)
                    print("Annotators", ann_id)
                    print("CM ", ann_id[ann_ix].cpu().detach().numpy() + 1, ": ", cm_.cpu().detach().numpy())

        # Final evaluation of crowd
        if is_cm_model:
            self.evaluate_crowd(trainloader, mode='train')

    def _build_supervised_loss(self, ignore_index, class_weights):
        """Return a callable ``(y_pred, labels) -> loss`` for ``self.loss_mode`` on softmax outputs.

        Raises:
            ValueError: if ``self.loss_mode`` is not 'ce', 'dice' or 'focal'.
        """
        if self.loss_mode == 'ce':
            nll = torch.nn.NLLLoss(reduction='mean', ignore_index=ignore_index, weight=class_weights)
            return lambda y_pred, labels: nll(torch.log(y_pred + eps), labels)
        if self.loss_mode == 'dice':
            return DiceLoss(ignore_index=ignore_index, from_logits=False, mode='multiclass')
        if self.loss_mode == 'focal':
            return FocalLoss(reduction='mean', ignore_index=ignore_index, mode='multiclass')
        raise ValueError(f"Unknown loss '{self.loss_mode}'; expected 'ce', 'dice' or 'focal'")

    def test(self, testloader):
        """
        Score the best saved checkpoint on ``testloader`` and persist the metrics.

        Calls ``save_image_color_legend()`` first (presumably a legend image for the
        visualisations), then evaluates in the default ``'test'`` mode, logs the
        returned metrics under the ``'test'`` prefix without a step, and saves them.

        Args:
            testloader: data loader for the test split.
        """
        save_image_color_legend()
        test_metrics = self.evaluate(testloader)
        log_results(test_metrics, mode='test', step=None)
        save_results(test_metrics)

    def evaluate(self, evaluatedata, mode='test'):
        """
        Run a model over a data loader and compute segmentation metrics.

        With ``mode == 'test'`` the checkpoint
        ``<config['logging']['experiment_folder']>/models/best_model.pth`` is loaded
        and evaluated; for any other mode the in-memory ``self.model`` is used.
        Images whose names appear in ``config['data']['visualize_images'][mode]``
        (or every image when that entry is the string ``'all'``) are written out
        through ``save_test_images``.

        Args:
            evaluatedata: iterable of ``(images, labels, names, _)`` batches. The
                labels are argmax'd over axis 1, so they are expected to carry the
                class channel on that axis; the fourth item is ignored.
            mode: selects the model source (checkpoint only for ``'test'``), the
                visualisation list and the printed/logged prefix.

        Returns:
            The metrics returned by ``self.get_results`` (indexable by metric name,
            e.g. ``'macro_dice'``).

        Raises:
            ValueError: if ``evaluatedata`` yields no batches.
        """
        class_no = config['data']['class_no']
        vis_images = config['data']['visualize_images'][mode]
        device = self.device

        if mode == 'test':
            print("Testing the best model")
            model_path = os.path.join(config['logging']['experiment_folder'], 'models', 'best_model.pth')
            # map_location puts the checkpoint on this handler's device even when it
            # was saved from a different one (e.g. trained on GPU, evaluated on CPU).
            model = torch.load(model_path, map_location=device)
        else:
            model = self.model
        model.eval()

        labels = []
        preds = []

        with torch.no_grad():
            for test_img, test_label, test_name, _ in evaluatedata:
                test_img = test_img.to(device=device, dtype=torch.float32)
                if config['model']['crowd_type'] == 'prob-unet':
                    # prob-unet: run forward on the batch without a target, then draw a sample.
                    model.forward(test_img, None, training=False)
                    test_pred = model.sample(testing=True)
                elif config['data']['crowd']:
                    # Crowd models return a (prediction, confusion-matrix) pair; keep the prediction.
                    test_pred, _ = model(test_img)
                else:
                    test_pred = model(test_img)
                # Restrict to the real class channels before the per-pixel argmax.
                _, test_pred = torch.max(test_pred[:, 0:class_no], dim=1)
                test_pred_np = test_pred.cpu().detach().numpy()
                test_label = np.argmax(test_label.cpu().detach().numpy(), axis=1)

                # flatten() already returns a copy, so no extra .copy() is needed.
                preds.append(test_pred_np.astype(np.int8).flatten())
                labels.append(test_label.astype(np.int8).flatten())

                for k, name in enumerate(test_name):
                    # Test 'all' first: `name in 'all'` would otherwise be a substring check.
                    if vis_images == 'all' or name in vis_images:
                        save_test_images(test_img[k], test_pred_np[k], test_label[k], name, mode)

        if not preds:
            raise ValueError(f"No batches to evaluate for mode '{mode}'")

        preds = np.concatenate(preds, axis=0, dtype=np.int8).flatten()
        labels = np.concatenate(labels, axis=0, dtype=np.int8).flatten()

        results = self.get_results(preds, labels)

        print('RESULTS for ' + mode)
        print(results)
        return results

    def evaluate_crowd(self, evaluatedata, mode='train'):
        """
        Visualise clean and annotator-specific predictions of the best checkpoint.

        Loads ``<config['logging']['experiment_folder']>/models/best_model.pth``, runs
        it on every batch together with that batch's annotator ids, maps the clean
        prediction through the predicted per-pixel confusion matrices to obtain an
        annotator-specific ("noisy") prediction, and calls ``save_crowd_images`` for
        the images listed in ``config['data']['visualize_images'][mode]`` (or all of
        them when that entry is ``'all'``). Nothing is returned and no metrics are
        computed.

        Args:
            evaluatedata: iterable of ``(images, labels, names, annotator_ids)`` batches;
                labels and annotator ids are argmax'd over axis 1.
            mode: key into ``config['data']['visualize_images']``.
        """
        class_no = config['data']['class_no']
        vis_images = config['data']['visualize_images'][mode]
        print("Testing the best model for crowds")
        device = self.device
        model_path = os.path.join(config['logging']['experiment_folder'], 'models', 'best_model.pth')
        # map_location keeps the checkpoint on this handler's device wherever it was saved.
        model = torch.load(model_path, map_location=device)
        model.eval()

        with torch.no_grad():
            for test_img, test_label, test_name, ann_id in evaluatedata:
                test_img = test_img.to(device=device, dtype=torch.float32)
                ann_id = ann_id.to(device=device)
                test_pred, cm = model(test_img, ann_id)

                # Clean hard prediction (argmax over all output channels) and ground truth.
                test_pred_np = np.argmax(test_pred.cpu().detach().numpy(), axis=1)
                _, test_label = torch.max(test_label, dim=1)
                test_label = test_label.cpu().detach().numpy()

                b, c, h, w = test_pred.size()

                # One c-vector per pixel: (b, c, h, w) -> (b*h*w, c, 1).
                pred_noisy = test_pred.view(b, c, h * w).permute(0, 2, 1).contiguous().view(b * h * w, c, 1)

                # One c x c matrix per pixel, normalised over dim 1 so every column sums to one.
                cm = cm.view(b, c ** 2, h * w).permute(0, 2, 1).contiguous().view(b * h * w, c, c)
                cm = cm / cm.sum(1, keepdim=True)

                # Annotator-specific prediction: apply each pixel's matrix to its clean prediction.
                pred_noisy = torch.bmm(cm, pred_noisy).view(b * h * w, c)
                pred_noisy = pred_noisy.view(b, h * w, c).permute(0, 2, 1).contiguous().view(b, c, h, w)
                _, pred_noisy = torch.max(pred_noisy[:, 0:class_no], dim=1)
                pred_noisy_np = pred_noisy.cpu().detach().numpy()

                # Reduce to a single c x c matrix per image for display.
                cm = cm.view(b, h * w, c, c)
                if config['model']['crowd_type'] == 'pixel':
                    # Average the pixel-wise matrices, then renormalise the columns.
                    cm = cm.mean(1)
                    cm = cm / cm.sum(1, keepdim=True)
                else:
                    # Presumably identical for every pixel here, so the first pixel's matrix is shown.
                    cm = cm[:, 0, :, :]

                _, ann = torch.max(ann_id, dim=1)
                ann = ann.cpu().detach().numpy()
                for k, name in enumerate(test_name):
                    # Test 'all' first: `name in 'all'` would otherwise be a substring check.
                    if vis_images == 'all' or name in vis_images:
                        save_crowd_images(test_img[k], test_pred_np[k], pred_noisy_np[k], test_label[k],
                                          name, self.annotators[ann[k]], cm[k])

    def get_results(self, pred, label):
        """
        Score predictions against labels with ``segmentation_scores``.

        Torch tensors are moved to host memory and flattened into NumPy arrays;
        any other input is passed through unchanged. The requested metric keys are
        ``'macro_dice'``, ``'micro_dice'``, ``'miou'``, ``'accuracy'`` plus one
        ``'dice_class_<id>_<class name>'`` entry per configured class.

        Args:
            pred: predicted labels (tensor or array).
            label: ground-truth labels (tensor or array).

        Returns:
            Whatever ``segmentation_scores(label, pred, metric_names)`` returns.
        """
        class_names = config['data']['class_names']
        per_class_keys = ['dice_class_' + str(idx) + '_' + class_names[idx]
                          for idx in range(config['data']['class_no'])]
        metrics_names = ['macro_dice', 'micro_dice', 'miou', 'accuracy'] + per_class_keys

        def _as_flat_array(values):
            # Only tensors are converted; NumPy inputs are handed over as-is.
            if not torch.is_tensor(values):
                return values
            return values.cpu().detach().numpy().copy().flatten()

        pred = _as_flat_array(pred)
        label = _as_flat_array(label)
        return segmentation_scores(label, pred, metrics_names)

In [19]:
if not os.path.exists('/kaggle/working/OxfordPet'):
    os.makedirs('/kaggle/working/OxfordPet')

!cp -r /kaggle/input/oxfordpets /kaggle/working

if os.path.exists('/kaggle/working/oxfordpets'):
    os.rename('/kaggle/working/oxfordpets', '/kaggle/working/OxfordPet')
    
if not os.path.exists('/kaggle/working/OxfordPet/experiments'):
    os.makedirs('/kaggle/working/OxfordPet/experiments')
    
!cp -r /kaggle/input/crowd-seg-model-training-results/models /kaggle/working/OxfordPet
!cp -r /kaggle/input/crowd-seg-model-training-results/qualitative_results /kaggle/working/OxfordPet
!cp -r /kaggle/input/crowd-seg-model-training-results/quantitative_results /kaggle/working/OxfordPet

In [None]:
def main():
    """
    Run the whole experiment.

    Starts logging, loads the train/validation/test loaders and the annotator list,
    trains a ``ModelHandler`` and finally evaluates the best saved checkpoint on the
    test split.
    """
    # os.curdir is always the literal '.', so print the resolved working directory instead.
    print(os.getcwd())

    start_logging()

    # load data
    trainloader, validateloader, testloader, annotators = get_data_supervised()

    # train the model, then test the best checkpoint written during training
    model_handler = ModelHandler(annotators)
    model_handler.train(trainloader, validateloader)
    model_handler.test(testloader)

# Simulate argument parsing for notebook execution
args = argparse.Namespace(
    default_config="/kaggle/working/OxfordPet/config.yaml",
    experiment_folder="/kaggle/working/OxfordPet/",
)
# init_global_config is expected to populate the module-level `config` read below.
init_global_config(args)

# Seed every global RNG the pipeline may draw from, not only torch's: Python's
# `random` and NumPy's global generator can also be used by data loading and
# augmentation code (e.g. albumentations, imported at the top).
SEED = config['model']['seed']
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Call the main function to execute the workflow: logging setup, data loading,
# training, and finally testing of the best checkpoint saved during training.
main()

In [None]:
# Load the best checkpoint for the metric cells below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location lets a checkpoint saved on GPU load on a CPU-only kernel (and vice versa).
model = torch.load('/kaggle/working/OxfordPet/models/best_model.pth', map_location=device)
# eval() switches layers such as dropout / batch-norm to inference behaviour; the
# assignment also keeps the (very long) module repr out of the cell output.
model = model.to(device).eval()

In [21]:
def load_chunk_from_hdf5(filename, dataset_name, start_index, end_index):
    """
    Load a contiguous slice of samples from a dataset stored in an HDF5 file.

    Only the samples ``[start_index, end_index)`` along the first axis are read
    from disk, so large files can be processed in batches. All remaining axes are
    returned in full, whatever the rank of the dataset.

    Args:
        filename (str): Path of the HDF5 file.
        dataset_name (str): Name of the dataset inside the file.
        start_index (int): First sample index to read (inclusive).
        end_index (int): Sample index to stop at (exclusive).

    Returns:
        numpy.ndarray: The selected samples.
    """
    with h5py.File(filename, 'r') as f:
        # Slicing only the first axis keeps every trailing axis for any dataset rank.
        chunk_data = f[dataset_name][start_index:end_index]
    return chunk_data

## Sample image and mask from the Oxford Pets dataset (HDF5 files)

In [None]:
# Load training samples 105..109 and show the first image together with its mask.
chunk_data_images = load_chunk_from_hdf5('/kaggle/working/OxfordPet/images_train.h5', 'dataset', 105, 110)

chunk_data_masks = load_chunk_from_hdf5('/kaggle/working/OxfordPet/masks_train.h5', 'dataset', 105, 110)

print(f'Images shape: {chunk_data_images.shape}, Masks shape: {chunk_data_masks.shape}')

# Write and display through the same paths so the cell does not depend on the cwd.
image_png = '/kaggle/working/image_example_train.png'
mask_png = '/kaggle/working/mask_example_train.png'
plt.imsave(image_png, chunk_data_images[0], format='png')
plt.imsave(mask_png, chunk_data_masks[0, :, :, 0], format='png')
display(Image(filename=image_png))
display(Image(filename=mask_png))

## Definition of performance metrics

### DICE metric

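$$\text{Dice} = {2 \cdot |\text{Intersection}| + \text{smooth} \over |\text{True}| + |\text{Pred}| + \text{smooth}}$$

Where:

$|\text{Intersection}| = \sum_{i=1}^{N} y\_{true\_i} \cdot y\_{pred\_i}$, $|\text{True}| = \sum_{i=1}^{N} y\_{true\_i}$, $|\text{Pred}| = \sum_{i=1}^{N} y\_{pred\_i}$

The denominator is the sum of the two mask sizes, not their union. The union would subtract the intersection once, as in the Jaccard index. In the code, the score is computed per image by summing over the spatial axes, and the per-image scores are then averaged.

- $N$ is the total number of elements in the segmentation masks.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted segmentation masks, respectively.
- $\text{smooth}$ is a smoothing parameter to avoid division by zero.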

In [23]:
class DiceMetric:
    """
    Dice-coefficient accumulator for binary segmentation masks.

    ``calculate`` computes one Dice score per image of a batch and appends the
    resulting array to the module-level list ``dice_total``; ``total_dice``
    averages every score collected so far. The list is shared by all instances,
    so reset it (``dice_total = []``) before evaluating a new split.

    Methods:
        calculate: Append the per-image Dice scores of one batch.
        total_dice: Mean of all Dice scores appended so far.
    """

    def __init__(self):
        """Create a Dice accumulator; it keeps no per-instance state."""

    def calculate(self, y_true, y_pred, axis=(1, 2), smooth=1e-5):
        """
        Append (2 * |overlap| + smooth) / (|truth| + |prediction| + smooth) per image.

        Parameters:
            y_true: Ground-truth masks whose trailing axis has length 1 (it is squeezed away).
            y_pred: Predicted masks, broadcast-compatible with the squeezed ``y_true``.
            axis: Axes summed to obtain the per-image overlap and mask sizes.
            smooth: Added to numerator and denominator; an empty truth and an empty
                    prediction therefore score 1.

        Returns:
            None
        """
        truth = np.squeeze(y_true, axis=-1).astype(np.float32)
        guess = y_pred.astype(np.float32)
        overlap = (truth * guess).sum(axis=axis)
        total_mass = truth.sum(axis=axis) + guess.sum(axis=axis)
        dice_total.append((2. * overlap + smooth) / (total_mass + smooth))

    def total_dice(self):
        """
        Average every per-image Dice score accumulated in ``dice_total``.

        Returns:
            value: The mean Dice coefficient.
        """
        return np.concatenate(dice_total, axis=0).mean()

### Jaccard metric

$$\text{Jaccard} = {|\text{Intersection}| + \text{smooth} \over |\text{Union}| + \text{smooth}}$$

Where:

$|\text{Intersection}| = \sum_{i=1}^{N} y\_{true\_i} \cdot y\_{pred\_i}$, $|\text{Union}| = \sum_{i=1}^{N} y\_{true\_i} + \sum_{i=1}^{N} y\_{pred\_i} - |\text{Intersection}|$
- $N$ is the total number of elements in the segmentation masks.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted segmentation masks, respectively.
- $\text{smooth}$ is a small smoothing parameter to avoid division by zero.

In [24]:
class JaccardMetric:
    """
    Jaccard index (Intersection over Union) accumulator for binary masks.

    ``calculate`` appends one IoU value per image of a batch to the module-level
    list ``jaccard_total``; ``total_jaccard`` returns the mean of everything
    appended. The list is shared across instances, so reset it
    (``jaccard_total = []``) before evaluating a new split.

    Methods:
        calculate: Append the per-image Jaccard indices of one batch.
        total_jaccard: Mean Jaccard index over everything appended.
    """

    def __init__(self):
        """Create a Jaccard accumulator; it keeps no per-instance state."""

    def calculate(self, y_true, y_pred, axis=(1, 2), smooth=1e-5):
        """
        Append (|overlap| + smooth) / (|union| + smooth) for every image of a batch.

        Parameters:
            y_true: Ground-truth masks whose trailing axis has length 1 (it is squeezed away).
            y_pred: Predicted masks, broadcast-compatible with the squeezed ``y_true``.
            axis: Axes summed to obtain the per-image overlap and union.
            smooth: Added to numerator and denominator; two empty masks score 1.

        Returns:
            None
        """
        truth = np.squeeze(y_true, axis=-1).astype(np.float32)
        guess = y_pred.astype(np.float32)
        overlap = (truth * guess).sum(axis=axis)
        union = truth.sum(axis=axis) + guess.sum(axis=axis) - overlap
        jaccard_total.append((overlap + smooth) / (union + smooth))

    def total_jaccard(self):
        """
        Average every per-image Jaccard index accumulated in ``jaccard_total``.

        Returns:
            value: The mean Jaccard index.
        """
        return np.concatenate(jaccard_total, axis=0).mean()

### Sensitivity metric

$$\text{Sensitivity} = {\text{True Positives} \over \text{Actual Positives} + \text{smooth}}$$

Where:

$\text{True Positives} = \sum_{i=1}^{N} y\_{true\_i} \cdot y\_{pred\_i}$, $\text{Actual Positives} = \sum_{i=1}^{N} y\_{true\_i}$


- $N$ is the total number of elements in the labels.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and predicted labels, respectively.
- $\text{smooth}$ is a small value added to the denominator to avoid division by zero.

In [25]:
class SensitivityMetric:
    """
    Sensitivity (recall / true-positive rate) accumulator for binary masks.

    ``calculate`` appends one sensitivity value per image of a batch to the
    module-level list ``sensitivity_total``; ``total_sensitivity`` returns their
    mean. The list is shared by every instance, so reset it
    (``sensitivity_total = []``) before evaluating a new split.

    Methods:
        calculate: Append per-image sensitivities for one batch.
        total_sensitivity: Mean sensitivity over everything appended.
    """

    def __init__(self):
        """Create a sensitivity accumulator; it keeps no per-instance state."""

    def calculate(self, y_true, y_pred, axis=(1, 2), smooth=1e-5):
        """
        Append true positives / (actual positives + smooth) for every image of a batch.

        Parameters:
            y_true: Ground-truth masks whose trailing axis has length 1 (it is squeezed away).
            y_pred: Predicted masks, broadcast-compatible with the squeezed ``y_true``.
            axis: Axes summed to obtain the per-image counts.
            smooth: Added to the denominator only, so an image without positives scores 0.

        Returns:
            None
        """
        truth = np.squeeze(y_true, axis=-1).astype(np.float32)
        guess = y_pred.astype(np.float32)
        hits = (truth * guess).sum(axis=axis)
        positives = truth.sum(axis=axis)
        sensitivity_total.append(hits / (positives + smooth))

    def total_sensitivity(self):
        """
        Average every per-image sensitivity accumulated in ``sensitivity_total``.

        Returns:
            value: The mean sensitivity.
        """
        return np.concatenate(sensitivity_total, axis=0).mean()

### Specificity metric

$$\text{Specificity} = {\text{True Negatives} \over \text{Actual Negatives} + \text{smooth}}$$

Where:

$\text{True Negatives} = \sum_{i=1}^{N} (1 - y\_{true\_i}) \cdot (1 - y\_{pred\_i})$, $\text{Actual Negatives} = \sum_{i=1}^{N} (1 - y\_{true\_i})$

- $N$ is the total number of elements in the segmentation masks.
- $y\_{true\_i}$ and $y\_{pred\_i}$ represent the value of the i-th element in the ground truth and the (binarized) predicted segmentation masks, respectively.
- $\text{smooth}$ is a small value added to the denominator to avoid division by zero.

In [26]:
class SpecificityMetric:
    """
    Specificity (true-negative rate) accumulator for binary masks.

    ``calculate`` appends one specificity value per image of a batch to the
    module-level list ``specificity_total``; ``total_specificity`` returns their
    mean. The list is shared by every instance, so reset it
    (``specificity_total = []``) before evaluating a new split.

    Methods:
        calculate: Append per-image specificities for one batch.
        total_specificity: Mean specificity over everything appended.
    """

    def __init__(self):
        """Create a specificity accumulator; it keeps no per-instance state."""

    def calculate(self, y_true, y_pred, axis=(1, 2), smooth=1e-5):
        """
        Append true negatives / (actual negatives + smooth) for every image of a batch.

        Parameters:
            y_true: Ground-truth masks whose trailing axis has length 1 (it is squeezed away).
            y_pred: Predicted masks, broadcast-compatible with the squeezed ``y_true``.
            axis: Axes summed to obtain the per-image counts.
            smooth: Added to the denominator only, so an image without negatives scores 0.

        Returns:
            None
        """
        truth = np.squeeze(y_true, axis=-1).astype(np.float32)
        guess = y_pred.astype(np.float32)
        true_neg = ((1 - truth) * (1 - guess)).sum(axis=axis)
        all_neg = (1 - truth).sum(axis=axis)
        specificity_total.append(true_neg / (all_neg + smooth))

    def total_specificity(self):
        """
        Average every per-image specificity accumulated in ``specificity_total``.

        Returns:
            value: The mean specificity.
        """
        return np.concatenate(specificity_total, axis=0).mean()

### Model performance for the training part of the dataset

In [None]:
# Evaluate the loaded model on the training split with the four metrics defined above.
# The metric classes append to these module-level lists, so reset them first.
dice_total = []
jaccard_total = []
sensitivity_total = []
specificity_total = []

train_dice = DiceMetric()
train_jaccard = JaccardMetric()
train_sensitivity = SensitivityMetric()
train_specificity = SpecificityMetric()

TRAIN_IMAGES_H5 = '/kaggle/working/OxfordPet/images_train.h5'
TRAIN_MASKS_H5 = '/kaggle/working/OxfordPet/masks_train.h5'
EVAL_BATCH_SIZE = 5
# Evaluate at most this many samples (indices 0..209); None evaluates the whole file.
MAX_EVAL_SAMPLES = 210
# Width of the all-zero annotator code fed to the crowd model (presumably the annotator count).
N_ANNOTATOR_CODES = 3

with h5py.File(TRAIN_IMAGES_H5, 'r') as h5_file:
    n_train = h5_file['dataset'].shape[0]
if MAX_EVAL_SAMPLES is not None:
    n_train = min(n_train, MAX_EVAL_SAMPLES)

for start in range(0, n_train, EVAL_BATCH_SIZE):
    stop = min(start + EVAL_BATCH_SIZE, n_train)
    original_images = load_chunk_from_hdf5(TRAIN_IMAGES_H5, 'dataset', start, stop)
    original_masks = load_chunk_from_hdf5(TRAIN_MASKS_H5, 'dataset', start, stop)
    # Channels-last -> channels-first, cast to float32 as ModelHandler.evaluate does.
    input_tensor = torch.from_numpy(original_images).permute(0, 3, 1, 2).to(device=device, dtype=torch.float32)
    # Sized to the actual batch so a short final batch still matches the images.
    A_id_tensor = torch.zeros(len(original_images), N_ANNOTATOR_CODES, device=device)
    with torch.no_grad():
        output, cm = model(input_tensor, A_id_tensor)
    # Threshold the class-1 channel at 0.5 to get one binary mask per image.
    predicted_masks = output[:, 1, :, :].detach().cpu().numpy() > 0.5

    train_dice.calculate(original_masks, predicted_masks)
    train_jaccard.calculate(original_masks, predicted_masks)
    train_sensitivity.calculate(original_masks, predicted_masks)
    train_specificity.calculate(original_masks, predicted_masks)

    del input_tensor, A_id_tensor, output, cm
    torch.cuda.empty_cache()

print(f'Dice: {train_dice.total_dice()}, Jaccard: {train_jaccard.total_jaccard()}, Sensitivity: {train_sensitivity.total_sensitivity()}, Specificity: {train_specificity.total_specificity()}')

In [None]:
# Show the last predicted mask of the final batch evaluated in the training cell.
# One path is used for writing and displaying so the cell does not depend on the cwd.
example_png = '/kaggle/working/image_example_train_1.png'
plt.imsave(example_png, predicted_masks[-1], format='png')
display(Image(filename=example_png))

In [None]:
# Evaluate the loaded model on the validation split (same procedure as the training cell).
# The metric classes append to these module-level lists, so reset them first.
dice_total = []
jaccard_total = []
sensitivity_total = []
specificity_total = []

val_dice = DiceMetric()
val_jaccard = JaccardMetric()
val_sensitivity = SensitivityMetric()
val_specificity = SpecificityMetric()

VAL_IMAGES_H5 = '/kaggle/working/OxfordPet/images_val.h5'
VAL_MASKS_H5 = '/kaggle/working/OxfordPet/masks_val.h5'
EVAL_BATCH_SIZE = 5
# Evaluate at most this many samples (indices 0..209); None evaluates the whole file.
MAX_EVAL_SAMPLES = 210
# Width of the all-zero annotator code fed to the crowd model (presumably the annotator count).
N_ANNOTATOR_CODES = 3

with h5py.File(VAL_IMAGES_H5, 'r') as h5_file:
    n_val = h5_file['dataset'].shape[0]
if MAX_EVAL_SAMPLES is not None:
    n_val = min(n_val, MAX_EVAL_SAMPLES)

for start in range(0, n_val, EVAL_BATCH_SIZE):
    stop = min(start + EVAL_BATCH_SIZE, n_val)
    original_images = load_chunk_from_hdf5(VAL_IMAGES_H5, 'dataset', start, stop)
    original_masks = load_chunk_from_hdf5(VAL_MASKS_H5, 'dataset', start, stop)
    # Channels-last -> channels-first, cast to float32 as ModelHandler.evaluate does.
    input_tensor = torch.from_numpy(original_images).permute(0, 3, 1, 2).to(device=device, dtype=torch.float32)
    # Sized to the actual batch so a short final batch still matches the images.
    A_id_tensor = torch.zeros(len(original_images), N_ANNOTATOR_CODES, device=device)
    with torch.no_grad():
        output, cm = model(input_tensor, A_id_tensor)
    # Threshold the class-1 channel at 0.5 to get one binary mask per image.
    predicted_masks = output[:, 1, :, :].detach().cpu().numpy() > 0.5

    # Use the validation instances (previously the train_* objects were called here).
    val_dice.calculate(original_masks, predicted_masks)
    val_jaccard.calculate(original_masks, predicted_masks)
    val_sensitivity.calculate(original_masks, predicted_masks)
    val_specificity.calculate(original_masks, predicted_masks)

    del input_tensor, A_id_tensor, output, cm
    torch.cuda.empty_cache()

print(f'Dice: {val_dice.total_dice()}, Jaccard: {val_jaccard.total_jaccard()}, Sensitivity: {val_sensitivity.total_sensitivity()}, Specificity: {val_specificity.total_specificity()}')

In [None]:
# Evaluate the loaded model on the test split (same procedure as the training cell).
# The metric classes append to these module-level lists, so reset them first.
dice_total = []
jaccard_total = []
sensitivity_total = []
specificity_total = []

test_dice = DiceMetric()
test_jaccard = JaccardMetric()
test_sensitivity = SensitivityMetric()
test_specificity = SpecificityMetric()

TEST_IMAGES_H5 = '/kaggle/working/OxfordPet/images_test.h5'
TEST_MASKS_H5 = '/kaggle/working/OxfordPet/masks_test.h5'
EVAL_BATCH_SIZE = 5
# Evaluate at most this many samples (indices 0..209); None evaluates the whole file.
MAX_EVAL_SAMPLES = 210
# Width of the all-zero annotator code fed to the crowd model (presumably the annotator count).
N_ANNOTATOR_CODES = 3

with h5py.File(TEST_IMAGES_H5, 'r') as h5_file:
    n_test = h5_file['dataset'].shape[0]
if MAX_EVAL_SAMPLES is not None:
    n_test = min(n_test, MAX_EVAL_SAMPLES)

for start in range(0, n_test, EVAL_BATCH_SIZE):
    stop = min(start + EVAL_BATCH_SIZE, n_test)
    original_images = load_chunk_from_hdf5(TEST_IMAGES_H5, 'dataset', start, stop)
    original_masks = load_chunk_from_hdf5(TEST_MASKS_H5, 'dataset', start, stop)
    # Channels-last -> channels-first, cast to float32 as ModelHandler.evaluate does.
    input_tensor = torch.from_numpy(original_images).permute(0, 3, 1, 2).to(device=device, dtype=torch.float32)
    # Sized to the actual batch so a short final batch still matches the images.
    A_id_tensor = torch.zeros(len(original_images), N_ANNOTATOR_CODES, device=device)
    with torch.no_grad():
        output, cm = model(input_tensor, A_id_tensor)
    # Threshold the class-1 channel at 0.5 to get one binary mask per image.
    predicted_masks = output[:, 1, :, :].detach().cpu().numpy() > 0.5

    # Use the test instances (previously the train_* objects were called here).
    test_dice.calculate(original_masks, predicted_masks)
    test_jaccard.calculate(original_masks, predicted_masks)
    test_sensitivity.calculate(original_masks, predicted_masks)
    test_specificity.calculate(original_masks, predicted_masks)

    del input_tensor, A_id_tensor, output, cm
    torch.cuda.empty_cache()

print(f'Dice: {test_dice.total_dice()}, Jaccard: {test_jaccard.total_jaccard()}, Sensitivity: {test_sensitivity.total_sensitivity()}, Specificity: {test_specificity.total_specificity()}')