In [38]:
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

In [2]:
class DiceLossA(nn.Module):
    """Soft Dice loss for sigmoid (multi-label) segmentation outputs.

    For each sample the Dice score is computed over the channel and spatial
    axes together (dims 1, 2, 3); the returned loss is the batch mean of
    ``1 - dice``.

    Args:
        eps: constant added to the denominator to avoid division by zero.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps

    def forward(  # type: ignore
        self,
        output: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Return the batch-averaged soft Dice loss.

        Args:
            output: raw logits of shape (B, N, H, W); a sigmoid is applied here.
            target: tensor on the same device whose last two dims match
                ``output``. Presumably a {0, 1} mask with the same full shape as
                ``output`` — only the trailing spatial dims are checked.

        Returns:
            0-dim tensor: mean over the batch of ``1 - dice``.

        Raises:
            TypeError: if ``output`` or ``target`` is not a tensor.
            ValueError: if ``output`` is not 4-D, the spatial dims differ, or
                the tensors live on different devices.
        """
        if not torch.is_tensor(output):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(output)))
        if not torch.is_tensor(target):
            raise TypeError("Target type is not a torch.Tensor. Got {}"
                            .format(type(target)))
        if output.dim() != 4:
            raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                             .format(output.shape))
        if output.shape[-2:] != target.shape[-2:]:
            # Report both shapes so a mismatch can actually be diagnosed.
            raise ValueError("input and target shapes must be the same. Got: {} and {}"
                             .format(output.shape, target.shape))
        if output.device != target.device:
            raise ValueError(
                "input and target must be in the same device. Got: {} and {}"
                .format(output.device, target.device))

        # Independent per-channel probabilities (sigmoid, not softmax).
        probs = torch.sigmoid(output)

        # Per-sample Dice score over channels and spatial positions.
        dims = (1, 2, 3)
        intersection = torch.sum(probs * target, dims)
        cardinality = torch.sum(probs + target, dims)

        dice_score = 2. * intersection / (cardinality + self.eps)
        # A Python scalar broadcasts on any device without allocating a tensor.
        return torch.mean(1.0 - dice_score)

In [13]:
from functools import partial 

class DiceLossB(nn.Module):
    """Loss wrapper around the module-level ``dice`` metric.

    ``forward`` returns ``1 - dice(logits, targets)``. The ``eps`` and
    ``threshold`` options are bound once at construction and forwarded to
    ``dice`` on every call. A non-None ``threshold`` makes ``dice`` binarize
    its probabilities with ``>``, which blocks gradients.
    """

    def __init__(self, eps: float = 1e-7, threshold: float = None):
        super().__init__()

        # Freeze the metric's keyword options so forward only passes tensors.
        metric = partial(dice, eps=eps, threshold=threshold)
        self.loss_fn = metric

    def forward(self, logits, targets):
        score = self.loss_fn(logits, targets)
        loss = 1 - score
        return loss
    
def dice(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    eps: float = 1e-7,
    threshold: Optional[float] = None,
) -> torch.Tensor:
    """
    Computes a single (global) Dice score between logits and targets.

    A sigmoid is applied to ``outputs``. The intersection and the totals are
    summed over *all* elements (batch, channels and spatial axes together),
    so one score is returned for the whole batch rather than a per-sample mean.

    Args:
        outputs (torch.Tensor): raw logits, broadcast-compatible with ``targets``.
        targets (torch.Tensor): ground truth, typically a binary mask.
        eps (float): added to the denominator to avoid division by zero.
        threshold (float, optional): if given, sigmoid probabilities are
            binarized with ``> threshold`` before scoring. Binarization is not
            differentiable, so leave this as ``None`` for training losses.
    Returns:
        torch.Tensor: 0-dim tensor with the Dice score (in [0, 1] when
        ``targets`` values are in [0, 1]).
    """
    probs = torch.sigmoid(outputs)

    if threshold is not None:
        probs = (probs > threshold).float()

    intersection = torch.sum(targets * probs)
    union = torch.sum(targets) + torch.sum(probs)
    score = 2 * intersection / (union + eps)

    return score

In [20]:
def dice_classwise(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    eps: float = 1e-7,
    threshold: Optional[float] = None,
) -> torch.Tensor:
    """
    Computes the Dice score separately for each class and averages them.

    Scores are computed on sigmoid probabilities. For each class (dim 1), the
    batch and all spatial positions are pooled together; the result is the
    unweighted mean of the per-class scores.

    Args:
        outputs (torch.Tensor): raw logits of shape (B, C, *spatial), e.g.
            (B, C, H, W); at least 2-D, with classes on dim 1.
        targets (torch.Tensor): ground truth with exactly the same shape.
        eps (float): added to each class denominator to avoid division by zero.
        threshold (float, optional): if given, probabilities are binarized
            with ``> threshold`` before scoring (not differentiable).
    Returns:
        torch.Tensor: 0-dim tensor with the class-averaged Dice score.
    Raises:
        ValueError: if ``outputs`` has fewer than 2 dims, ``targets`` has a
            different shape, or there are no classes.
    """
    if outputs.dim() < 2:
        raise ValueError("outputs must have a class axis at dim 1 (B x C x ...). Got: {}"
                         .format(tuple(outputs.shape)))
    if targets.shape != outputs.shape:
        # Broadcasting here would make intersection and union count different elements.
        raise ValueError("outputs and targets shapes must be the same. Got: {} and {}"
                         .format(tuple(outputs.shape), tuple(targets.shape)))
    if outputs.shape[1] == 0:
        # A mean over zero classes is undefined.
        raise ValueError("outputs has no classes (size 0 on dim 1)")

    probs = torch.sigmoid(outputs)

    if threshold is not None:
        probs = (probs > threshold).float()

    # Reduce over every axis except the class axis: batch (0) and all spatial axes.
    reduce_dims = (0,) + tuple(range(2, probs.dim()))
    intersection = torch.sum(targets * probs, dim=reduce_dims)
    union = torch.sum(targets, dim=reduce_dims) + torch.sum(probs, dim=reduce_dims)
    per_class = 2 * intersection / (union + eps)
    return per_class.mean()

In [30]:
# Synthetic batch for comparing the Dice implementations.
# `outputs` plays the role of raw logits (every implementation applies a sigmoid),
# laid out as B x N x H x W; `targets` is an all-ones mask of the same shape.
torch.manual_seed(0)  # make the random logits, and hence the printed losses, reproducible
outputs = torch.rand(8, 4, 32, 32)
targets = torch.ones(8, 4, 32, 32)

In [35]:
# Loss A: per-sample soft Dice over dims (1, 2, 3), averaged over the batch.
l1 = DiceLossA()(output=outputs, target=targets)
l1

tensor(0.2347)

In [36]:
# Loss B: 1 minus a single Dice score pooled over the whole batch.
l2 = DiceLossB()(logits=outputs, targets=targets)
l2

tensor(0.2347)

In [37]:
l3 = dice_classwise(outputs, targets)
# dice_classwise returns a Dice *score*, so convert it to a loss before comparing.
# The three implementations reduce over different axes (per-sample mean, one
# global score, per-class mean), so they agree only approximately in general;
# check the agreement explicitly instead of eyeballing the printed values.
assert torch.allclose(1 - l3, l1, atol=1e-4), (1 - l3, l1)
assert torch.allclose(1 - l3, l2, atol=1e-4), (1 - l3, l2)
1 - l3

tensor(0.2347)