Skip to content

Commit

Permalink
Use forks of parameterized.product/parameters to limit the number o…
Browse files Browse the repository at this point in the history
…f test cases that run in public GitHub Actions. Without this, the previous change caused the whole parameter grid to be evaluated publicly, making the GitHub Actions runs time out.

PiperOrigin-RevId: 458600975
  • Loading branch information
romanngg committed Jul 2, 2022
1 parent da15ae2 commit b1202ed
Show file tree
Hide file tree
Showing 12 changed files with 379 additions and 321 deletions.
19 changes: 9 additions & 10 deletions tests/batching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Tests for `neural_tangents/_src/batching.py`."""

from absl.testing import absltest
from absl.testing import parameterized

from functools import partial
from jax import jit
Expand Down Expand Up @@ -158,7 +157,7 @@ def _get_data_and_kernel_fn(
kernel_fn = kernel_fn(key, input_shape, network, **kwargs)
return data_other, data_self, kernel_fn

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES,
test_size=TEST_SIZES,
input_shape=INPUT_SHAPES,
Expand Down Expand Up @@ -189,7 +188,7 @@ def testSerial(

# We also exclude tests for dropout + parallel. It is not clear what is the
# best way to handle this case.
@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES,
test_size=TEST_SIZES,
input_shape=INPUT_SHAPES,
Expand Down Expand Up @@ -217,7 +216,7 @@ def testParallel(
_test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
data_other, True)

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES,
test_size=TEST_SIZES,
input_shape=INPUT_SHAPES,
Expand Down Expand Up @@ -250,7 +249,7 @@ def testComposition(
_test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
data_other)

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES,
test_size=TEST_SIZES,
input_shape=INPUT_SHAPES,
Expand Down Expand Up @@ -334,7 +333,7 @@ def _test_analytic_kernel_composition(self, batching_fn):
composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
self.assertAllClose(ker_out, composed_ker_out)

@parameterized.product(
@test_utils.product(
store_on_device=[True, False],
batch_size=[2, 8]
)
Expand All @@ -349,7 +348,7 @@ def testAnalyticKernelComposeParallel(self):
test_utils.stub_out_pmap(batching, 2)
self._test_analytic_kernel_composition(batching._parallel)

@parameterized.product(
@test_utils.product(
store_on_device=[True, False],
batch_size=[2, 8]
)
Expand Down Expand Up @@ -428,7 +427,7 @@ def broadcast(arg):
self.assertAllClose(res_1[0][1], res_2[0][1])
self.assertAllClose(tree_map(broadcast, res_1[1]), res_2[1])

@parameterized.product(
@test_utils.product(
same_inputs=[True, False]
)
def test_parallel_in_out(self, same_inputs):
Expand Down Expand Up @@ -483,7 +482,7 @@ def net(N_out):
batch_K_readout_fn(batch_K_readin_fn(x1, x2)),
RTOL)

@parameterized.product(
@test_utils.product(
same_inputs=[True, False]
)
def test_parallel_in_out_empirical(self, same_inputs):
Expand Down Expand Up @@ -533,7 +532,7 @@ def net(N_out):
batch_kernel_fn(x1, x2, params),
RTOL)

@parameterized.product(
@test_utils.product(
same_inputs=[True, False],
device_count=[-1, 0, 1, 2],
trace_axes=[(-1,), (1, -1), ()],
Expand Down
23 changes: 11 additions & 12 deletions tests/empirical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import operator
from typing import Any, Callable, Sequence, Tuple, Optional, Dict, List
from absl.testing import absltest
from absl.testing import parameterized
from flax import linen as nn
import jax
from jax import jacobian, lax, remat
Expand Down Expand Up @@ -166,7 +165,7 @@ def _get_init_data(cls, shape):
x0 = random.normal(split, (shape[-1], 1))
return key, params, x0

@parameterized.product(
@test_utils.product(
shape=TAYLOR_MATRIX_SHAPES
)
def testLinearization(self, shape):
Expand All @@ -184,7 +183,7 @@ def testLinearization(self, shape):
x0, x, params, do_alter, do_shift_x=do_shift_x),
f_lin(x, params, do_alter, do_shift_x=do_shift_x))

@parameterized.product(
@test_utils.product(
shape=TAYLOR_MATRIX_SHAPES
)
def testTaylorExpansion(self, shape):
Expand Down Expand Up @@ -241,7 +240,7 @@ def _compare_kernels(self, x1, x2, ntk_fns, ntk_fns_vmapped, nngp_fn):
for i, ntk in ntks_vmapped.items():
self.assertAllClose(ntk_ref, ntk, err_msg=f'{i} vmapped impl. fails.')

@parameterized.product(
@test_utils.product(
train_test_network=list(zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)),
kernel_type=list(KERNELS.keys())
)
Expand Down Expand Up @@ -273,7 +272,7 @@ def testNTKAgainstDirect(self, train_test_network, kernel_type):
self._compare_kernels(x1, None, ntk_fns, ntk_fns_vmapped, nngp_fn)
self._compare_kernels(x1, x2, ntk_fns, ntk_fns_vmapped, nngp_fn)

@parameterized.product(
@test_utils.product(
diagonal_axes=[
(),
(0,),
Expand Down Expand Up @@ -333,7 +332,7 @@ def testAxes(self, diagonal_axes, trace_axes):
if 0 not in _trace_axes and 0 not in _diagonal_axes:
self._compare_kernels(x1, x2, ntk_fns, ntk_fns_vmapped, nngp_fn)

@parameterized.product(
@test_utils.product(
same_inputs=[True, False]
)
def test_parallel_in_out(self, same_inputs):
Expand Down Expand Up @@ -381,7 +380,7 @@ def layer(N_out):
self.assertEqual(nngp[1].shape, (3, 3 if same_inputs else 4))
self._compare_kernels(x1, x2, ntk_fns, ntk_fns_vmapped, nngp_fn)

@parameterized.product(
@test_utils.product(
same_inputs=[True, False]
)
def test_parallel_nested(self, same_inputs):
Expand Down Expand Up @@ -440,7 +439,7 @@ def layer(N_out):
self.assertEqual(nngp[0][1].shape, nngp_shape)
self.assertEqual(nngp[1].shape, nngp_shape)

@parameterized.product(
@test_utils.product(
same_inputs=[True, False]
)
def test_vmap_axes(self, same_inputs):
Expand Down Expand Up @@ -843,7 +842,7 @@ def _compare_ntks(

class StructuredDerivativesTest(test_utils.NeuralTangentsTestCase):

@parameterized.product(
@test_utils.product(
_j_rules=[
True,
False
Expand Down Expand Up @@ -1269,7 +1268,7 @@ def _get_mixer_b16_config() -> Dict[str, Any]:
)


@parameterized.product(
@test_utils.product(
j_rules=[
True,
False
Expand Down Expand Up @@ -1405,7 +1404,7 @@ def apply_fn(params, x):
s_rules, fwd)


@parameterized.product(
@test_utils.product(
j_rules=[
True,
False
Expand Down Expand Up @@ -1457,7 +1456,7 @@ def test_flax_cnn(self, same_inputs, do_jit, do_remat, dtype, j_rules,
s_rules, fwd, vmap_axes=0)


@parameterized.product(
@test_utils.product(
j_rules=[
True,
False
Expand Down
3 changes: 1 addition & 2 deletions tests/experimental/empirical_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def _resnet(x, blocks_per_layer, classes, filters):
x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, name='maxpool')(x)

x = _make_layer(x, filters, blocks_per_layer[0], name='layer1')
x = _make_layer(x, 2 * filters, blocks_per_layer[1], stride=2, name='layer2')

x = tf.keras.layers.GlobalAveragePooling2D(name='avgpool')(x)
initializer = tf.keras.initializers.RandomUniform(-1.0 / (2 * filters)**0.5,
Expand All @@ -182,7 +181,7 @@ def _resnet(x, blocks_per_layer, classes, filters):

def _MiniResNet(classes, input_shape, weights):
inputs = tf.keras.Input(shape=input_shape)
outputs = _resnet(inputs, [1, 1, 1, 1], classes=classes, filters=2)
outputs = _resnet(inputs, [1, 1, 1, 1], classes=classes, filters=4)
return tf.keras.Model(inputs=inputs, outputs=outputs)


Expand Down
13 changes: 6 additions & 7 deletions tests/monte_carlo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Tests for `neural_tangents/_src/monte_carlo.py`."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import random
from jax.config import config
Expand Down Expand Up @@ -67,7 +66,7 @@ def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):

class MonteCarloTest(test_utils.NeuralTangentsTestCase):

@parameterized.product(
@test_utils.product(
batch_size=BATCH_SIZES,
device_count=DEVICE_COUNTS,
store_on_device=STORE_ON_DEVICE,
Expand All @@ -93,7 +92,7 @@ def test_sample_once_batch(
one_sample_batch = sample_once_batch_fn(x1, x2, key, get)
self.assertAllClose(one_sample, one_sample_batch)

@parameterized.product(
@test_utils.product(
batch_size=BATCH_SIZES,
device_count=DEVICE_COUNTS,
store_on_device=STORE_ON_DEVICE,
Expand All @@ -118,7 +117,7 @@ def test_batch_sample_once(
one_batch_sample = batch_sample_once_fn(x1, x2, key, get)
self.assertAllClose(one_sample, one_batch_sample)

@parameterized.product(
@test_utils.product(
batch_size=BATCH_SIZES,
device_count=DEVICE_COUNTS,
store_on_device=STORE_ON_DEVICE
Expand All @@ -143,7 +142,7 @@ def test_sample_vs_analytic_nngp(

test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)

@parameterized.product(
@test_utils.product(
batch_size=BATCH_SIZES,
device_count=DEVICE_COUNTS,
store_on_device=STORE_ON_DEVICE
Expand All @@ -169,7 +168,7 @@ def test_monte_carlo_vs_analytic_ntk(

test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)

@parameterized.product(
@test_utils.product(
batch_size=BATCH_SIZES,
device_count=DEVICE_COUNTS,
store_on_device=STORE_ON_DEVICE,
Expand Down Expand Up @@ -240,7 +239,7 @@ def test_monte_carlo_generator(
self.assertAllClose(ker_analytic_12, s_12, atol=2., rtol=2.)
self.assertAllClose(ker_analytic_12, ker_analytic_34)

@parameterized.product(
@test_utils.product(
same_inputs=[True, False],
batch_size=[1, 2]
)
Expand Down
21 changes: 10 additions & 11 deletions tests/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import math

from absl.testing import absltest
from absl.testing import parameterized
from jax import grad
from jax import jit
from jax import random
Expand Down Expand Up @@ -176,7 +175,7 @@ def _get_inputs(cls, out_logits, test_shape, train_shape):
x_test = random.normal(split, test_shape)
return key, x_test, x_train, y_train

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES,
test_size=TEST_SIZES,
input_shape=INPUT_SHAPES,
Expand Down Expand Up @@ -265,7 +264,7 @@ def _cov_empirical(cls, x):
return np.einsum('itjk,itlk->tjl', x, x, optimize=True) / (x.shape[0] *
x.shape[-1])

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES[:1],
test_size=TEST_SIZES[:1],
input_shape=INPUT_SHAPES[:1],
Expand Down Expand Up @@ -338,7 +337,7 @@ def predict_mc(count, key):
assert_close(cov_test_mc, cov_test_inf)
assert_close(fx_test_mc, fx_test_inf)

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES[:-1],
test_size=TEST_SIZES[:-1],
input_shape=INPUT_SHAPES,
Expand Down Expand Up @@ -383,7 +382,7 @@ def testGradientDescentMseEnsembleGet(
self.assertAllClose(out[0], out2[1])
self.assertAllClose(out[1], out2[0])

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES[:-1],
test_size=TEST_SIZES[:-1],
input_shape=INPUT_SHAPES[:-1],
Expand Down Expand Up @@ -424,7 +423,7 @@ def testInfiniteTimeAgreement(
self.assertAllClose(inf, inf_x)
self.assertAllClose(inf_x, fin_x)

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES,
test_size=TEST_SIZES,
input_shape=INPUT_SHAPES,
Expand Down Expand Up @@ -480,7 +479,7 @@ def always_ntk(x1, x2, get=('nngp', 'ntk')):
return out._replace(nngp=out.ntk)
return always_ntk

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES,
test_size=TEST_SIZES,
input_shape=INPUT_SHAPES,
Expand Down Expand Up @@ -573,7 +572,7 @@ def always_nngp(x1, x2, get=('nngp', 'ntk')):
# Although, due to accumulation of numerical errors, only roughly.
self.assertAllClose(nngp_cov, nngp_ntk_cov)

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES,
test_size=TEST_SIZES,
input_shape=INPUT_SHAPES,
Expand Down Expand Up @@ -610,7 +609,7 @@ def testPredCovPosDef(
self.assertAllClose(cov, np.moveaxis(cov, -1, -2))
self.assertGreater(np.min(np.linalg.eigh(cov)[0]), -1e-4)

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES[:1],
test_size=TEST_SIZES[:1],
input_shape=INPUT_SHAPES[:1],
Expand Down Expand Up @@ -935,7 +934,7 @@ def testPredictND(self):
self.assertAllClose(y_test_shape, p_test.shape)
self.assertAllClose(y_train_shape, p_train.shape)

@parameterized.product(
@test_utils.product(
train_size=TRAIN_SIZES,
input_shape=INPUT_SHAPES,
network=NETWORK,
Expand Down Expand Up @@ -1008,7 +1007,7 @@ def get_loss(opt_state):

class PredictKwargsTest(test_utils.NeuralTangentsTestCase):

@parameterized.product(
@test_utils.product(
do_batch=[True, False],
mode=['analytic', 'mc', 'empirical']
)
Expand Down

0 comments on commit b1202ed

Please sign in to comment.