In [None]:
#export

import torch
import torch.nn as NN

In [None]:
#export
class CustomLoss:
    """Combined KL-divergence + binary cross-entropy loss on raw logits.

    Both terms compare ``sigmoid(X)`` against the target ``y``. Optionally, the
    summed value is squared.

    Args:
        reduction: reduction mode forwarded unchanged to both ``NN.KLDivLoss``
            and ``NN.BCEWithLogitsLoss`` (e.g. ``"mean"``, ``"sum"``, ``"none"``).
        quadratic: if True, ``__call__`` returns the square of the combined loss.
    """

    def __init__(self, reduction="mean", quadratic=False):
        # Kept so existing code that reads ``.sigmoid`` keeps working.
        self.sigmoid = NN.Sigmoid()
        # KLDivLoss interprets its *input* as log-probabilities. LogSigmoid gives
        # log(sigmoid(X)) without the underflow of log(sigmoid(X)) for large
        # negative logits.
        self.logsigmoid = NN.LogSigmoid()
        self.kl = NN.KLDivLoss(reduction=reduction)
        self.cel = NN.BCEWithLogitsLoss(reduction=reduction)
        self.quadratic = quadratic

    def __call__(self, X, y):
        """Compute the combined loss.

        Args:
            X: raw logits (pre-sigmoid); expected to match the shape of ``y``,
                as ``BCEWithLogitsLoss`` requires.
            y: target probabilities in [0, 1] (e.g. 0/1 labels).

        Returns:
            ``KLDivLoss(log(sigmoid(X)), y) + BCEWithLogitsLoss(X, y)`` reduced
            according to ``reduction``. The result is squared when
            ``quadratic`` is True.
        """
        # Passing sigmoid(X) (plain probabilities) here would violate
        # KLDivLoss's log-probability input contract and yield a wrong,
        # possibly negative, KL term.
        comb_loss = self.kl(self.logsigmoid(X), y) + self.cel(X, y)
        if self.quadratic:
            return torch.pow(comb_loss, 2.)
        return comb_loss

In [None]:
#export
def get_weighted_loss(pos_weights, neg_weights, epsilon=1e-7, reduced=True):
    """Build a per-column weighted binary cross-entropy on probability inputs.

    Args:
        pos_weights: weights for the positive (``y_true``) term. Column ``i``
            uses ``pos_weights[i]``, and ``len(pos_weights)`` sets how many
            leading columns are scored.
        neg_weights: weights for the negative (``1 - y_true``) term, indexed in
            parallel with ``pos_weights``.
        epsilon: constant added inside each ``log`` to avoid ``log(0)``.
        reduced: if True, every per-column term is averaged with
            ``torch.mean`` before summing, giving a scalar. Otherwise, the
            element-wise terms are summed across columns.

    Returns:
        A closure ``weighted_loss(y_pred, y_true)``. ``y_pred`` must already
        hold probabilities (no sigmoid is applied). Both inputs are indexed
        as ``[:, i]`` (presumably ``(batch, num_classes)`` — confirm with
        callers). With empty weights the closure returns the float ``0.0``.
    """
    def weighted_loss(y_pred, y_true):
        reduce_fn = torch.mean if reduced else (lambda term: term)

        def column_term(col):
            target = y_true[:, col]
            prob = y_pred[:, col]
            positive = reduce_fn(-pos_weights[col] * target * torch.log(prob + epsilon))
            negative = reduce_fn(-neg_weights[col] * (1 - target) * torch.log(1 - prob + epsilon))
            return positive + negative

        # Left-to-right accumulation starting from 0.0, column by column.
        return sum((column_term(col) for col in range(len(pos_weights))), 0.0)

    return weighted_loss

In [None]:
#export
def get_weighted_loss_with_logits(pos_weights, neg_weights, epsilon=1e-7, reduced=True):
    """Build a per-column weighted binary cross-entropy that accepts logits.

    The closure first maps its ``y_pred`` argument through a sigmoid. It then
    applies the same weighted BCE as the probability-based variant.

    Args:
        pos_weights: weights for the positive (``y_true``) term. Column ``i``
            uses ``pos_weights[i]``, and ``len(pos_weights)`` sets how many
            leading columns are scored.
        neg_weights: weights for the negative (``1 - y_true``) term, indexed in
            parallel with ``pos_weights``.
        epsilon: constant added inside each ``log`` to avoid ``log(0)``.
        reduced: if True, every per-column term is averaged with
            ``torch.mean`` before summing, giving a scalar. Otherwise, the
            element-wise terms are summed across columns.

    Returns:
        A closure ``weighted_loss(y_pred, y_true)`` taking raw logits in
        ``y_pred``. Both inputs are indexed as ``[:, i]`` (presumably
        ``(batch, num_classes)`` — confirm with callers). With empty weights
        the closure returns the float ``0.0``.
    """
    def weighted_loss(y_pred, y_true):
        probs = torch.sigmoid(y_pred)
        reduce_fn = torch.mean if reduced else (lambda term: term)

        def column_term(col):
            target = y_true[:, col]
            prob = probs[:, col]
            positive = reduce_fn(-pos_weights[col] * target * torch.log(prob + epsilon))
            negative = reduce_fn(-neg_weights[col] * (1 - target) * torch.log(1 - prob + epsilon))
            return positive + negative

        # Left-to-right accumulation starting from 0.0, column by column.
        return sum((column_term(col) for col in range(len(pos_weights))), 0.0)

    return weighted_loss

In [None]:
#export
def get_weighted_loss_with_logits_and_punishment(pos_weights, neg_weights, epsilon=1e-7):
    """Build a weighted BCE on logits, plus a penalty for indecisive predictions.

    The returned loss is ``mean(per_row_weighted_bce) + mean(penalty)``, where
    ``penalty = (2 * | |p - 0.5| - 0.5 |) ** 2`` and ``p = sigmoid(y_pred)``.
    The penalty is 1 at ``p = 0.5`` and 0 at ``p`` in {0, 1}. It is evaluated
    over every element of ``y_pred``, not only the first ``len(pos_weights)``
    columns that the BCE term scores.

    Args:
        pos_weights: weights for the positive (``y_true``) term. Column ``i``
            uses ``pos_weights[i]``. It must be non-empty: with no columns the
            BCE accumulator stays the float ``0.0``, which has no ``.mean()``.
        neg_weights: weights for the negative (``1 - y_true``) term, indexed in
            parallel with ``pos_weights``.
        epsilon: constant added inside each ``log`` to avoid ``log(0)``.

    Returns:
        A closure ``weighted_loss(y_pred, y_true)`` taking raw logits and
        returning a scalar tensor. Both inputs are indexed as ``[:, i]``
        (presumably ``(batch, num_classes)`` — confirm with callers).
    """
    def weighted_loss(y_pred, y_true):
        probs = torch.sigmoid(y_pred)

        def column_bce(col):
            target = y_true[:, col]
            prob = probs[:, col]
            positive = -pos_weights[col] * target * torch.log(prob + epsilon)
            negative = -neg_weights[col] * (1 - target) * torch.log(1 - prob + epsilon)
            return positive + negative

        bce_per_row = sum((column_bce(col) for col in range(len(pos_weights))), 0.0)
        # Grows toward 1 as probabilities approach 0.5 (maximally uncertain).
        indecision = (2 * (probs - 0.5).abs().sub(0.5).abs()) ** 2
        return bce_per_row.mean() + indecision.mean()

    return weighted_loss