
NaN occurs when backpropagating through a kernel that includes a ReLU layer #88

Closed
ZuowenWang0000 opened this issue Dec 22, 2020 · 2 comments
Labels
bug Something isn't working

Comments

@ZuowenWang0000

Hello,
I have encountered a problem: when I try to get the gradient of a loss function with respect to an input variable x, I get NaN after several iterations. This only happens when the NTK includes a ReLU layer; with Erf or Sigmoid the problem does not occur.

The kernel function I am using comes from:

  # assuming: from neural_tangents import stax as ntstax
  self.init_fn, self.f, self.kernel_fn = ntstax.serial(
      ntstax.Dense(1, parameterization='ntk'),
      ntstax.Relu(do_backprop=True, do_stabilize=True),
      ntstax.Dense(1, parameterization='ntk')
  )

And I compute the gradient via:

  grads = grad(model_loss, argnums=1)(params, (x, y))[0]

where model_loss unpacks the (x, y) pair:

  model_loss = lambda params, xy: loss_func(pred(params, xy[0]), xy[1])
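
For context, a minimal self-contained sketch of the setup described above (assuming `ntstax` is `neural_tangents.stax`; the input shapes and the squared-sum loss are illustrative, and the `pred`/`loss_func` from the snippet above are not reproduced):

    import jax.numpy as np
    from jax import grad, random
    from neural_tangents import stax as ntstax

    _, _, kernel_fn = ntstax.serial(
        ntstax.Dense(1, parameterization='ntk'),
        ntstax.Relu(),  # the original report passed do_backprop=True, do_stabilize=True
        ntstax.Dense(1, parameterization='ntk'),
    )

    x_train = random.normal(random.PRNGKey(0), (8, 4))

    def loss_wrt_inputs(x):
        # Differentiating NTK entries with respect to the inputs is where the
        # NaNs were reported when the network contains a ReLU layer.
        k = kernel_fn(x, x_train, 'ntk')
        return np.sum(k ** 2)

    dx = grad(loss_wrt_inputs)(x_train)  # gradient w.r.t. the inputs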

By the way, is the `_safe_sqrt` function here:

  def _safe_sqrt(x):

backprop-safe? We might need `np.where` in addition to `np.maximum`, just like in `_sqrt`.
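
For reference, a minimal sketch of the kind of `np.where` guard being suggested (the function name, `tol`, and body are illustrative, not the library's actual code): `np.sqrt(np.maximum(x, 0.))` is fine in the forward pass, but its gradient at `x == 0` is infinite, so the unsafe inputs have to be masked out before the square root is taken.

    import jax.numpy as np

    def _safe_sqrt_sketch(x, tol=0.):
        # Mask out values at or below the tolerance before the sqrt so that the
        # derivative 1 / (2 sqrt(x)) is never evaluated at (near-)zero inputs.
        safe_x = np.where(x > tol, x, np.ones_like(x))
        return np.where(x > tol, np.sqrt(safe_x), np.zeros_like(x))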

Thanks!

@romanngg romanngg added the bug Something isn't working label Dec 24, 2020
romanngg added a commit that referenced this issue Dec 24, 2020
… help with #88

1) By jaschasd@: switch to arctangent in ReLU instead of arccosine; this greatly stabilizes all NNGP derivatives, and stabilizes the first NTK derivative.

2) Use a new `_sqrt` with a custom derivative to make sure the tolerance cutoff only kicks in when computing the derivative, not the forward pass (a sketch of this pattern follows the commit message below). Remove the `do_backprop` parameter, since now the forward pass is as cheap as before (and in fact, more numerically accurate).

3) Remove safety checks in Erf and other places where I believe the inputs must be within a safe range and we don't need to be that careful. However, I imagine we may have introduced these for a reason; does anyone recall why?

4) Add some basic differentiating sanity checks.

Co-authored-by: Jascha Sohl-dickstein <jaschasd@google.com>
PiperOrigin-RevId: 348948017
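
A minimal sketch of the custom-derivative square root pattern described in (2) (the name, tolerance, and body are illustrative rather than the library's actual `_sqrt`): the forward pass is a plain `np.sqrt`, while the JVP clamps the input away from zero so the derivative stays finite.

    import jax
    import jax.numpy as np

    @jax.custom_jvp
    def _sqrt_sketch(x):
        # Forward pass: plain square root, no tolerance cutoff, so values are
        # as accurate as before.
        return np.sqrt(x)

    @_sqrt_sketch.defjvp
    def _sqrt_sketch_jvp(primals, tangents):
        x, = primals
        t, = tangents
        # The tolerance only enters the derivative: clamp the input away from
        # zero so that d/dx sqrt(x) = 1 / (2 sqrt(x)) stays finite at x == 0.
        tol = 1e-30
        return np.sqrt(x), t / (2. * np.sqrt(np.maximum(x, tol)))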
@romanngg
Contributor

Thanks for the report - could you check if this^ commit helps? It should remove the do_backprop argument and fix nans, as well as improve numerical stability of differentiating nonlinearities - lmk if this works!

@ZuowenWang0000
Author

> Thanks for the report - could you check if this^ commit helps? It should remove the do_backprop argument and fix nans, as well as improve numerical stability of differentiating nonlinearities - lmk if this works!

ReLU works now. Thanks for the timely update!

@romanngg romanngg closed this as completed Jan 5, 2021