In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on softmax probabilities.

    Adapted from the discussion in https://github.com/pytorch/pytorch/issues/1249

    For each reduced group the score is

        dice = 2 * sum(t * p) / (sum(t * t) + sum(p * p) + eps)

    where ``p = softmax(input, dim=1)`` and ``t`` is the one-hot encoded target.
    The returned loss is ``-dice``, so it lies in [-1, 0] and -1 means a perfect
    prediction (this is NOT the common ``1 - dice`` convention).

    Shapes:
        input:  (N, C) or (N, C, d1, ..., dk) raw, unnormalized scores.
        target: (N,) or (N, d1, ..., dk) int64 class indices in [0, C).

    For (N, C) input the sums run over the class dimension, giving one score
    per sample, shape (N,). For (N, C, d1, ..., dk) input the sums run over the
    trailing dimensions d1..dk, giving one score per (sample, class), shape (N, C).

    NOTE: with spatial input, a class absent from a sample's target scores 0
    regardless of the prediction (its numerator is 0), which pulls a 'mean'
    reduction towards 0.

    Args:
        reduction: 'none' | 'mean' | 'sum' (case-insensitive), applied to the
            per-group loss described above.
        eps: added to the denominator to avoid division by zero.

    Raises:
        ValueError: on an unknown ``reduction`` or on mismatched input/target shapes.
    """

    def __init__(self, reduction="mean", eps=1e-6):
        super().__init__()
        self.eps = eps
        mode = reduction.lower()
        if mode == "none":
            self.reduction_op = None
        elif mode == "mean":
            self.reduction_op = torch.mean
        elif mode == "sum":
            self.reduction_op = torch.sum
        else:
            raise ValueError("expected one of ('none', 'mean', 'sum'), got {}".format(reduction))

    def forward(self, input, target):
        """Return the (optionally reduced) negative soft Dice score of ``input`` vs ``target``."""
        ndim = input.dim()
        if ndim < 2:
            raise ValueError("expected input with at least 2 dimensions (N, C, ...), got {}".format(ndim))
        if target.dim() != ndim - 1:
            raise ValueError("expected target dimension {} for input dimension {}, got {}".format(ndim - 1, ndim, target.dim()))
        # Without this check a smaller target would silently broadcast (or be
        # partially scattered) and produce a wrong loss instead of an error.
        expected_shape = (input.size(0),) + tuple(input.shape[2:])
        if tuple(target.shape) != expected_shape:
            raise ValueError("expected target shape {} for input shape {}, got {}".format(
                expected_shape, tuple(input.shape), tuple(target.shape)))

        # (N, C): reduce over classes -> one score per sample.
        # (N, C, d1, ..., dk): reduce over d1..dk -> one score per (sample, class).
        reduce_dims = (1,) if ndim == 2 else tuple(range(2, ndim))

        probabilities = nn.functional.softmax(input, 1)
        # Build the one-hot target with the same shape, dtype and device as the
        # probabilities, so CUDA / float64 inputs work without any copies.
        target_onehot = torch.zeros_like(probabilities).scatter_(1, target.unsqueeze(1), 1)

        num = torch.sum(target_onehot * probabilities, dim=reduce_dims)
        # A 0/1 one-hot tensor squared equals itself, so sum(t * t) == sum(t).
        den_t = torch.sum(target_onehot, dim=reduce_dims)
        den_p = torch.sum(probabilities * probabilities, dim=reduce_dims)
        loss = -2 * (num / (den_t + den_p + self.eps))

        if self.reduction_op is None:
            return loss
        return self.reduction_op(loss)

In [3]:
def onehot(tensor, num_classes, dtype=None):
    """Convert integer class labels to a one-hot tensor with classes on dim 1.

    Args:
        tensor: label tensor of shape (N, d1, ..., dk) (k may be 0); used as a
            ``scatter_`` index, so it must be int64 with values in [0, num_classes).
        num_classes: size of the class dimension in the output.
        dtype: dtype of the result; None (default) uses torch's default
            floating dtype, as before.

    Returns:
        Tensor of shape (N, num_classes, d1, ..., dk) on the same device as
        ``tensor``, holding 1 at each label position and 0 elsewhere.
    """
    index = tensor.unsqueeze(1)
    # Allocate on the labels' device: a bare torch.zeros(...) always lands on
    # the CPU, which breaks scatter_ (and later arithmetic) for CUDA labels.
    encoded = torch.zeros(index.size(0), num_classes, *index.size()[2:], dtype=dtype, device=tensor.device)
    encoded.scatter_(1, index, 1)
    return encoded

In [4]:
loss = DiceLoss()

target = torch.tensor([1])
out = torch.tensor([[-100., 100., -100., -50.]])
print("Target:\n", target)
print("Model out:\n", out)
# Call the module itself (not .forward) so any registered hooks run.
single_loss = loss(out, target)
print("Loss:\n", single_loss)
# A confident, correct prediction should reach the best possible score of -1.
assert abs(single_loss.item() + 1.0) < 1e-4

Target:
 tensor([1])
Model out:
 tensor([[-100.,  100., -100.,  -50.]])
Loss:
 tensor(-1.0000)


In [5]:
target = torch.tensor([1, 0, 0])
out = torch.tensor([[-100., 100., -100., -50.], [-100., -100., -100., 50.], [100., -100., -100., -50.]])
print("Target:\n", target)
print("Model out:\n", out)
# Call the module itself (not .forward) so any registered hooks run.
batch_loss = loss(out, target)
print("Loss:\n", batch_loss)
# Samples 1 and 3 are confidently right (-1 each), sample 2 confidently wrong (~0).
assert abs(batch_loss.item() + 2.0 / 3.0) < 1e-4

Target:
 tensor([1, 0, 0])
Model out:
 tensor([[-100.,  100., -100.,  -50.],
        [-100., -100., -100.,   50.],
        [ 100., -100., -100.,  -50.]])
Loss:
 tensor(-0.6667)


In [6]:
target = torch.randint(3, (2, 5, 5)).long()
predicted_labels = torch.randint(3, (2, 5, 5)).long()
# Near-one-hot logits for an unrelated random labelling: expect a poor score.
out = onehot(predicted_labels, 3) * 100
print("Target:\n", target.size())
print("Model out:\n", out.size())
# Call the module itself (not .forward) so any registered hooks run.
random_loss = loss(out, target)
print("Loss:\n", random_loss)
assert -1.0 <= random_loss.item() <= 0.0

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 3, 5, 5])
Loss:
 tensor(-0.2780)


In [7]:
target = torch.randint(3, (2, 5, 5)).long()
# Logits that put (almost) all probability mass on the true class.
out = onehot(target, 3).float() * 100
print("Target:\n", target.size())
print("Model out:\n", out.size())
# Expect -1 as long as every class occurs in both samples; a class absent from
# a sample scores 0 and would pull the mean above -1.
print("Loss:\n", loss(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 3, 5, 5])
Loss:
 tensor(-1.0000)


In [11]:
target = torch.randint(3, (2, 5, 5)).long()
out = onehot(target, 3).float() * 100
%timeit loss.forward(out, target)

The slowest run took 282.25 times longer than the fastest. This could mean that an intermediate result is being cached.
8.76 ms ± 7.48 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
