# Loss functions

> A set of custom loss functions

In [None]:
#| default_exp losses

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from bioMONAI.core import store_attr

In [None]:
#| export
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import sigmoid

from monai.losses import SSIMLoss

from scipy.optimize import curve_fit
from fastai.vision.all import mse, mae, CrossEntropyLossFlat, Any

from bioMONAI.metrics import FRCMetric, get_fourier_ring_correlations
from bioMONAI.core import torchTensor


In [None]:
#| export

def MSELoss(
    inp: Any,
    targ: Any
    ) -> torchTensor:
    "Mean squared error between `inp` and `targ`, delegated to fastai's `mse`."
    loss = mse(inp, targ)
    return loss


In [None]:
#| export
def L1Loss(
    inp: Any,
    targ: Any
    ) -> torchTensor:
    "Mean absolute error between `inp` and `targ`, delegated to fastai's `mae`."
    loss = mae(inp, targ)
    return loss

In [None]:
show_doc(SSIMLoss)

---

### SSIMLoss

>      SSIMLoss (spatial_dims:int, data_range:float=1.0,
>                kernel_type:monai.metrics.regression.KernelType|str=gaussian,
>                win_size:int|collections.abc.Sequence[int]=11,
>                kernel_sigma:float|collections.abc.Sequence[float]=1.5,
>                k1:float=0.01, k2:float=0.03,
>                reduction:monai.utils.enums.LossReduction|str=mean)

*Compute the loss function based on the Structural Similarity Index Measure (SSIM) Metric.

For more info, visit
    https://vicuesoft.com/glossary/term/ssim-ms-ssim/

SSIM reference paper:
    Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
    similarity." IEEE transactions on image processing 13.4 (2004): 600-612.*

### Combined Losses

In [None]:
#| export

class CombinedLoss:
    """
    Weighted sum of an SSIM loss, an MSE loss and an MAE (L1) loss.

    The MSE and MAE weights are given explicitly; the SSIM term receives
    whatever is left, i.e. `1 - mse_weight - mae_weight`.

    CombinedLoss reference paper:
    Shah, Z. H., Müller, M., Hammer, B., Huser, T., & Schenck, W. (2022, July). 
    Impact of different loss functions on denoising of microscopic images. 
    In 2022 International Joint Conference on Neural Networks (IJCNN) (pp. 1-10). IEEE.
    """
    def __init__(self, spatial_dims=2,  # Number of spatial dimensions (2 for 2D images, 3 for 3D images)
                 mse_weight=0.33,       # Weight for the MSE loss component
                 mae_weight=0.33,       # Weight for the MAE loss component
                 ):
        # Keep the constructor arguments as attributes of the same name.
        store_attr()
        self.MAE_loss = nn.L1Loss()
        self.MSE_loss = nn.MSELoss()
        self.SSIM_loss = SSIMLoss(spatial_dims=spatial_dims)

    def __call__(self, pred, targ):
        """
        Return the weighted combination of the three losses for `pred` vs `targ`.
        """
        # The SSIM term takes the weight not claimed by MSE and MAE.
        ssim_weight = 1 - self.mse_weight - self.mae_weight
        ssim_term = ssim_weight * self.SSIM_loss(pred, targ)
        mse_term = self.mse_weight * self.MSE_loss(pred, targ)
        mae_term = self.mae_weight * self.MAE_loss(pred, targ)
        return ssim_term + mse_term + mae_term
        

