In [None]:
from nbdev import *
%nbdev_default_export losses

Cells will be exported to deepflash2.losses,
unless a different module is specified after an export flag: `%nbdev_export special.module`


# Losses

> Implements custom loss functions.

In [None]:
%nbdev_hide
from nbdev.showdoc import *
from fastcore.test import *
from fastai.torch_core import TensorImage, TensorMask
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
%nbdev_export
import torch
import torch.nn.functional as F

## Weighted Softmax Cross Entropy Loss

As described in Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." *Nature Methods* 16.1 (2019): 67–70.


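- `axis` is the dimension used for the softmax in `activation` and the argmax in `decodes`. For segmentation it should be 1 (the channel dimension), which is also the class dimension `F.cross_entropy` uses in `forward`.
- `reduction` (`'mean'`, `'sum'` or `'none'`) controls how the weighted per-pixel losses are aggregated; `Learner.get_preds` switches it to `'none'` when per-item losses are requested.
- `activation` is applied to the raw output logits of the model when calling `Learner.get_preds` or `Learner.predict`.
- `decodes` converts the output of the model into the same format as the target (here, binary masks). It is used by `Learner.predict`.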

In [None]:
%nbdev_export
class WeightedSoftmaxCrossEntropy(torch.nn.Module):
    "Softmax cross-entropy loss with per-element weights (Falk et al., 2019)."
    def __init__(self, *args, axis=1, reduction='mean'):
        # `*args` is accepted and ignored; kept so existing callers that pass
        # positional arguments keep working.
        super().__init__()
        # Read at call time in `forward`, so it may be changed after construction
        # (e.g. set to 'none' to obtain per-element losses).
        self.reduction = reduction
        # Class dimension used by `activation` and `decodes`. Defaults to 1 because
        # `F.cross_entropy` in `forward` always treats dim 1 of `inputs` as the class
        # dimension; the previous default (-1) only agreed with it for 2-D inputs.
        self.axis = axis

    def forward(self, inputs, targets, weights):
        """Weighted cross-entropy between logits `inputs` and class-index `targets`.

        inputs:  raw logits with classes along dim 1, e.g. (N, C, H, W).
        targets: integer class indices, e.g. (N, H, W).
        weights: per-element weights, multiplied with (so broadcastable to) the
                 unreduced loss, e.g. (N, H, W).
        Returns a scalar for reduction 'mean'/'sum', otherwise the unreduced
        weighted loss. Raises ValueError for any other `reduction` value.
        """
        # Weight every element's loss before any reduction is applied.
        loss = F.cross_entropy(inputs, targets, reduction='none') * weights
        if self.reduction == 'mean':
            # Divides by the number of elements, not by the sum of the weights.
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction in ('none', None):
            return loss
        raise ValueError(
            f"reduction must be 'mean', 'sum' or 'none', got {self.reduction!r}")

    def decodes(self, x):
        "Convert logits `x` to class indices by taking the argmax over `self.axis`."
        return x.argmax(dim=self.axis)

    def activation(self, x):
        "Convert logits `x` to class probabilities via softmax over `self.axis`."
        return F.softmax(x, dim=self.axis)

In a segmentation task, we want to take the softmax over the channel dimension, so we pass `axis=1`.

In [None]:
torch.manual_seed(0)  # reproducible random inputs
tst = WeightedSoftmaxCrossEntropy(axis=1)
output = TensorImage(torch.randn(4, 5, 356, 356, requires_grad=True))
targets = TensorMask(torch.ones(4, 356, 356).long())
# Weights of a weighted loss are non-negative, so draw them from [0, 1).
weights = torch.rand(4, 356, 356)
loss = tst(output, targets, weights)

# Default 'mean' reduction: a scalar equal to the mean of the weighted per-pixel losses.
test_eq(loss.ndim, 0)
assert torch.allclose(loss, (F.cross_entropy(output, targets, reduction='none') * weights).mean())

test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))

## Export -

In [None]:
%nbdev_hide
from nbdev.export import *
# Convert every notebook in the project into its Python module
# (the cell output lists each converted notebook).
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 04_callbacks.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted add_information.ipynb.
Converted gt_estimation.ipynb.
Converted index.ipynb.
Converted model_library.ipynb.
Converted predict.ipynb.
Converted train.ipynb.
