# In [2]:
import torch
from torch.nn import functional as F
from torch.autograd import Function


def Dice_Loss_Coefficient(pred, target, weights, smooth=1.):
    """Weighted soft Dice loss, averaged over batch and channels.

    For every (sample, channel) pair the weighted Dice coefficient

        (2 * sum(w * p * t) + smooth) / (sum(w * p) + sum(w * t) + smooth)

    is computed over all axes after the first two, and the function returns
    the mean of ``1 - dice`` over samples and channels.

    Args:
        pred: Predictions of shape ``(n, c, *spatial)``. No normalisation is
            applied here, so callers are expected to pass probabilities
            (e.g. a softmax output).
        target: Tensor broadcast-compatible with ``pred`` (e.g. one-hot
            labels of the same shape).
        weights: Per-position weights broadcastable to ``pred``'s shape,
            typically ``(n, 1, *spatial)``. Note that a ``(n, *spatial)``
            tensor would broadcast its batch axis against the channel axis.
        smooth: Constant added to numerator and denominator so channels
            with empty prediction and target do not divide by zero.

    Returns:
        A scalar tensor holding the mean Dice loss.
    """

    def _spatial_sum(x):
        # Collapse every axis after (n, c) into one and sum over it; this
        # works for 2-D, 4-D (images) and 5-D (volumes) inputs alike, where
        # the old `.sum(dim=2).sum(dim=2)` only handled 4-D tensors.
        return x.reshape(x.shape[0], x.shape[1], -1).sum(dim=-1)

    intersection = _spatial_sum(pred * target * weights)
    union = _spatial_sum(weights * pred) + _spatial_sum(weights * target)

    loss = 1 - (2. * intersection + smooth) / (union + smooth)
    return loss.mean()


class Weighted_Cross_Entropy_Loss(torch.nn.Module):
    """Cross-entropy loss with a per-position weight map.

    For each sample, the negative log-probability of the target class at
    every position is averaged using ``weights`` as weights. The per-sample
    values are then averaged over the mini-batch.
    """

    def __init__(self):
        super(Weighted_Cross_Entropy_Loss, self).__init__()

    def forward(self, pred, target, weights):
        """Compute the weighted cross-entropy.

        Args:
            pred: Unnormalised class scores of shape ``(n, c, *spatial)``;
                log-softmax is taken over dim 1.
            target: Integer class indices of shape ``(n, *spatial)`` or
                ``(n, 1, *spatial)``.
            weights: Per-position weights of shape ``(n, *spatial)``, or any
                shape broadcastable to ``(n, 1, *spatial)``.

        Returns:
            A scalar tensor: minus the batch mean of the per-sample weighted
            average log-probability of the target class.
        """
        n = pred.shape[0]
        spatial = pred.shape[2:]

        # Calculate log probabilities
        logp = F.log_softmax(pred, dim=1)

        # Gather the log probability of the target class -> (n, 1, *spatial).
        # reshape (not view) tolerates non-contiguous targets; gather needs int64.
        index = target.reshape(n, 1, *spatial).long()
        logp = torch.gather(logp, 1, index)

        # A (n, *spatial) weight map must gain the singleton class axis before
        # multiplying: otherwise broadcasting aligns its batch axis with the
        # class axis of `logp`, yielding (n, n, *spatial) and mixing samples.
        if weights.dim() == logp.dim() - 1:
            weights = weights.unsqueeze(1)
        flat_weights = weights.expand_as(logp).reshape(n, -1)

        # Multiply with weights
        weighted_logp = logp.reshape(n, -1) * flat_weights

        # Rescale so that loss is in approx. same interval
        weighted_loss = weighted_logp.sum(1) / flat_weights.sum(1)

        # Average over mini-batch
        return -weighted_loss.mean()

class BU_Net_Loss(torch.nn.Module):
    """Sum of a pixel-weighted cross-entropy loss and a weighted Dice loss.

    Pixel weights come from ``compute_class_weight``: each pixel is weighted
    by the inverse frequency of its class in the current batch.
    """

    def __init__(self, weight=None):
        super(BU_Net_Loss, self).__init__()
        # Kept as a public attribute for backward compatibility; the loss
        # derives its weights from the target and does not read it.
        self.weight = weight
        # Weighted_Cross_Entropy_Loss takes no constructor arguments, so
        # passing `weight` here raised TypeError on construction.
        self.cross_entropy_loss = Weighted_Cross_Entropy_Loss()

    def forward(self, pred, target):
        """Compute weighted cross-entropy + weighted Dice loss.

        Args:
            pred: Raw class scores of shape ``(n, c, H, W)``.
            target: Integer class indices of shape ``(n, H, W)`` with values
                in ``[0, c)`` (required by the one-hot encoding below).

        Returns:
            A scalar tensor, the sum of both loss terms.
        """
        target = target.long()
        # (n, H, W) -> (n, 1, H, W) so the weights line up with the class
        # axis in both losses instead of broadcasting against it.
        weights = self.compute_class_weight(target).unsqueeze(1)

        wce_loss = self.cross_entropy_loss(pred, target, weights)

        # Dice compares probabilities with one-hot labels of pred's shape;
        # feeding raw logits and (n, H, W) indices did not broadcast correctly.
        probs = F.softmax(pred, dim=1)
        one_hot = F.one_hot(target, num_classes=pred.shape[1])
        one_hot = one_hot.movedim(-1, 1).to(probs.dtype)
        dice_loss = Dice_Loss_Coefficient(probs, one_hot, weights)

        return wce_loss + dice_loss

    def compute_class_weight(self, target):
        """Per-pixel weights equal to the inverse batch frequency of the class.

        Every element of ``target`` equal to ``i`` (for ``i`` in
        ``0..target.max()``) receives ``1 / (count_i + 1e-6)``, where
        ``count_i`` is the number of elements equal to ``i`` in the whole
        batch. Elements with negative labels keep weight 0.

        Args:
            target: Integer class-index tensor, e.g. ``(n, H, W)``.

        Returns:
            A float tensor of ``target``'s shape on ``target``'s device.
        """
        class_weights = torch.zeros(target.shape, device=target.device)
        if target.numel() == 0:
            # target.max() is undefined on an empty tensor.
            return class_weights
        for cls in range(int(target.max()) + 1):
            mask = (target == cls).float()
            class_weights += mask * (1.0 / (mask.sum() + 1e-6))
        return class_weights