In [None]:
#| export
class MSSSIMLoss(torch.nn.Module):
    """
    Multi-Scale Structural Similarity (MSSSIM) Loss using MONAI's SSIMLoss as the base.

    The SSIM loss is evaluated on the inputs and on successively 2x
    average-pooled copies of them (`levels` scales in total). The per-scale
    values are combined as a weighted average using `weights`.

    Raises `ValueError` at construction time when `levels` is smaller than 1
    or when fewer than `levels` weights are available, and at call time when
    the two inputs differ in size.
    """
    def __init__(self, spatial_dims=2,      # Number of spatial dimensions (2 for 2D images, 3 for 3D images).
                 window_size: int = 8,      # Size of the Gaussian filter for SSIM.
                 sigma: float = 1.5,        # Standard deviation of the Gaussian filter.
                 reduction: str = "mean",   # Specifies the reduction to apply to the output ('mean', 'sum', or 'none').
                 levels: int = 3,           # Number of scales to use for MS-SSIM.
                 weights=None,              # Weights to apply to each scale. If None, default values are used.
                 ):
        super().__init__()
        # The base loss is left unreduced so each scale can be weighted before
        # the final reduction is applied in `forward`.
        self.ssim = SSIMLoss(spatial_dims, win_size=window_size, kernel_sigma=sigma, reduction="none")
        self.levels = levels
        self.weights = self._resolve_weights(weights, levels)
        self.reduction = reduction
        self.spatial_dims = spatial_dims

    @staticmethod
    def _resolve_weights(weights, levels):
        """Return the first `levels` scale weights as a float tensor, validating the count."""
        if levels < 1:
            raise ValueError(f"levels must be at least 1, got {levels}.")
        if weights is None:
            # Default weights for 5 levels, typically used in MS-SSIM
            weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
        resolved = torch.FloatTensor(weights)[:levels]
        # Fail at construction instead of on the first forward pass.
        if len(resolved) != levels:
            raise ValueError(f"Number of weights ({len(resolved)}) must match the number of levels ({levels}).")
        return resolved

    def forward(self, x, y):
        # Ensure input tensors are the same size
        if x.size() != y.size():
            raise ValueError("Input images must have the same dimensions.")

        # Re-check here as well: `levels` / `weights` may have been reassigned
        # after construction.
        if len(self.weights) != self.levels:
            raise ValueError(f"Number of weights ({len(self.weights)}) must match the number of levels ({self.levels}).")

        # The pooling operator does not change between scales; pick it once.
        pool = F.avg_pool2d if self.spatial_dims == 2 else F.avg_pool3d
        last_level = self.levels - 1

        weighted_values = []
        for level in range(self.levels):
            # SSIM loss at the current scale, scaled by that scale's weight
            weighted_values.append(self.ssim(x, y) * self.weights[level])

            # Halve the resolution for the next scale, except after the last one
            if level < last_level:
                x = pool(x, kernel_size=2, stride=2)
                y = pool(y, kernel_size=2, stride=2)

        # Weighted average over scales (normalised by the sum of the weights)
        msssim = torch.stack(weighted_values, dim=0).sum(dim=0) / self.weights.sum()

        # Apply reduction (mean or sum); any other value returns the unreduced tensor
        if self.reduction == "mean":
            return msssim.mean()
        if self.reduction == "sum":
            return msssim.sum()
        return msssim


In [None]:

# Demo: compare the multi-scale SSIM loss with the single-scale SSIM loss on
# the same pair of random images.
# NOTE: the `.cuda()` calls below mean this cell needs a CUDA-capable device.
msssim_loss = MSSSIMLoss(levels=3)
ssim_loss = SSIMLoss(2)
output = torch.rand(10, 3, 64, 64).cuda()  # Example output
target = torch.rand(10, 3, 64, 64).cuda()  # Example target
loss = msssim_loss(output, target)
loss2 = ssim_loss(output,target)
print("ms-ssim: ",loss, '\nssim: ', loss2)


ms-ssim:  tensor(0.9686, device='cuda:0') 
ssim:  tensor(0.9949, device='cuda:0')


