Skip to content

Commit

Permalink
Allow suppression of GPU warning via jax_platform_name
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 28, 2021
1 parent 0d68dbd commit c8e571a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
3 changes: 1 addition & 2 deletions jax/lib/xla_bridge.py
Expand Up @@ -225,13 +225,12 @@ def backends():
# we expect a RuntimeError.
logging.info("Unable to initialize backend '%s': %s" % (name, err))
continue
if _default_backend.platform == "cpu":
if _default_backend.platform == "cpu" and FLAGS.jax_platform_name != 'cpu':
logging.warning('No GPU/TPU found, falling back to CPU. '
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
return _backends



@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
def get_backend(platform=None):
# TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
Expand Down
28 changes: 27 additions & 1 deletion tests/api_test.py
Expand Up @@ -20,8 +20,10 @@
from functools import partial
import operator
import re
import unittest
import subprocess
import sys
import types
import unittest
import warnings
import weakref
import functools
Expand Down Expand Up @@ -55,6 +57,7 @@
FLAGS = config.FLAGS


python_version = (sys.version_info[0], sys.version_info[1])
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))


Expand Down Expand Up @@ -5436,5 +5439,28 @@ def test_integer_overflow(self, jit_type, func):
self.assertRaises(OverflowError, f, int_min - 1)


class BackendsTest(jtu.JaxTestCase):

@unittest.skipIf(not sys.executable, "test requires sys.executable")
@unittest.skipIf(python_version < (3, 7), "test requires Python 3.7 or higher")
@jtu.skip_on_devices("gpu", "tpu")
def test_cpu_warning_suppression(self):
warning_expected = (
"import jax; "
"jax.numpy.arange(10)")
warning_not_expected = (
"import jax; "
"jax.config.update('jax_platform_name', 'cpu'); "
"jax.numpy.arange(10)")

result = subprocess.run([sys.executable, '-c', warning_expected],
check=True, capture_output=True)
assert "No GPU/TPU found" in result.stderr.decode()

result = subprocess.run([sys.executable, '-c', warning_not_expected],
check=True, capture_output=True)
assert "No GPU/TPU found" not in result.stderr.decode()


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit c8e571a

Please sign in to comment.