In [None]:
# default_exp losses

# Losses

> Implements custom loss functions.

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.test import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#export 
from fastai2.vision.all import *
#import torch
#import torch.nn.functional as F

## Weighted Softmax Cross Entropy Loss

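As described in Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." *Nature Methods* 16.1 (2019): 67-70. For one-hot targets, each pixel $x$ contributes $-w(x)\,\log p_{\ell(x)}(x)$. Here $p_{\ell(x)}(x)$ is the softmax probability of the pixel's true class $\ell(x)$ and $w(x)$ is its weight. These per-pixel terms are then aggregated according to `reduction`.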


- `axis`: the dimension over which the softmax is computed. Defaults to 1 (the channel dimension).
- `reduction`: how per-pixel losses are aggregated: `'mean'`, `'sum'`, or `'none'` for unreduced per-pixel losses. This matters when calling `Learner.get_preds`, which can request per-item losses.
- `activation`: applied to the model's raw output logits when calling `Learner.get_preds` or `Learner.predict`.
- `decodes`: converts the model's output into the same format as the target (here, masks of class indices). It is used by `Learner.predict`.

In [None]:
#export
class WeightedSoftmaxCrossEntropy(torch.nn.Module):
    """Weighted softmax cross-entropy loss (Falk et al., 2019).

    For every position the loss is ``-w * sum_c(t_c * log_softmax(x)_c)``, the sum
    running over ``axis``; the per-position losses are then aggregated according
    to ``reduction``.

    Args:
        axis: dimension holding the class scores. ``forward``, ``activation`` and
            ``decodes`` all use it, so the loss, probabilities and decoded masks
            always agree. Defaults to 1, the channel dimension of an
            ``(N, C, ...)`` model output.
        *args: accepted for backward compatibility; ignored.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (``None`` behaves like
            ``'none'``). It is read on every call, so it may be changed after
            construction.
    """
    def __init__(self, axis=1, *args, reduction='mean'):
        super().__init__()
        self.reduction = reduction
        self.axis = axis

    def decodes(self, x):
        "Convert raw logits into predicted class indices along `axis`."
        return x.argmax(dim=self.axis)

    def activation(self, x):
        "Convert raw logits into class probabilities along `axis`."
        return F.softmax(x, dim=self.axis)

    def forward(self, inputs, targ_weights):
        """Return the weighted cross-entropy of `inputs` against `targ_weights`.

        Args:
            inputs: raw logits with class scores along ``self.axis``.
            targ_weights: ``(targets, weights)`` pair. ``targets`` must broadcast
                against ``inputs`` (one-hot masks, or soft labels); ``weights`` must
                broadcast against ``inputs`` with ``self.axis`` removed.

        Returns:
            A scalar for ``'mean'``/``'sum'``, otherwise the per-position losses.

        Raises:
            ValueError: if ``self.reduction`` is not a recognised value.
        """
        targets, weights = targ_weights[0], targ_weights[1]

        # Standard cross-entropy: -sum_c t_c * log p_c over the class axis. For one-hot
        # targets this is -log p of the true class; it is also correct for soft labels.
        log_probs = F.log_softmax(inputs, dim=self.axis)
        loss = -(log_probs * targets).sum(dim=self.axis) * weights

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction in ('none', None):
            return loss
        # Fail loudly: a typo here would otherwise return a non-scalar "loss".
        raise ValueError(f"reduction must be 'mean', 'sum' or 'none', got {self.reduction!r}")

In a segmentation task the softmax is taken over the channel dimension, so we pass `axis=1`.

In [None]:
torch.manual_seed(42)  # reproducible test data
n_classes = 2
tst = WeightedSoftmaxCrossEntropy(axis=1)
output = torch.randn(4, n_classes, 64, 64)
labels = torch.randint(0, n_classes, (4, 64, 64))
targets = F.one_hot(labels, n_classes).permute(0, 3, 1, 2).float()  # one-hot masks, (N, C, H, W)
weights = torch.rand(4, 64, 64)                                      # non-negative per-pixel weights
targ_weights = (targets, weights)

# With one-hot targets the loss must equal PyTorch's per-pixel cross-entropy scaled by the weights.
reference = F.cross_entropy(output, labels, reduction='none') * weights
test_close(tst(output, targ_weights), reference.mean())
# Without reduction, one loss value per pixel is returned.
test_eq(WeightedSoftmaxCrossEntropy(axis=1, reduction='none')(output, targ_weights).shape, weights.shape)

test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))

## Export

In [None]:
#hide
# Convert the project's notebooks into library modules; the output below lists each converted notebook.
from nbdev.export import *
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 04_callbacks.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted index.ipynb.
