In [1]:
import torch
import torch.nn as nn
import numpy as np

# Binary dice loss

In [2]:
class BinaryDiceWithLogitsLoss(nn.Module):
    """Computes the Sørensen–Dice loss with logits for binary data.

    Dice_coefficient = 2 * intersection(X, Y) / (|X| + |Y|)
    where, X and Y are sets of binary data, in this case, predictions and targets.
    |X| and |Y| are the cardinalities of the corresponding sets.

    The optimizer minimizes the loss function therefore:
    Dice_loss = -Dice_coefficient
    (min(-x) = max(x))

    See: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Arguments:
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'mean'
        eps (float, optional): small value to avoid division by zero. Default: 1e-6.

    """

    def __init__(self, reduction="mean", eps=1e-6):
        super().__init__()
        self.eps = eps
        if reduction.lower() == "none":
            self.reduction_op = None
        elif reduction.lower() == "mean":
            self.reduction_op = torch.mean
        elif reduction.lower() == "sum":
            self.reduction_op = torch.sum
        else:
            raise ValueError(
                "expected one of ('none', 'mean', 'sum'), got {}".format(reduction)
            )

    def forward(self, input, target):
        if input.size() != target.size():
            raise ValueError(
                "size mismatch, {} != {}".format(input.size(), target.size())
            )
        elif target.unique(sorted=True).tolist() not in [[0, 1], [0], [1]]:
            raise ValueError("target values are not binary")

        input = input.view(-1)
        target = target.view(-1)

        # Dice = 2 * intersection(X, Y) / (|X| + |Y|)
        # X and Y are sets of binary data, in this case, probabilities and targets
        # |X| and |Y| are the cardinalities of the corresponding sets
        probabilities = torch.sigmoid(input)
        num = torch.sum(target * probabilities)
        den_t = torch.sum(target)
        den_p = torch.sum(probabilities)
        loss = -2 * (num / (den_t + den_p + self.eps))

        if self.reduction_op is not None:
            return self.reduction_op(loss)
        else:
            return loss

In [3]:
loss = BinaryDiceWithLogitsLoss()

