# Categorical focal loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------------
# Compute class weights (alpha) from class counts
# ---------------------------------------------------------------
def compute_alpha_from_counts(class_counts):
    """
    Compute normalized inverse-frequency class weights.

    Each class weight is proportional to ``1 / frequency`` and the weights
    are normalized to sum to 1, so rarer classes receive larger weights.

    Args:
        class_counts (Tensor or sequence of numbers): shape (num_classes,),
            the number of samples observed for each class.

    Returns:
        alpha (Tensor): float32 tensor of shape (num_classes,) summing to 1.

    Raises:
        ValueError: if ``class_counts`` is empty or contains a count that is
            not strictly positive. A zero count would make its inverse
            frequency infinite, which turns every other weight into 0 and
            the empty class's weight into NaN.
    """
    # as_tensor also accepts plain lists/tuples; .float() yields float32
    counts = torch.as_tensor(class_counts).float()

    if counts.numel() == 0:
        raise ValueError("class_counts must contain at least one class")
    # Written as "not all > 0" so that NaN counts are rejected too
    if not bool((counts > 0).all()):
        raise ValueError(
            f"all class counts must be strictly positive, got {counts.tolist()}"
        )

    # Compute class frequencies
    freq = counts / counts.sum()

    # Compute inverse frequencies
    inv_freq = 1.0 / freq

    # Normalize to sum to 1
    alpha = inv_freq / inv_freq.sum()

    return alpha

# Worked example: three classes with a 7:2:1 sample imbalance
class_counts = torch.tensor([700, 200, 100])
alpha = compute_alpha_from_counts(class_counts)
print(f"Computed alpha weights: {alpha.tolist()}")

# ---------------------------------------------------------------
# Define categorical focal loss class
# ---------------------------------------------------------------
class CategoricalFocalLoss(nn.Module):
    """
    Categorical (multi-class) focal loss.

    For an example with true class ``t`` and predicted probability
    ``p_t = softmax(logits)[t]`` the per-example loss is::

        loss = -alpha[t] * (1 - p_t) ** gamma * log(p_t)

    With ``gamma=0`` and ``alpha=None`` this equals standard cross-entropy.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        """
        Categorical Focal Loss.

        Args:
            gamma (float): focusing parameter γ; larger values down-weight
                well-classified examples more strongly.
            alpha (Tensor, sequence or None): class weights, shape
                (num_classes,). If None, no weighting.
            reduction (str): 'mean', 'sum', or 'none'.

        Raises:
            ValueError: if ``reduction`` is not one of the supported values.
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.reduction = reduction
        if alpha is not None:
            # torch.tensor(<Tensor>) emits a "copy construct" UserWarning;
            # as_tensor(...).detach().clone() avoids it while still copying,
            # so the buffer never aliases the caller's tensor.
            alpha = torch.as_tensor(alpha, dtype=torch.float32).detach().clone()
        # A None buffer keeps `self.alpha` defined but excludes it from
        # state_dict, matching the previous behavior when alpha is None.
        self.register_buffer('alpha', alpha)

    def forward(self, logits, targets):
        """
        Compute focal loss.

        Args:
            logits (Tensor): unnormalized scores of shape
                (batch_size, num_classes); extra trailing dimensions
                (batch_size, num_classes, d1, ...) are also accepted.
            targets (Tensor): integer class indices of shape (batch_size,)
                (or (batch_size, d1, ...) to match ``logits``).

        Returns:
            loss (Tensor): scalar if reduced, else per-example losses with
            the same shape as ``targets``.
        """
        # gather requires int64 indices; .long() is a no-op for int64 input
        targets = targets.long()

        # log_softmax is numerically stable: unlike log(softmax(x) + eps) it
        # stays exact, and keeps its gradient, when p_t underflows to 0.
        log_probs = F.log_softmax(logits, dim=1)

        # Pick log p_t along the class dim. Unlike indexing with a CPU
        # torch.arange, gather works on any device and with extra dims.
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        # Modulating factor (1 - p_t)^gamma. The clamp is defensive: a tiny
        # negative base from rounding would yield NaN for non-integer gamma.
        focal_factor = (1.0 - pt).clamp(min=0.0) ** self.gamma

        # Basic focal loss
        loss = -focal_factor * log_pt

        # Apply class weights if given: weight of each example's true class
        if self.alpha is not None:
            loss = self.alpha[targets] * loss

        # Reduce. 'mean' is a plain average over examples; it is NOT divided
        # by the summed alpha weights as F.cross_entropy(weight=...) does.
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# ---------------------------------------------------------------
# Dummy test example
# ---------------------------------------------------------------
# A tiny batch of 4 examples over 3 classes; requires_grad lets us
# inspect d(loss)/d(logits) after calling backward().
logits = torch.tensor(
    [
        [2.0, 0.5, 0.3],
        [0.2, 2.2, 0.1],
        [1.0, 0.1, 3.0],
        [0.5, 1.5, 0.5],
    ],
    requires_grad=True,
)
targets = torch.tensor([0, 1, 2, 1])  # ground-truth class index per row

criterion = CategoricalFocalLoss(gamma=2.0, alpha=alpha, reduction='mean')

# Forward pass, then report the scalar loss
loss = criterion(logits, targets)
print(f"Focal Loss: {loss.item()}")

# Backpropagate and inspect the gradient flowing into the logits
loss.backward()
print(f"Gradients w.r.t. logits:\n {logits.grad}")


Computed alpha weights: [0.08695652335882187, 0.30434781312942505, 0.6086956262588501]


  alpha = torch.tensor(alpha, dtype=torch.float32)


Focal Loss: 0.009567061439156532
Gradients w.r.t. logits:
 tensor([[-0.0014,  0.0008,  0.0006],
        [ 0.0010, -0.0018,  0.0009],
        [ 0.0013,  0.0005, -0.0018],
        [ 0.0072, -0.0145,  0.0072]])


: 

# Augmentation