In [4]:
import torch
from torch import nn
import torch.nn.functional as F

In [5]:
class GHM_Loss(nn.Module):
    """Base class for gradient-harmonizing (GHM) losses.

    Every element of the input is assigned to one of ``bins`` equal-width
    bins according to the magnitude of its loss gradient. Elements that fall
    into densely populated bins are down-weighted, elements in sparsely
    populated bins are up-weighted, and the resulting per-element weights are
    handed to the concrete loss implemented by a subclass.

    Subclasses must implement ``_custom_loss`` and ``_custom_loss_grad``.

    Args:
        bins: number of gradient-magnitude bins.
        alpha: momentum of the exponential moving average that smooths the
            per-bin counts across successive ``forward`` calls
            (``0`` disables smoothing).
    """

    def __init__(self, bins=10, alpha=0.5):
        super(GHM_Loss, self).__init__()
        self._bins = bins
        self._alpha = alpha
        # Smoothed per-bin counts from the previous forward call (None until
        # the first call).
        self._last_bin_count = None

    def _g2bin(self, g):
        """Map gradient magnitudes to integer bin indices in ``[0, bins - 1]``.

        The ``- 0.0001`` keeps ``g == 1`` inside the last bin. Indices are
        additionally clamped so that an out-of-range magnitude lands in the
        nearest edge bin instead of producing an invalid index.
        """
        idx = torch.floor(g * (self._bins - 0.0001)).long()
        return idx.clamp_(0, self._bins - 1)

    def _custom_loss(self, x, target, weight):
        """Return the loss for ``x`` / ``target`` using per-element ``weight``."""
        raise NotImplementedError

    def _custom_loss_grad(self, x, target):
        """Return the per-element loss gradient used for binning."""
        raise NotImplementedError

    def forward(self, x, target):
        """Compute the gradient-harmonized loss.

        Args:
            x: model outputs; every element is treated as one example.
            target: targets, passed unchanged to the subclass hooks.

        Returns:
            Whatever ``_custom_loss`` returns for the harmonized weights.
        """
        # Gradient magnitude only steers the weighting; it must not take part
        # in back-propagation.
        g = torch.abs(self._custom_loss_grad(x, target)).detach()
        bin_idx = self._g2bin(g)

        # Population of every bin, computed on the same device as the input.
        bin_count = torch.bincount(bin_idx.reshape(-1), minlength=self._bins)
        bin_count = bin_count.to(torch.get_default_dtype())

        if self._last_bin_count is not None:
            previous = self._last_bin_count.to(bin_count.device)
            bin_count = self._alpha * previous + (1 - self._alpha) * bin_count
        self._last_bin_count = bin_count

        # Gradient density: bin population scaled by the number of non-empty
        # bins; clamped so that empty bins do not divide by zero.
        nonempty_bins = (bin_count > 0).sum().item()
        gd = torch.clamp(bin_count * nonempty_bins, min=0.0001)

        # Total number of elements (equals rows * cols for 2-D input).
        beta = x.numel() / gd
        return self._custom_loss(x, target, beta[bin_idx])


class GHMC_Loss(GHM_Loss):
    """GHM loss for classification, built on binary cross-entropy with logits.

    Args:
        bins: number of gradient-magnitude bins (forwarded to ``GHM_Loss``).
        alpha: momentum for smoothing the bin counts (forwarded to
            ``GHM_Loss``).
    """

    def __init__(self, bins=10, alpha=0.5):
        super(GHMC_Loss, self).__init__(bins, alpha)

    def _custom_loss(self, x, target, weight):
        """Binary cross-entropy on logits ``x`` with per-element ``weight``."""
        return F.binary_cross_entropy_with_logits(x, target, weight=weight)

    def _custom_loss_grad(self, x, target):
        """Gradient of the BCE-with-logits loss w.r.t. ``x``: sigmoid(x) - target."""
        return torch.sigmoid(x).detach() - target

In [None]:
# Smoke-test GHMC_Loss on random logits (3 samples, 5 classes) with one-hot targets.
x = torch.randn(3, 5)
target = torch.randint(0, 5, (3, 1))  # one random class index per sample
target = torch.zeros(3, 5).scatter_(1, target, 1)  # class indices -> one-hot rows
ghmc = GHMC_Loss(10, 0.5)
# Call the module itself rather than .forward so that nn.Module hooks run.
result = ghmc(x, target)

