Commit

Fix #123 and add test (Return non-clipped value in custom-jvp rule for `stax._sqrt`).

PiperOrigin-RevId: 395889733
romanngg committed Sep 10, 2021
1 parent 475a030 commit 8b7917f
Showing 2 changed files with 98 additions and 5 deletions.
3 changes: 2 additions & 1 deletion neural_tangents/stax.py
@@ -4345,7 +4345,8 @@ def _sqrt_jvp(tol, primals, tangents):
   x_dot, = tangents
   safe_tol = max(tol, 1e-30)
   square_root = _sqrt(x, safe_tol)
-  return square_root, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
+  square_root_out = _sqrt(x, tol)
+  return square_root_out, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
 
 
 def _get_diagonal(
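Why the fix: the JVP rule clamps `x` at `safe_tol` so that the derivative term `1 / (2 * square_root)` stays finite near zero, but before this change the same clamped value was also returned as the primal output, so the forward value computed under differentiation could disagree with a direct `_sqrt(x, tol)` call. The fix keeps the stricter clamp for the derivative only. A minimal self-contained sketch of the pattern, assuming `_sqrt` is a square root clipped from below at `tol` (the names `safe_sqrt` / `safe_sqrt_jvp` are illustrative, not the library's exact code):

import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
def safe_sqrt(x, tol):
  # Primal: square root of `x`, clipped from below at `tol`.
  return jnp.sqrt(jnp.maximum(x, tol))


@safe_sqrt.defjvp
def safe_sqrt_jvp(tol, primals, tangents):
  x, = primals
  x_dot, = tangents
  safe_tol = max(tol, 1e-30)
  # The derivative uses the stricter clamp so `1 / (2 * square_root)` never
  # divides by zero; the tangent is zeroed out where `x <= safe_tol`.
  square_root = safe_sqrt(x, safe_tol)
  # The primal output is clipped only at the caller's `tol` (the fix above);
  # previously the `safe_tol`-clamped value leaked into the forward pass.
  square_root_out = safe_sqrt(x, tol)
  return square_root_out, jnp.where(x > safe_tol, x_dot / (2 * square_root), 0.)

With this rule, `jax.grad(lambda x: safe_sqrt(x, 0.))(0.)` evaluates to `0.` rather than `inf`, while the primal value agrees exactly with a direct `safe_sqrt(x, 0.)` call.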
100 changes: 96 additions & 4 deletions tests/stax_test.py
@@ -26,7 +26,7 @@
 from jax import lax
 from jax import ops
 from jax import test_util as jtu
-from jax import jit, vjp, jvp, jacfwd, jacrev
+from jax import jit, vjp, jvp, jacfwd, jacrev, value_and_grad
 from jax.config import config
 from jax.lib import xla_bridge
 import jax.numpy as np
@@ -3718,16 +3718,14 @@ def test_conv_local_conv(self):
         stax.Relu(),
         stax.Conv(width, (3, 3), padding='SAME'),
         stax.Relu(),
-        stax.Conv(width, (3, 3), padding='SAME'),
-        stax.Relu(),
         stax.Conv(width, (3, 3), padding='SAME'))
 
     init_fn, apply_fn, kernel_fn = local_conv
 
     # No projection layer
     k = kernel_fn(x1, x2)
     self.assertEqual(k.diagonal_spatial, False)
-    self._test_against_mc(apply_fn, init_fn, k.reverse().nngp, x1, x2, 0.03)
+    self._test_against_mc(apply_fn, init_fn, k.nngp, x1, x2, 0.03)
 
     # Top layer flat
     init_fn, apply_fn, kernel_fn = stax.serial(local_conv, stax.Flatten())
@@ -3898,6 +3896,100 @@ def assert_close(x, y, tol=3e-5):
     assert_close(np.moveaxis(k_fwd_1, (0, 2, 4), (1, 3, 5)), k_fwd_10)
     assert_close(np.moveaxis(k_rev_1, (0, 2, 4), (1, 3, 5)), k_rev_10)
 
+  @jtu.parameterized.named_parameters(
+      jtu.cases_from_list({
+          'testcase_name':
+              f'{get}-{same_inputs}-{input_type}-{phi.__name__}-{do_jit}',
+          'get': get,
+          'same_inputs': same_inputs,
+          'phi': phi,
+          'input_type': input_type,
+          'do_jit': do_jit
+      }
+      for get in [
+          'ntk',
+          'nngp'
+      ]
+      for do_jit in [True, False]
+      for input_type in ['zeros', 'ones', 'random']
+      for same_inputs in [True, False, None]
+      for phi in [
+          stax.Erf,
+          stax.Abs,
+          stax.Gelu,
+          stax.Relu,
+      ]))
+  def test_issue_123(self, get, input_type, same_inputs, phi, do_jit):
+    """Tests https://github.com/google/neural-tangents/issues/123."""
+    def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
+      main = stax.serial(
+          phi(),
+          stax.Conv(
+              channels, (3, 3), strides, padding='SAME',
+              parameterization='standard'
+          ),
+          phi(),
+          stax.Conv(channels, (3, 3), padding='SAME',
+                    parameterization='standard'),
+      )
+      shortcut = (
+          stax.Identity()
+          if not channel_mismatch
+          else stax.Conv(
+              channels, (3, 3), strides, padding='SAME',
+              parameterization='standard'
+          )
+      )
+      return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
+                         stax.FanInSum())
+
+    def WideResnetGroup(n, channels, strides=(1, 1)):
+      blocks = []
+      blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
+      for _ in range(n - 1):
+        blocks += [WideResnetBlock(channels, (1, 1))]
+      return stax.serial(*blocks)
+
+    def WideResnet(block_size, k, num_classes):
+      return stax.serial(
+          stax.Conv(16, (3, 3), padding='SAME', parameterization='standard'),
+          WideResnetGroup(block_size, int(16 * k)),
+          stax.GlobalAvgPool(),
+          stax.Dense(num_classes, 1.0, 0.0, parameterization='standard'),
+      )
+
+    _, _, kernel_fn = WideResnet(block_size=1, k=1, num_classes=1)
+
+    def get_x(key):
+      shape = (1, 8, 8, 3)
+      if input_type == 'zeros':
+        x = np.zeros(shape)
+      elif input_type == 'ones':
+        x = np.ones(shape)
+      elif input_type == 'random':
+        x = random.normal(random.PRNGKey(key), shape)
+      else:
+        raise ValueError(input_type)
+      return x
+
+    x1 = get_x(1)
+    if same_inputs is None:
+      x2 = None
+    elif same_inputs:
+      x2 = x1
+    else:
+      x2 = get_x(2)
+
+    def kernel_scalar(x1, x2):
+      return kernel_fn(x1, x2, get)[0, 0]
+
+    if do_jit:
+      kernel_scalar = jit(kernel_scalar)
+
+    k1 = kernel_scalar(x1, x2)
+    k2 = value_and_grad(kernel_scalar)(x1, x2)[0]
+    self.assertAllClose(k1, k2)
+
 
 if __name__ == '__main__':
   absltest.main()
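The new test guards the regression by comparing a direct kernel evaluation against the primal returned by `value_and_grad`, which is computed through the custom JVP rule. The same check applied to the hypothetical `safe_sqrt` sketch earlier on this page (an illustration, not the test itself):

f = lambda x: safe_sqrt(x, 0.)

value, grad = jax.value_and_grad(f)(0.)
assert value == f(0.)  # Failed before the fix: the JVP rule returned the
                       # value clipped at `safe_tol = 1e-30` as the primal.
assert grad == 0.      # Tangent is zeroed below the tolerance.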
