Register TPU profiler plugin when get_topology_desc is called with tpu platform.

This allows the TPU profiler to work with other plugin backends.

Tested on a GPU VM:
$ pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$ pip install -e .
$ TPU_SKIP_MDS_QUERY=1 python tests/cross_aot_test.py
Running tests under Python 3.10.12: /usr/bin/python
[ RUN      ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices
NOT_FOUND: WARNING: could not determine TPU accelerator type. Set env var `TPU_ACCELERATOR_TYPE` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285

NOT_FOUND: WARNING: could not determine TPU worker number. Set env var `TPU_WORKER_ID` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285

NOT_FOUND: WARNING: could not determine TPU worker hostnames or internal IP addresses. Set env var `TPU_WORKER_HOSTNAMES` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285
learning/45eac/tfrc/runtime/common_lib.cc:341

I0510 00:32:03.063246 130900437979136 cross_aot_test.py:58] Expected to fail to get topology
I0510 00:32:03.079923 130900437979136 xla_bridge.py:884] Unable to initialize backend 'cuda':
I0510 00:32:03.080080 130900437979136 xla_bridge.py:884] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0510 00:32:03.089399 130900437979136 xla_bridge.py:884] Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: No ba16c7433 device found.
W0510 00:32:03.089633 130900437979136 xla_bridge.py:931] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
/home/jieying/.local/lib/python3.10/site-packages/tensorflow/__init__.py:30: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives
  import distutils as _distutils
2024-05-10 00:32:03.359597: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-10 00:32:03.359652: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-10 00:32:03.361368: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-10 00:32:04.562557: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[       OK ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices
----------------------------------------------------------------------
Ran 1 test in 2.549s

OK

In tests/cross_aot_test.py:

# Imports abridged from the test file.
import glob
import logging
import os
import tempfile

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.lib import xla_extension
from jax.experimental import topologies

class JaxAotTest(jtu.JaxTestCase):
  def test_tpu_profiler_registered_get_topology_from_devices(self):
    try:
      _ = topologies.get_topology_desc(
          topology_name='fake_topology',
          platform='tpu',
      )
    except xla_extension.XlaRuntimeError:
      logging.info('Expected to fail to get topology')

    with tempfile.TemporaryDirectory() as tmpdir:
      try:
        jax.profiler.start_trace(tmpdir)
        jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
            jnp.ones(jax.local_device_count())
        )
      finally:
        jax.profiler.stop_trace()

      proto_path = glob.glob(
          os.path.join(tmpdir, '**/*.xplane.pb'), recursive=True
      )
      self.assertLen(proto_path, 1)
      with open(proto_path[0], 'rb') as f:
        proto = f.read()
      # Sanity check that serialized proto contains host, and Python traces
      # without deserializing.
      self.assertIn(b'/host:metadata', proto)
      if jtu.test_device_matches(['tpu']):
        self.assertNotIn(b'/device:TPU', proto)
      self.assertIn(b'pxla.py', proto)
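The byte-substring assertions at the end work because the protobuf wire format stores string and bytes fields verbatim as length-delimited payloads, so markers such as b'pxla.py' survive serialization unchanged. A minimal hand-built illustration (the payload below is hypothetical, not a real xplane proto):

```python
# Proto wire format: a length-delimited field is a tag byte, a length byte,
# then the raw payload, so string contents appear byte-for-byte in the
# serialized message and can be probed with `in` without deserializing.
payload = b"\x0a\x07pxla.py" + b"\x12\x0e/host:metadata"  # hand-built, hypothetical
print(b"pxla.py" in payload)        # marker present
print(b"/device:TPU" in payload)    # marker absent
```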

PiperOrigin-RevId: 633076007
jyingl3 authored and jax authors committed May 13, 2024
1 parent af4bddb commit ba8480a
Showing 1 changed file with 2 additions and 1 deletion.
jax/_src/xla_bridge.py:

@@ -1212,7 +1212,8 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs):
     raise RuntimeError(
         "JAX TPU support not installed; cannot generate TPU topology. See"
         " https://github.com/google/jax#installation")
-  xla_client.load_pjrt_plugin_dynamically("tpu", library_path)
+  c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path)
+  xla_client.profiler.register_plugin_profiler(c_api)
   assert xla_client.pjrt_plugin_loaded("tpu")
   if not xla_client.pjrt_plugin_initialized("tpu"):
     xla_client.initialize_pjrt_plugin("tpu")
