Skip to content

Commit

Permalink
Remove calling configure_library_path during jax import and get libtp…
Browse files Browse the repository at this point in the history
…u path from libtpu_module.get_library_path().

PiperOrigin-RevId: 572306461
  • Loading branch information
jyingl3 authored and jax authors committed Oct 10, 2023
1 parent 269d7ce commit b81a3e1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
4 changes: 1 addition & 3 deletions jax/_src/cloud_tpu_init.py
Expand Up @@ -58,9 +58,7 @@ def cloud_tpu_init() -> None:
# if the following hold: a) libtpu is installed b) JAX_FORCE_TPU_INIT is set
# Exit early if we're not running on Cloud TPU.
libtpu_module = maybe_import_libtpu()
if libtpu_module is not None:
libtpu_module.configure_library_path()
elif not jax_force_tpu_init():
if libtpu_module is None and not jax_force_tpu_init():
return

running_in_cloud_tpu_vm = True
Expand Down
19 changes: 18 additions & 1 deletion jax/_src/xla_bridge.py
Expand Up @@ -37,6 +37,7 @@

from jax._src import config
from jax._src import distributed
from jax._src.cloud_tpu_init import maybe_import_libtpu
from jax._src.lib import cuda_versions
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
Expand Down Expand Up @@ -98,6 +99,19 @@

# Backends


def _get_tpu_library_path() -> Optional[str]:
path_from_env = os.getenv("TPU_LIBRARY_PATH")
if path_from_env is not None:
return path_from_env

libtpu_module = maybe_import_libtpu()
if libtpu_module is not None:
return libtpu_module.get_library_path()

return None


def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]:
def _log_warning():
warnings.warn(
Expand All @@ -111,7 +125,10 @@ def _log_warning():
t.start()

try:
client = xla_client.make_tpu_client()
if xla_extension_version >= 205:
client = xla_client.make_tpu_client(_get_tpu_library_path()) # type: ignore
else:
client = xla_client.make_tpu_client() # type: ignore
finally:
t.cancel()

Expand Down
2 changes: 1 addition & 1 deletion tests/xla_bridge_test.py
Expand Up @@ -146,7 +146,7 @@ def test_timer_tpu_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")

def _mock_tpu_client():
def _mock_tpu_client(library_path=None):
time_to_wait = 5
start = time.time()
while not w:
Expand Down

0 comments on commit b81a3e1

Please sign in to comment.