
About loss_c in Multibox Loss #14

Closed · taneslle opened this issue May 31, 2018 · 7 comments

Comments

@taneslle commented May 31, 2018

Hi~

loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1,1))

This operation looks to me the same as computing the softmax cross-entropy loss, so why not just use torch.nn.functional.cross_entropy directly?

P.S. The +x_max and -x_max in log_sum_exp,

return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max

also seem to do nothing.
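
For reference, a minimal check of that log_sum_exp helper with made-up logit values (the shift cancels algebraically, but it matters numerically once the logits get large):

```python
import torch

def log_sum_exp(x):
    # Same idea as the repo's helper: shift by the max before exponentiating.
    x_max = x.data.max()
    return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max

conf = torch.tensor([[1000.0, 999.0, 998.0]])   # made-up, deliberately large logits
naive = torch.log(torch.sum(torch.exp(conf), 1, keepdim=True))
print(naive)              # tensor([[inf]]) -- exp overflows without the shift
print(log_sum_exp(conf))  # tensor([[1000.4076]])
```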

@hellojialee commented

@taneslle Hi, have you figured out why?

@lzx1413 (Owner) commented Aug 30, 2018

Because we want the softmax loss for each element. However, early versions of PyTorch did not support this; only the sum or the average of the loss over all elements was produced.
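
For reference, a minimal sketch (shapes and names are assumed, not taken from the repo) of how the same per-element loss can be written today, now that reduction='none' exists:

```python
import torch
import torch.nn.functional as F

def log_sum_exp(x):
    x_max = x.data.max()
    return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max

torch.manual_seed(0)
batch_conf = torch.randn(8732, 21)            # (num_priors, num_classes), made-up sizes
conf_t = torch.randint(0, 21, (8732,))        # matched class index per prior

# Hand-rolled per-element softmax loss, as in the repo:
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

# Modern equivalent: per-element cross entropy straight from PyTorch.
loss_ce = F.cross_entropy(batch_conf, conf_t, reduction='none')

print(torch.allclose(loss_c.view(-1), loss_ce, atol=1e-5))  # True
```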

@hellojialee commented

@lzx1413 Thank you so much for your reply! Now we can pass reduce=False to get the element-wise loss. So do you mean we can just use F.cross_entropy twice now: once for the hard negative mining, and once for computing the final conf_loss?

@lzx1413 (Owner) commented Aug 30, 2018

You can calculate the softmax loss once and generate a mask for the positive instances and the hard negative instances. Then you can multiply them together to keep only the loss terms you want.
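
A rough sketch of that masking idea (tensor names, sizes, and the exact hard-negative selection rule are assumptions here, not the repo's code):

```python
import torch
import torch.nn.functional as F

num_priors, num_classes, neg_pos_ratio = 8732, 21, 3        # made-up sizes
conf_data = torch.randn(num_priors, num_classes)            # predicted class scores
conf_t = torch.randint(0, num_classes, (num_priors,))       # matched targets, 0 = background

# Element-wise softmax loss, computed once.
loss_c = F.cross_entropy(conf_data, conf_t, reduction='none')

pos = conf_t > 0                                            # positive (matched) priors
num_pos = int(pos.sum())
num_neg = min(neg_pos_ratio * num_pos, num_priors - num_pos)

# Hard negative mining: pick the negatives with the largest loss.
loss_neg = loss_c.clone()
loss_neg[pos] = 0                                           # ignore positives when ranking
_, neg_idx = loss_neg.sort(descending=True)
neg = torch.zeros_like(pos)
neg[neg_idx[:num_neg]] = True

# Mask keeps positives + hard negatives; everything else contributes zero loss.
mask = (pos | neg).float()
loss_conf = (loss_c * mask).sum() / max(num_pos, 1)
```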

@hellojialee commented

Yes. Another question: if we use F.cross_entropy twice, will it change the input tensor's gradient, since the forward pass of the same input is tracked twice? Thank you for your help again!

@lzx1413 (Owner) commented Aug 30, 2018

If you don't add the first softmax loss to the final total loss, it will not affect the network.
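
Put differently, the mining pass only matters if its output enters the loss you call backward() on; to make that explicit you could even wrap it in torch.no_grad() (a sketch with assumed names and a placeholder selection rule):

```python
import torch
import torch.nn.functional as F

conf_data = torch.randn(100, 21, requires_grad=True)
conf_t = torch.randint(0, 21, (100,))

# First pass: only used to rank negatives, so keep it out of the graph.
with torch.no_grad():
    mining_loss = F.cross_entropy(conf_data, conf_t, reduction='none')

# ... select positives / hard negatives from mining_loss to build `keep` ...
keep = mining_loss > mining_loss.median()    # placeholder selection rule

# Second pass: only this one contributes gradients.
final_loss = F.cross_entropy(conf_data[keep], conf_t[keep])
final_loss.backward()
```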

@hellojialee commented

Thank you.

@lzx1413 closed this as completed Aug 30, 2018