
Don't understand the code. #3

Closed
caotong0 opened this issue Apr 26, 2022 · 6 comments

@caotong0

What do you mean by this line?
```python
# Cross entropy between logits and [0, 1, 2, ..., batch_size]
F.cross_entropy(logits, torch.arange(pred.shape[0]))
```

@painlove1999

Yes, from the code's point of view there appears to be a dimension mismatch problem.

@jiawei-ren
Owner

jiawei-ren commented Apr 26, 2022

Logits captures the similarity between each prediction and all labels in the batch, so its shape should be [BS, BS]. For example, Logit[i][j] measures the i-th prediction's similarity to the j-th label. Ultimately, we would like each prediction to be most similar to its corresponding label, i.e., the i-th prediction should be most similar to the i-th label in the batch. This is referred to as "classifying within a batch" in the paper. Therefore, we use the cross-entropy loss for this "classification" and use `torch.arange(pred.shape[0])` as the ground truth, which means exactly that the i-th prediction's ground-truth index is i.

There should be a more straightforward way to implement BMC. I wrote it this way for conciseness and to show its resemblance to the supervised contrastive loss, as explained in Eqn. 3.15. A per-sample expansion is sketched below.
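
For reference, here is a minimal sketch (my own illustration, not code from the repo; `bmc_loss_expanded` is a hypothetical name) that expands the cross-entropy over the [BS, BS] logits into a per-sample form. It matches the one-liner implementation up to the optional scale-restoration step:

```python
import torch

def bmc_loss_expanded(pred, target, noise_var):
    # pred, target: [batch, 1]; noise_var: scalar
    diffs = (pred - target.T).pow(2) / (2 * noise_var)        # [batch, batch] scaled squared distances
    # Row i of the cross-entropy with ground-truth index i expands to:
    #   diffs[i, i] + logsumexp_j(-diffs[i, j])
    # i.e. the "positive" squared distance plus a batch-wise log-partition term.
    loss = diffs.diagonal() + torch.logsumexp(-diffs, dim=1)  # [batch]
    return loss.mean()
```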

@painlove1999 Could you provide more information on the dimension mismatch? The current implementation requires both `pred` and `label` to be [BS, 1].

@jiawei-ren
Owner

I have added more descriptions of the variable sizes to the readme; hope this helps.

```python
import torch
import torch.nn.functional as F

def bmc_loss(pred, target, noise_var):
    """Compute the Balanced MSE Loss (BMC) between `pred` and the ground truth `target`.
    Args:
      pred: A float tensor of size [batch, 1].
      target: A float tensor of size [batch, 1].
      noise_var: A float number or tensor.
    Returns:
      loss: A float tensor. Balanced MSE Loss.
    """
    logits = - (pred - target.T).pow(2) / (2 * noise_var)            # logit size: [batch, batch]
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))      # contrastive-like loss
    loss = loss * (2 * noise_var).detach()  # optional: restore the loss scale; 'detach' when noise is learnable

    return loss
```
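
And a quick usage sketch (my own example; `noise_sigma` is a hypothetical learnable noise scale, kept as a tensor so that `.detach()` in the last line of `bmc_loss` works):

```python
import torch

pred = torch.randn(32, 1)                             # [batch, 1] predictions
target = torch.randn(32, 1)                           # [batch, 1] ground truth
noise_sigma = torch.nn.Parameter(torch.tensor(1.0))   # hypothetical learnable noise scale

loss = bmc_loss(pred, target, noise_sigma ** 2)       # noise_var = sigma^2
loss.backward()                                       # gradients reach noise_sigma through the logits
```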

@painlove1999

> Logits captures the similarity between each prediction and all labels in the batch, so its shape should be [BS, BS]. […] We use the cross-entropy loss for this "classification" and `torch.arange(pred.shape[0])` as the ground truth.
>
> @painlove1999 Could you provide more information on the dimension mismatch? The current implementation requires both `pred` and `label` to be [BS, 1].

Thank you for your answer. I mistakenly assumed that `pred` in your code was logits, so that its dimension would be [batch, cls_num]. I will try to adapt your code to my program.

@caotong0
Author

Great, love the response.
Another question: is the range of target [0, 1] or [0, inf)?

@jiawei-ren
Owner

> Great, love the response. Another question: is the range of target [0, 1] or [0, inf)?

Thanks! As with the standard MSE loss, the target range can be anything, e.g., in age estimation it is [0, 120]. It can be unbounded as well, i.e., (-inf, inf).

caotong0 closed this as completed May 1, 2022