target = torch.Tensor([1])
out = torch.Tensor([2.2])
print("Target:\n", target)
print("Model out:\n", out)
print("BD Loss:\n", loss.forward(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 tensor([1.])
Model out:
 tensor([2.2000])
BD Loss:
 tensor(-0.9475)
BCE Loss:
 tensor(0.1051)


In [4]:
target = torch.Tensor([1])
out = torch.Tensor([3.43])
print("Target:\n", target)
print("Model out:\n", out)
print("BD Loss:\n", loss.forward(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 tensor([1.])
Model out:
 tensor([3.4300])
BD Loss:
 tensor(-0.9841)
BCE Loss:
 tensor(0.0319)


In [5]:
target = torch.Tensor([1, 0])
out = torch.Tensor([100, -50])
print("Target:\n", target)
print("Model out:\n", out)
print("Loss:\n", loss.forward(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 tensor([1., 0.])
Model out:
 tensor([100., -50.])
Loss:
 tensor(-1.0000)
BCE Loss:
 tensor(0.)


In [6]:
target = torch.Tensor([1, 0, 0, 0, 1])
out = torch.Tensor([-5, -2.5, -6, -10, -2])
print("Target:\n", target)
print("Model out:\n", out)
print("Loss:\n", loss.forward(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 tensor([1., 0., 0., 0., 1.])
Model out:
 tensor([ -5.0000,  -2.5000,  -6.0000, -10.0000,  -2.0000])
Loss:
 tensor(-0.1142)
BCE Loss:
 tensor(1.4430)


In [7]:
target = torch.randint(2, (2, 5, 5))
out = torch.randint(2, (2, 5, 5)) * 3.44
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("Loss:\n", loss.forward(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 5, 5])
Loss:
 tensor(-0.5891)
BCE Loss:
 tensor(1.3547)


In [8]:
target = torch.randint(2, (2, 5, 5)).float()
out = (target * 100) - 50
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("Loss:\n", loss.forward(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 5, 5])
Loss:
 tensor(-1.)
BCE Loss:
 tensor(0.)


In [9]:
target = torch.randint(2, (10, 2048, 2048))
out = target * 100
%timeit loss.forward(out, target)

871 ms ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Multi-class dice loss

In [10]:
class DiceWithLogitsLoss(nn.Module):
    """Computes the Sørensen–Dice loss with logits.

    Dice_coefficient = 2 * intersection(X, Y) / (|X| + |Y|)
    where, X and Y are sets of binary data, in this case, predictions and targets.
    |X| and |Y| are the cardinalities of the corresponding sets.

    The optimizer minimizes the loss function therefore:
    Dice_loss = -Dice_coefficient
    (min(-x) = max(x))

    See: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Arguments:
        num_classes (int): number of classes in the classification problem
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'mean'
        eps (float, optional): small value to avoid division by zero. Default: 1e-6.

    """

    def __init__(self, reduction="mean", eps=1e-6):
        super().__init__()
        self.eps = eps
        if reduction.lower() == "none":
            self.reduction_op = None
        elif reduction.lower() == "mean":
            self.reduction_op = torch.mean
        elif reduction.lower() == "sum":
            self.reduction_op = torch.sum
        else:
            raise ValueError(
                "expected one of ('none', 'mean', 'sum'), got {}".format(reduction)
            )

    def forward(self, input, target):
        if input.dim() != 2 and input.dim() != 4:
            raise ValueError(
                "expected input of size 4 or 2, got {}".format(input.dim())
            )

        if target.dim() != 1 and target.dim() != 3:
            raise ValueError(
                "expected target of size 3 or 1, got {}".format(target.dim())
            )

        if input.dim() == 4 and target.dim() == 3:
            reduce_dims = (0, 3, 2)
        elif input.dim() == 2 and target.dim() == 1:
            reduce_dims = 0
        else:
            raise ValueError(
                "expected target dimension {} for input dimension {}, got {}".format(
                    input.dim() - 1, input.dim(), target.dim()
                )
            )

        target_onehot = to_onehot(target, input.size(1))
        probabilities = nn.functional.softmax(input, 1)

        # Dice = 2 * intersection(X, Y) / (|X| + |Y|)
        # X and Y are sets of binary data, in this case, probabilities and targets
        # |X| and |Y| are the cardinalities of the corresponding sets
        num = torch.sum(target_onehot * probabilities, dim=reduce_dims)
        den_t = torch.sum(target_onehot, dim=reduce_dims)
        den_p = torch.sum(probabilities, dim=reduce_dims)
        loss = -2 * (num / (den_t + den_p + self.eps))

        if self.reduction_op is not None:
            return self.reduction_op(loss)
        else:
            return loss

In [11]:
def to_onehot(tensor, num_classes):    
    tensor = tensor.unsqueeze(1)
    onehot = torch.zeros(tensor.size(0), num_classes, *tensor.size()[2:])
    onehot.scatter_(1, tensor, 1)
    
    return onehot

In [12]:
loss = DiceWithLogitsLoss()

target = torch.Tensor([1]).long()
out = torch.Tensor([[-100, 100, -100, -50]]).float()
print("Target:\n", target)
print("Model out:\n", out)
print("Loss:\n", loss.forward(out, target))

Target:
 tensor([1])
Model out:
 tensor([[-100.,  100., -100.,  -50.]])
Loss:
 tensor(-0.2500)


In [13]:
target = torch.Tensor([1, 0, 0]).long()
out = torch.Tensor([[-100, 100, -100, -50], [-100, -100, -100, 50], [100, -100, -100, -50]]).float()
print("Target:\n", target)
print("Model out:\n", out)
print("Loss:\n", loss.forward(out, target))

Target:
 tensor([1, 0, 0])
Model out:
 tensor([[-100.,  100., -100.,  -50.],
        [-100., -100., -100.,   50.],
        [ 100., -100., -100.,  -50.]])
Loss:
 tensor(-0.4167)


In [14]:
target = torch.randint(3, (2, 5, 5)).long()
out = torch.randint(3, (2, 5, 5)).long()
out = to_onehot(out, 3) * 100
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("Loss:\n", loss.forward(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 3, 5, 5])
Loss:
 tensor(-0.3683)


In [15]:
target = torch.randint(3, (2, 5, 5)).long()
out = to_onehot(target, 3).float() * 100
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("Loss:\n", loss.forward(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 3, 5, 5])
Loss:
 tensor(-1.)


In [16]:
target = torch.randint(3, (10, 2048, 2048)).long()
out = to_onehot(target, 3).float() * 100
%timeit loss.forward(out, target)

6.28 s ± 109 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
