Use absltest throughout the project.
PiperOrigin-RevId: 319672185
romanngg committed Jul 5, 2020
1 parent 11d70e1 commit 24549f3
Showing 3 changed files with 29 additions and 32 deletions.
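The change is mechanical: each `raise unittest.SkipTest(...)` becomes `raise absltest.SkipTest(...)` (absltest re-exports `unittest.SkipTest`), and the now-unused `import unittest` lines are dropped, leaving the tests dependent on `absl.testing` alone. A minimal sketch of the resulting pattern; the test class, condition, and message below are illustrative, not taken from this commit:

```python
from absl.testing import absltest

# Illustrative flag; in the real tests the condition comes from the
# backend platform or the parameter combination under test.
_UNSUPPORTED_CONFIG = False


class ExampleTest(absltest.TestCase):

  def test_skip_pattern(self):
    # absltest.SkipTest is unittest.SkipTest re-exported, so raising it
    # skips the test exactly as before, without importing unittest.
    if _UNSUPPORTED_CONFIG:
      raise absltest.SkipTest('Config not supported on this backend.')
    self.assertAlmostEqual(2 ** 0.5 * 2 ** 0.5, 2.)


if __name__ == '__main__':
  absltest.main()
```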
3 changes: 1 addition & 2 deletions neural_tangents/tests/empirical_test.py
@@ -15,7 +15,6 @@
 """Tests for `utils/empirical.py`."""
 
 from functools import partial
-import unittest
 
 from absl.testing import absltest
 from jax import test_util as jtu
@@ -270,7 +269,7 @@ def testAxes(self, diagonal_axes, trace_axes):
     _trace_axes = utils.canonicalize_axis(trace_axes, data_self)
 
     if any(d == c for d in _diagonal_axes for c in _trace_axes):
-      raise unittest.SkipTest(
+      raise absltest.SkipTest(
           'diagonal axes must be different from channel axes.')
 
     implicit, direct, nngp = KERNELS['empirical_logits_3'](
3 changes: 1 addition & 2 deletions neural_tangents/tests/predict_test.py
@@ -16,7 +16,6 @@
 
 
 import math
-import unittest
 
 from absl.testing import absltest
 from jax import test_util as jtu
@@ -238,7 +237,7 @@ def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
     trace_axes = () if g_dd.ndim == 4 else (-1,)
     if loss == 'mse_analytic':
       if momentum is not None:
-        raise unittest.SkipTest(momentum)
+        raise absltest.SkipTest(momentum)
       predictor = predict.gradient_descent_mse(g_dd, y_train,
                                                learning_rate=learning_rate,
                                                trace_axes=trace_axes)
55 changes: 27 additions & 28 deletions neural_tangents/tests/stax_test.py
@@ -33,7 +33,6 @@
 from neural_tangents.utils import monte_carlo
 from neural_tangents.utils import test_utils
 import numpy as onp
-import unittest
 from typing import Tuple
 
 
@@ -350,17 +349,17 @@ def test_exact(self, model, width, strides, padding, phi, same_inputs,
     # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
     if is_conv:
       if xla_bridge.get_backend().platform == 'cpu':
-        raise unittest.SkipTest('Not running CNN models on CPU to save time.')
+        raise absltest.SkipTest('Not running CNN models on CPU to save time.')
 
       if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                   (padding == 'VALID' and filter_shape !=
                                    (1, 1)))):
-        raise unittest.SkipTest('Different paths in a residual models need to '
+        raise absltest.SkipTest('Different paths in a residual models need to '
                                 'return outputs of the same shape.')
     elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or
           strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
           use_pooling):
-      raise unittest.SkipTest('FC models do not have these parameters.')
+      raise absltest.SkipTest('FC models do not have these parameters.')
 
     pool_type = 'AVG'
     W_std, b_std = 2.**0.5, 0.5**0.5
@@ -419,9 +418,9 @@ def test_parameterizations(self, model, width, same_inputs, is_ntk,
     # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
     if is_conv:
       if xla_bridge.get_backend().platform == 'cpu':
-        raise unittest.SkipTest('Not running CNN models on CPU to save time.')
+        raise absltest.SkipTest('Not running CNN models on CPU to save time.')
     elif proj_into_2d != PROJECTIONS[0]:
-      raise unittest.SkipTest('FC models do not have these parameters.')
+      raise absltest.SkipTest('FC models do not have these parameters.')
 
     net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                    padding, phi, strides, width, is_ntk, proj_into_2d,
@@ -469,9 +468,9 @@ def test_layernorm(self,
     # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
     if is_conv:
       if xla_bridge.get_backend().platform == 'cpu':
-        raise unittest.SkipTest('Not running CNN models on CPU to save time.')
+        raise absltest.SkipTest('Not running CNN models on CPU to save time.')
     elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC'):
-      raise unittest.SkipTest('FC models do not have these parameters.')
+      raise absltest.SkipTest('FC models do not have these parameters.')
 
     W_std, b_std = 2.**0.5, 0.5**0.5
     filter_shape = FILTER_SHAPES[0]
@@ -528,9 +527,9 @@ def test_pool(self, width, same_inputs, is_ntk, pool_type,
     # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
 
     if xla_bridge.get_backend().platform == 'cpu':
-      raise unittest.SkipTest('Not running CNN models on CPU to save time.')
+      raise absltest.SkipTest('Not running CNN models on CPU to save time.')
     if pool_type == 'SUM' and normalize_edges:
-      raise unittest.SkipTest('normalize_edges not applicable to SumPool.')
+      raise absltest.SkipTest('normalize_edges not applicable to SumPool.')
 
     net = _get_net_pool(width, is_ntk, pool_type,
                         padding, filter_shape, strides, normalize_edges)
@@ -624,7 +623,7 @@ def test_avg_pool(self):
   def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides,
                    filter_shape, phi, use_pooling, proj_into_2d):
     if xla_bridge.get_backend().platform == 'tpu' and same_inputs:
-      raise unittest.SkipTest(
+      raise absltest.SkipTest(
           'Skip TPU test for `same_inputs`. Need to handle '
           'random keys carefully for dropout + empirical kernel.')
 
@@ -638,17 +637,17 @@ def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides,
     parameterization = 'ntk'
     if is_conv:
       if xla_bridge.get_backend().platform == 'cpu':
-        raise unittest.SkipTest('Not running CNN models on CPU to save time.')
+        raise absltest.SkipTest('Not running CNN models on CPU to save time.')
 
       if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                   (padding == 'VALID' and filter_shape !=
                                    (1, 1)))):
-        raise unittest.SkipTest('Different paths in a residual models need to '
+        raise absltest.SkipTest('Different paths in a residual models need to '
                                 'return outputs of the same shape.')
     elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or
           strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
           use_pooling):
-      raise unittest.SkipTest('FC models do not have these parameters.')
+      raise absltest.SkipTest('FC models do not have these parameters.')
 
     net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                    padding, phi, strides, width, is_ntk, proj_into_2d,
@@ -852,7 +851,7 @@ def _test_activation(self, activation_fn, same_inputs, model, get,
                        rbf_gamma=None):
     platform = xla_bridge.get_backend().platform
     if platform == 'cpu' and 'conv' in model:
-      raise unittest.SkipTest('Not running CNNs on CPU to save time.')
+      raise absltest.SkipTest('Not running CNNs on CPU to save time.')
 
     key = random.PRNGKey(1)
     key, split = random.split(key)
@@ -940,9 +939,9 @@ def test_activation(self, same_inputs, model, phi_name, get, abc):
     elif phi_name == 'Gelu':
       activation = stax.Gelu()
       if a != 1. or b != 1. or c != 0.:
-        unittest.SkipTest('Skip `Gelu` test if (a, b, c) != (1., 1., 0.).')
+        absltest.SkipTest('Skip `Gelu` test if (a, b, c) != (1., 1., 0.).')
     else:
-      raise unittest.SkipTest(f'Activation {phi_name} is not implemented.')
+      raise absltest.SkipTest(f'Activation {phi_name} is not implemented.')
     self._test_activation(activation, same_inputs, model, get)
 
   @jtu.parameterized.named_parameters(
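Note that `SkipTest` is an exception class: the `Gelu` branch in the hunk above constructs `absltest.SkipTest(...)` without `raise` (both before and after this change), so that call discards the exception and the test keeps running. A hedged sketch of the two forms that actually skip, with illustrative stand-ins for the (a, b, c) parameters:

```python
from absl.testing import absltest

# Illustrative values; the real test unpacks these from its `abc` parameter.
_A, _B, _C = 1., 1., 0.5


class GeluConfigTest(absltest.TestCase):

  def test_raise_form(self):
    if (_A, _B, _C) != (1., 1., 0.):
      # The exception must be raised; constructing it alone is a no-op.
      raise absltest.SkipTest('(a, b, c) != (1., 1., 0.).')

  def test_method_form(self):
    if (_A, _B, _C) != (1., 1., 0.):
      # Equivalent helper inherited from unittest.TestCase.
      self.skipTest('(a, b, c) != (1., 1., 0.).')


if __name__ == '__main__':
  absltest.main()
```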
@@ -1197,11 +1196,11 @@ def _get_phi(cls, i):
                                 'dense_after_branch_in']))
   def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
     if axis in (None, 0) and branch_in == 'dense_after_branch_in':
-      raise unittest.SkipTest('`FanInSum` and `FanInConcat(0)` '
+      raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                               'require `is_gaussian`.')
 
     if axis == 1 and branch_in == 'dense_before_branch_in':
-      raise unittest.SkipTest('`FanInConcat` on feature axis requires a dense '
+      raise absltest.SkipTest('`FanInConcat` on feature axis requires a dense '
                               'layer after concatenation.')
 
     key = random.PRNGKey(1)
@@ -1295,14 +1294,14 @@ def test_fan_in_conv(self,
                        branch_in,
                        readout):
     if xla_bridge.get_backend().platform == 'cpu':
-      raise unittest.SkipTest('Not running CNNs on CPU to save time.')
+      raise absltest.SkipTest('Not running CNNs on CPU to save time.')
 
     if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
-      raise unittest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
+      raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                               'require `is_gaussian`.')
 
     if axis == 3 and branch_in == 'dense_before_branch_in':
-      raise unittest.SkipTest('`FanInConcat` on feature axis requires a dense '
+      raise absltest.SkipTest('`FanInConcat` on feature axis requires a dense '
                               'layer after concatenation.')
 
     key = random.PRNGKey(1)
@@ -1413,11 +1412,11 @@ def test_conv_nd(self, same_inputs, n, get, proj, use_attn, channels_first,
                    use_dropout, use_layernorm):
     platform = xla_bridge.get_backend().platform
     if platform == 'cpu':
-      raise unittest.SkipTest('Skipping CPU CNN tests for speed.')
+      raise absltest.SkipTest('Skipping CPU CNN tests for speed.')
     elif platform == 'gpu' and n not in (0, 1, 2, 3):
-      raise unittest.SkipTest('>=4D CNN does not work on GPU.')
+      raise absltest.SkipTest('>=4D CNN does not work on GPU.')
     elif platform == 'tpu' and use_dropout and same_inputs:
-      raise unittest.SkipTest('Batched empirical kernel with dropout not '
+      raise absltest.SkipTest('Batched empirical kernel with dropout not '
                               'supported.')
 
     width = 1024
@@ -1568,7 +1567,7 @@ class InputReqTest(test_utils.NeuralTangentsTestCase):
   def test_input_req(self, same_inputs):
     platform = xla_bridge.get_backend().platform
     if platform == 'cpu':
-      raise unittest.SkipTest('Skipping CPU CNN tests for speed.')
+      raise absltest.SkipTest('Skipping CPU CNN tests for speed.')
 
     key = random.PRNGKey(1)
     x1 = random.normal(key, (2, 7, 8, 4, 3))
@@ -1800,9 +1799,9 @@ def apply_mask(x):
   def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat,
                      proj, p, use_attn, n):
     if xla_bridge.get_backend().platform == 'cpu':
-      raise unittest.SkipTest('Skipping CNN tests on CPU for speed.')
+      raise absltest.SkipTest('Skipping CNN tests on CPU for speed.')
     elif xla_bridge.get_backend().platform == 'gpu' and n > 3:
-      raise unittest.SkipTest('>=4D-CNN is not supported on GPUs.')
+      raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')
 
     width = 1024
     n_samples = 128
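Most of the `stax_test.py` edits sit inside the same backend guard, which skips expensive CNN cases depending on the JAX platform. A standalone sketch of that recurring pattern, assuming the `jax.lib.xla_bridge` import these tests already use; the test class and body are illustrative:

```python
from absl.testing import absltest
from jax.lib import xla_bridge


class PlatformSkipTest(absltest.TestCase):

  def test_cnn_kernel(self):
    # Mirrors the recurring guard above: CNN cases are skipped on CPU
    # purely to keep the suite fast, not because they would fail.
    platform = xla_bridge.get_backend().platform
    if platform == 'cpu':
      raise absltest.SkipTest('Not running CNN models on CPU to save time.')
    # ... the actual CNN kernel checks would go here ...


if __name__ == '__main__':
  absltest.main()
```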
