The trainer module accepts a string for the loss function. Can I pass a function instead?
Otherwise, where can I find the mapping between these strings and the actual PyTorch loss functions?
For example, how do I use `CrossEntropyLoss`?
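```python
import torch.nn.functional as F


def _validate_loss_input(loss):
    # Lower-cased names of everything in torch.nn.functional, in dir() order
    dir_f = dir(F)
    loss_fns = [d.lower() for d in dir_f]
    if isinstance(loss, str):
        # "unconstrained" variants return the input itself (or its sum / mean)
        if loss.lower() == 'unconstrained':
            return lambda x: x
        elif loss.lower() == 'unconstrained_sum':
            return lambda x: x.sum()
        elif loss.lower() == 'unconstrained_mean':
            return lambda x: x.mean()
        else:
            # Case-insensitive lookup of the name in torch.nn.functional
            try:
                str_idx = loss_fns.index(loss.lower())
            except ValueError:
                raise ValueError('Invalid loss string input - must match pytorch function.')
            return getattr(F, dir_f[str_idx])
    elif callable(loss):
        # Any other callable (e.g. a custom loss function) is used as-is
        return loss
    else:
        raise ValueError('Invalid loss input')
```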
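The string branch above is just a case-insensitive name lookup on a module. Here is a minimal sketch of that same lookup done with a dict instead of `list.index`. `resolve_function` is a hypothetical helper, and Python's `math` module stands in for `torch.nn.functional` so the snippet runs without PyTorch installed.

```python
import math
from typing import Any


def resolve_function(namespace: Any, name: str) -> Any:
    """Hypothetical helper: case-insensitive attribute lookup on a module."""
    # Map each lower-cased attribute name to its real spelling.
    # setdefault keeps the first match in dir() order, like list.index does.
    by_lower = {}
    for attr in dir(namespace):
        by_lower.setdefault(attr.lower(), attr)
    try:
        real_name = by_lower[name.lower()]
    except KeyError:
        raise ValueError(f"Invalid loss string input {name!r}") from None
    return getattr(namespace, real_name)


# math stands in for torch.nn.functional in this sketch
sqrt_fn = resolve_function(math, "SQRT")
print(sqrt_fn(16.0))
```

Building the dict once makes each lookup O(1) and avoids keeping two parallel lists in sync. The behavior is otherwise the same: the case of the input string does not matter, and unknown names raise `ValueError`.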
As you can see, the string is matched against the function names in `torch.nn.functional`, and the match is case-insensitive. So for `CrossEntropyLoss` you would pass `'cross_entropy'`, which resolves to `F.cross_entropy`. You can also pass in any callable that takes the same arguments as a function like `F.nll_loss`, and it will be used directly.
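The following sketch illustrates both options for cross-entropy. It assumes PyTorch is installed and uses a made-up toy batch of logits and class targets. It does not call the trainer itself, because the trainer's exact parameter names are not shown here. It only checks what the validator would accept. An `nn.Module` instance such as `nn.CrossEntropyLoss()` is callable, so it passes the `callable(loss)` branch as-is.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy batch: 4 samples, 3 classes (random logits, integer class targets)
torch.manual_seed(0)
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])

# Option 1: a string. The validator lower-cases dir(F) and looks the name up there.
lowered_names = [d.lower() for d in dir(F)]
print("cross_entropy" in lowered_names)     # the functional name is found
print("crossentropyloss" in lowered_names)  # the class name is not a function in F

# F.cross_entropy gives the same value as the nn.CrossEntropyLoss module
functional_value = F.cross_entropy(logits, targets)
module_value = nn.CrossEntropyLoss()(logits, targets)

# Option 2: a callable. Module instances are callable, so this can be passed directly.
criterion = nn.CrossEntropyLoss()
```

Passing the string keeps the configuration serializable. Passing `nn.CrossEntropyLoss(...)` is handy when you need constructor arguments such as class weights, which the bare string cannot express.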