In [None]:
#| export
class MSSSIML1Loss(torch.nn.Module):
    """
    Multi-Scale Structural Similarity (MSSSIM) with Gaussian-weighted L1 Loss.

    The loss is `alpha * MS-SSIM + (1 - alpha) * L1`, where the element-wise
    L1 error is multiplied by a centred Gaussian weight map that sums to 1
    over the spatial dimensions. Both 2D (B, C, H, W) and 3D (B, C, D, H, W)
    inputs are supported.

    Reference paper:
    Zhao, H., Gallo, O., Frosio, I., & Kautz, J. (2016). 
    Loss functions for image restoration with neural networks. 
    IEEE Transactions on computational imaging, 3(1), 47-57.
    """
    def __init__(self, spatial_dims=2, # Number of spatial dimensions.
                 alpha: float = 0.025, #  Weighting factor between MS-SSIM and L1 loss.
                 window_size: int = 8, # Size of the Gaussian filter for SSIM.
                 sigma: float = 1.5, # Standard deviation of the Gaussian filter.
                 reduction: str = "mean", # Specifies the reduction to apply to the output ('mean', 'sum', or 'none').
                 levels: int = 3, # Number of scales to use for MS-SSIM.
                 weights=None, # Weights to apply to each scale. If None, default values are used.
                 ):
        super().__init__()
        # The inner MS-SSIM loss is kept unreduced; the reduction is applied
        # once, to the combined loss, in `forward`.
        self.msssim = MSSSIMLoss(spatial_dims=spatial_dims, window_size=window_size, sigma=sigma, 
                                 reduction="none", levels=levels, weights=weights)
        # Keep every constructor argument as an attribute of the same name.
        store_attr()

    def forward(self, x, y):
        batch_size = x.shape[0]

        # Collapse the MS-SSIM term to exactly one value per sample, shape (B,).
        # With reduction="none" the inner loss is expected to come back as
        # (B, 1) (MONAI SSIMLoss behaviour - verify for the installed version).
        # Adding that directly to a (B,) L1 term would broadcast to (B, B).
        msssim_loss = self.msssim(x, y).reshape(batch_size, -1).mean(dim=1)

        # Element-wise L1 error, weighted by a Gaussian built on x's device
        gaussian = self.get_gaussian_weight(x.size(), device=x.device)
        l1_loss = F.l1_loss(x, y, reduction='none') * gaussian

        # Every non-batch dimension (channels + spatial); works for 2D and 3D
        non_batch_dims = tuple(range(1, x.ndim))

        if self.reduction == "mean":
            l1_loss = l1_loss.mean(dim=non_batch_dims)
        elif self.reduction == "sum":
            l1_loss = l1_loss.sum(dim=non_batch_dims)
        else:
            # No reduction: broadcast the per-sample MS-SSIM value over the
            # element-wise L1 map.
            msssim_loss = msssim_loss.view(batch_size, *([1] * (x.ndim - 1)))

        # Combine the two losses
        combined_loss = self.alpha * msssim_loss + (1 - self.alpha) * l1_loss

        if self.reduction == "mean":
            return combined_loss.mean()
        if self.reduction == "sum":
            return combined_loss.sum()
        return combined_loss

    def get_gaussian_weight(self, size, device=None):
        """
        Generate a Gaussian weight tensor based on input size.

        `size` is the full input size (batch, channels, *spatial) with two or
        three spatial dimensions. The Gaussian is centred at `(n - 1) / 2`
        along every spatial axis, uses `sigma = spatial[-2] / 6`, is normalised
        to sum to 1 and is expanded (without copying) to `size`.

        `device` selects where the tensor is created. When omitted, CUDA is
        used if available (the previous behaviour) and the CPU otherwise.

        Raises `ValueError` for any other number of spatial dimensions.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        batch_size, channels, *spatial_shape = size
        if len(spatial_shape) not in (2, 3):
            raise ValueError(f"Expected 2 or 3 spatial dimensions, got {len(spatial_shape)}.")

        # Same scale as before: the second-to-last spatial extent divided by 6.
        sigma = spatial_shape[-2] / 6.0

        axes = [torch.arange(extent, dtype=torch.float32, device=device) for extent in spatial_shape]
        grids = torch.meshgrid(*axes, indexing='ij')

        # Squared distance of every grid position from the geometric centre
        squared_distance = sum((grid - (extent - 1) / 2.0) ** 2 for grid, extent in zip(grids, spatial_shape))

        gaussian = torch.exp(-squared_distance / (2 * sigma**2))
        gaussian /= gaussian.sum()

        return gaussian.view(1, 1, *spatial_shape).expand(batch_size, channels, *spatial_shape)


In [None]:
# Demo: MS-SSIM + L1 loss on random grayscale images. `window_size=11`
# overrides the class default of 8.
msssiml1_loss = MSSSIML1Loss(alpha=0.025, window_size=11, sigma=1.5, levels=3)
input_image = torch.randn(4, 1, 128, 128)  # Batch of 4 grayscale images (1 channel)
target_image = torch.randn(4, 1, 128, 128)

# Compute MSSSIM + Gaussian-weighted L1 loss
loss = msssiml1_loss(input_image, target_image)
# `ssim_loss` is the single-scale SSIMLoss instance created in an earlier cell.
loss2 = ssim_loss(input_image, target_image)
# The value printed under the "ms-ssim" label is the combined MS-SSIM + L1 loss.
print("ms-ssim: ", loss, '\nssim: ', loss2)


ms-ssim:  tensor(0.0250) 
ssim:  tensor(0.9955)


In [None]:
#| export
class MSSSIML2Loss(torch.nn.Module):
    """
    Multi-Scale Structural Similarity (MSSSIM) with Gaussian-weighted L2 Loss.

    The loss is `alpha * MS-SSIM + (1 - alpha) * L2`, where the element-wise
    squared error is multiplied by a Gaussian weight map that sums to 1 over
    the spatial dimensions. Both 2D (B, C, H, W) and 3D (B, C, D, H, W)
    inputs are supported.

    Reference paper:
    Zhao, H., Gallo, O., Frosio, I., & Kautz, J. (2016). 
    Loss functions for image restoration with neural networks. 
    IEEE Transactions on computational imaging, 3(1), 47-57.    
    """
    def __init__(self, spatial_dims=2, # Number of spatial dimensions.
                 alpha: float = 0.1,# Weighting factor between MS-SSIM and L2 loss.
                 window_size: int = 11,# Size of the Gaussian window for SSIM.
                 sigma: float = 1.5,# Standard deviation of the Gaussian.
                 reduction: str = "mean",# Specifies the reduction to apply to the output ('mean', 'sum', or 'none').
                 levels: int = 3,# Number of scales to use for MS-SSIM.
                 weights=None,# Weights to apply to each scale. If None, default values are used.
                 ):
        super().__init__()
        # The inner MS-SSIM loss is kept unreduced; the reduction is applied
        # once, to the combined loss, in `forward`.
        self.msssim = MSSSIMLoss(spatial_dims=spatial_dims, window_size=window_size, sigma=sigma, reduction="none", levels=levels, weights=weights)
        self.alpha = alpha
        self.reduction = reduction
        self.window_size = window_size
        self.sigma = sigma

    def forward(self, x, y):
        batch_size = x.shape[0]

        # Collapse the MS-SSIM term to exactly one value per sample, shape (B,).
        # With reduction="none" the inner loss is expected to come back as
        # (B, 1) (MONAI SSIMLoss behaviour - verify for the installed version).
        # Adding that directly to a (B,) L2 term would broadcast to (B, B).
        msssim_loss = self.msssim(x, y).reshape(batch_size, -1).mean(dim=1)

        # Element-wise squared error, weighted by a Gaussian built on x's device
        gaussian = self.get_gaussian_weight(x.size(), device=x.device)
        l2_loss = F.mse_loss(x, y, reduction='none') * gaussian

        # Every non-batch dimension (channels + spatial); works for 2D and 3D
        non_batch_dims = tuple(range(1, x.ndim))

        if self.reduction == "mean":
            l2_loss = l2_loss.mean(dim=non_batch_dims)
        elif self.reduction == "sum":
            l2_loss = l2_loss.sum(dim=non_batch_dims)
        else:
            # No reduction: broadcast the per-sample MS-SSIM value over the
            # element-wise L2 map.
            msssim_loss = msssim_loss.view(batch_size, *([1] * (x.ndim - 1)))

        # Combine the two losses
        combined_loss = self.alpha * msssim_loss + (1 - self.alpha) * l2_loss

        # Apply final reduction to the combined loss
        if self.reduction == "mean":
            return combined_loss.mean()
        if self.reduction == "sum":
            return combined_loss.sum()
        return combined_loss

    def get_gaussian_weight(self, size, device=None):
        """
        Generate a Gaussian weight tensor based on input size.

        `size` is the full input size (batch, channels, *spatial) with two or
        three spatial dimensions. Along each spatial axis the Gaussian is
        centred at `(n - 1) / 2` for odd `n` and at `n / 2` for even `n`. It
        uses `sigma = spatial[-2] / 6`, is normalised to sum to 1 and is
        expanded (without copying) to `size`.

        `device` selects where the tensor is created. When omitted, CUDA is
        used if available (the previous behaviour) and the CPU otherwise.

        Raises `ValueError` for any other number of spatial dimensions.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        batch_size, channels, *spatial_shape = size
        if len(spatial_shape) not in (2, 3):
            raise ValueError(f"Expected 2 or 3 spatial dimensions, got {len(spatial_shape)}.")

        # For 2D input this is the first spatial extent / 6, as before.
        sigma = spatial_shape[-2] / 6.0

        axes = [torch.arange(extent, dtype=torch.float32, device=device) for extent in spatial_shape]
        # Even-sized axes place the centre on index n / 2 rather than between
        # the two middle samples.
        centers = [(extent - 1) / 2.0 if extent % 2 == 1 else extent / 2.0 for extent in spatial_shape]

        grids = torch.meshgrid(*axes, indexing='ij')
        squared_distance = sum((grid - center) ** 2 for grid, center in zip(grids, centers))

        gaussian = torch.exp(-squared_distance / (2 * sigma**2))
        gaussian /= gaussian.sum()  # Normalize the Gaussian

        return gaussian.view(1, 1, *spatial_shape).expand(batch_size, channels, *spatial_shape)

