In [1]:
import torch
from typing import List
import torch.nn.functional as F

Tensor = torch.Tensor


In [2]:
# set random seed
seed = 5
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

# define random targets and predictions
num_samples = 5
num_classes = 2

preds = torch.rand(num_samples, num_classes)
preds_softmax = F.softmax(preds, dim=1)

targets = torch.randint(0, num_classes, (num_samples, 1))
targets_onehot = F.one_hot(targets, num_classes=num_classes).squeeze(1)

preds_pos = torch.max(preds_softmax * targets_onehot, dim=1).values
preds_neg = torch.max(preds_softmax * (1 - targets_onehot), dim=1).values

# define random aesthetic classes
aesthetics = torch.randint(0, 2, (num_samples, 3))


In [3]:
# set which aesthetic class/classes we want to ignore
# the sequence is low contrast, blurry, broken
aes_to_ignore = torch.tensor([1, 0, 0])

difficult_labels = []
for x in aesthetics:
    if (x & aes_to_ignore).sum() > 0:
        difficult_labels.append(1)
    else:
        difficult_labels.append(0)
difficult_labels = torch.tensor(difficult_labels)


In [4]:
# defined necessary loss functions for AGCL
def cross_entropy(preds: Tensor = None, targets: Tensor = None):
    """
    Calculate the Cross Entropy Loss given predictions and targets (ground truths).
    Please refer to https://pytorch.org/docs/stable/generated/F.cross_entropy.html.

    Args:
        preds (Tensor): Predicted unnormalized logits.
        targets (Tensor) : Ground truth class indices or class probabilities.
    Returns:
        ce_loss (Tensor): calculated cross entropy loss.
    """
    ce_loss = F.cross_entropy(preds, targets.squeeze(1), reduction="none")
    return ce_loss


def positive_agcl(
    positive_w: float = 0.0, positive_gamma: float = 0.0, preds_pos: Tensor = None
):
    """
    Calculate the AGCL given positive predictions and targets (ground truths).

    Args:
        positive_w (float): Value for w_+ in Equation 5.
        positive_gamma (float): Value for gamma_+ in Equation 5.
        preds_pos (Tensor): Predicted unnormalized logits of positive samples.
    Returns:
        loss (Tensor): calculated positive_agcl.
    """
    loss = (
        -positive_w * torch.pow((1 - preds_pos), positive_gamma) * torch.log(preds_pos)
    )
    return loss


def negative_agcl(
    negative_w: float = 0.0,
    m: float = 0.0,
    negative_gamma: float = 0.0,
    preds_neg: Tensor = None,
):
    """
    Calculate the AGCL given negative predictions and targets (ground truths).

    Args:
        negative_w (float): Value for w_- in Equation 5.
        m (float): Value for shifted probability (p_m) specified in 
        the work of Asymmetric Loss (https://arxiv.org/pdf/2009.14119.pdf).
        negative_gamma (float): Value for gamma_- in Equation 5.
        preds_neg (Tensor): Predicted unnormalized logits of negative samples.
    Returns:
        loss (Tensor): calculated negative_agcl.
    """
    p_m = (preds_neg - m).clamp(min=0)
    p_m_2 = (1 - preds_neg + m).clamp(min=0)
    loss = -negative_w * torch.pow(p_m, negative_gamma) * torch.log(p_m_2)
    return loss


