Deflake tests; update citation file

PiperOrigin-RevId: 500748660
romanngg committed Jan 9, 2023
1 parent 42e9ecf commit db1c240
Showing 7 changed files with 127 additions and 116 deletions.
10 changes: 10 additions & 0 deletions CITATION
@@ -37,3 +37,13 @@
   pdf={https://arxiv.org/abs/2001.07301},
   url={https://github.com/google/neural-tangents}
 }
+
+# Elementwise nonlinearities and sketching:
+@inproceedings{han2022fast,
+  title={Fast Neural Kernel Embeddings for General Activations},
+  author={Insu Han and Amir Zandieh and Jaehoon Lee and Roman Novak and Lechao Xiao and Amin Karbasi},
+  booktitle={Advances in Neural Information Processing Systems},
+  year={2022},
+  pdf={https://arxiv.org/abs/2209.04121},
+  url={https://github.com/google/neural-tangents}
+}
209 changes: 105 additions & 104 deletions README.md

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions tests/empirical_test.py
@@ -1254,7 +1254,7 @@ def __call__(self, x, train: bool = True):
     return x


-_ResNet18 = partial(_ResNet, stage_sizes=[2, 2, 2, 2],
+_ResNet18 = partial(_ResNet, stage_sizes=[1, 1, 1, 1],
                     block_cls=_ResNetBlock)

@@ -1440,8 +1440,8 @@ def test_resnet18(self, same_inputs, do_jit, do_remat, dtype, j_rules,

     model = _ResNet18(num_classes=1)
     k1, k2, ki = random.split(random.PRNGKey(1), 3)
-    x1 = random.normal(k1, (1, 224, 224, 1), dtype)
-    x2 = None if same_inputs else random.normal(k2, (1, 224, 224, 1), dtype)
+    x1 = random.normal(k1, (1, 128, 128, 1), dtype)
+    x2 = None if same_inputs else random.normal(k2, (1, 128, 128, 1), dtype)
     p = model.init(ki, x1)

     def apply_fn(params, x):
@@ -1456,8 +1456,8 @@ def test_mixer_b16(self, same_inputs, do_jit, do_remat, dtype, j_rules,

     model = _MlpMixer(num_classes=1, **_get_mixer_b16_config())
     k1, k2, ki = random.split(random.PRNGKey(1), 3)
-    x1 = random.normal(k1, (1, 224, 224, 1), dtype)
-    x2 = None if same_inputs else random.normal(k2, (1, 224, 224, 1), dtype)
+    x1 = random.normal(k1, (1, 128, 128, 1), dtype)
+    x2 = None if same_inputs else random.normal(k2, (1, 128, 128, 1), dtype)
     p = model.init(ki, x1, train=True)

     def apply_fn(params, x):
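Note: both tests above check the empirical (finite-width) NTK of a full Flax architecture, and the smaller inputs shrink that computation. A minimal sketch of the pattern they exercise, using the public `nt.empirical_ntk_fn` API; the toy `apply_fn` below is a hypothetical stand-in for the test-local `_ResNet18` / `_MlpMixer`:

```python
import jax.numpy as jnp
from jax import random
import neural_tangents as nt

# Hypothetical stand-in for the Flax apply functions used in the tests.
def apply_fn(params, x):
  return jnp.tanh(x @ params['w'])

k1, k2, kw = random.split(random.PRNGKey(1), 3)
params = {'w': random.normal(kw, (128, 1))}
x1 = random.normal(k1, (2, 128))
x2 = random.normal(k2, (3, 128))

# Finite-width NTK; trace_axes=() keeps the full output-output covariance,
# vmap_axes=0 vectorizes the Jacobian computation over the batch dimension.
ntk_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=(), vmap_axes=0)
k = ntk_fn(x1, x2, params)  # shape (2, 3, 1, 1)
```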
6 changes: 3 additions & 3 deletions tests/experimental/empirical_tf_test.py
@@ -181,7 +181,7 @@ def _resnet(x, blocks_per_layer, classes, filters):

 def _MiniResNet(classes, input_shape, weights):
   inputs = tf.keras.Input(shape=input_shape)
-  outputs = _resnet(inputs, [1, 1, 1, 1], classes=classes, filters=4)
+  outputs = _resnet(inputs, [1, 1, 1, 1], classes=classes, filters=2)
   return tf.keras.Model(inputs=inputs, outputs=outputs)


@@ -262,7 +262,7 @@ def _compare_ntks(
         # tf.keras.applications.MobileNet,
     ],
     input_shape=[
-        (64, 64, 3)
+        (32, 32, 3)
     ],
     trace_axes=[
         (),
@@ -291,7 +291,7 @@ def test_keras_functional(

 @parameterized.product(
     input_shape=[
-        (32, 32, 3)
+        (16, 16, 3)
     ],
     trace_axes=[
         (),
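Note: this file compares Keras-model NTKs computed through the library's experimental TensorFlow entry point. A sketch under stated assumptions: the name `neural_tangents.experimental.empirical_ntk_fn_tf` and its `(x1, x2, params)` calling convention are recalled from the experimental API and may differ by version, and the tiny model below is a hypothetical stand-in for the test's `_MiniResNet`:

```python
import tensorflow as tf
from neural_tangents import experimental

# Hypothetical stand-in for _MiniResNet; only the reduced (32, 32, 3)
# input shape mirrors the test.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(2, 3, activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1),
])

x1 = tf.random.normal((2, 32, 32, 3))
x2 = tf.random.normal((3, 32, 32, 3))

# Assumed API: empirical_ntk_fn_tf returns a kernel function taking
# (x1, x2, params); exact argument conventions may differ by version.
ntk_fn = experimental.empirical_ntk_fn_tf(model, trace_axes=())
k = ntk_fn(x1, x2, model.trainable_variables)
```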
2 changes: 1 addition & 1 deletion tests/rules_test.py
@@ -74,7 +74,7 @@
     (1, 2, 3),
     (6, 3, 2),
     # (2, 1, 1),
-    (3, 2, 1),
+    # (3, 2, 1),
     (1, 2, 1, 3),
     # (2, 2, 2, 2),
     (2, 1, 3, 4),
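Note: these tuples are array shapes fed to the structured-derivative contraction rules that back the faster empirical-NTK implementations (the `j_rules` / `s_rules` parameters in the tests above). As a sketch of the invariant those rules protect, every implementation should produce the same kernel; the `NtkImplementation` enum members below are assumed from the library's public API:

```python
import jax.numpy as jnp
from jax import random
import neural_tangents as nt

def f(params, x):
  return jnp.tanh(x @ params)

params = random.normal(random.PRNGKey(0), (3, 2))
x = random.normal(random.PRNGKey(1), (6, 3))

# Assumed enum members; all implementations should yield the same NTK.
impls = [nt.NtkImplementation.JACOBIAN_CONTRACTION,
         nt.NtkImplementation.NTK_VECTOR_PRODUCTS,
         nt.NtkImplementation.STRUCTURED_DERIVATIVES]
kernels = [nt.empirical_ntk_fn(f, implementation=i)(x, None, params)
           for i in impls]
for k in kernels[1:]:
  assert jnp.allclose(kernels[0], k, atol=1e-5)
```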
4 changes: 2 additions & 2 deletions tests/stax/elementwise_test.py
@@ -800,7 +800,7 @@ def test_activations(
     test_utils.skip_test(self)

     n_out = 1 if get == 'ntk' else 1024
-    width = 832
+    width = 612

     W_std_in = width**(-0.5) if parameterization_out == 'standard' else 1.
     if phi == stax.Exp:
@@ -909,7 +909,7 @@ def kernel_scalar_mc_grad_mean(x1, x2):
     k_mc2_mean, k_mc2_grad_mean = kernel_scalar_mc_grad_mean(x1, x2)

     # Compare kernels.
-    self.assertAllClose(k1, k_mc2_mean, atol=4e-3, rtol=4e-2)
+    self.assertAllClose(k1, k_mc2_mean, atol=8e-3, rtol=4e-2)

     if phi == stax.Sign and get == 'nngp':
       raise absltest.SkipTest('Derivative of the empirical NNGP of a '
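Note: the loosened `atol` absorbs Monte Carlo noise: this test compares a closed-form stax kernel against an average over randomly initialized finite networks. A minimal sketch of that comparison; the toy architecture and sample count are illustrative assumptions:

```python
from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(612), stax.Relu(), stax.Dense(1))

x1 = random.normal(random.PRNGKey(1), (4, 8))

# Closed-form (infinite-width) kernel.
k_exact = kernel_fn(x1, None, 'nngp')

# Monte Carlo estimate, averaged over n_samples random initializations;
# agreement improves with width and sample count, hence finite tolerances.
mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn,
                                        random.PRNGKey(2), n_samples=128)
k_mc = mc_kernel_fn(x1, None, 'nngp')
```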
2 changes: 1 addition & 1 deletion tests/stax/linear_test.py
@@ -1628,7 +1628,7 @@ def test_conv_local_conv(self):
     time_pool = time.time()
     k = k_jit(x1, x2).nngp.block_until_ready()
     time_pool = time.time() - time_pool
-    self.assertLess(time_flat * 5, time_pool)
+    self.assertLess(time_flat * 4, time_pool)
     self._test_against_mc(apply_fn, init_fn, k, x1, x2, 0.03)

     # Top layer LCN + pooling
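Note: the relaxed 5x to 4x factor guards a wall-clock comparison between two kernel computations. Timing JIT-compiled JAX code requires `block_until_ready()`, since dispatch is asynchronous; a sketch of the pattern with a toy function and illustrative sizes:

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sin(x) @ jnp.cos(x).T)
x = jnp.ones((500, 500))

f(x).block_until_ready()  # warm-up run: exclude compilation time

t = time.time()
f(x).block_until_ready()  # block so async dispatch doesn't end the timer early
elapsed = time.time() - t
```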