In [None]:
# Demo: MS-SSIM + L2 loss with default settings on random RGB-like images.
# NOTE: the `.cuda()` calls below mean this cell needs a CUDA-capable device.
msssim_l2_loss = MSSSIML2Loss()
output = torch.rand(10, 3, 64, 64).cuda()  # Example output with even dimensions
target = torch.rand(10, 3, 64, 64).cuda()  # Example target with even dimensions
loss = msssim_l2_loss(output, target)
print(loss)


tensor(0.0956, device='cuda:0')


### CrossEntropy and Dice Loss

In [None]:
#| export
class CrossEntropyLossFlat3D(CrossEntropyLossFlat):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target for 3D inputs."
    def __call__(self, 
        inp: torchTensor, # Predictions (e.g., NCDHW or similar format)
        targ: torchTensor, # Targets
        **kwargs
    ) -> torchTensor:
        "Move the class axis last, optionally flatten both tensors, then apply the wrapped loss."
        # Axis 1 (classes) goes to the end: (N, C, D, H, W) -> (N, D, H, W, C)
        channels_last = inp.permute(0, 2, 3, 4, 1).contiguous()
        labels = targ.contiguous()
        if not self.flatten:
            return self.func(channels_last, labels, **kwargs)
        # One row per voxel for the predictions, one entry per voxel for the targets
        n_classes = channels_last.shape[-1]
        return self.func(channels_last.view(-1, n_classes), labels.view(-1), **kwargs)

