Register TPU profiler plugin when get_topology_desc is called with tpu platform.

This allows the TPU profiler to work with other plugin backends.

Tested on a GPU VM:

$ pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$ pip install -e .
$ TPU_SKIP_MDS_QUERY=1 python tests/cross_aot_test.py
Running tests under Python 3.10.12: /usr/bin/python
[ RUN ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices
NOT_FOUND: WARNING: could not determine TPU accelerator type. Set env var `TPU_ACCELERATOR_TYPE` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285
NOT_FOUND: WARNING: could not determine TPU worker number. Set env var `TPU_WORKER_ID` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285
NOT_FOUND: WARNING: could not determine TPU worker hostnames or internal IP addresses. Set env var `TPU_WORKER_HOSTNAMES` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285
learning/45eac/tfrc/runtime/common_lib.cc:341
I0510 00:32:03.063246 130900437979136 cross_aot_test.py:58] Expected to fail to get topology
I0510 00:32:03.079923 130900437979136 xla_bridge.py:884] Unable to initialize backend 'cuda':
I0510 00:32:03.080080 130900437979136 xla_bridge.py:884] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0510 00:32:03.089399 130900437979136 xla_bridge.py:884] Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: No ba16c7433 device found.
W0510 00:32:03.089633 130900437979136 xla_bridge.py:931] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
/home/jieying/.local/lib/python3.10/site-packages/tensorflow/__init__.py:30: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives
  import distutils as _distutils
2024-05-10 00:32:03.359597: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-10 00:32:03.359652: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-10 00:32:03.361368: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-10 00:32:04.562557: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[ OK ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices
----------------------------------------------------------------------
Ran 1 test in 2.549s

OK

In tests/cross_aot_test.py:

class JaxAotTest(jtu.JaxTestCase):

  def test_tpu_profiler_registered_get_topology_from_devices(self):
    try:
      _ = topologies.get_topology_desc(
          topology_name='fake_topology',
          platform='tpu',
      )
    except xla_extension.XlaRuntimeError:
      logging.info('Expected to fail to get topology')

    with tempfile.TemporaryDirectory() as tmpdir:
      try:
        jax.profiler.start_trace(tmpdir)
        jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
            jnp.ones(jax.local_device_count())
        )
      finally:
        jax.profiler.stop_trace()

      proto_path = glob.glob(
          os.path.join(tmpdir, '**/*.xplane.pb'), recursive=True
      )
      self.assertLen(proto_path, 1)
      with open(proto_path[0], 'rb') as f:
        proto = f.read()

      # Sanity check that serialized proto contains host and Python traces
      # without deserializing.
      self.assertIn(b'/host:metadata', proto)
      if jtu.test_device_matches(['tpu']):
        self.assertNotIn(b'/device:TPU', proto)
      self.assertIn(b'pxla.py', proto)

PiperOrigin-RevId: 633076007
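The test locates the serialized profile with a recursive glob, since the profiler writes its trace into nested subdirectories under the trace directory. A minimal, self-contained sketch of that lookup (the directory layout below is invented for illustration; the real layout is produced by jax.profiler):

```python
import glob
import os
import tempfile

# Hypothetical layout: the trace file sits several directories deep,
# so a recursive '**' glob is needed to find it.
with tempfile.TemporaryDirectory() as tmpdir:
    run_dir = os.path.join(tmpdir, 'plugins', 'profile', 'run1')
    os.makedirs(run_dir)
    open(os.path.join(run_dir, 'trace.xplane.pb'), 'wb').close()

    proto_path = glob.glob(
        os.path.join(tmpdir, '**/*.xplane.pb'), recursive=True
    )
    print(len(proto_path))  # -> 1
```

With `recursive=True`, the `**` component matches any number of intermediate directories (including none), which is why the test finds the trace file without hard-coding the profiler's output layout.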
- Loading branch information
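The byte-marker sanity check at the end of the test can be reproduced in isolation: it searches the serialized proto for known substrings rather than deserializing it. A minimal sketch (the `has_marker` helper and `fake_proto` bytes are invented for illustration; a real run would read the bytes of an .xplane.pb file):

```python
# Check a serialized proto for known byte markers without deserializing it.
# The markers mirror those asserted in cross_aot_test.py; fake_proto is a
# stand-in for the contents of a real .xplane.pb trace file.
def has_marker(serialized: bytes, marker: bytes) -> bool:
    return marker in serialized

fake_proto = b'\x0a\x12/host:metadata\x0a\x07pxla.py'

print(has_marker(fake_proto, b'/host:metadata'))  # -> True
print(has_marker(fake_proto, b'pxla.py'))         # -> True
print(has_marker(fake_proto, b'/device:TPU'))     # -> False
```

This keeps the test independent of the xplane proto schema: it only asserts that host metadata and Python (pxla.py) trace entries made it into the serialized output.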