… help with #88
1) By jaschasd@: switch from arccosine to arctangent in ReLU; this greatly stabilizes all NNGP derivatives, and stabilizes the first NTK derivative.
2) Use a new `_sqrt` with a custom derivative to make sure the tolerance cutoff only kicks in when computing the derivative, not the forward pass (a sketch of the idea follows below). Remove the `do_backprop` parameter, since the forward pass is now as cheap as before (and in fact more numerically accurate).
3) Remove safety checks in Erf and other places where I believe the inputs must be within the safe range and we don't need to be that careful. However, I imagine we may have introduced these for a reason; does anyone recall why?
4) Add some basic differentiation sanity checks.
Co-authored-by: Jascha Sohl-dickstein <jaschasd@google.com>
PiperOrigin-RevId: 348948017
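A minimal sketch of the custom-derivative trick from item 2 above, assuming a scalar tolerance `tol` and a hypothetical clamp constant `1e-30` (the actual `_sqrt` in neural-tangents may differ in details):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def _sqrt(x, tol=0.):
  # Forward pass: a plain square root -- as cheap and accurate as before.
  return jnp.sqrt(x)

@_sqrt.defjvp
def _sqrt_jvp(primals, tangents):
  x, tol = primals
  x_dot, _ = tangents  # the tolerance is not differentiated
  # The cutoff kicks in only here, in the derivative: 1 / (2 sqrt(x))
  # blows up as x -> 0, so clamp the argument away from zero first.
  safe_x = jnp.maximum(x, jnp.maximum(tol, 1e-30))
  return jnp.sqrt(x), x_dot / (2 * jnp.sqrt(safe_x))
```

With this, `jax.grad(lambda x: _sqrt(x, 1e-12))(0.)` is large but finite instead of `nan`, while the forward pass still computes the exact square root.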
Thanks for the report - could you check if this^ commit helps? It should remove the `do_backprop` argument and fix the NaNs, as well as improve the numerical stability of differentiating nonlinearities - let me know if this works!
Hello,
I have encountered a problem: when I try to take the gradient of a loss function with respect to an input variable x, I get a NaN after several iterations. This only happens when the NTK includes a ReLU layer; I've tried Erf and Sigmoid, and neither has this problem.
I get the kernel function from:
And I take the gradient via:
grads = grad(model_loss, argnums=1)(params, (x, y))[0]
with model_loss defined as:
model_loss = lambda params, xy: loss_func(pred(params, xy[0]), xy[1])
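For reference, a self-contained sketch of this kind of setup (the architecture, data, and loss below are illustrative placeholders, not the code from the report):

```python
import jax.numpy as jnp
from jax import grad, random
from neural_tangents import stax

# An infinite-width network whose kernel includes a ReLU layer.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(1))

key_x, key_y = random.split(random.PRNGKey(0))
x = random.normal(key_x, (8, 4))
y = random.normal(key_y, (8, 1))

def model_loss(params, xy):
  x, y = xy
  # Differentiating through kernel_fn w.r.t. x is the operation that
  # produced NaNs in the report.
  ntk = kernel_fn(x, x, 'ntk')
  return jnp.mean((ntk @ y - y) ** 2)  # toy quadratic loss

# params is unused in this toy loss, so an empty pytree is passed.
grads = grad(model_loss, argnums=1)((), (x, y))[0]
print(jnp.isnan(grads).any())
```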
By the way, is the `_safe_sqrt` function here:
neural-tangents/neural_tangents/stax.py
Line 3847 in c6f759d
not backprop-safe? We might need to change `np.where` to `np.maximum`, just as in `_sqrt`.
Thanks!
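As background for the `_safe_sqrt` question above: a bare `np.where` does not make the square root's gradient safe, because JAX differentiates both branches, and the `inf` from the derivative of `sqrt` at `x = 0` turns into `nan` when multiplied by the zero cotangent of the untaken branch. A small generic illustration (not the library code):

```python
import jax
import jax.numpy as jnp

def where_sqrt(x):
  # Looks safe, but the gradient at x = 0 is nan: sqrt is still
  # differentiated there, and 0 * inf = nan leaks through the where.
  return jnp.where(x > 0, jnp.sqrt(x), 0.)

def maximum_sqrt(x, tol=1e-30):
  # Clamping the argument instead keeps the derivative finite everywhere.
  return jnp.sqrt(jnp.maximum(x, tol))

print(jax.grad(where_sqrt)(0.))    # nan
print(jax.grad(maximum_sqrt)(0.))  # 0.0
```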