In [None]:
#| export

class DiceLoss(nn.Module):
    """
    DiceLoss computes the Sørensen–Dice coefficient loss, which is often used 
    for evaluating the performance of image segmentation algorithms.

    The Dice coefficient is a measure of overlap between two samples. It ranges 
    from 0 (no overlap) to 1 (perfect overlap). The Dice loss is computed as 
    1 - Dice coefficient, so it ranges from 1 (no overlap) to 0 (perfect overlap).

    Attributes:
        smooth (float): A smoothing factor to avoid division by zero and ensure numerical stability.

    Methods:
        forward(inputs, targets):
            Computes the Dice loss between the raw predictions (logits, which
            are passed through a sigmoid internally) and the ground truth
            (targets). Both tensors must hold the same number of elements;
            the coefficient is computed over all elements at once.
    """

    def __init__(self, smooth=1, # Smoothing factor to avoid division by zero
                 ):

        """
        Initializes the DiceLoss instance with a smoothing factor.

        """
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):

        # Turn the raw predictions (logits) into probabilities
        probabilities = sigmoid(inputs)

        # Flatten tensors. `reshape` is used instead of `view` so that
        # non-contiguous tensors (e.g. after `permute`/`transpose`) work too.
        probabilities = probabilities.reshape(-1)
        targets = targets.reshape(-1)

        # Calculate the intersection
        intersection = (probabilities * targets).sum()

        # Compute Dice coefficient
        dice = (2. * intersection + self.smooth) / (probabilities.sum() + targets.sum() + self.smooth)

        # Compute Dice loss
        return 1 - dice
        

