# Losses and metrics


In [None]:
# hide
import sys
sys.path.append("..")
from nbdev.showdoc import *

In [None]:
# default_exp models.losses
# export 
from fastai.basics import *
import torchvision, torch
from warnings import warn

## Loss functions

### DICE Loss

In [None]:
# export
class DiceLossBinary():
    """
    Simple DICE loss as described in:
        https://arxiv.org/pdf/1911.02855.pdf

    Computes 1 - Sørensen–Dice coefficient. The coefficient itself is larger-is-better,
    but PyTorch optimizers minimize a loss, so it is subtracted from 1. For inputs in
    [0, 1] the loss lies in [0, 1] and is 0 (up to `eps`) for a perfect prediction.

    Args (constructor):
        method:  How the DICE score is calculated.
                    "simple"   = standard DICE: (2*TP + smooth) / (sum(p) + sum(g) + smooth)
                    "miletari" = squared denominator for faster convergence:
                                 (2*TP + smooth) / (sum(p**2) + sum(g**2) + smooth)
                    "tversky"  = Tversky index, allows to weight FP vs FN:
                                 (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)
        alpha, beta: weights for FP and FN for the "tversky" method. If both values are 0.5
                 (and smooth is 0) the "tversky" loss equals the "simple" DICE loss.
        eps:     added to the denominator for numerical stability (avoids division by 0).
        smooth:  smoothing factor, added once to the numerator and once to the denominator.

    Args (call):
        input:   A tensor of shape [B, C, ...] (e.g. [B, 1, D, H, W]) with predictions that are
                 expected in range [0, 1] (e.g. after a sigmoid); a warning is emitted otherwise.
        target:  A tensor of the same shape with the ground truth.
    Returns:
        dice_loss: tensor of shape [C], one loss per channel (terms summed over batch and spatial dims).
    Raises:
        NotImplementedError: if `method` is not one of the values above.
    """

    def __init__(self, method = 'miletari', alpha = 0.5, beta = 0.5, eps = 1e-7, smooth = 1.) -> None:
        self.method, self.alpha, self.beta = method, alpha, beta
        self.eps, self.smooth = eps, smooth

    def __call__(self, input: Tensor, target: Tensor) -> Tensor:
        if input.min() < 0 or input.max() > 1:
            warn("Input is not in range between 0 and 1 but the loss will work better with input in that range. Consider rescaling your input. ")

        # reduce over batch and spatial dims but keep the channel dim -> one score per channel
        dims = (0,) + tuple(range(2, target.ndim))
        tps = torch.sum(input * target, dims)

        if self.method == 'simple':
            # smooth is added once on each side, so the score can never exceed 1
            dice = (2. * tps + self.smooth) / (torch.sum(input + target, dims) + self.smooth + self.eps)

        elif self.method == 'miletari':
            dice = (2. * tps + self.smooth) / (torch.sum(input**2 + target**2, dims) + self.smooth + self.eps)

        elif self.method == 'tversky':
            fps = torch.sum(input * (1 - target), dims)
            fns = torch.sum((1 - input) * target, dims)
            # no factor 2: with alpha = beta = 0.5 the Tversky index already equals the DICE score
            dice = (tps + self.smooth) / (tps + self.alpha * fps + self.beta * fns + self.smooth + self.eps)

        else:
            raise NotImplementedError('The specified type of DICE loss is not implemented')

        return 1 - dice

