In [None]:
# hide
# skip
!git clone https://github.com/benihime91/gale # install gale on colab
!pip install -e "gale[dev]"

In [None]:
# default_exp losses

In [None]:
# hide
%load_ext nb_black
%load_ext autoreload
%autoreload 2
%matplotlib inline

<IPython.core.display.Javascript object>

In [None]:
# hide
import warnings

from nbdev.export import *
from nbdev.showdoc import *

warnings.filterwarnings("ignore")

<IPython.core.display.Javascript object>

# Loss Functions
> Custom loss functions in `Gale`

In [None]:
# export
import logging
from typing import *

import torch
import torch.nn.functional as F
import torch.nn.modules.loss as torch_losses
from fastcore.all import store_attr
from fvcore.nn import sigmoid_focal_loss
from omegaconf import DictConfig, OmegaConf
from timm.loss import SoftTargetCrossEntropy
from torch import Tensor, nn

from gale.torch_utils import maybe_convert_to_onehot
from gale.utils.structures import LOSS_REGISTRY

_logger = logging.getLogger(__name__)

<IPython.core.display.Javascript object>

In [None]:
from fastcore.all import *

<IPython.core.display.Javascript object>

In [None]:
# export
# Names defined outside this notebook that should still be listed among the
# exported names of the generated module.
# NOTE(review): relies on the nbdev `_all_` convention — confirm against the nbdev version in use.
_all_ = ["SoftTargetCrossEntropy", "LOSS_REGISTRY"]

<IPython.core.display.Javascript object>

In [None]:
# export
# Register timm's `SoftTargetCrossEntropy` so that `build_loss` can look it up
# by name in `LOSS_REGISTRY`, like the losses defined in this notebook.
LOSS_REGISTRY.register(SoftTargetCrossEntropy)

<IPython.core.display.Javascript object>

In [None]:
# export
@LOSS_REGISTRY.register()
class LabelSmoothingCrossEntropy(nn.Module):
    """
    Cross Entropy Loss with Label Smoothing

    The loss is `eps / C * (-sum of log-probabilities over the classes)` plus
    `(1 - eps)` times the regular negative log-likelihood, where `C` is the
    number of classes.

    Arguments:
    - `eps` (float): smoothing factor. `0` disables the smoothing term.
    - `reduction` (str): `none` | `mean` | `sum`.
    - `weight` (Tensor, optional): per-class rescaling weight forwarded to
      `F.nll_loss`. It only affects the negative log-likelihood term.
    """

    def __init__(
        self,
        eps: float = 0.1,
        reduction: str = "mean",
        weight: Optional[Tensor] = None,
    ):
        super(LabelSmoothingCrossEntropy, self).__init__()
        store_attr("eps, reduction")

        # Accept plain sequences (e.g. a list read from a config file).
        if weight is not None and not isinstance(weight, Tensor):
            weight = torch.as_tensor(weight, dtype=torch.float)

        # Register `weight` as a buffer so that `.to(device)` / `.cuda()` move it
        # together with the module; as a plain attribute it would stay on its
        # original device and `F.nll_loss` would fail on a device mismatch.
        # `persistent=False` keeps the `state_dict` free of a new key, so
        # checkpoints saved before this change still load.
        self.register_buffer("weight", weight, persistent=False)

    def forward(self, input: Tensor, target: Tensor):
        r"""
        Shape:
        - Input  : $(N, C)$ where $N$ is the mini-batch size and $C$ is the total number of classes
        - Target : $(N)$ where each value is $0 \leq {targets}[i] \leq C-1$
        - Output : scalar. If `reduction` is `none`, then $(N)$, one loss value per sample.
        """
        num_classes = input.size(1)
        log_preds = F.log_softmax(input, dim=1)

        # Smoothing term: negative sum of the log-probabilities over the classes,
        # reduced the same way as the negative log-likelihood term below.
        if self.reduction == "sum":
            smooth_loss = -log_preds.sum()
        else:
            smooth_loss = -log_preds.sum(dim=1)
            if self.reduction == "mean":
                smooth_loss = smooth_loss.mean()

        nll = F.nll_loss(
            log_preds, target.long(), weight=self.weight, reduction=self.reduction
        )
        return smooth_loss * self.eps / num_classes + (1 - self.eps) * nll

<IPython.core.display.Javascript object>

In [None]:
# Smoke test: the loss runs on random logits with integer class targets.
criterion = LabelSmoothingCrossEntropy(reduction="mean")

# Logits of shape (32, 5) and one integer class index in [0, 5) per row.
output = torch.randn(32, 5, requires_grad=True)
target = torch.empty(32, dtype=torch.long).random_(5)

loss = criterion(output, target)

