nnf_nll_loss - ignore_index #53

Closed
jwijffels opened this issue Jun 16, 2020 · 6 comments
Labels
nn Related to nn API

Comments

@jwijffels
Contributor

It would be great if nnf_nll_loss had a default value for ignore_index:
https://github.com/mlverse/torch/blob/master/R/nnf-loss.R#L343

@dfalbel
Member

dfalbel commented Jun 16, 2020

Yeah, it should be -100, following the PyTorch implementation:

def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
             reduce=None, reduction='mean'):
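For readers unfamiliar with ignore_index, the following is a simplified pure-Python sketch of what a mean-reduced NLL loss does with it. It is not PyTorch's or torch for R's implementation: the weight argument is omitted, inputs are plain lists, and the function name nll_loss_mean is hypothetical. Targets equal to ignore_index contribute neither to the summed loss nor to the count used for the mean.

```python
import math

IGNORE_INDEX = -100  # default used by PyTorch's nll_loss signature above

def nll_loss_mean(log_probs, targets, ignore_index=IGNORE_INDEX):
    """Hypothetical helper: mean negative log-likelihood that skips ignored targets.

    log_probs: list of rows, each a list of log-probabilities (one per class).
    targets:   list of 0-based class indices, one per row.
    """
    total = 0.0
    count = 0
    for row, t in zip(log_probs, targets):
        if t == ignore_index:
            continue  # ignored rows add nothing to the loss or the count
        total += -row[t]
        count += 1
    # If every target is ignored the mean is 0/0; this sketch returns NaN.
    return total / count if count else float("nan")

rows = [[math.log(0.7), math.log(0.3)],
        [math.log(0.2), math.log(0.8)]]
# Only the first row counts here, so the result equals -log(0.7).
print(nll_loss_mean(rows, [0, IGNORE_INDEX]))
```

Because -100 can never be a valid class index, it works as a safe sentinel default: ordinary calls are unaffected, and padded positions can be masked simply by setting their targets to -100.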

@jwijffels
Contributor Author

Yes. I tested that as well, but it went boom on my Windows machine.

@jwijffels
Contributor Author

jwijffels commented Jun 16, 2020

What I meant is that the following crashes my R session at the call to cpp_torch_namespace_nll_loss_self_Tensor_target_Tensor:

library(torch)
m = nn_log_softmax(dim=1)
input = torch_randn(3, 5, requires_grad=TRUE)
target = torch_tensor(c(1L, 0L, 4L))
input = m(input)
output = nnf_nll_loss(input, target, ignore_index=-100L)
output

whereas it should be calling https://github.com/mlverse/torch/blob/master/src/lantern/lantern.h#L1649

@dfalbel
Member

dfalbel commented Jun 16, 2020

OK, I'll take a look ASAP.

@dfalbel
Member

dfalbel commented Jun 16, 2020

This works for me if I do:

target = torch_tensor(c(1L, 0L, 4L), dtype = torch_long())

I could consider making torch_long() the default dtype when converting from R integers to torch tensors. We did something similar for R doubles that are converted to Tensors with dtype = torch_float(). What do you think?

@dfalbel dfalbel added the nn Related to nn API label Jun 16, 2020
@jwijffels
Contributor Author

Indeed, it works with long instead of int. I don't know enough about the C API of lantern/libtorch to give advice. I don't mind specifying that it is a long. I don't currently know whether this affects speed in any way.