def agcl(
    stage: int = 1,
    positive_w: float = 0.0,
    positive_gamma: float = 0.0,
    preds_pos: float = 0.0,
    difficult_labels: List[int] = [0],
    negative_w: float = 0.0,
    m: float = 0.0,
    negative_gamma: float = 0.0,
    preds_neg: Tensor = None,
    preds: Tensor = None,
    targets: Tensor = None,
):
    """
    Calculate the complete AGCL in two stages.
    Stage 1 of AGCL: 
        - Calculate the loss of positive samples using positive_agcl.
        - Calculate the loss of negative samples using negative_agcl.
        - Mask out the loss for difficult samples to zero out their gradients.
    
    Stage 2 of AGCL:
        - Calculate the loss of all samples using cross_entropy.

    Args:
        stage (int): Value to indicate stage of AGCL.
        positive_w (float): Value for w_+ in Equation 5.
        positive_gamma (float): Value for gamma_+ in Equation 5.
        preds_pos (Tensor): Predicted unnormalized logits of positive samples.
        difficult_labels (List[int]): Mask to zero out the gradient of difficult samples.
        negative_w (float): Value for w_- in Equation 5.
        m (float): Value for shifted probability (p_m) specified in 
        the work of Asymmetric Loss (https://arxiv.org/pdf/2009.14119.pdf).
        negative_gamma (float): Value for gamma_- in Equation 5.
        preds_neg (Tensor): Predicted unnormalized logits of negative samples.
    Returns:
        loss (Tensor): calculated loss averaged over all samples.
    """
    if stage == 1:
        positive_loss = positive_agcl(positive_w, positive_gamma, preds_pos)
        negative_loss = negative_agcl(negative_w, m, negative_gamma, preds_neg)
        loss = positive_loss * difficult_labels + negative_loss * difficult_labels
    else:
        loss = cross_entropy(preds, targets)
    return loss.mean()


In [5]:
# Detailed breakdown of Stage 1 of AGCL
positive_w = 1
positive_gamma = 0

stage_1_positive_agcl = positive_agcl(positive_w, positive_gamma, preds_pos)

negative_w = 0.52
m = 0.14
negative_gamma = 1.06

stage_1_negative_agcl = negative_agcl(negative_w, m, negative_gamma, preds_neg)

stage_1_final = (
    stage_1_positive_agcl * difficult_labels + stage_1_negative_agcl * difficult_labels
)

zero_indices = ", ".join(str(x[0]) for x in (difficult_labels == 0).nonzero().tolist())
print(f"Here we can see that loss values of {stage_1_final} at {zero_indices} is 0.")
print("Therefore, we ignored the training loss of difficult characters.")
print(f"Stage 1 of AGCL Loss: {stage_1_final.mean()}")


Here we can see that loss values of tensor([0.0000, 0.7183, 1.3946, 0.0000, 1.5419]) at 0, 3 is 0.
Therefore, we ignored the training loss of difficult characters.
Stage 1 of AGCL Loss: 0.7309598326683044


In [6]:
# Detailed breakdown of Stage 2 of AGCL
stage_2_ce = cross_entropy(preds, targets)

print(f"Stage 2 of AGCL Loss: {stage_2_ce.mean()}")


Stage 2 of AGCL Loss: 1.0470192432403564


In [7]:
# General training flow of AGCL
# This is just for demo purposes
positive_w = 1
positive_gamma = 0
negative_w = 0.52
m = 0.14
negative_gamma = 1.06

total_epochs = 10
stage_2_epoch = 8

for epoch in range(total_epochs):
    if epoch <= stage_2_epoch - 1:
        loss = agcl(
            1,
            positive_w,
            positive_gamma,
            preds_pos,
            difficult_labels,
            negative_w,
            m,
            negative_gamma,
            preds_neg,
        )
        print(f"Stage: 1, Epoch: {epoch}, Loss: {loss}")
    else:
        loss = agcl(2, preds=preds, targets=targets)
        print(f"Stage: 2, Epoch: {epoch}, Loss: {loss}")


Stage: 1, Epoch: 0, Loss: 0.7309598326683044
Stage: 1, Epoch: 1, Loss: 0.7309598326683044
Stage: 1, Epoch: 2, Loss: 0.7309598326683044
Stage: 1, Epoch: 3, Loss: 0.7309598326683044
Stage: 1, Epoch: 4, Loss: 0.7309598326683044
Stage: 1, Epoch: 5, Loss: 0.7309598326683044
Stage: 1, Epoch: 6, Loss: 0.7309598326683044
Stage: 1, Epoch: 7, Loss: 0.7309598326683044
Stage: 2, Epoch: 8, Loss: 1.0470192432403564
Stage: 2, Epoch: 9, Loss: 1.0470192432403564
