In [None]:
#default_exp losses
from nbdev.showdoc import show_doc

# Losses

> Implements custom loss functions.

In [None]:
#hide
from fastcore.test import *
from fastai.torch_core import TensorImage, TensorMask
from deepflash2.utils import import_package

In [None]:
#export
import torch
import torch.nn.functional as F
from fastai.torch_core import TensorBase

## Weighted Softmax Cross Entropy Loss

As described by Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature Methods 16.1 (2019): 67–70.


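- `axis` is the dimension used for the softmax in `activation` and the argmax in `decodes`. It defaults to -1; pass `axis=1` to use the channel dimension of segmentation outputs. The loss itself (`forward`) always treats dimension 1 as the class dimension.
- `reduction` (`'mean'`, `'sum'` or `'none'`) controls how the weighted per-element loss is reduced; it is also used when we call `Learner.get_preds`.
- `activation` is applied to the raw output logits of the model when calling `Learner.get_preds` or `Learner.predict`.
- `decodes` converts the output of the model to a format similar to the target (here integer class-index masks). This is used in `Learner.predict`.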

In [None]:
#export
class WeightedSoftmaxCrossEntropy(torch.nn.Module):
    """Weighted softmax cross-entropy loss (Falk et al., Nature Methods 2019).

    The unreduced cross-entropy of every element is multiplied by a weight map
    before the configured reduction is applied.

    Args:
        *args: Accepted for backward compatibility and ignored.
        axis: Dimension used by `activation` (softmax) and `decodes` (argmax).
            Defaults to -1; pass `axis=1` for (N, C, ...) segmentation logits.
            NOTE: `forward` does not use `axis` -- `F.cross_entropy` always
            treats dim 1 as the class dimension.
        reduction: One of 'mean', 'sum' or 'none'. It is read on every call, so
            it may be changed after construction (it is used by `Learner.get_preds`).
    """
    _valid_reductions = ('mean', 'sum', 'none')

    def __init__(self, *args, axis=-1, reduction = 'mean'):
        super().__init__()
        self.reduction = reduction
        self.axis = axis

    def _contiguous(self, x):
        "Return `x` made contiguous and cast to fastai's `TensorBase`."
        return TensorBase(x.contiguous())

    def forward(self, inp, targ, weights):
        """Return the weighted cross-entropy between logits `inp` and labels `targ`.

        Args:
            inp: Raw logits with the class dimension at dim 1.
            targ: Targets in the format accepted by `F.cross_entropy`
                (e.g. integer class indices).
            weights: Weights multiplied (with broadcasting) into the unreduced loss.

        Returns:
            A scalar for 'mean' / 'sum'; the unreduced weighted loss for 'none'.

        Raises:
            ValueError: If `self.reduction` is not 'mean', 'sum' or 'none'.
        """
        # Fail loudly instead of silently returning an unreduced loss on a typo.
        if self.reduction not in self._valid_reductions:
            raise ValueError(
                f"reduction must be one of {self._valid_reductions}, got {self.reduction!r}")
        # Cast fastai tensor subclasses (e.g. TensorImage / TensorMask) to plain
        # TensorBase; presumably avoids subclass-dispatch clashes in cross_entropy.
        inp, targ = map(self._contiguous, (inp, targ))
        loss = F.cross_entropy(inp, targ, reduction='none') * weights
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def decodes(self, x):
        "Convert logits `x` to class indices via argmax over `self.axis`."
        return x.argmax(dim=self.axis)

    def activation(self, x):
        "Apply softmax over `self.axis`, turning logits `x` into class probabilities."
        return F.softmax(x, dim=self.axis)

In a segmentation task, we want to take the softmax over the channel dimension, so we pass `axis=1`.

In [None]:
torch.manual_seed(0)
tst = WeightedSoftmaxCrossEntropy(axis=1)
output = TensorImage(torch.randn(4, 5, 356, 356, requires_grad=True))
targets = TensorMask(torch.ones(4, 356, 356).long())
weights = torch.randn(4, 356, 356)
loss = tst(output, targets, weights)
# Reference value recorded for seed 0. Compare with a tolerance: a float32 mean over
# ~500k elements can differ in the last digits across torch versions / backends.
test_close(loss.item(), -0.002415925730019808, eps=1e-6)
test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))

## Kornia Segmentation Losses Integration

A helper function to load segmentation losses from [kornia](https://github.com/kornia/kornia).
Read the [docs](https://kornia.readthedocs.io/en/latest/losses.html#module) for a detailed explanation.

In [None]:
#export 
def load_kornia_loss(loss_name, alpha=0.5, beta=0.5, gamma=2.0, reduction='mean', eps = 1e-08):
    """Load a segmentation loss from kornia by name.

    Args:
        loss_name: "DiceLoss", "TverskyLoss" or "FocalLoss".
        alpha: Passed to `TverskyLoss` and `FocalLoss`.
        beta: Passed to `TverskyLoss` only.
        gamma: Passed to `FocalLoss` only.
        reduction: Passed to `FocalLoss` only.
        eps: Passed to all three losses.

    Returns:
        An instance of the requested `kornia.losses` class.

    Raises:
        NotImplementedError: If `loss_name` is not one of the supported names.
    """
    supported = ("DiceLoss", "TverskyLoss", "FocalLoss")
    # Validate before importing kornia so an unsupported name fails fast.
    if loss_name not in supported:
        raise NotImplementedError(
            f"Unsupported kornia loss {loss_name!r}; expected one of {supported}")
    # Local import: the exported module has no top-level import of `import_package`
    # (in this notebook it is only imported in a non-exported cell).
    from deepflash2.utils import import_package
    kornia = import_package('kornia')
    if loss_name == "DiceLoss":
        return kornia.losses.DiceLoss(eps=eps)
    if loss_name == "TverskyLoss":
        return kornia.losses.TverskyLoss(alpha=alpha, beta=beta, eps=eps)
    return kornia.losses.FocalLoss(alpha=alpha, gamma=gamma, reduction=reduction, eps=eps)

In [None]:
torch.manual_seed(0)
output = TensorImage(torch.randn(4, 5, 356, 356, requires_grad=True))
targets = TensorMask(torch.ones(4, 356, 356).long())
# With alpha = beta = 0.5 the Tversky index equals the Dice coefficient, so both
# losses must agree -- up to float rounding, since kornia computes them differently.
tst = load_kornia_loss("TverskyLoss", alpha=0.5, beta=0.5)
loss = tst(output, targets)
tst2 = load_kornia_loss("DiceLoss")
loss2 = tst2(output, targets)
test_close(loss.item(), loss2.item())

tst3 = load_kornia_loss("FocalLoss")
loss3 = tst3(output, targets)
# No closed-form reference for the focal loss here; at least check it is finite.
assert torch.isfinite(loss3).all()

## Export -

In [None]:
#hide
# Export all `#export` cells of the notebooks to the library modules.
from nbdev.export import notebook2script
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 02a_transforms.ipynb.
Converted 03_metrics.ipynb.
Converted 04_callbacks.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted 07_tta.ipynb.
Converted 08_gui.ipynb.
Converted 09_gt.ipynb.
Converted add_information.ipynb.
Converted deepflash2.ipynb.
Converted gt_estimation.ipynb.
Converted index.ipynb.
Converted model_library.ipynb.
Converted predict.ipynb.
Converted train.ipynb.
Converted tutorial.ipynb.