In [None]:
# export 
class DiceLossMulti(DiceLossBinary):
    """
    Multi-class DICE loss. Applies a softmax over the channel dim (so `input` is expected to be
    raw logits), one hot encodes single-channel targets and returns the (optionally weighted)
    mean of the per-class `DiceLossBinary` losses.

    Args:
        n_classes: number of classes; labels outside 0..n_classes-1 get an all-zero one-hot row.
        weights:   None (every class weighted 1), a list/tuple/tensor with one weight per class,
                   or 'auto' for the inverse class frequency of the first target seen
                   (stored in `self.weights` and reused for later calls).
        **kwargs:  forwarded to `DiceLossBinary` (method, alpha, beta, eps, smooth).
    Call:
        input:   tensor of shape [B, C, ...] with logits, C == n_classes.
        target:  tensor of shape [B, 1, ...] with integer labels or [B, C, ...] one-hot.
    Returns:
        scalar tensor, the weighted mean of the per-class losses.
    Raises:
        ValueError: if target has neither 1 nor C channels.
    """
    def __init__(self, n_classes, weights=None, **kwargs):
        store_attr()
        super().__init__(**kwargs)

    def __call__(self, input:Tensor, target:Tensor) -> Tensor:
        if target.size(1) == 1:
            target = self.to_one_hot(target)
        elif target.size(1) != input.size(1):
            raise ValueError("Number of Channels between input and target do not match. "
                             "Expected target to have 1 or {} channels but got {}".format(input.size(1), target.size(1)))
        # class weights are estimated from the encoded target, never from the raw logits
        weights = self.get_weights(target)
        loss = super().__call__(self.activation(input), target)
        if torch.is_tensor(weights):
            weights = weights.to(loss.device)  # list/tuple weights are created on the CPU
        return torch.mean(loss * weights)

    def get_weights(self, target):
        """
        Resolve `self.weights` in place and return it: None -> 1., list/tuple -> tensor,
        'auto' -> inverse class frequency of `target` (computed once, then reused).
        """
        if isinstance(self.weights, str) and self.weights == 'auto':
            # mean over batch and spatial dims = fraction of pixels per class;
            # a class absent from this target receives an infinite weight
            dims = (0,) + tuple(range(2, target.ndim))
            self.weights = 1 / torch.mean(target.float(), dims)
        elif self.weights is None:
            self.weights = 1.
        elif isinstance(self.weights, (tuple, list)):
            self.weights = torch.as_tensor(self.weights)
        return self.weights

    def make_binary(self, t, set_to_one):
        "Exact indicator of class `set_to_one` as a float tensor."
        return (t == set_to_one).float()

    def to_one_hot(self, target:Tensor):
        "Turn [B, 1, ...] integer labels into a [B, n_classes, ...] float one-hot tensor."
        target = target.squeeze(1).long() # remove the solitary channel (if there is one) and set type to int64
        one_hot = [self.make_binary(target, set_to_one=i) for i in range(0, self.n_classes)]
        return torch.stack(one_hot, 1)

    def activation(self, input):
        "Softmax over the channel dim."
        return F.softmax(input, 1)
        

### MCC Loss

Implementing the Matthews correlation coefficient (MCC) as a loss function:


$$\mathrm{MCC} = \frac{TP \cdot TN - FP \cdot FN + \gamma}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)} + \gamma}, \qquad \mathcal{L}_{MCC} = 1 - \mathrm{MCC}$$

with the soft confusion-matrix terms

$$TP = \sum_{i}^{n} p_i g_i, \quad FP = \sum_{i}^{n} p_i (1 - g_i), \quad FN = \sum_{i}^{n} (1 - p_i) g_i, \quad TN = \sum_{i}^{n} (1 - p_i)(1 - g_i),$$

where $p_i$ is the prediction for pixel $i$, $g_i$ the corresponding ground-truth pixel and $\gamma$ the smoothing factor.

