In [None]:
# default_exp loss
# all_slow

# Loss functions

> Various loss functions in PyTorch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/marcomatteo/steel_segmentation/blob/master/nbs/06_loss.ipynb)

In [None]:
# hide
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# export
from steel_segmentation.metadata import *
from steel_segmentation.masks import *
from steel_segmentation.datasets import *
from steel_segmentation.dataloaders import *
from steel_segmentation.metrics import *

from fastai.torch_core import TensorBase
from fastai.losses import *

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

import segmentation_models_pytorch as smp

from functools import partial

This module collects loss functions for binary and multi-label segmentation, where each target channel is an independent binary mask (one per defect class).

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Build the segmentation DataLoaders (batch size 4) and grab one training batch
# to exercise the loss functions below.
dls = get_segmnt_dls(bs=4, device=device)

x, targs = dls.train.one_batch()
x.shape, targs.shape

(torch.Size([4, 3, 256, 1600]), torch.Size([4, 4, 256, 1600]))

In [None]:
# NOTE(review): in the recorded output `x` is on the GPU while `targs` is still on the CPU,
# so the targets must be moved/converted before being compared with model outputs.
x.device, targs.device

(device(type='cuda', index=0), device(type='cpu'))

In [None]:
# U-Net with an ImageNet-pretrained ResNet18 encoder and 4 output channels (matching the
# 4 target channels). activation=None keeps the outputs as raw logits, which is what the
# losses below take (they apply the sigmoid themselves).
model = smp.Unet("resnet18", 
                 encoder_weights="imagenet", 
                 classes=4, 
                 activation=None).to(device)
logits = model(x)
probs = torch.sigmoid(logits)
preds = ( probs > 0.5).float()  # binarize at a 0.5 probability threshold
preds.shape

torch.Size([4, 4, 256, 1600])

## BCE and SoftDice loss

