In [None]:
import import_ipynb

In [None]:
import torch.nn as nn

import torch
import base
import functional as F

importing Jupyter notebook from base.ipynb
importing Jupyter notebook from functional.ipynb


In [None]:
class JaccardLoss(base.Loss):
    """Jaccard (IoU) loss, defined as ``1 - F.jaccard(y_pr, y_gt, ...)``.

    ``threshold=None`` is always forwarded, so no threshold is requested for
    ``y_pr`` here; the exact meaning of that argument lives in
    ``functional.jaccard`` (not visible from this cell).

    Args:
        eps: smoothing constant forwarded to ``F.jaccard``.
        ignore_channels: forwarded unchanged to ``F.jaccard``.
        **kwargs: passed through to ``base.Loss.__init__``.
    """

    def __init__(self, eps=1.0, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.ignore_channels = ignore_channels
        self.eps = eps

    def forward(self, y_pr, y_gt):
        """Return ``1 - jaccard(y_pr, y_gt)`` using the stored settings."""
        score = F.jaccard(
            y_pr,
            y_gt,
            threshold=None,
            ignore_channels=self.ignore_channels,
            eps=self.eps,
        )
        return 1 - score

In [None]:
class DiceLoss(base.Loss):
    """Dice / F-beta loss, defined as ``1 - F.f_score(y_pr, y_gt, ...)``.

    With the default ``beta=1.0`` the F-score is the Dice coefficient.
    ``threshold=None`` is always forwarded; its exact meaning lives in
    ``functional.f_score`` (not visible from this cell).

    Args:
        eps: smoothing constant forwarded to ``F.f_score``.
        beta: F-score beta forwarded to ``F.f_score``.
        ignore_channels: forwarded unchanged to ``F.f_score``.
        **kwargs: passed through to ``base.Loss.__init__``.
    """

    def __init__(self, eps=1.0, beta=1.0, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.ignore_channels = ignore_channels
        self.beta = beta
        self.eps = eps

    def forward(self, y_pr, y_gt):
        """Return ``1 - f_score(y_pr, y_gt)`` using the stored settings."""
        score = F.f_score(
            y_pr,
            y_gt,
            threshold=None,
            ignore_channels=self.ignore_channels,
            eps=self.eps,
            beta=self.beta,
        )
        return 1 - score

In [None]:
class L1Loss(nn.L1Loss, base.Loss):
    """``torch.nn.L1Loss`` mixed with the project's ``base.Loss`` interface."""


class MSELoss(nn.MSELoss, base.Loss):
    """``torch.nn.MSELoss`` mixed with the project's ``base.Loss`` interface."""


class CrossEntropyLoss(nn.CrossEntropyLoss, base.Loss):
    """``torch.nn.CrossEntropyLoss`` mixed with the project's ``base.Loss`` interface."""


class NLLLoss(nn.NLLLoss, base.Loss):
    """``torch.nn.NLLLoss`` mixed with the project's ``base.Loss`` interface."""


class BCELoss(nn.BCELoss, base.Loss):
    """``torch.nn.BCELoss`` mixed with the project's ``base.Loss`` interface."""


class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, base.Loss):
    """``torch.nn.BCEWithLogitsLoss`` mixed with the project's ``base.Loss`` interface."""

['BINARY_MODE',
 'DiceLoss',
 'FocalLoss',
 'JaccardLoss',
 'LovaszLoss',
 'MCCLoss',
 'MULTICLASS_MODE',
 'MULTILABEL_MODE',
 'SoftBCEWithLogitsLoss',
 'SoftCrossEntropyLoss',
 'TverskyLoss',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '_functional',
 'constants',
 'dice',
 'focal',
 'jaccard',
 'lovasz',
 'mcc',
 'soft_bce',
 'soft_ce',
 'tversky']

In [None]:
class FocalLoss(base.Loss):
    """Element-wise binary focal loss.

    For every element the binary cross-entropy ``bce`` is computed and then
    scaled by ``alpha * (1 - pt) ** gamma`` with ``pt = exp(-bce)``. For hard
    0/1 targets ``pt`` is the predicted probability of the true class, so
    well-classified elements are down-weighted.

    Args:
        alpha: constant multiplier applied to every element's loss.
        gamma: focusing exponent; ``gamma=0`` gives alpha-scaled BCE.
        class_weights: optional scalar or 1-D sequence/tensor of weights.
            A 1-D vector with more than one entry is applied along dim 1
            (assumes inputs are laid out as (N, C, ...) — TODO confirm).
            Defaults to ``1.`` (no weighting).
        logits: if True, ``y_pr`` holds raw logits and
            ``binary_cross_entropy_with_logits`` is used; otherwise ``y_pr``
            must already be probabilities in [0, 1].
        reduction: ``'mean'`` (default), ``'sum'``, or ``'none'`` / ``None``
            to return the per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, class_weights=None, logits=False, reduction='mean'):
        super().__init__()
        assert reduction in ['mean', 'sum', 'none', None], f"unsupported reduction: {reduction!r}"
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduction = reduction
        self.class_weights = class_weights if class_weights is not None else 1.

    def forward(self, y_pr, y_gt):
        """Compute the focal loss of predictions ``y_pr`` against targets ``y_gt``."""
        # reduction='none' keeps one BCE value per element. The default
        # ('mean') would collapse to a scalar before the focal modulation,
        # so individual hard/easy elements would never be re-weighted.
        if self.logits:
            bce_loss = nn.functional.binary_cross_entropy_with_logits(y_pr, y_gt, reduction='none')
        else:
            bce_loss = nn.functional.binary_cross_entropy(y_pr, y_gt, reduction='none')

        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        focal_loss = focal_loss * self._weights_like(focal_loss)

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

    def _weights_like(self, loss):
        """Return ``class_weights`` as a tensor broadcastable against ``loss``."""
        # as_tensor avoids the copy/warning torch.tensor() emits for tensor input
        # and matches the loss dtype/device.
        weights = torch.as_tensor(self.class_weights, dtype=loss.dtype, device=loss.device)
        if weights.dim() == 1 and weights.numel() > 1 and loss.dim() >= 2:
            # Align per-class weights with dim 1 instead of letting them
            # broadcast against the trailing dimension.
            weights = weights.view(1, -1, *([1] * (loss.dim() - 2)))
        return weights