In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BinaryNoiseBalancingLoss(nn.modules.loss._Loss):
    """
    x = p(c_true=0|c_noisy=1)
    y = p(c_true=1|c_noisy=0)
    theta_0 = p(c_noisy=1|c_true=0)
    theta_1 = p(c_noisy=0|c_true=1)
    """

    def __init__(self, x, y, apply_softmax=True, reduction="mean"):
        super().__init__()
        self.theta_0 = x / (1 + x - y)
        self.theta_1 = y / (1 + y - x)
        self.apply_softmax = apply_softmax
        self.reduction = reduction

    def forward(self, input, target):
        if self.apply_softmax:
            p = torch.log_softmax(input, 1)
        else:
            p = input

        p_0 = p[:, 0]
        p_1 = p[:, 1]

        a = (self.theta_0 * p_1) + ((1 - self.theta_1) * p_0)
        b = (1 - self.theta_0) * p_1 + (self.theta_1 * p_0)

        loss = F.nll_loss(torch.stack((a, b), dim=1), target)

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss
        
        

loss = BinaryNoiseBalancingLoss(x=0, y=1)

ZeroDivisionError: division by zero