This section collects loss functions used by @khornlund in his [repository](https://github.com/khornlund/severstal-steel-defect-detection) for the Severstal competition. The `#reference` comment above each class links to the implementation it was adapted from.

In [None]:
def bce_loss(output, target):
    """Binary cross-entropy on raw logits, averaged over every element.

    Delegates to PyTorch's `binary_cross_entropy_with_logits`, which applies the
    sigmoid internally, using its default 'mean' reduction.
    """
    criterion = F.binary_cross_entropy_with_logits
    return criterion(input=output, target=target, reduction="mean")

In [None]:
# Move the logits to the CPU (the recorded device check shows `targs` is already there) and
# cast both to float. NOTE(review): the legacy `torch.Tensor(...)` constructor is presumably
# used to drop fastai tensor subclasses; `.as_subclass(torch.Tensor)` is the modern spelling — confirm.
t_logits, t_targs = torch.Tensor(logits.float().cpu()), torch.Tensor(targs.float())

In [None]:
# Reference value: plain PyTorch BCE-with-logits on the converted tensors.
bce_loss(t_logits, t_targs)

tensor(0.8846, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [None]:
# fastai's flattened BCE (channel axis 1) on the same tensors; the recorded outputs of this
# cell and the `bce_loss` cell agree (0.8846).
BCEWithLogitsLossFlat(axis=1)(t_logits, t_targs)

TensorBase(0.8846, grad_fn=<AliasBackward>)

In [None]:
#export
#reference: https://github.com/asanakoy/kaggle_carvana_segmentation/blob/master/asanakoy/losses.py
class SoftDiceLoss(nn.Module):
    """Soft (differentiable) Dice loss computed on raw logits.

    Each sample of the batch is flattened and scored independently with a
    Laplace-smoothed Dice coefficient ``(2 * I + smooth) / (|P| + |T| + smooth)``;
    the loss is ``1 - mean(score)``.

    Args:
        smooth (float): additive smoothing applied equally to numerator and
            denominator. It keeps the score defined on empty masks, and for
            targets in ``[0, 1]`` the score stays in ``[0, 1]`` (a perfect
            empty prediction on an empty mask scores exactly 1, loss 0).
    """

    def __init__(self, smooth: float = 1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, labels):
        """
        Args:
            logits: raw model outputs; the first dimension is the batch.
            labels: targets with the same number of elements per sample.

        Returns:
            Scalar tensor ``1 - mean(dice)`` over the batch.
        """
        probs = torch.sigmoid(logits)  # F.sigmoid is deprecated
        num = labels.size(0)
        # reshape (unlike view) also accepts non-contiguous inputs such as channel slices.
        m1 = probs.reshape(num, -1)
        m2 = labels.reshape(num, -1).to(m1.dtype)
        intersection = (m1 * m2).sum(1)
        # Smooth numerator and denominator by the same amount; the previous
        # 2 * (I + 1) / (S + 1) reached 2 on empty masks, driving the loss to -1.
        score = (2. * intersection + self.smooth) / (m1.sum(1) + m2.sum(1) + self.smooth)
        return 1 - score.sum() / num

In [None]:
# Unweighted soft Dice loss on the same logits/targets.
loss = SoftDiceLoss()
loss(t_logits, t_targs)

tensor(0.9508, grad_fn=<RsubBackward1>)

In [None]:
#export
#reference: https://github.com/zdaiot/Kaggle-Steel-Defect-Detection
class WeightedSoftDiceLoss(nn.Module):
    """Pixel-weighted soft Dice loss on logits.

    Probabilities and targets are mapped from ``[0, 1]`` to ``[-1, 1]``, and every
    pixel is scaled by a weight interpolated between ``weight[0]`` (target 0) and
    ``weight[1]`` (target 1), before computing
    ``1 - 2 * <p, t> / (<p, p> + <t, t>)`` per sample.

    Args:
        size_average (bool): if True return the batch mean, otherwise the
            per-sample losses with shape ``(batch_size,)``.
        weight (sequence of 2 floats): ``(negative_weight, positive_weight)``.
    """

    def __init__(self, size_average=True, weight=(0.2, 0.8)):
        super().__init__()
        self.size_average = size_average
        # Plain attribute (not a buffer) so the module's state_dict layout is unchanged.
        self.weight = torch.FloatTensor(weight)

    def forward(self, logit_pixel, truth_pixel):
        """
        Args:
            logit_pixel: raw logits; the first dimension is the batch.
            truth_pixel: targets with the same number of elements as `logit_pixel`.

        Returns:
            Scalar mean loss if `size_average`, else a ``(batch_size,)`` tensor.
        """
        batch_size = len(logit_pixel)
        # reshape (unlike view) also works on non-contiguous inputs such as channel slices.
        logit = logit_pixel.reshape(batch_size, -1)
        truth = truth_pixel.reshape(batch_size, -1)
        assert logit.shape == truth.shape, (
            f"logits {tuple(logit.shape)} and targets {tuple(truth.shape)} differ in size")

        loss = self.soft_dice_criterion(logit, truth)

        if self.size_average:
            loss = loss.mean()
        return loss

    def soft_dice_criterion(self, logit, truth):
        """Per-sample weighted soft Dice loss; returns a ``(batch_size,)`` tensor."""
        batch_size = len(logit)
        p = torch.sigmoid(logit).reshape(batch_size, -1)
        t = truth.reshape(batch_size, -1)

        # Cast/move the weights locally instead of overwriting self.weight inside forward.
        weight = self.weight.type_as(logit)
        # Pixel weights depend on the (detached) target only; no gradient flows through them.
        w = t.detach() * (weight[1] - weight[0]) + weight[0]

        p = w * (p * 2 - 1)  # convert [0, 1] --> [-1, 1], then weight
        t = w * (t * 2 - 1)

        intersection = (p * t).sum(-1)
        # For binary targets t * t == w ** 2, so union > 0 whenever the weights are positive.
        union = (p * p).sum(-1) + (t * t).sum(-1)
        return 1 - 2 * intersection / union

In [None]:
# Pixel-weighted soft Dice with the default weights (0.2 for background, 0.8 for defect pixels).
loss = WeightedSoftDiceLoss()
loss(t_logits, t_targs)

tensor(0.9108, grad_fn=<MeanBackward0>)

In [None]:
#export
#reference: https://github.com/zdaiot/Kaggle-Steel-Defect-Detection
class SoftBCEDiceLoss(nn.Module):
    """Weighted sum of BCE-with-logits and `WeightedSoftDiceLoss`.

    Args:
        size_average (bool): if truthy (or None) both terms are averaged; if falsy
            the BCE term is summed (legacy ``size_average=False`` semantics) and
            the Dice term is returned per sample.
        weight (sequence of 2 floats): ``(negative_weight, positive_weight)``.
            ``weight[1]`` is also the BCE ``pos_weight`` (0.8 by default, i.e.
            positive pixels weigh slightly less than negatives in the BCE term).
        bce_weight (float): coefficient of the BCE term (default 0.7).
        dice_weight (float): coefficient of the Dice term (default 0.3).
    """

    def __init__(self, size_average=True, weight=(0.2, 0.8), bce_weight=0.7, dice_weight=0.3):
        super().__init__()
        self.size_average = size_average
        self.weight = weight
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        # `size_average` is deprecated in nn.BCEWithLogitsLoss; translate it to the equivalent
        # `reduction` (legacy: None/True -> 'mean', False with default reduce -> 'sum').
        reduction = 'mean' if size_average is None or size_average else 'sum'
        self.bce_loss = nn.BCEWithLogitsLoss(reduction=reduction, pos_weight=torch.tensor(self.weight[1]))
        self.softdiceloss = WeightedSoftDiceLoss(size_average=self.size_average, weight=weight)

    def forward(self, input, target):
        # NOTE(review): wrapping both in fastai's TensorBase presumably lets tensors of different
        # fastai subclasses (e.g. image vs mask) be combined — confirm.
        input, target = TensorBase(input).float(), TensorBase(target).float()
        soft_bce_loss = self.bce_loss(input, target)
        soft_dice_loss = self.softdiceloss(input, target)
        return self.bce_weight * soft_bce_loss + self.dice_weight * soft_dice_loss

In [None]:
# Default mix: 0.7 * BCE-with-logits + 0.3 * weighted soft Dice.
loss = SoftBCEDiceLoss()
loss(t_logits, t_targs)

TensorBase(0.7508, grad_fn=<AliasBackward>)

In [None]:
#export
#reference: https://github.com/zdaiot/Kaggle-Steel-Defect-Detection
class MultiClassesSoftBCEDiceLoss(nn.Module):
    """Mean of `SoftBCEDiceLoss` evaluated independently on each class channel.

    Args:
        classes_num (int): number of leading channels (dim 1) to score.
        size_average (bool): forwarded to `SoftBCEDiceLoss`.
        weight (list of 2 floats): forwarded to `SoftBCEDiceLoss`.
    """

    def __init__(self, classes_num=4, size_average=True, weight=[0.2, 0.8]):
        super().__init__()
        self.classes_num = classes_num
        self.size_average = size_average
        self.weight = weight
        self.soft_bce_dice_loss = SoftBCEDiceLoss(size_average=self.size_average, weight=self.weight)

    def forward(self, input, target):
        """
        Args:
            input: tensor, [batch_size, classes_num, height, width]
            target: tensor, [batch_size, classes_num, height, width]

        Returns:
            The per-class losses summed and divided by ``classes_num``.
        """
        per_class_losses = (
            self.soft_bce_dice_loss(input[:, channel, :, :], target[:, channel, :, :])
            for channel in range(self.classes_num)
        )
        total = sum(per_class_losses)
        return total / self.classes_num

In [None]:
loss = MultiClassesSoftBCEDiceLoss()
# Use the same converted float targets (`t_targs`) as the other loss demos so results are comparable.
loss(t_logits, t_targs)

TensorBase(0.7436, grad_fn=<AliasBackward>)

In [None]:
#export
class IoULoss(nn.Module):
    """
    Intersection over union (Jaccard) loss, computed as ``1 - iou(outputs, targets)``.

    The score comes from the `iou` metric (presumably `steel_segmentation.metrics.iou`).
    This class applies no activation itself; whether `iou` expects logits or
    probabilities is decided there — TODO confirm before passing raw logits.

    Args:
        eps (float): epsilon to avoid zero division, forwarded to `iou`.
        threshold (float, optional): threshold for outputs binarization, forwarded
            to `iou`; defaults to None. NOTE(review): a hard threshold is presumably
            non-differentiable, so keep it None when training — confirm.
    """

    def __init__(
        self,
        eps: float = 1e-7,
        threshold: float = None
    ):
        super().__init__()
        self.metric_fn = partial(iou, eps=eps, threshold=threshold)

    def forward(self, outputs, targets):
        # Local name `score` avoids shadowing the module-level `iou` metric.
        score = self.metric_fn(outputs, targets)
        return 1 - score

In [None]:
# 1 - IoU on the same logits/targets, with `threshold` left at its default (None).
loss = IoULoss()
loss(t_logits, t_targs)

tensor(0.9924, grad_fn=<RsubBackward1>)

In [None]:
# hide
# Export the `#export`-tagged cells of the project notebooks to the Python package (nbdev).
from nbdev.export import notebook2script
notebook2script()

Converted 01_metadata.ipynb.
Converted 02_masks.ipynb.
Converted 03_datasets.ipynb.
Converted 04_dataloaders.ipynb.
Converted 05_metrics.ipynb.
Converted 06_loss.ipynb.
Converted 07_trainer.ipynb.
Converted 08_predict.ipynb.
Converted 09_visualize.ipynb.
Converted 10_fastai.classifier.ipynb.
Converted 11_fastai.seg_unet_resnet34.ipynb.
Converted 11_resnet18-UNET.ipynb.
Converted 12_fastai.seg_unet_xresnext34.ipynb.
Converted 13_torch.seg_fpn_resnet34.ipynb.
Converted 14_torch.seg_unet_resnet34.ipynb.
Converted 15_torch.seg_unet_resnet18.ipynb.
Converted 21_ensemble_unet_fpn_resnet34.ipynb.
Converted index.ipynb.