<IPython.core.display.Javascript object>

In [None]:
# export
@LOSS_REGISTRY.register()
class BinarySigmoidFocalLoss(nn.Module):
    """
    Creates a criterion that computes the focal loss between binary `input` and `target`.
    Focal Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Thin `nn.Module` wrapper around `fvcore.nn.sigmoid_focal_loss`.

    Arguments:
    - `alpha` (float): weighting factor to balance positive vs negative examples.
      The default of `-1` disables the weighting (see the fvcore documentation).
    - `gamma` (float): focusing exponent of the modulating factor.
    - `reduction` (str): `none` | `mean` | `sum`.

    Source: https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
    """

    def __init__(
        self,
        alpha: float = -1,
        gamma: float = 2,
        reduction: str = "mean",
    ):
        super(BinarySigmoidFocalLoss, self).__init__()
        store_attr("alpha, gamma, reduction")

    def forward(self, input: Tensor, target: Tensor):
        """
        Shape:
        - Input: $(N, *)$ where $*$ means, any number of additional dimensions.
        - Target: $(N, *)$, same shape as the input.
        - Output: scalar. If `reduction` is 'none', then $(N, *)$ , same shape as input.
        """
        # Pass the hyper-parameters by keyword: fvcore's positional order is
        # (inputs, targets, alpha, gamma, reduction), so passing them
        # positionally as (gamma, alpha) silently swaps the two values.
        return sigmoid_focal_loss(
            input,
            target,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )

<IPython.core.display.Javascript object>

[Focal Loss](https://arxiv.org/pdf/1708.02002.pdf) is the same as cross entropy, except that easy-to-classify observations are down-weighted in the loss calculation. The strength of the down-weighting increases with the gamma parameter. Put another way, the larger gamma is, the less the easy-to-classify observations contribute to the loss.

In [None]:
# Smoke test: the loss runs on a (10, 64) batch of binary targets.
criterion = BinarySigmoidFocalLoss(reduction="mean")

# All-positive float targets and a constant value of 1.5 for every prediction.
target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5)

loss = criterion(output, target)

<IPython.core.display.Javascript object>

In [None]:
# export
@LOSS_REGISTRY.register()
class FocalLoss(nn.Module):
    r"""
    Same as `nn.CrossEntropyLoss` but with focal parameter, `gamma`.
    Focal Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Focal loss is computed as follows :
    ${FL}(p_t) = -\alpha(1 - p_t)^{\gamma}{log}(p_t)$

    Source: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/focal.html
    """

    def __init__(
        self,
        alpha: float = 1,
        gamma: float = 2,
        reduction: str = "mean",
        eps: float = 1e-8,
    ):

        super(FocalLoss, self).__init__()
        store_attr("alpha, gamma, reduction, eps")

    def forward(self, input: Tensor, target: Tensor):
        r"""
        Shape:
        - Input  : $(N, C)$ where $N$ is the mini-batch size and $C$ is the total number of classes
        - Target : $(N)$ where each value is $0 \leq {targets}[i] \leq C-1$
        - Output : scalar. If `reduction` is `none`, the loss summed over the class axis only.

        Raises:
        - `ValueError` if `input` has fewer than 2 dimensions or if the batch
          sizes of `input` and `target` differ.
        - `NotImplementedError` if `reduction` is not `none`, `mean` or `sum`.
        """
        if input.dim() < 2:
            raise ValueError(
                "Invalid input shape, we expect BxCx*. Got: {}".format(input.shape)
            )

        if input.size(0) != target.size(0):
            raise ValueError(
                "Expected input batch_size ({}) to match target batch_size ({}).".format(
                    input.size(0), target.size(0)
                )
            )

        # class probabilities; `eps` keeps the logarithm below finite
        probs: Tensor = F.softmax(input, dim=1) + self.eps

        # one-hot encoded targets
        # NOTE(review): presumably built to match `probs` — confirm against `gale.torch_utils`
        one_hot_targs: Tensor = maybe_convert_to_onehot(target, probs)

        # Modulating factor (1 - p)^gamma. Adding `eps` can push a saturated
        # probability slightly above 1, which makes the base negative and turns
        # the power into NaN for a non-integer `gamma`, hence the clamp.
        focal_weight = torch.pow((1.0 - probs).clamp(min=0.0), self.gamma)

        focal_factor = -self.alpha * focal_weight * torch.log(probs)

        # keep only the term of the target class
        loss = torch.sum(one_hot_targs * focal_factor, dim=1)

        if self.reduction == "none":
            return loss
        if self.reduction == "mean":
            return torch.mean(loss)
        if self.reduction == "sum":
            return torch.sum(loss)
        raise NotImplementedError("Invalid reduction mode: {}".format(self.reduction))

