In [100]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class BinaryFocalLoss(nn.Module):
    """Focal loss for binary / multi-label classification, computed on raw logits.

    Per element the loss is ``alpha_t * (1 - p_t) ** gamma * BCE(logit, target)``
    where ``p_t`` is the sigmoid probability assigned to the target value and
    ``alpha_t`` is ``alpha`` for positive targets and ``1 - alpha`` for negatives.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean', eps=1e-8):
        """
        alpha: class weight (balances positive and negative samples; 0.25 is
            suggested when positives are scarce)
        gamma: hard/easy sample modulating factor (the larger it is, the more
            the loss focuses on hard samples)
        reduction: 'mean' / 'sum' / 'none'; any other value behaves like 'none'
        eps: numerical-stability constant. It is stored but not read by
            ``forward``; the parameter is kept so existing callers that pass
            it keep working.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets``.

        ``targets`` must have the same shape as ``inputs``. Integer or bool
        label tensors are accepted and converted to the dtype of ``inputs``.
        """
        # binary_cross_entropy_with_logits rejects integer/bool targets, so
        # hard labels are cast here; floating-point targets pass through
        # untouched so previously working calls behave exactly as before.
        if not torch.is_floating_point(targets):
            targets = targets.to(dtype=inputs.dtype)

        # Probabilities are only needed for the modulating factor; the
        # cross-entropy term itself is computed from the logits.
        probs = torch.sigmoid(inputs)
        bce_loss = F.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )

        # Focal weight: (1 - p_t)^gamma, with p_t = p if t == 1 else 1 - p
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t).pow(self.gamma)

        # Alpha weight: alpha for positives, 1 - alpha for negatives
        alpha_weight = self.alpha * targets + (1 - self.alpha) * (1 - targets)

        # Combined per-element loss
        loss = focal_weight * alpha_weight * bce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss




In [229]:

class AsymmetricLossOptimized(nn.Module):
    """Asymmetric loss for multi-label classification, computed on raw logits.

    Positives and negatives get separate focusing exponents (``gamma_pos`` /
    ``gamma_neg``), and the negative probability is shifted by ``clip`` before
    the log is taken. Intermediate tensors are kept on the instance
    (``targets``, ``xs_pos``, ``loss``, ...) and overwritten on every call.

    When ``ft_cls`` is set, hand-edited per-class tables inside ``forward`` are
    used for fine-tuning (see the notes there).
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False, ft_cls=None, num_classes=9):
        """
        gamma_neg: focusing exponent applied to negative targets
        gamma_pos: focusing exponent applied to positive targets
        clip: margin added to the negative probability (then clamped to 1);
            ``None`` or a value <= 0 disables clipping
        eps: lower clamp applied to probabilities before ``log``
        disable_torch_grad_focal_loss: if True, the focusing weight is computed
            with gradient tracking off, so it acts as a constant in backward
        ft_cls: fine-tuning mode; ``None`` disables it, ``1`` and ``2`` select
            the masking/weighting branches at the end of ``forward``
        num_classes: stored on the instance; not read by ``forward``
        """
        super().__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        self.flag = True

        self.ft_cls = ft_cls
        self.num_classes = num_classes
        # Slots for the intermediates produced by forward().
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    @staticmethod
    def _per_class(values, like):
        """Convert a per-class list to a tensor matching ``like`` (dtype, device, class count)."""
        table = like.new_tensor(values)
        if table.numel() != like.shape[-1]:
            raise ValueError(
                "per-class fine-tuning table has %d entries but the input has %d classes"
                % (table.numel(), like.shape[-1])
            )
        return table

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)

        Returns
        -------
        Scalar tensor: the loss summed over dim 1 and averaged over the rest.
        """
        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Log-likelihood (sign is flipped on return)
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Fine-tuning tables: must be edited by hand to match the fine-tuning
        # requirements.
        # Based on current test results, misses happen because: 1) the positive
        # classes score too low; 2) class 0 scores too high.
        # Since classes 1 and 0 are often close, a scheme that leaves class 1
        # untouched is also worth considering.
        # NOTE(review): these tables have 9 entries, while the ft_cls branch
        # below asserts 10 classes -- reconcile them before using ft_cls;
        # until then _per_class raises a ValueError on the mismatch.
        ft_gamma_neg = [1.0] + [1.0] + [10.] + [10.] + [10.] + [10.] + [1.] * 3
        # previous setting: weights = [0.] + [1.]*5 + [0.]*4
        ft_weights = [0.] + [1.] + [2.] * 4 + [0.] * 3

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.ft_cls is not None:
                gamma_neg = self._per_class(ft_gamma_neg, x)
            else:
                gamma_neg = self.gamma_neg
            gamma_pos = self.gamma_pos

            # Scoped switch: the previous grad mode is restored on exit, so a
            # caller running under torch.no_grad() is not silently flipped
            # back to grad-enabled.
            track_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(track_grad):
                self.xs_pos = self.xs_pos * self.targets
                self.xs_neg = self.xs_neg * self.anti_targets
                self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                              gamma_pos * self.targets + gamma_neg * self.anti_targets)
            self.loss = self.loss * self.asymmetric_w

        if self.ft_cls is not None:
            assert self.loss.shape[-1] == 10
            if self.ft_cls == 1:
                print("移除阳性类的loss")
                # Remove the loss of the positive classes: only column 0 is kept.
                self.loss *= self.loss.new_tensor([1.] + [0.] * 5 + [0.] * 4)
            elif self.ft_cls == 2:
                # Remove the loss of the negative class via the per-class weights.
                print("移除阴性类的loss")
                self.loss = self.loss * self._per_class(ft_weights, self.loss)

        return -self.loss.sum(dim=1).mean()


In [250]:
import torch.nn.functional as F

# Smoke test: evaluate the three criteria on one random batch with one-hot targets.
num_classes = 10
batch_size = 5

# Random logits and one random class index per sample, expanded to one-hot floats.
logits = torch.randn(batch_size, num_classes)
labels = torch.randint(num_classes, (batch_size,))
labels_onehot = F.one_hot(labels, num_classes=num_classes).float()
print(logits.shape, labels_onehot.shape)

# Focal loss
criterion_focal = BinaryFocalLoss(alpha=0.25, gamma=2, reduction='mean', eps=1e-8)
loss_focal = criterion_focal(logits, labels_onehot)

# Plain BCE-with-logits baseline
criterion_bce = nn.BCEWithLogitsLoss(reduction='mean')
loss_bce = criterion_bce(logits, labels_onehot)

# Asymmetric loss
criterion_asl = AsymmetricLossOptimized(gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8)
loss_asl = criterion_asl(logits, labels_onehot)

print(loss_focal, loss_bce, loss_asl)

torch.Size([5, 10]) torch.Size([5, 10])
tensor(0.2845) tensor(0.8440) tensor(1.6843)