In [None]:
# `inputs` and `targets` must be tensors of the same shape
from torch import randn, randint


In [None]:

# Demo: Dice loss between random raw predictions and a random binary mask.
inputs = randn((1, 1, 256, 256))  # Input
targets = randint(0, 2, (1, 1, 256, 256)).float()  # Ground Truth

# Initialize
dice_loss = DiceLoss()

# Compute loss
loss = dice_loss(inputs, targets)
print('Dice Loss:', loss.item())



Dice Loss: 0.4982335567474365


### Fourier Ring Correlation

In [None]:
#| export

def FRCLoss(image1,# The first input image.
            image2,# The second input image.
            ):

    """
    Fourier Ring Correlation (FRC) loss between two images.

    The loss is one minus the value returned by `FRCMetric(image1, image2)`.

    Returns:
        - torch.Tensor: The FRC loss.
    """

    frc_score = FRCMetric(image1, image2)
    return 1 - frc_score
    

In [None]:
#| export
def FCRCutoff(image1,# The first input image.
             image2,# The second input image.
             threshold=1/7,# Correlation level that defines the cutoff.
             ):


    """
    Calculate the cutoff frequency at which the Fourier ring correlation drops to `threshold` (1/7 by default).

    The ring correlations are fitted with `a * exp(-b * f) + c`, and the fitted
    curve is solved for the frequency `f` where it equals `threshold`.

    Returns:
        - float: The cutoff frequency, or NaN when the fitted curve never
          reaches `threshold`.

    Raises:
        - RuntimeError: Propagated from `scipy.optimize.curve_fit` when the fit
          does not converge.
    """

    # Get y and x coordinates
    y, x = get_fourier_ring_correlations(image1, image2)

    # x -> frequency   y -> correlation
    # Assumes both are torch tensors; detach and move to the CPU so the
    # conversion also works for GPU tensors or tensors that track gradients.
    x = x.detach().cpu().numpy()
    y = y.detach().cpu().numpy()

    # Exponential function to fit
    def exponential_func(x, a, b, c):
        return a * np.exp(-b * x) + c

    # Make fit
    (a, b, c), _ = curve_fit(exponential_func, x, y, p0=[1, 1, 1])

    # Solve a * exp(-b * f) + c == threshold for f:
    #     f = -ln((threshold - c) / a) / b
    # Evaluating the fitted curve at f = threshold would return a correlation
    # value rather than a frequency.
    if a == 0 or b == 0:
        return float('nan')
    ratio = (threshold - c) / a
    if ratio <= 0:
        # The curve stays entirely above (or below) the threshold.
        return float('nan')

    return float(-np.log(ratio) / b)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()