In [14]:
# Weighted cross-entropy on random logits with one-hot (probability-style) targets.
input = torch.randn(3, 2)
target = torch.tensor([0, 1, 1])
target = F.one_hot(target, num_classes=2).float()
print(input)
print(target)
# One weight per class. Take the absolute value so that the random weights are
# non-negative; a negative class weight would flip the sign of that class's
# contribution and make the loss meaningless.
weight_CE = torch.randn(2).abs()
print(weight_CE)
ce = nn.CrossEntropyLoss(weight=weight_CE)
loss = ce(input, target)
print(loss)

tensor([[-0.2240,  0.5549],
        [-1.1436, -0.0177],
        [ 0.0555, -0.1390]])
tensor([[1., 0.],
        [0., 1.],
        [0., 1.]])
tensor([2.8617, 0.4143])
tensor(1.2519)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class BinaryDiceLoss(nn.Module):
    """Soft Dice loss between one prediction map and one target map.

    Args:
        ignore_index: target value whose positions are masked out of both the
            prediction and the target, so they contribute nothing to the loss
            or to the input gradient.
        reduction: reduction applied to the per-row loss:
            'none' | 'mean' | 'sum'.
        **kwargs: optional settings
            batch_loss (bool): when True the whole batch is flattened into a
                single row and one Dice score is computed for it
                (default False: one score per sample).
            smooth (number): value added to numerator and denominator
                (default 1); a larger value such as 10 or 100 is suggested
                when the target area is large.
    Shapes:
        output: tensor of shape [N, *]; raw logits when ``use_sigmoid`` is
            True, probabilities otherwise.
        target: tensor with the same shape as ``output``.
    Returns:
        Loss tensor reduced according to ``reduction``.
    Raises:
        ValueError: if ``reduction`` is not one of the supported values
            (raised when ``forward`` is called).
    """

    def __init__(self, ignore_index=None, reduction='mean', **kwargs):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = kwargs.get('smooth', 1)
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.batch_dice = kwargs.get('batch_loss', False)

    def forward(self, output, target, use_sigmoid=True):
        assert output.shape[0] == target.shape[0], "output & target batch size don't match"
        if use_sigmoid:
            output = torch.sigmoid(output)

        if self.ignore_index is not None:
            validmask = (target != self.ignore_index).float()
            # Out-of-place multiplication: in-place ops would break autograd.
            output = output.mul(validmask)
            target = target.float().mul(validmask)

        # One row per sample, or a single row for the whole batch.
        rows = 1 if self.batch_dice else output.shape[0]
        output = output.reshape(rows, -1)
        target = target.reshape(rows, -1).float()

        num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth
        den = torch.sum(output.abs() + target.abs(), dim=1) + self.smooth
        loss = 1 - (num / den)

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        raise ValueError('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    """Multi-class Dice loss: average of per-class binary Dice losses.

    Args:
        weight: optional sequence of length num_classes; the Dice loss of
            class ``i`` is multiplied by ``weight[i]``.
        ignore_index: class index (int/float) or list/tuple of class indices
            that are skipped when accumulating the loss.
        **kwargs: forwarded unchanged to ``BinaryDiceLoss``.
    Shapes:
        output: tensor of shape [N, C, *] holding raw scores; softmax over
            dim 1 is applied inside ``forward``.
        target: tensor with the same shape as ``output``.
    Returns:
        Same as ``BinaryDiceLoss``, averaged over the classes that were not
        ignored.
    Raises:
        TypeError: if ``ignore_index`` has an unsupported type.
        ZeroDivisionError: if every class is ignored.
    """

    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        if isinstance(ignore_index, (int, float)):
            self.ignore_index = [int(ignore_index)]
        elif ignore_index is None:
            self.ignore_index = []
        elif isinstance(ignore_index, (list, tuple)):
            self.ignore_index = ignore_index
        else:
            raise TypeError("Expect 'int|float|list|tuple', while get '{}'".format(type(ignore_index)))

    def forward(self, output, target):
        assert output.shape == target.shape, 'output & target shape do not match'
        num_classes = target.shape[1]
        if self.weight is not None:
            assert len(self.weight) == num_classes, \
                'Expect weight shape [{}], get[{}]'.format(num_classes, len(self.weight))

        dice = BinaryDiceLoss(**self.kwargs)
        probs = F.softmax(output, dim=1)

        # Only classes that are actually evaluated count towards the average,
        # so out-of-range or duplicated ignore indices cannot skew the divisor.
        kept_classes = [i for i in range(num_classes) if i not in self.ignore_index]

        total_loss = 0
        for i in kept_classes:
            dice_loss = dice(probs[:, i], target[:, i], use_sigmoid=False)
            if self.weight is not None:
                dice_loss = dice_loss * self.weight[i]
            total_loss = total_loss + dice_loss
        return total_loss / len(kept_classes)


def test():
    """Smoke test: run DiceLoss on a small random 3-D volume and back-propagate.

    Builds a random single-channel batch, maps it to 4 class channels with a
    3-D convolution, compares against random one-hot labels (classes 2 and 3
    ignored) and prints the resulting scalar loss.
    """
    volume = torch.rand((3, 1, 32, 32, 32))
    net = nn.Conv3d(1, 4, 3, padding=1)
    labels = torch.randint(0, 4, (3, 1, 32, 32, 32)).float()
    one_hot_labels = make_one_hot(labels, num_classes=4)
    dice_criterion = DiceLoss(ignore_index=[2, 3], reduction='mean')
    predictions = net(volume)
    loss = dice_criterion(predictions, one_hot_labels)
    loss.backward()
    print(loss.item())


def make_one_hot(input, num_classes=None):
    """Convert a class-index tensor into a one-hot encoded tensor.

    Args:
        input: tensor of shape [N, 1, *] holding class indices; it is moved
            to the CPU and cast to long before encoding.
        num_classes: number of classes; when None it is taken as
            ``input.max() + 1``.
    Returns:
        A CPU tensor of shape [N, num_classes, *] with a 1 at each position's
        class index along dim 1 and 0 elsewhere.
    """
    if num_classes is None:
        num_classes = input.max() + 1

    # Same shape as the input, except that dim 1 spans the classes.
    one_hot_shape = list(input.shape)
    one_hot_shape[1] = int(num_classes)

    class_indices = input.cpu().long()
    return torch.zeros(one_hot_shape).scatter_(1, class_indices, 1)


# Run the smoke test only when this code is executed directly, not when imported.
if __name__ == '__main__':
    test()

In [None]:
class BinaryDiceLoss(nn.Module):
    """Self-adjusting Dice loss from 'Dice Loss for Data-imbalanced NLP Tasks'.

    Useful when dealing with unbalanced data. No activation is applied here:
    ``y_pred`` is used exactly as given.

    Args:
        alpha: exponent of the ``(1 - y_pred)`` factor that scales each
            prediction.
        gamma: smoothing constant added to numerator and denominator.
    """

    def __init__(self, alpha=0.5, gamma=0.5):
        super(BinaryDiceLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, y_pred, y_true, reduction='mean'):
        """
        :param y_pred: [N, C, ]
        :param y_true: [N, C, ]
        :param reduction: 'mean' or 'sum'; any other value returns the
            unreduced per-sample loss
        """
        n_samples = y_true.size(0)
        flat_pred = y_pred.contiguous().view(n_samples, -1)
        flat_true = y_true.contiguous().view(n_samples, -1)

        # (1 - p) ** alpha: the factor applied to every prediction.
        decay = torch.pow((1 - flat_pred), self.alpha)

        # Numerator and denominator of the per-sample Dice coefficient.
        numerator = torch.sum(2 * decay * flat_pred * flat_true, dim=1) + self.gamma
        denominator = torch.sum(decay * flat_pred + flat_true, dim=1) + self.gamma
        loss = 1 - (numerator / denominator)

        if reduction == 'sum':
            return loss.sum()
        return loss.mean() if reduction == 'mean' else loss

class DiceLoss(nn.Module):
    """Multi-class wrapper around ``BinaryDiceLoss``.

    Applies softmax over dim 1 of the predictions, one-hot encodes the integer
    labels to the same layout, and delegates to ``BinaryDiceLoss``.

    Args:
        alpha: forwarded to ``BinaryDiceLoss``.
        gamma: forwarded to ``BinaryDiceLoss``.
    """

    def __init__(self, alpha=1, gamma=1):
        super(DiceLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.binary_dice_loss = BinaryDiceLoss(alpha, gamma)

    def forward(self, y_pred, y_true, reduction='mean'):
        """
        :param y_pred: [N, C, ]
        :param y_true: [N, ]
        :param reduction: 'mean' or 'sum'
        """
        num_labels = y_pred.shape[1]

        # F.one_hot appends the class axis last; this order moves it to dim 1
        # so that the labels line up with y_pred: [0, last, 1, ..., last - 1].
        last_axis = y_pred.dim() - 1
        axis_order = [0, last_axis] + list(range(1, last_axis))

        probs = torch.softmax(y_pred, dim=1)
        one_hot = F.one_hot(y_true, num_classes=num_labels).permute(*axis_order)
        return self.binary_dice_loss(probs, one_hot, reduction)