Skip to content

Commit

Permalink
tests: improve warnings-related tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 30, 2023
1 parent fe237cd commit d2b4800
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 67 deletions.
8 changes: 4 additions & 4 deletions jax/_src/test_util.py
Expand Up @@ -1040,9 +1040,9 @@ def assertMultiLineStrippedEqual(self, expected, what):

@contextmanager
def assertNoWarnings(self):
with warnings.catch_warnings(record=True) as caught_warnings:
with warnings.catch_warnings():
warnings.simplefilter("error")
yield
self.assertEmpty(caught_warnings)

def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None,
rtol=None, atol=None, check_cache_misses=True):
Expand Down Expand Up @@ -1124,9 +1124,9 @@ def _assertDeleted(self, x, deleted):


@contextmanager
def ignore_warning(**kw):
def ignore_warning(*, message='', category=Warning, **kw):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", **kw)
warnings.filterwarnings("ignore", message=message, category=category, **kw)
yield

# -------------------- Mesh parametrization helpers --------------------
Expand Down
36 changes: 6 additions & 30 deletions tests/api_test.py
Expand Up @@ -34,7 +34,6 @@
import types
from typing import Callable, NamedTuple, Optional
import unittest
import warnings
import weakref

from absl import logging
Expand Down Expand Up @@ -394,16 +393,9 @@ def test_jit_donate_warning_raised(self, argnum_type, argnum_val):
y = jnp.array([1, 2], jnp.int32)
f = jit(lambda x, y: x.sum() + jnp.float32(y.sum()),
**{argnum_type: argnum_val})
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"):
f(x, y)

self.assertLen(w, 1)
self.assertTrue(issubclass(w[-1].category, UserWarning))
self.assertIn(
"Some donated buffers were not usable:",
str(w[-1].message))

@parameterized.named_parameters(
("argnums", "donate_argnums", 0),
("argnames", "donate_argnames", 'x'),
Expand Down Expand Up @@ -480,8 +472,7 @@ def _test(array):

x = jnp.asarray([0, 1])
x_copy = jnp.array(x, copy=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with jtu.ignore_warning():
_test(x) # donation

# Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
Expand Down Expand Up @@ -2879,35 +2870,20 @@ def f(x):

def test_dtype_from_builtin_types(self):
for dtype in [bool, int, float, complex]:
with warnings.catch_warnings(record=True) as caught_warnings:
with self.assertNoWarnings():
x = jnp.array(0, dtype=dtype)
self.assertEmpty(caught_warnings)
assert x.dtype == dtypes.canonicalize_dtype(dtype)
self.assertEqual(x.dtype, dtypes.canonicalize_dtype(dtype))

def test_dtype_warning(self):
# cf. issue #1230
if config.enable_x64.value:
raise unittest.SkipTest("test only applies when x64 is disabled")

def check_warning(warn, nowarn):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")

nowarn() # get rid of extra startup warning

prev_len = len(w)
nowarn()
assert len(w) == prev_len

with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype"):
warn()
assert len(w) > 0
msg = str(w[-1].message)
expected_prefix = "Explicitly requested dtype "
self.assertEqual(expected_prefix, msg[:len(expected_prefix)])

prev_len = len(w)
with self.assertNoWarnings():
nowarn()
assert len(w) == prev_len

check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"),
lambda: jnp.array([1, 2, 3], dtype="float32"))
Expand Down
4 changes: 1 addition & 3 deletions tests/jaxpr_effects_test.py
Expand Up @@ -14,7 +14,6 @@
import functools
import threading
import unittest
import warnings

from absl.testing import absltest
import jax
Expand Down Expand Up @@ -535,8 +534,7 @@ def test_cant_jit_and_pmap_function_with_unordered_effects(self):
def f(x):
effect_p.bind(effect=bar_effect)
return x + 1
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with jtu.ignore_warning():
f(jnp.arange(jax.device_count())) # doesn't crash

def test_cant_jit_and_pmap_function_with_ordered_effects(self):
Expand Down
5 changes: 1 addition & 4 deletions tests/lax_numpy_indexing_test.py
Expand Up @@ -18,7 +18,6 @@
import itertools
import typing
from typing import Any, Optional
import warnings

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -1524,10 +1523,8 @@ def np_fun(data, segment_ids):
def testIndexDtypeError(self):
# https://github.com/google/jax/issues/2795
jnp.array(1) # get rid of startup warning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("error")
with self.assertNoWarnings():
jnp.zeros(5).at[::2].set(1)
self.assertLen(w, 0)

@jtu.sample_product(
[dict(idx=idx, idx_type=idx_type)
Expand Down
10 changes: 3 additions & 7 deletions tests/lax_numpy_test.py
Expand Up @@ -25,7 +25,6 @@
from typing import cast, Optional
import unittest
from unittest import SkipTest
import warnings

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -4564,8 +4563,7 @@ def testR_(self):
self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]])

# matrix directives
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
with jtu.ignore_warning(category=PendingDeprecationWarning):
self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]])
self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]])

Expand Down Expand Up @@ -4613,8 +4611,7 @@ def testC_(self):
self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]])
self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]])
# matrix directives, avoid numpy deprecation warning
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
with jtu.ignore_warning(category=PendingDeprecationWarning):
self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]])
self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]])

Expand Down Expand Up @@ -5497,8 +5494,7 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin):
args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes)
try:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero", RuntimeWarning)
with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"):
_ = func(*args)
except TypeError:
pass
Expand Down
8 changes: 1 addition & 7 deletions tests/memories_test.py
Expand Up @@ -14,7 +14,6 @@

import functools
import math
import warnings
from absl.testing import absltest
from absl.testing import parameterized
from absl import flags
Expand Down Expand Up @@ -929,14 +928,9 @@ def test_no_donation_across_memory_kinds(self):
def f(x):
return x * 2

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"):
f(inp)

self.assertLen(w, 1)
self.assertTrue(issubclass(w[-1].category, UserWarning))
self.assertIn("Some donated buffers were not usable:", str(w[-1].message))

lowered_text = f.lower(inp).as_text("hlo")
self.assertNotIn("input_output_alias", lowered_text)
self.assertNotDeleted(inp)
Expand Down
8 changes: 1 addition & 7 deletions tests/pmap_test.py
Expand Up @@ -24,7 +24,6 @@
from typing import Optional, cast
import unittest
from unittest import SkipTest
import warnings
import weakref

import numpy as np
Expand Down Expand Up @@ -1848,14 +1847,9 @@ def testJitOfPmapWarningMessage(self):

def foo(x): return x

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with self.assertWarnsRegex(UserWarning, "The jitted function foo includes a pmap"):
jit(self.pmap(foo))(jnp.arange(device_count))

self.assertGreaterEqual(len(w), 1)
self.assertIn("The jitted function foo includes a pmap",
str(w[-1].message))

def testJitOfPmapOutputSharding(self):
device_count = jax.device_count()

Expand Down
6 changes: 1 addition & 5 deletions tests/xla_bridge_test.py
Expand Up @@ -355,12 +355,8 @@ def test_factory_returns_none(self):
xb.get_backend("none")

def cpu_fallback_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with self.assertWarnsRegex(UserWarning, "No GPU/TPU found, falling back to CPU"):
xb.get_backend()
self.assertLen(w, 1)
msg = str(w[-1].message)
self.assertIn("No GPU/TPU found, falling back to CPU", msg)

def test_jax_platforms_flag(self):
self._register_factory("platform_A", 20, assert_used_at_most_once=True)
Expand Down

0 comments on commit d2b4800

Please sign in to comment.