In [None]:
# export
class MCCLossBinary(DiceLossBinary):
    
    """
    Computes the MCC loss, i.e. 1 - (smoothed) Matthews correlation coefficient.
    
    From Wikipedia (https://en.wikipedia.org/wiki/Matthews_correlation_coefficient):
        > The coefficient takes into account true and false positives and negatives and is generally 
        > regarded as a balanced measure which can be used even if the classes are of very different sizes
        > The MCC is in essence a correlation coefficient between the observed 
        > and predicted binary classifications; it returns a value between −1 and +1. 
        > A coefficient of +1 represents a perfect prediction, 0 no better than random prediction
        > and −1 indicates total disagreement between prediction and observation    
    
    `compute_loss` passes `input` through `activation`, which is a sigmoid in this class, so
    `input` is expected to be raw logits (applying a sigmoid beforehand squashes it twice).
    Subclasses replace `activation` (e.g. softmax, identity or thresholding).
    Note that PyTorch optimizers minimize a loss. So the coefficient is subtracted from 1,
    which puts the loss roughly in [0, 2].

    Math (p_i: prediction for pixel i, g_i: ground truth, gamma: `smooth`):
        TP = sum_i p_i * g_i            FP = sum_i p_i * (1 - g_i)
        FN = sum_i (1 - p_i) * g_i      TN = sum_i (1 - p_i) * (1 - g_i)
        MCC = (TP*TN - FP*FN + gamma) / (sqrt((TP + FP)(TP + FN)(FP + TN)(TN + FN) + eps) + gamma)

    Args:
        input:   A tensor of shape [B, 1, D, H, W] (any [B, C, ...] works). Logits.
        target:  A tensor of the same shape. Ground truth.
        smooth:  Smoothing factor, default is 1. Inherited from DiceLossBinary base class 
        eps:     Added for numerical stability. Inherited from DiceLossBinary base class
    Returns:
        mcc_loss: tensor with one loss per channel (terms summed over batch and spatial dims).
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    
    def __call__(self, input: Tensor, target: Tensor) -> Tensor:         
        return self.compute_loss(input, target)
    
    def activation(self, input):
        "Sigmoid, turning logits into probabilities."
        return torch.sigmoid(input)
       
    def compute_loss(self, input: Tensor, target: Tensor):
        "Soft confusion-matrix terms per channel -> 1 - smoothed MCC."
        # reduce over batch and spatial dims -> one value per channel
        dims = (0,) + tuple(range(2, target.ndim))
        # apply the activation once instead of once per confusion-matrix term
        pred = self.activation(input)

        tps = torch.sum(pred * target, dims)
        fps = torch.sum(pred * (1 - target), dims)
        fns = torch.sum((1 - pred) * target, dims)
        tns = torch.sum((1 - pred) * (1 - target), dims)

        numerator = (tps * tns - fps * fns) + self.smooth
        denominator = ((tps + fps) * (tps + fns) * (fps + tns) * (tns + fns) + self.eps)**0.5 + self.smooth

        mcc_loss = numerator / denominator

        return 1 - mcc_loss

In [None]:
# MCCLossBinary applies a sigmoid itself, so feed raw logits rather than probabilities
i = torch.randn(5,1,2,25,25)
t = torch.randn(5,1,2,25,25).sigmoid().round()

MCCLossBinary()(i, t)

tensor([1.0011])

In [None]:
#export
class MCCLossMulti(MCCLossBinary):
    
    """
    Computes the MCC loss for a multilabel target. Basically the same as `MCCLossBinary`,
    but applies a softmax over the channel dim and one hot encodes single-channel targets
    before computation.
    
    Args (constructor):
        n_classes: number of classes to predict. Must cover every label value in the whole
                   dataset; labels outside 0..n_classes-1 get an all-zero one-hot row.
        weights:   None (every class weighted 1), a list/tuple/tensor of specified weights
                   (one per class), or 'auto' for the inverse class frequency of the first
                   target seen (stored in `self.weights` and reused for later calls).
        smooth:    Smoothing factor, default is 1. Inherited from DiceLossBinary base class 
        eps:       Added for numerical stability.
    Args (call):
        input:   A tensor of shape [B, C, D, H, W] with raw logits, where `n_classes` should correspond to C.
        target:  A tensor of shape [B, 1, D, H, W] with integer labels or [B, C, D, H, W] where C is the same size as in the input.  
    Returns:
        mcc_loss: scalar tensor, the weighted mean of the per-class MCC losses.
    Raises:
        ValueError: if target has neither 1 nor C channels.
    """
    def __init__(self, n_classes, weights=None, **kwargs):
        store_attr()
        super().__init__(**kwargs)
        
    def __call__(self, input: Tensor, target: Tensor) -> Tensor:
        if target.size(1) == 1:
            target = self.to_one_hot(target)
        elif target.size(1) != input.size(1):
            raise ValueError("Number of Channels between input and target do not match. "
                             "Expected target to have 1 or {} channels but got {}".format(input.size(1), target.size(1)))
        # class weights are estimated from the encoded target, never from the raw logits
        weights = self.get_weights(target)
        loss = super().__call__(input, target)
        if torch.is_tensor(weights):
            weights = weights.to(loss.device)  # list/tuple weights are created on the CPU
        return torch.mean(loss * weights)
    
    def get_weights(self, target):
        """
        Resolve `self.weights` in place and return it: None -> 1., list/tuple -> tensor,
        'auto' -> inverse class frequency of `target` (computed once, then reused).
        """
        if isinstance(self.weights, str) and self.weights == 'auto':
            # mean over batch and spatial dims = fraction of pixels per class;
            # a class absent from this target receives an infinite weight
            dims = (0,) + tuple(range(2, target.ndim))
            self.weights = 1 / torch.mean(target.float(), dims)
        elif self.weights is None:
            self.weights = 1.
        elif isinstance(self.weights, (tuple, list)):
            self.weights = torch.as_tensor(self.weights)
        return self.weights
    
    def make_binary(self, t, set_to_one):
        "Exact indicator of class `set_to_one` as a float tensor."
        return (t == set_to_one).float()
  
    def to_one_hot(self, target:Tensor):
        "Turn [B, 1, ...] integer labels into a [B, n_classes, ...] float one-hot tensor."
        target = target.squeeze(1).long() # remove the solitary channel (if there is one) and set type to int64
        one_hot = [self.make_binary(target, set_to_one=i) for i in range(0, self.n_classes)]
        return torch.stack(one_hot, 1)
    
    def activation(self, input): 
        "Softmax over the channel dim; `input` is expected to be raw logits."
        return F.softmax(input, 1) 

In [None]:
# 5-class logits plus single-channel integer labels in 0..4 (also reused by the cells below)
i = torch.randn(5, 5, 2, 25, 25)
t = torch.randint(0, 5, (5, 1, 2, 25, 25))

MCCLossMulti(n_classes=5)(i, t)

tensor(1.0020)

In [None]:
# export
class SoftMCCLossMulti(MCCLossMulti):
    """
    Same as MCCLossMulti, but the one hot encoding tolerates float-valued labels: instead of
    exact equality, a label counts for class k when it lies in (k - 0.49, k + 0.49].
    Example: 
        t = torch.tensor([[ 0.9113, -0.7525, -2.1771, -0.2420, -0.2245],
                          [ 1.9503, -1.2903,  0.1201,  0.2830,  0.0473]])

        MCCLossMulti(2).make_binary(t, 1)       # exact match only
        # tensor([[0., 0., 0., 0., 0.],
        #         [0., 0., 0., 0., 0.]])

        SoftMCCLossMulti(2).soft_binary(t, 1)   # values in (0.51, 1.49] are kept
        # tensor([[0.9113, 0.0000, 0.0000, 0.0000, 0.0000],
        #         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
    """
    
    def soft_binary(self, t, set_to_one):
        """
        Soft indicator for class `set_to_one`: where `t` lies in (set_to_one - 0.49, set_to_one + 0.49]
        the raw value of `t` is kept; elsewhere the result is 0, or 1 when set_to_one == 0.

        NOTE(review): the kept value is the label itself (close to set_to_one), so for classes >= 2
        it exceeds 1, and for class 0 matching pixels get ~0 while all other pixels get 1.
        Confirm this is the intended encoding.
        """
        # t.gt(lo) != t.gt(hi) is True exactly for lo < t <= hi
        return torch.where(t.gt(set_to_one - 0.49) != t.gt(set_to_one + 0.49), 
                           t.float(), 
                           tensor(0.).to(t.device) if set_to_one > 0 else tensor(1.).to(t.device))
    
    def to_one_hot(self, target:Tensor):
        "Stack `soft_binary` for every class along dim 1: [B, 1, ...] -> [B, n_classes, ...]."
        target = target.squeeze(1) # remove the solitary channel (if there is one); dtype is kept (no int64 cast) so fractional labels survive
        one_hot = [self.soft_binary(target, set_to_one=i) for i in range(0, self.n_classes)]
        return torch.stack(one_hot, 1)

In [None]:
# same logits/labels as above, using the tolerance-based (soft) one-hot encoding
SoftMCCLossMulti(n_classes=5)(i, t)

tensor(1.0059)

In [None]:
# export
class WeightedMCCLossMulti(MCCLossMulti):
    """
    Weighted version of `MCCLossMulti`: every confusion-matrix term inside the MCC denominator
    is multiplied by its own weight, while the numerator stays unweighted.
    Class specific weights can still be added through `weights` during initialization.

    Args:
        gamma: weight for false negatives (default 0.5).
        delta: weight for true negatives (default 0.5).
        *args, **kwargs: forwarded to `MCCLossMulti`. Since gamma and delta come first, pass
               `n_classes` by keyword (e.g. `WeightedMCCLossMulti(n_classes=5)`) or after them.
        alpha, beta: given via kwargs and stored by `DiceLossBinary`; weights for true positives
               and false positives (default 0.5 each).

    Notes:
        `activation` is the identity here, so `input` is used as-is (no softmax) —
        presumably it should already contain probabilities; TODO confirm.
        With all four weights equal to w, the denominator is scaled by w**2 (ignoring eps and
        smooth), i.e. the coefficient grows by 1/w**2 (4x for the 0.5 defaults).
    """

    def __init__(self, gamma=0.5, delta=0.5, *args, **kwargs):
        "alpha and beta are already inherited from `DiceLossBinary`"
        store_attr()
        super().__init__(*args, **kwargs)

    def compute_loss(self, input: Tensor, target: Tensor):
        "Weighted soft confusion-matrix terms per channel -> 1 - weighted MCC."
        reduce_dims = (0,) + tuple(range(2, target.ndim))
        prob = self.activation(input)
        not_prob, not_target = 1 - prob, 1 - target

        tps = torch.sum(prob * target, reduce_dims)
        fps = torch.sum(prob * not_target, reduce_dims)
        fns = torch.sum(not_prob * target, reduce_dims)
        tns = torch.sum(not_prob * not_target, reduce_dims)

        a, b, g, d = self.alpha, self.beta, self.gamma, self.delta
        pred_pos = tps * a + fps * b     # predicted positives
        actual_pos = tps * a + fns * g   # actual positives
        actual_neg = fps * b + tns * d   # actual negatives
        pred_neg = tns * d + fns * g     # predicted negatives

        numerator = (tps * tns - fps * fns) + self.smooth
        denominator = (pred_pos * actual_pos * actual_neg * pred_neg + self.eps) ** 0.5 + self.smooth
        return 1 - numerator / denominator

    def activation(self, x):
        "Identity: no activation is applied."
        return x

# Metrics

In [None]:
# export
class MCCScore(MCCLossMulti):
    """
    Matthews correlation coefficient as a metric (larger is better): 1 - the `MCCLossMulti`
    loss, computed on hard predictions obtained by thresholding `input` at `thres`.

    Args:
        n_classes: number of classes. None (default) scores a single binary mask, where every
                   target value that is non-zero after casting to an integer counts as foreground.
        thres:     threshold applied to `input` to get hard predictions.
        **kwargs:  forwarded to `MCCLossMulti` (weights, smooth, eps, ...).
    """
    def __init__(self, n_classes = None, thres = 0.5, **kwargs):
        super().__init__(n_classes, **kwargs)
        
        self.n_classes = 1 if n_classes is None else n_classes
        self.thres = thres
    
    def __call__(self, input:Tensor , target: Tensor):
        if self.n_classes == 1 and target.size(1) == 1:
            # MCCLossMulti one hot encodes single-channel targets; with one class that keeps only
            # class 0 (background). Encoding here first yields the background mask, which the
            # parent's encoding turns back into the foreground mask.
            target = self.to_one_hot(target)
        # multi-class and already one-hot targets are encoded/validated by MCCLossMulti itself
        return 1 - torch.mean(super().__call__(input, target))
        
    def activation(self, input): 
        "Hard predictions: 1. where `input` exceeds `thres`, else 0."
        return (input > self.thres).float()

In [None]:
# `i` holds 5-class logits and `t` labels 0..4, so score all 5 classes; MCCScore() without
# n_classes would instead score a single binary mask (non-zero labels = foreground)
MCCScore(n_classes=5)(i, t)

tensor(0.0062)

In [None]:
# hide
# Convert the notebooks into library modules with nbdev; the output below lists each converted notebook.
from nbdev.export import *
notebook2script()

Converted 01-basics.ipynb.
Converted 02-transforms.ipynb.
Converted 03-datablock.ipynb.
Converted 04-datasets.ipynb.
Converted 05-models-all.ipynb.
Converted 05-models-losses-and-metrics-Copy1.ipynb.
Converted 05a-models-modules.ipynb.
Converted 05b-models-unet.ipynb.
Converted 05c-models-losses.ipynb.
Converted 06-callbacks.ipynb.
Converted 06-various-tools.ipynb.