<IPython.core.display.Javascript object>

Arguments to `FocalLoss`:
- `alpha` (float, optional): Weighting factor $\alpha$ in `[0, 1]`. Default: 1.
- `gamma` (float, optional): Focusing parameter $\gamma$ >= 0. Default: 2.
- `reduction` (str, optional): Specifies the reduction to apply to the output: `none` | `mean` | `sum`.
  * `none`: no reduction will be applied,
  * `mean`: the sum of the output will be divided by the number of elements in the output,
  * `sum`: the output will be summed.
  * Default: `mean`.
- `eps` (float, optional): Scalar to enforce numerical stability. Default: 1e-8.

In [None]:
criterion = FocalLoss(alpha=0.5, gamma=2.0, reduction="mean")

N = 5  # num_classes
# named `logits` rather than `input` so the builtin `input` is not shadowed
# for the rest of the kernel session
logits = torch.randn(32, N, requires_grad=True)
target = torch.empty(32, dtype=torch.long).random_(N)
loss = criterion(logits, target)


# With alpha = 1 and gamma = 0 the focal loss must match cross entropy
fl = FocalLoss(alpha=1, gamma=0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
output = torch.randn(32, N, requires_grad=True)
target = torch.empty(32, dtype=torch.long).random_(N)
test_close(fl(output, target), ce(output, target))

# With gamma > 0 the focal loss must differ from cross entropy
fl = FocalLoss(gamma=2)
with torch.no_grad():
    test_ne(fl(output, target), ce(output, target))

<IPython.core.display.Javascript object>

## Build

Losses are created by the Lightning-Tasks in Gale using the Config. To load a loss via the gale config, the loss must either be registered in `LOSS_REGISTRY` or be one of the losses available in the [`torch.nn.modules.loss` module](https://pytorch.org/docs/stable/nn.html#loss-functions).

In [None]:
# export
def build_loss(config: DictConfig):
    """
    Builds a loss from a config.

    The config needs a `name` key, which selects the loss class, and may have
    an `init_args` mapping whose entries are passed to the class as keyword
    arguments. For instance, `{"name": "my_loss", "init_args": {"foo": "bar"}}`
    instantiates the class registered as "my_loss" with `foo="bar"`.

    `name` is first looked up in `LOSS_REGISTRY` and then in
    `torch.nn.modules.loss`; a custom loss must therefore be registered into
    `LOSS_REGISTRY` first. A non-null `weight` entry in `init_args` is
    converted to a float tensor before the loss is instantiated.

    Arguments:
    - `config`: an OmegaConf config (interpolations are resolved) or a plain mapping.

    Returns the instantiated loss.

    Raises `AssertionError` if `name` is missing from the config or does not
    refer to a known loss.
    """

    assert "name" in config, f"name not provided for loss: {config}"

    # OmegaConf containers are converted to plain python objects; plain
    # mappings are used as they are.
    if OmegaConf.is_config(config):
        config = OmegaConf.to_container(config, resolve=True)

    name = config["name"]
    # `init_args` is optional: a missing or null entry means "no arguments"
    init_args = config.get("init_args")
    # copy so that the caller's mapping is never mutated below
    kwargs = dict(init_args) if init_args is not None else {}

    # if we are passing weights, we need to change the weights from a list to a tensor
    if kwargs.get("weight") is not None:
        kwargs["weight"] = torch.tensor(kwargs["weight"], dtype=torch.float)

    if name in LOSS_REGISTRY:
        loss_cls = LOSS_REGISTRY.get(name)

    # the name should be available in torch.nn.modules.loss
    else:
        assert hasattr(torch_losses, name), (
            f"{name} isn't a registered loss"
            ", nor is it available in torch.nn.modules.loss"
        )
        loss_cls = getattr(torch_losses, name)

    loss = loss_cls(**kwargs)
    _logger.debug("Built loss function: {}".format(loss.__class__.__name__))
    return loss

<IPython.core.display.Javascript object>

For Image Classification a loss is created like so ...

In [None]:
# hide-output
from gale.config import get_config

# load the config named "classification"
cfg = get_config(config_name="classification")

# grab the config for the Loss Function
loss_cfg = cfg.training.train_loss_fn

# uncomment to inspect the loss config as YAML:
# print(OmegaConf.to_yaml(loss_cfg))
loss = build_loss(loss_cfg)

<IPython.core.display.Javascript object>

## Export-

In [None]:
# hide
# Convert this notebook with nbdev.
# NOTE(review): presumably writes the `# export` cells to the module named by `default_exp` — confirm.
notebook2script("01a_losses.ipynb")

Converted 01a_losses.ipynb.


<IPython.core.display.Javascript object>