Don't understand the code. #3
Yes, there will be a dimension mismatch problem from the code point of view.

`logits` finds the similarity between a prediction and all labels in the batch; its shape should be `[batch, batch]`. There should be a more straightforward way to implement BMC; I wrote it this way for conciseness and to show its resemblance to supervised contrastive loss, as explained in Eqn. 3.15.

@painlove1999 Could you provide more information on the dimension mismatch? The current implementation requires both `pred` and `target` to be of size `[batch, 1]`.
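As a quick shape check (a standalone sketch, not code from the repo): with both inputs of size `[batch, 1]`, subtracting the transpose broadcasts to `[batch, batch]`:

```python
import torch

batch = 3
pred = torch.randn(batch, 1)
target = torch.randn(batch, 1)

# [batch, 1] - [1, batch] broadcasts to [batch, batch]
diff = pred - target.T
print(diff.shape)  # torch.Size([3, 3])
```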
I have added more descriptions of the variable sizes in the readme; hope this helps.

```python
import torch
import torch.nn.functional as F

def bmc_loss(pred, target, noise_var):
    """Compute the Balanced MSE Loss (BMC) between `pred` and the ground truth `target`.

    Args:
        pred: A float tensor of size [batch, 1].
        target: A float tensor of size [batch, 1].
        noise_var: A float number or tensor.
    Returns:
        loss: A float tensor. Balanced MSE Loss.
    """
    logits = - (pred - target.T).pow(2) / (2 * noise_var)  # logits size: [batch, batch]
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0], device=pred.device))  # contrastive-like loss
    loss = loss * (2 * noise_var).detach()  # optional: restore the loss scale; 'detach' when noise_var is learnable
    return loss
```
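To sanity-check the shapes and the learnable `noise_var` path, here is a minimal usage sketch (the input values are made up, and the loss is repeated inline so the snippet runs on its own):

```python
import torch
import torch.nn.functional as F

def bmc_loss(pred, target, noise_var):
    logits = - (pred - target.T).pow(2) / (2 * noise_var)         # [batch, batch]
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))   # contrastive-like loss
    return loss * (2 * noise_var).detach()                        # restore loss scale

torch.manual_seed(0)
batch = 4
pred = torch.randn(batch, 1, requires_grad=True)    # model output
target = torch.randn(batch, 1)                      # ground-truth labels
noise_var = torch.tensor(1.0, requires_grad=True)   # learnable noise variance

loss = bmc_loss(pred, target, noise_var)
loss.backward()
print(loss.shape)  # torch.Size([]) -- a scalar
```

Note that `noise_var` is passed as a tensor here so that `.detach()` works and the gradient flows to it through `logits` only, as the comment in the function intends.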
Thank you for your answer. I had misunderstood `pred` in your code to be logits, so I thought its dimension should be `[batch, cls_num]`. I will try to adapt your code in my program.
Great, love the response.
Thanks! As with the standard MSE loss, the target range can be anything; e.g., in age estimation, the target range is [0, 120]. It can be unbounded as well, i.e., (-inf, inf).
What do you mean by this line?

```python
F.cross_entropy(logits, torch.arange(pred.shape[0]))
```

Cross entropy between `logits` and the labels `[0, 1, 2, ..., batch_size - 1]`.
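In other words, row `i` of `logits` is treated as a classification over the batch, and `torch.arange` says the correct "class" for prediction `i` is its own target `i` (the diagonal). A small illustration with made-up logits:

```python
import torch
import torch.nn.functional as F

batch = 3
labels = torch.arange(batch)  # tensor([0, 1, 2])

# Diagonal-dominant logits: each prediction is most similar to its own target.
aligned = torch.eye(batch) * 10.0
# Column-shifted logits: each prediction is most similar to the wrong target.
shuffled = torch.roll(aligned, shifts=1, dims=1)

# The aligned case yields a much lower loss than the shuffled one.
print(F.cross_entropy(aligned, labels) < F.cross_entropy(shuffled, labels))  # tensor(True)
```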