In [None]:
# default_exp losses

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
import torch
import torch.nn as nn

from functools import update_wrapper
from torch.nn.modules.loss import _Loss

# Losses

> Common objective (loss) functions

## Attribution

This section introduces the `FlatLoss` class. This code is, for the most part, copied from the [fast.ai](https://github.com/fastai/fastai/blob/8013797e05f0ae0d771d60ecf7cf524da591503c/fastai/layers.py) library.

## Wrappers

Below we define several wrappers around well-known losses that are defined and implemented in the PyTorch library. The main idea is that we need to *flatten* our predictions (and targets) before passing them to the chosen loss function.

In [None]:
# export
class FlatLoss():
    """Same as whatever `func` is, but with flattened input and target.

    `func` is instantiated once with `**kwargs`. Each call then:

    1. moves `axis` to the last position;
    2. optionally casts the target to float;
    3. flattens both tensors;
    4. delegates to that instance.

    Args:
        func: loss *class* (e.g. `nn.MSELoss`); it is instantiated with `**kwargs`.
        axis: dimension moved to the last position before flattening.
        to_float: if True, cast the target to float before evaluating the loss.
        is_2d: if True, flatten the prediction to `(-1, prediction.shape[-1])`,
            i.e. keep the (moved) `axis` dimension. Otherwise flatten it to 1-D.
            The target is always flattened to 1-D.
        **kwargs: forwarded to the `func` constructor.
    """
    def __init__(self, func: _Loss, axis: int = -1, to_float: bool = False,
                 is_2d: bool = False, **kwargs):
        self.func = func(**kwargs)
        self.axis = axis
        self.to_float = to_float
        self.is_2d = is_2d

        # Copy the wrapped instance's metadata and `__dict__` entries onto this
        # wrapper and set `__wrapped__` (as in the fast.ai original).
        update_wrapper(self, self.func)

    def __repr__(self):
        return f'FlatLoss of {self.func}'

    @property
    def reduction(self) -> str:
        """`reduction` attribute of the wrapped loss.

        Raises:
            AttributeError: if the wrapped loss has no `reduction` attribute.
                Raising AttributeError (not AssertionError) keeps `hasattr` and
                `getattr(obj, name, default)` working on the wrapper, and the
                check is not stripped under `python -O`.
        """
        if not hasattr(self.func, 'reduction'):
            raise AttributeError(f'{self.func} does not have "reduction" attribute')
        return self.func.reduction

    @reduction.setter
    def reduction(self, reduction: str):
        self.func.reduction = reduction

    @property
    def weight(self) -> torch.Tensor:
        """`weight` attribute of the wrapped loss.

        Raises:
            AttributeError: if the wrapped loss has no `weight` attribute.
        """
        if not hasattr(self.func, 'weight'):
            raise AttributeError(f'{self.func} does not have "weight" attribute')
        return self.func.weight

    @weight.setter
    def weight(self, weight: torch.Tensor):
        self.func.weight = weight

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
        """Flatten `prediction` and `target`, then evaluate the wrapped loss.

        Extra `**kwargs` are forwarded to the wrapped loss call.
        """
        # Move `axis` last; `contiguous()` is required because `view` cannot
        # reshape the (generally non-contiguous) result of `transpose`.
        prediction = prediction.transpose(self.axis, -1).contiguous()
        target = target.transpose(self.axis, -1).contiguous()

        if self.to_float:
            target = target.float()

        if self.is_2d:
            # Keep the last (moved) dimension as a separate column.
            prediction = prediction.view(-1, prediction.shape[-1])
        else:
            prediction = prediction.view(-1)
        return self.func(prediction, target.view(-1), **kwargs)

The `FlatLoss` class creates a callable that does whatever the wrapped loss function would do, but flattens the input and target before the operation.

## Common losses

In [None]:
# export
def FlatCrossEntropyLoss(axis: int = -1, to_float: bool = True, is_2d: bool = False, **kwargs):
    """Build a `FlatLoss` that wraps `nn.CrossEntropyLoss`.

    Args:
        axis: dimension moved to the last position before flattening.
        to_float: cast the target to float before evaluating the loss.
        is_2d: flatten the prediction to `(-1, size_of_axis)` instead of 1-D.
        **kwargs: forwarded to the `nn.CrossEntropyLoss` constructor.

    NOTE(review): with the defaults (`to_float=True`, `is_2d=False`), the
    prediction is flattened to 1-D and the target is cast to float.
    `nn.CrossEntropyLoss` normally expects `(N, C)` scores with integer class
    indices, so callers likely need `is_2d=True, to_float=False`. Confirm
    against the callers before relying on the defaults.
    """
    flat_options = dict(axis=axis, to_float=to_float, is_2d=is_2d)
    return FlatLoss(nn.CrossEntropyLoss, **flat_options, **kwargs)

In [None]:
# export 
def FlatBCELoss(axis: int = -1, to_float: bool = True, is_2d: bool = False, **kwargs):
    """Build a `FlatLoss` that wraps `nn.BCELoss`.

    Args:
        axis: dimension moved to the last position before flattening.
        to_float: cast the target to float before evaluating the loss.
        is_2d: flatten the prediction to `(-1, size_of_axis)` instead of 1-D.
        **kwargs: forwarded to the `nn.BCELoss` constructor.
    """
    flat_options = dict(axis=axis, to_float=to_float, is_2d=is_2d)
    return FlatLoss(nn.BCELoss, **flat_options, **kwargs)

In [None]:
# export
def FlatMSELoss(axis: int = -1, to_float: bool = True, is_2d: bool = False, **kwargs):
    """Build a `FlatLoss` that wraps `nn.MSELoss`.

    Args:
        axis: dimension moved to the last position before flattening.
        to_float: cast the target to float before evaluating the loss.
        is_2d: flatten the prediction to `(-1, size_of_axis)` instead of 1-D.
        **kwargs: forwarded to the `nn.MSELoss` constructor.
    """
    flat_options = dict(axis=axis, to_float=to_float, is_2d=is_2d)
    return FlatLoss(nn.MSELoss, **flat_options, **kwargs)