How to specify loss function? #45

Closed
rosenfeldamir opened this issue Jun 29, 2017 · 2 comments

@rosenfeldamir

The trainer module accepts a string for the loss function. Can I pass a function instead?
Otherwise, where can I find the mapping between these strings and the actual PyTorch loss functions?
For example, how do I use CrossEntropyLoss?

rosenfeldamir changed the title from "question" to "How to specify loss function?" on Jun 29, 2017
@ncullen93
Member

Here is the function the trainer uses to turn the loss argument into a callable:

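import torch.nn.functional as F

def _validate_loss_input(loss):
    # Case-insensitive lookup table over every name in torch.nn.functional
    dir_f = dir(F)
    loss_fns = [d.lower() for d in dir_f]
    if isinstance(loss, str):
        # "unconstrained" modes return the input itself (or its sum/mean) as the loss
        if loss.lower() == 'unconstrained':
            return lambda x: x
        elif loss.lower() == 'unconstrained_sum':
            return lambda x: x.sum()
        elif loss.lower() == 'unconstrained_mean':
            return lambda x: x.mean()
        else:
            try:
                str_idx = loss_fns.index(loss.lower())
            except ValueError:
                raise ValueError('Invalid loss string input - must match pytorch function.')
            # Fetch the function under its original casing, e.g. F.nll_loss
            return getattr(F, dir_f[str_idx])
    elif callable(loss):
        # Any callable is used as-is
        return loss
    else:
        raise ValueError('Invalid loss input')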

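The string branch above can be replayed outside the trainer to see which names it accepts. This is a minimal sketch assuming a recent PyTorch install; `resolve_loss_name` is a hypothetical helper that mirrors the lookup in `_validate_loss_input`, not part of the trainer's API. Since `CrossEntropyLoss` is a class in `torch.nn`, the string to use for it is its functional counterpart, `cross_entropy`.

```python
import torch.nn.functional as F

def resolve_loss_name(name):
    """Hypothetical helper: mirrors the case-insensitive lookup of _validate_loss_input."""
    names = dir(F)
    lowered = [n.lower() for n in names]
    try:
        idx = lowered.index(name.lower())
    except ValueError:
        raise ValueError(f"{name!r} does not match any function in torch.nn.functional")
    return getattr(F, names[idx])

# Names in F that look like losses (the exact set depends on the PyTorch version)
loss_like = [n for n in dir(F) if n.endswith('_loss') or 'entropy' in n]
print(loss_like)

# Case does not matter: both strings resolve to the same function object
print(resolve_loss_name('cross_entropy') is F.cross_entropy)
print(resolve_loss_name('CROSS_ENTROPY') is F.cross_entropy)
```
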
As you can see, the string is matched against the function names in torch.nn.functional, and the match is case-insensitive. You can also pass in any callable that takes the same arguments as, for example, F.nll_loss, and the trainer will use it as-is.
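
For the callable route, here is a sketch of three interchangeable ways to get cross-entropy, assuming the trainer calls the loss as `loss(output, target)` like `F.nll_loss` (an assumption based on the reply above). `my_cross_entropy` is a hypothetical custom loss; an `nn.CrossEntropyLoss()` instance also qualifies because modules are callable, so it would take the `callable(loss)` branch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)           # batch of 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 0])  # class index per sample

# Option 1: the functional form, which the string 'cross_entropy' resolves to
loss_fn = F.cross_entropy

# Option 2: an nn.Module instance is callable with (input, target)
loss_module = nn.CrossEntropyLoss()

# Option 3: a hypothetical custom loss with the same (input, target) signature
def my_cross_entropy(output, target):
    # Cross-entropy is the negative log-likelihood of the log-softmax outputs
    return F.nll_loss(F.log_softmax(output, dim=1), target)

for fn in (loss_fn, loss_module, my_cross_entropy):
    assert callable(fn)
    print(float(fn(logits, target)))
```

All three compute the same mean cross-entropy up to floating-point rounding, so the choice is mostly about whether you need extra behavior (class weights, custom terms) beyond the built-in function.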

@rosenfeldamir
Author

Thanks!
