
value_and_grad(kernel_fn) not equal to kernel_fn with standard parameterization #123

Closed
PythonNut opened this issue Sep 10, 2021 · 8 comments
Labels: bug (Something isn't working)

Comments

@PythonNut

PythonNut commented Sep 10, 2021

I am confused by the behavior of the following snippet of code (the WideResNet from the README with standard parameterization):

import jax
from neural_tangents import stax


def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    main = stax.serial(
        stax.Relu(),
        stax.Conv(
            channels, (3, 3), strides, padding="SAME", parameterization="standard"
        ),
        stax.Relu(),
        stax.Conv(channels, (3, 3), padding="SAME", parameterization="standard"),
    )
    shortcut = (
        stax.Identity()
        if not channel_mismatch
        else stax.Conv(
            channels, (3, 3), strides, padding="SAME", parameterization="standard"
        )
    )
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut), stax.FanInSum())


def WideResnetGroup(n, channels, strides=(1, 1)):
    blocks = []
    blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks += [WideResnetBlock(channels, (1, 1))]
    return stax.serial(*blocks)


def WideResnet(block_size, k, num_classes):
    return stax.serial(
        stax.Conv(16, (3, 3), padding="SAME", parameterization="standard"),
        WideResnetGroup(block_size, int(16 * k)),
        WideResnetGroup(block_size, int(32 * k), (2, 2)),
        WideResnetGroup(block_size, int(64 * k), (2, 2)),
        stax.AvgPool((8, 8)),
        stax.Flatten(),
        stax.Dense(num_classes, 1.0, 0.0, parameterization="standard"),
    )


_, _, kernel_fn = WideResnet(block_size=4, k=1, num_classes=1)

def kernel_scalar(x, y):
    return kernel_fn(x, y, "ntk")[0, 0]

z = jax.numpy.zeros((1, 32, 32, 3))
print(jax.value_and_grad(kernel_scalar)(z, z)[0])
print(kernel_scalar(z, z))

My understanding is that the two printed values should be the same. However, when I run it, I get two totally different values:

34.41480472358908
64.62813414153004

Is my understanding correct? I have not yet found a simpler network that features this behavior.

Versions:

  • jax 0.2.20
  • jaxlib 0.1.71+cuda111
  • neural-tangents 0.3.7
@romanngg
Contributor

Thanks for the repro and good find! It's indeed a bug in our custom differentiation rule for the square root: we clip the derivative around zero, but accidentally clipped the outputs as well. I've sent a change to fix it, but it needs code review, so it will likely land tomorrow. In the meantime, this is what the change looks like:

def _sqrt_jvp(tol, primals, tangents):
  x, = primals
  x_dot, = tangents
  safe_tol = max(tol, 1e-30)
  square_root = _sqrt(x, safe_tol)
+ square_root_out = _sqrt(x, tol)
- return square_root, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
+ return square_root_out, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
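For reference, here is a minimal, self-contained sketch of the patched rule. It assumes, for illustration only, that _sqrt(x, tol) clamps its argument at tol before taking the square root; this is not the exact library code, just the structure of the fix (use safe_tol only inside the derivative, return an output computed with the caller's tol).

import functools
import jax
import jax.numpy as np

@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
def _sqrt(x, tol):
  # Illustrative stand-in for the library's clipped square root:
  # clamp the argument from below at `tol` before taking the root.
  return np.sqrt(np.maximum(x, tol))

@_sqrt.defjvp
def _sqrt_jvp(tol, primals, tangents):
  x, = primals
  x_dot, = tangents
  safe_tol = max(tol, 1e-30)
  square_root = _sqrt(x, safe_tol)   # clipped at safe_tol, used only inside the derivative
  square_root_out = _sqrt(x, tol)    # the primal output is clipped only at the caller's tol
  return square_root_out, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)

# With the fix, the primal value is no longer affected by safe_tol:
print(jax.value_and_grad(lambda x: _sqrt(x, 0.))(0.))  # value 0.0, derivative 0.0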

@romanngg added the bug label Sep 10, 2021
@PythonNut
Author

Wow, thanks for quickly determining the issue!

@PythonNut
Author

Hmm, after pulling 8b7917f, I see that the values match now, but the gradient is all NaNs. Is that the intended outcome?

@romanngg reopened this Sep 20, 2021
@romanngg
Contributor

romanngg commented Oct 1, 2021

Thanks, I'll need to look into this. In the meantime, I suspect it's only happening for zero-valued inputs and generally shouldn't be a problem otherwise (but perhaps I'm wrong, so it's worth double-checking whether there are still NaNs or discrepancies for normal inputs like images, etc.).
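For instance, a double-check along those lines might look like the following. This is just a sketch reusing kernel_scalar from the first snippet with random image-shaped inputs; it is not part of the library.

import jax
import jax.numpy as np

# Non-zero, image-like inputs instead of all-zeros.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1, 32, 32, 3))

val, g = jax.value_and_grad(kernel_scalar)(x, x)
print(np.allclose(val, kernel_scalar(x, x)))  # expect True: the two values should agree
print(np.isnan(g).any())                      # True here would mean NaNs also appear for normal inputs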

@PythonNut
Author

A much smaller example reproducing the NaN issue:

import jax
from jax import grad
from neural_tangents.stax import serial, Conv, Relu, Flatten

def f(x): return serial(Conv(1, (3, 3)), Relu(), Flatten())[2](x, x, "ntk")[0][0]
print(grad(f)(jax.numpy.zeros((1, 32, 32, 3))))

I guess this no longer has anything to do with parameterization="standard" (it happens either way), so should this be a new issue?

@romanngg
Contributor

romanngg commented Oct 2, 2021

I think it's probably the same issue, likely related to differentiating kernel functions with Relu (maybe some other nonlinearities too; I'll need to look into this) at exactly zero, i.e. jax.numpy.zeros((1, 32, 32, 3)). I agree it's likely not specific to parameterization.

romanngg added a commit that referenced this issue Jan 21, 2022
…7d01d65#diff-096654d44536fb53f7fee3c9c85f41ab9fedb894de1333cbdb39959b6b914fd6 I made a change to default to no bias variable instead of `b_std=0`, but in standard parameterization no bias variable (`b_std=None`) is different from a zero-variance bias variable (`b_std=0`). This is also related to #123.

Also update tests to catch this error (scale down `W_std`, which in standard parameterization dwarfed the bias contribution). Add absolute tolerance to testing and logging. Make `stax` tests deterministic.

PiperOrigin-RevId: 421074614
romanngg added a commit that referenced this issue Feb 8, 2022
…ble nonlinearities at 0 - #123.

Currently we have NaNs and/or values at 0 that are inconsistent with the limit of finite-width empirical kernels.

First, note that `np.sign(0) == 0`, therefore we must have `T_sign(0, x2) = T_sign(x1, 0) = T_sign(0, 0) = 0`, but we currently have it equal to 1, i.e. we assume incorrectly in our infinite width limit that `np.sign(0) == 1`.

Secondly, JAX defines the gradient of ABRelu at 0 to be (a + b) / 2, i.e. the mean subgradient. This means that we must have `Tdot_abrelu(0, x2) = Tdot_abrelu(x1, 0) = Tdot_abrelu(0, 0) = [(a + b) / 2]^2`, but we currently have it equal to `(a^2 + b^2) / 2`, which is equivalent to assuming that the gradient is `[(a^2 + b^2) / 2]^0.5`, i.e. for Relu the gradient at 0 is 1/2^0.5 instead of 1/2.

We fix the above issues by extending `np.arctan2(0, 0) := np.pi / 2` (mathematically the function is undefined, and by default JAX/numpy have it be 0, but `np.pi / 2` gives us correct values above).

Finally, we also extend the gradient of np.arctan2 at (0, 0) to (0, 0). The gradient at 0 is by default undefined, and earlier we had NaN gradients at zero inputs to nonlinearities. While the gradient can't be extended continuously at (0, 0), setting it to (0, 0) at least makes it continuous along `x = 0` or `y = 0`, and helps fix a lot of NaNs.
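For context, the default behavior being overridden here can be seen directly with a standalone check; the NaN is what the usual arctan2 derivative formulas give when evaluated at 0/0.

import jax
import jax.numpy as np

print(np.arctan2(0., 0.))                            # 0.0: default JAX/NumPy value at the origin
print(jax.grad(np.arctan2, argnums=(0, 1))(0., 0.))  # (nan, nan): gradient undefined at (0, 0)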

Also add more tests, including comparisons of gradients of infinite-width kernels with MC estimates.

Make matrix comparison tests fail on NaNs or infinities.

PiperOrigin-RevId: 427073332
@romanngg
Contributor

romanngg commented Feb 8, 2022

Thank you for your patience here! I think you were right that there were actually two bugs here.

One was the wrong treatment of biases with b_std=None, parameterization='standard' (precisely, in standard parameterization, having no bias (b_std=None) and having a zero-variance bias (b_std=0) are not the same).
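As a quick illustration of this first point, here is a hypothetical check, assuming the current stax.Dense API where b_std=None means no bias term at all:

import jax.numpy as np
from neural_tangents import stax

x = np.ones((2, 3))
_, _, k_none = stax.Dense(1, W_std=1., b_std=None, parameterization='standard')
_, _, k_zero = stax.Dense(1, W_std=1., b_std=0., parameterization='standard')

# In standard parameterization a zero-variance bias is still a trainable
# parameter, so its gradient contributes to the NTK and the two kernels differ.
print(k_none(x, x, 'ntk'))
print(k_zero(x, x, 'ntk'))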

The other was that the derivative at x1 or x2 equal to zero is technically undefined for ReLU and similar activations. We now set the gradient at zero inputs to 0. Note that this is correct for x1 = x2 = 0 and for x2 = 0, but e.g. dK(x1, x2 != 0)/dx1 is genuinely undefined/infinite at x1 = 0, yet we will return 0. While this is technically incorrect, it matches JAX's behavior of defining the gradient of non-differentiable functions as the mean subgradient, e.g. jax.grad(jax.numpy.sign)(0.) == 0. and jax.grad(lambda x: jax.numpy.maximum(x, 0.))(0.) == 0.5, so arguably this is a reasonable value to return.
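Concretely, the JAX conventions referred to above (a standalone check):

import jax
import jax.numpy as np

print(jax.grad(np.sign)(0.))                      # 0.0: mean subgradient of sign at 0
print(jax.grad(lambda x: np.maximum(x, 0.))(0.))  # 0.5: mean subgradient of ReLU at 0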

Hope this helps!

@PythonNut
Author

Thanks so much for the thorough fix! All of the gradient-related anomalies I've been seeing have gone away. I'll open new issues if I run into more problems in the future.
