In [None]:
#default_exp losses
from nbdev.showdoc import show_doc

# Losses

> Implements popular segmentation loss functions.

In [None]:
#hide
from fastcore.test import *
from fastai.torch_core import TensorImage, TensorMask
from fastai.losses import CrossEntropyLossFlat

In [None]:
#export
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
import fastai
from fastai.torch_core import TensorBase
import segmentation_models_pytorch as smp
from deepflash2.utils import import_package

Names of the losses that can be loaded with `get_loss()`:

In [None]:
#export
# Loss names accepted by `get_loss()` (checked there with an assert).
LOSSES = [
    'CrossEntropyLoss',
    'DiceLoss',
    'SoftCrossEntropyLoss',
    'CrossEntropyDiceLoss',
    'JaccardLoss',
    'FocalLoss',
    'LovaszLoss',
    'TverskyLoss',
]

## Loss Wrapper functions

Wrapper for handling different tensor types from [fastai](https://docs.fast.ai/torch_core.html#TensorBase).

In [None]:
#export 
class FastaiLoss(_Loss):
    '''
    Wrapper class around a loss function for handling different tensor types.

    Every positional argument that is a `torch.Tensor` is made contiguous and
    re-wrapped as a fastai `TensorBase` before being handed to the wrapped loss;
    any other argument is forwarded untouched.

    Args:
        loss: callable invoked as `loss(*inputs)`.
        axis: stored as `self.axis`; not used by this class itself.
    '''
    def __init__(self, loss, axis=1):
        super().__init__()
        self.loss, self.axis = loss, axis

    def _contiguous(self, x):
        # Non-tensors (e.g. ints, None) pass straight through.
        if not isinstance(x, torch.Tensor):
            return x
        return TensorBase(x.contiguous())

    def forward(self, *inputs):
        return self.loss(*map(self._contiguous, inputs))

Wrappers for weighting and combining different losses, adapted from [pytorch-toolbelt](https://github.com/BloodAxe/pytorch-toolbelt).

In [None]:
#export 
# from https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/joint_loss.py
class WeightedLoss(_Loss):
    '''
    Scale the output of a loss function by a fixed factor.

    Helps balance several losses that live on different scales.

    Args:
        loss: callable invoked as `loss(*inputs)`.
        weight: constant multiplier applied to the loss value (default 1.0).
    '''
    def __init__(self, loss, weight=1.0):
        super().__init__()
        self.loss, self.weight = loss, weight

    def forward(self, *inputs):
        value = self.loss(*inputs)
        return value * self.weight

class JointLoss(_Loss):
    '''
    Combine two losses into a weighted sum evaluated on the same inputs:
    `first_weight * first(*inputs) + second_weight * second(*inputs)`.

    Args:
        first: first loss (evaluated first).
        second: second loss.
        first_weight: multiplier for the first loss (default 1.0).
        second_weight: multiplier for the second loss (default 1.0).
    '''

    def __init__(self, first: nn.Module, second: nn.Module, first_weight=1.0, second_weight=1.0):
        super().__init__()
        self.first = WeightedLoss(first, first_weight)
        self.second = WeightedLoss(second, second_weight)

    def forward(self, *inputs):
        first_term = self.first(*inputs)
        second_term = self.second(*inputs)
        return first_term + second_term

## Popular segmentation losses

The `get_loss()` function loads popular segmentation losses from [fastai](https://docs.fast.ai/losses.html), [Segmentation Models PyTorch](https://github.com/qubvel/segmentation_models.pytorch) and [kornia](https://kornia.readthedocs.io/en/latest/losses.html#semantic-segmentation): 
- (Soft) CrossEntropy Loss
- Dice Loss
- CrossEntropyDice Loss (sum of CrossEntropy and Dice Loss)
- Jaccard Loss
- Focal Loss
- Lovasz Loss
- Tversky Loss

In [None]:
#export 
def get_loss(loss_name, mode='multiclass', classes=[1], smooth_factor=0., alpha=0.5, beta=0.5, gamma=2.0, reduction='mean', **kwargs):
    '''
    Load a segmentation loss based on `loss_name`.

    Args:
        loss_name: name of the loss; must be one of `LOSSES`.
        mode: forwarded to the smp Dice, Jaccard, Focal and Lovasz losses
            (and to the Dice term of 'CrossEntropyDiceLoss').
        classes: forwarded to the smp Dice and Jaccard losses (and to the Dice
            term of 'CrossEntropyDiceLoss'). The default list is shared between
            calls; it is only passed through here, never mutated.
        smooth_factor: forwarded to 'SoftCrossEntropyLoss'.
        alpha: forwarded to 'FocalLoss' and 'TverskyLoss'.
        beta: forwarded to 'TverskyLoss'.
        gamma: forwarded to 'FocalLoss'.
        reduction: forwarded to 'FocalLoss'.
        **kwargs: forwarded to the selected loss constructor (for
            'CrossEntropyDiceLoss' only to the Dice term); ignored for
            'CrossEntropyLoss'.

    Returns:
        The instantiated loss module.

    Raises:
        AssertionError: if `loss_name` is not in `LOSSES`.
        NotImplementedError: if `loss_name` is listed in `LOSSES` but has no
            matching branch below.
    '''
    # `import fastai` alone is not guaranteed to load the `fastai.losses`
    # submodule, so import the class explicitly instead of relying on
    # attribute access through the package.
    from fastai.losses import CrossEntropyLossFlat

    assert loss_name in LOSSES, f'Select one of {LOSSES}'

    if loss_name=="CrossEntropyLoss":
        loss = CrossEntropyLossFlat(axis=1)

    elif loss_name=="SoftCrossEntropyLoss":
        loss = smp.losses.SoftCrossEntropyLoss(smooth_factor=smooth_factor, **kwargs)

    elif loss_name=="DiceLoss":
        loss = smp.losses.DiceLoss(mode=mode, classes=classes, **kwargs)

    elif loss_name=="JaccardLoss":
        loss = smp.losses.JaccardLoss(mode=mode, classes=classes, **kwargs)

    elif loss_name=="FocalLoss":
        loss = smp.losses.FocalLoss(mode=mode, alpha=alpha, gamma=gamma, reduction=reduction, **kwargs)

    elif loss_name=="LovaszLoss":
        loss = smp.losses.LovaszLoss(mode=mode, **kwargs)

    elif loss_name=="TverskyLoss":
        # kornia is loaded on demand through `import_package`
        # (presumably so it is only required when this loss is used).
        kornia = import_package('kornia')
        loss = kornia.losses.TverskyLoss(alpha=alpha, beta=beta, **kwargs)

    elif loss_name=="CrossEntropyDiceLoss":
        # Unweighted sum (both weights 1) of cross-entropy and Dice loss.
        dc = smp.losses.DiceLoss(mode=mode, classes=classes, **kwargs)
        ce = CrossEntropyLossFlat(axis=1)
        loss = JointLoss(ce, dc, 1, 1)

    else:
        # Reached only if LOSSES gains a name without a branch here; previously
        # this surfaced as an UnboundLocalError on `loss`.
        raise NotImplementedError(f'No constructor for loss {loss_name!r}')

    return loss

In [None]:
# Smoke test: every loss in LOSSES can be built and returns a finite value.
# Fixed seed so this and the comparison cells below are reproducible.
torch.manual_seed(0)
n_classes = 2
output = torch.randn(4, n_classes, 356, 356, requires_grad=True)
target = torch.randint(0, n_classes, (4, 356, 356))
for loss_name in LOSSES:
    print(f'Testing {loss_name}')
    tst = get_loss(loss_name, classes=list(range(1,n_classes)))
    loss = tst(output, target)
    assert torch.isfinite(loss).all(), f'{loss_name} returned a non-finite value'

Testing CrossEntropyLoss
Testing DiceLoss
Testing SoftCrossEntropyLoss
Testing CrossEntropyDiceLoss
Testing JaccardLoss
Testing FocalLoss
Testing LovaszLoss
Testing TverskyLoss


In [None]:
# With smooth_factor=0, smp's SoftCrossEntropyLoss should match fastai's CrossEntropyLossFlat
soft_ce = get_loss('SoftCrossEntropyLoss', smooth_factor=0)
flat_ce = CrossEntropyLossFlat(axis=1)
test_close(soft_ce(output, target), flat_ce(output, target), eps=1e-04)

In [None]:
# Compare Jaccard loss with Dice loss.
# With Dice index D and Jaccard index J = D/(2-D), the losses dl = 1-D and
# jl = 1-J satisfy jl = 2*dl/(dl+1), because 2-D = 1+dl.
# The identity holds per class; with the default `classes=[1]` a single class is scored.
# The loose eps presumably absorbs the implementations' numerical smoothing terms.
jc = get_loss('JaccardLoss')
dc = get_loss('DiceLoss')  # also reused by the Tversky comparison below
dc_loss = dc(output, target)
dc_to_jc = 2*dc_loss/(dc_loss+1)
test_close(jc(output, target), dc_to_jc, eps=1e-02)

In [None]:
# Compare TverskyLoss with alpha=0.5 and beta=0.5 (the Dice-equivalent setting)
# against the Dice loss `dc` from the previous cell.
# Only approximate agreement is expected: `dc` scores only `classes=[1]`, whereas
# get_loss passes no `classes` to the kornia Tversky loss. The loose eps=1e-02 is
# what makes this pass on the random data above.
# NOTE(review): confirm how kornia aggregates over classes before tightening eps.
tversky = get_loss("TverskyLoss", alpha=0.5, beta=0.5)
test_close(dc(output, target), tversky(output, target), eps=1e-02)

In [None]:
# Temporary test that the `classes` argument has an effect.
# With the class-1 logits fixed to a constant, the Dice loss built with
# `classes=None` should differ from the one restricted to `classes=[1]`.
# Uses a new name so `output` from the smoke-test cell is not overwritten.
output_const = torch.randn(4, n_classes, 356, 356)
output_const[:,1,...] = 0.5
dice_all = get_loss(loss_name='DiceLoss', classes=None)
dice_fg = get_loss(loss_name='DiceLoss', classes=list(range(1,n_classes)))
test_ne(dice_all(output_const, target), dice_fg(output_const, target))

## Export -

In [None]:
#hide
# Export the `#export` cells of all notebooks to the library modules.
from nbdev.export import notebook2script
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted 07_tta.ipynb.
Converted 08_gui.ipynb.
Converted 09_gt.ipynb.
Converted add_information.ipynb.
Converted deepflash2.ipynb.
Converted gt_estimation.ipynb.
Converted index.ipynb.
Converted model_library.ipynb.
Converted predict.ipynb.
Converted train.ipynb.
Converted tutorial.ipynb.
