From d2b480072385a973e09489c11c53cc4350808d78 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 30 Nov 2023 10:35:24 -0800 Subject: [PATCH] tests: improve warnings-related tests --- jax/_src/test_util.py | 8 +++---- tests/api_test.py | 36 ++++++-------------------------- tests/jaxpr_effects_test.py | 4 +--- tests/lax_numpy_indexing_test.py | 5 +---- tests/lax_numpy_test.py | 10 +++------ tests/memories_test.py | 8 +------ tests/pmap_test.py | 8 +------ tests/xla_bridge_test.py | 6 +----- 8 files changed, 18 insertions(+), 67 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 43b78c7c9937..3704925205b1 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1040,9 +1040,9 @@ def assertMultiLineStrippedEqual(self, expected, what): @contextmanager def assertNoWarnings(self): - with warnings.catch_warnings(record=True) as caught_warnings: + with warnings.catch_warnings(): + warnings.simplefilter("error") yield - self.assertEmpty(caught_warnings) def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None, rtol=None, atol=None, check_cache_misses=True): @@ -1124,9 +1124,9 @@ def _assertDeleted(self, x, deleted): @contextmanager -def ignore_warning(**kw): +def ignore_warning(*, message='', category=Warning, **kw): with warnings.catch_warnings(): - warnings.filterwarnings("ignore", **kw) + warnings.filterwarnings("ignore", message=message, category=category, **kw) yield # -------------------- Mesh parametrization helpers -------------------- diff --git a/tests/api_test.py b/tests/api_test.py index 9d02229f8aaa..05e730432910 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -34,7 +34,6 @@ import types from typing import Callable, NamedTuple, Optional import unittest -import warnings import weakref from absl import logging @@ -394,16 +393,9 @@ def test_jit_donate_warning_raised(self, argnum_type, argnum_val): y = jnp.array([1, 2], jnp.int32) f = jit(lambda x, y: x.sum() + jnp.float32(y.sum()), **{argnum_type: argnum_val}) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"): f(x, y) - self.assertLen(w, 1) - self.assertTrue(issubclass(w[-1].category, UserWarning)) - self.assertIn( - "Some donated buffers were not usable:", - str(w[-1].message)) - @parameterized.named_parameters( ("argnums", "donate_argnums", 0), ("argnames", "donate_argnames", 'x'), @@ -480,8 +472,7 @@ def _test(array): x = jnp.asarray([0, 1]) x_copy = jnp.array(x, copy=True) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + with jtu.ignore_warning(): _test(x) # donation # Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer. @@ -2879,10 +2870,9 @@ def f(x): def test_dtype_from_builtin_types(self): for dtype in [bool, int, float, complex]: - with warnings.catch_warnings(record=True) as caught_warnings: + with self.assertNoWarnings(): x = jnp.array(0, dtype=dtype) - self.assertEmpty(caught_warnings) - assert x.dtype == dtypes.canonicalize_dtype(dtype) + self.assertEqual(x.dtype, dtypes.canonicalize_dtype(dtype)) def test_dtype_warning(self): # cf. issue #1230 @@ -2890,24 +2880,10 @@ def test_dtype_warning(self): raise unittest.SkipTest("test only applies when x64 is disabled") def check_warning(warn, nowarn): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - nowarn() # get rid of extra startup warning - - prev_len = len(w) - nowarn() - assert len(w) == prev_len - + with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype"): warn() - assert len(w) > 0 - msg = str(w[-1].message) - expected_prefix = "Explicitly requested dtype " - self.assertEqual(expected_prefix, msg[:len(expected_prefix)]) - - prev_len = len(w) + with self.assertNoWarnings(): nowarn() - assert len(w) == prev_len check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"), lambda: jnp.array([1, 2, 3], dtype="float32")) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 161665c15e26..255822a26109 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -14,7 +14,6 @@ import functools import threading import unittest -import warnings from absl.testing import absltest import jax @@ -535,8 +534,7 @@ def test_cant_jit_and_pmap_function_with_unordered_effects(self): def f(x): effect_p.bind(effect=bar_effect) return x + 1 - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + with jtu.ignore_warning(): f(jnp.arange(jax.device_count())) # doesn't crash def test_cant_jit_and_pmap_function_with_ordered_effects(self): diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 404b34f7d855..c77cdd35a1e2 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -18,7 +18,6 @@ import itertools import typing from typing import Any, Optional -import warnings from absl.testing import absltest from absl.testing import parameterized @@ -1524,10 +1523,8 @@ def np_fun(data, segment_ids): def testIndexDtypeError(self): # https://github.com/google/jax/issues/2795 jnp.array(1) # get rid of startup warning - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("error") + with self.assertNoWarnings(): jnp.zeros(5).at[::2].set(1) - self.assertLen(w, 0) @jtu.sample_product( [dict(idx=idx, idx_type=idx_type) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 51cf262eb9a7..9ce9a1cf62a2 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -25,7 +25,6 @@ from typing import cast, Optional import unittest from unittest import SkipTest -import warnings from absl.testing import absltest from absl.testing import parameterized @@ -4564,8 +4563,7 @@ def testR_(self): self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]]) # matrix directives - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + with jtu.ignore_warning(category=PendingDeprecationWarning): self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]]) self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]]) @@ -4613,8 +4611,7 @@ def testC_(self): self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]]) self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]]) # matrix directives, avoid numpy deprecation warning - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + with jtu.ignore_warning(category=PendingDeprecationWarning): self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]]) self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]]) @@ -5497,8 +5494,7 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]: for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin): args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes) try: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "divide by zero", RuntimeWarning) + with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"): _ = func(*args) except TypeError: pass diff --git a/tests/memories_test.py b/tests/memories_test.py index 1d7de4436bb8..7f0be2043c6f 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -14,7 +14,6 @@ import functools import math -import warnings from absl.testing import absltest from absl.testing import parameterized from absl import flags @@ -929,14 +928,9 @@ def test_no_donation_across_memory_kinds(self): def f(x): return x * 2 - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with self.assertWarnsRegex(UserWarning, "Some donated buffers were not usable"): f(inp) - self.assertLen(w, 1) - self.assertTrue(issubclass(w[-1].category, UserWarning)) - self.assertIn("Some donated buffers were not usable:", str(w[-1].message)) - lowered_text = f.lower(inp).as_text("hlo") self.assertNotIn("input_output_alias", lowered_text) self.assertNotDeleted(inp) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index d0c42ee424d1..ef2816ff0b5c 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -24,7 +24,6 @@ from typing import Optional, cast import unittest from unittest import SkipTest -import warnings import weakref import numpy as np @@ -1848,14 +1847,9 @@ def testJitOfPmapWarningMessage(self): def foo(x): return x - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with self.assertWarnsRegex(UserWarning, "The jitted function foo includes a pmap"): jit(self.pmap(foo))(jnp.arange(device_count)) - self.assertGreaterEqual(len(w), 1) - self.assertIn("The jitted function foo includes a pmap", - str(w[-1].message)) - def testJitOfPmapOutputSharding(self): device_count = jax.device_count() diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index b5522e2e55ca..a3d78f1c2e18 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -355,12 +355,8 @@ def test_factory_returns_none(self): xb.get_backend("none") def cpu_fallback_warning(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with self.assertWarnsRegex(UserWarning, "No GPU/TPU found, falling back to CPU"): xb.get_backend() - self.assertLen(w, 1) - msg = str(w[-1].message) - self.assertIn("No GPU/TPU found, falling back to CPU", msg) def test_jax_platforms_flag(self): self._register_factory("platform_A", 20, assert_used_at_most_once=True)