Skip to content

Commit

Permalink
On Cloud TPU, use pip-installed libtpu instead of system default if a…
Browse files Browse the repository at this point in the history
…pplicable.
  • Loading branch information
skye committed Jun 22, 2021
1 parent 3b84f85 commit ba972f0
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions jax/_src/cloud_tpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,27 @@ def cloud_tpu_init():
as JAX's C++ backend is loaded! I.e. call this before xla_bridge or xla_client
is imported.**
These environment variables are used to tell the TPU runtime what kind of mesh
topology to use. It assumes a single-host topology by default, so we manually
set them here to default to the full pod slice if applicable.
Some of these environment variables are used to tell the TPU runtime what kind
of mesh topology to use. It assumes a single-host topology by default, so we
manually set them here to default to the full pod slice if applicable.
This will not set any env vars if a single topology-related env var is already
set.
"""
if not _running_in_cloud_tpu_vm():
return

# Use pip-installed libtpu if applicable, rather than system default.
try:
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
import libtpu
# pytype: enable=import-error
# pylint: enable=import-outside-toplevel
libtpu.configure_library_path()
except ImportError:
pass

os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')

# If the user has set any topology-related env vars, don't set any
Expand Down

0 comments on commit ba972f0

Please sign in to comment.