In [None]:


import torch
from torch.nn import functional as F
from torch.autograd import Function


def dice_loss(pred, target, smooth=1.):
    """Return the mean soft Dice loss between ``pred`` and ``target``.

    The two tensors are multiplied element-wise and reduced over
    dimensions 2 and 3, which yields one Dice term per remaining index
    (for 4-D input: one per (dim 0, dim 1) pair). The terms
    ``1 - (2 * I + smooth) / (P + T + smooth)`` are then averaged.

    Args:
        pred: Prediction tensor; must have at least four dimensions.
        target: Tensor broadcast-compatible with ``pred``.
        smooth: Additive constant applied to numerator and denominator so
            the ratio stays defined when both sums are zero.

    Returns:
        A zero-dimensional tensor holding the averaged loss.
    """
    def _reduce_dims_2_and_3(tensor):
        # After the first reduction the former dim 3 becomes dim 2, so
        # reducing dim 2 twice collapses both of them.
        return tensor.sum(dim=2).sum(dim=2)

    p = pred.contiguous()
    t = target.contiguous()

    overlap = _reduce_dims_2_and_3(p * t)
    denominator = _reduce_dims_2_and_3(p) + _reduce_dims_2_and_3(t) + smooth
    dice_coefficient = (2. * overlap + smooth) / denominator

    return (1 - dice_coefficient).mean()


class Weighted_Cross_Entropy_Loss(torch.nn.Module):
    """Cross-entropy loss in which every pixel carries its own weight.

    For each sample, the log-probability of the target class at every
    pixel is multiplied by that pixel's weight, summed, and divided by the
    sample's total weight. The negated per-sample values are averaged
    over the mini-batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, weights):
        """Compute the weighted cross-entropy.

        Args:
            pred: Logits of shape ``(n, c, H, W)``.
            target: Integer class indices with ``n * H * W`` elements
                (e.g. shape ``(n, H, W)``).
            weights: Per-pixel weights, either ``(n, H, W)`` or any shape
                broadcastable against ``(n, 1, H, W)``.

        Returns:
            A zero-dimensional tensor holding the loss.
        """
        n, c, H, W = pred.shape

        # Log-probability of every class at every pixel.
        logp = F.log_softmax(pred, dim=1)

        # Keep only the log-probability of the target class -> (n, 1, H, W).
        # reshape (not view) so non-contiguous targets are accepted too.
        logp = torch.gather(logp, 1, target.reshape(n, 1, H, W))

        # A 3-D (n, H, W) weight map would broadcast against (n, 1, H, W)
        # into (n, n, H, W) and mix weights across samples, so give it the
        # singleton channel axis first.
        if weights.dim() == 3:
            weights = weights.unsqueeze(1)

        weighted_logp = (logp * weights).view(n, -1)

        # Normalise by the total weight of each sample so the loss stays in
        # roughly the same range regardless of the weight magnitudes.
        weighted_loss = weighted_logp.sum(1) / weights.view(n, -1).sum(1)

        # Average over the mini-batch and flip the sign.
        return -weighted_loss.mean()

# def class_weight(target):
#     weight = torch.zeros(batch_size, H, W)
#     for i in range(out_channels):
#         i_t = i * torch.ones([batch_size, H, W], dtype=torch.long)
#         loc_i = (target == i_t).to(torch.long)
#         count_i = loc_i.view(out_channels, -1).sum(1)
#         total = H*W
#         weight_i = total / count_i
#         weight_t = loc_i * weight_i.view(-1, 1, 1)
#         weight += weight_t
#     return weight

In [None]:
class BU_Net_Loss(torch.nn.Module):
    """Combined segmentation loss: weighted cross-entropy plus Dice loss.

    ``forward`` expects raw logits of shape ``(n, C, H, W)`` and integer
    class labels of shape ``(n, H, W)`` with values in ``[0, C)``.
    """

    def __init__(self, weight=None):
        """Create the loss module.

        Args:
            weight: Stored on the instance as ``self.weight``; it is not
                used by the loss computation, which derives its weights
                from the targets of each batch.
        """
        super().__init__()
        self.weight = weight
        # Weighted_Cross_Entropy_Loss takes no constructor arguments; the
        # per-pixel weights are supplied on every forward call instead.
        self.cross_entropy_loss = Weighted_Cross_Entropy_Loss()

    def forward(self, pred, target):
        """Return ``WCE + Dice`` for a batch of logits and integer labels."""
        # Add a channel axis so the weight map lines up with the
        # (n, 1, H, W) log-probabilities gathered inside the cross-entropy
        # instead of broadcasting across the batch dimension.
        weights = self.compute_class_weight(target).unsqueeze(1)
        wce_loss = self.cross_entropy_loss(pred, target, weights)

        # Dice multiplies prediction and target element-wise, so both must
        # be (n, C, H, W): class probabilities and one-hot labels.
        probs = F.softmax(pred, dim=1)
        one_hot = F.one_hot(target, num_classes=pred.shape[1])
        one_hot = one_hot.permute(0, 3, 1, 2).to(probs.dtype)
        dice = dice_loss(probs, one_hot)

        return wce_loss + dice

    def compute_class_weight(self, target):
        """Build a per-pixel weight map from inverse class frequency.

        Every pixel receives ``1 / (count + 1e-6)``, where ``count`` is the
        number of pixels in the whole batch that share its class label.

        Args:
            target: Integer label tensor of shape ``(n, H, W)``.

        Returns:
            Float tensor of shape ``(n, H, W)`` on ``target``'s device.
        """
        n, H, W = target.size()
        class_weights = torch.zeros(n, H, W, device=target.device)
        num_classes = int(target.max().item()) + 1
        for i in range(num_classes):
            mask = (target == i).float()
            # The epsilon guards against division by zero for class
            # indices that do not occur in this batch.
            class_weight = 1.0 / (mask.sum() + 1e-6)
            class_weights += mask * class_weight
        return class_weights


### 사용법 예시

```python
pred = torch.randn(8, 3, 256, 256)  # Example predictions
target = torch.randint(0, 3, (8, 256, 256))  # Example target
loss_fn = BU_Net_Loss()
loss = loss_fn(pred, target)
```


- Weighted Cross-Entropy Loss (WCE)

클래스 분류를 위한 손실 함수이다. 클래스가 불균형한 경우, 각 클래스에 가중치를 부여하여 손실을 계산한다.

- Dice Loss Coefficient (DLC)

예측된 분할 영역과 실제 분할 영역의 중첩을 최대화하기 위한 목적 함수이다. 주로 이미지 분할 작업에서 사용되며, 분할 성능을 향상시킨다.

- Loss Function Formulation
  - $L_{\text{total}} = \text{WCE} + \text{DLC}$