From e0bba8ff3981ba241ff7dd74579f3f59969e1eba Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 22 Feb 2024 16:54:52 -0800 Subject: [PATCH] [PJRT C API] Supports plugins to register a method to create topology. - Add a topology factory registration to xla_bridge.py. - Move discovery and registration of plugins to the first time backends() or make_pjrt_topology() is called. PiperOrigin-RevId: 609544983 --- jax/_src/xla_bridge.py | 57 ++++++++++++++++++++++++---------- jax/experimental/topologies.py | 13 ++++++-- tests/aot_test.py | 2 +- 3 files changed, 51 insertions(+), 21 deletions(-) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 945b7d030d2d..89344e2ee261 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -168,6 +168,7 @@ def _log_warning(): # device. BackendFactory = Callable[[], Union[xla_client.Client, None]] +TopologyFactory = Callable[..., Union[xla_client.DeviceTopology, None]] @dataclasses.dataclass class BackendRegistration: @@ -195,6 +196,7 @@ class BackendRegistration: _backend_lock = threading.Lock() _plugins_registered: bool = False _plugin_lock = threading.Lock() +_topology_factories: dict[str, TopologyFactory] = {} # The set of known non-experimental plugins. # @@ -209,12 +211,15 @@ class BackendRegistration: def register_backend_factory(name: str, factory: BackendFactory, *, priority: int = 0, fail_quietly: bool = True, - experimental: bool = False) -> None: + experimental: bool = False, + make_topology: TopologyFactory | None = None) -> None: with _backend_lock: if name in _backends: raise RuntimeError(f"Backend {name} already initialized") _backend_factories[name] = BackendRegistration( factory, priority, fail_quietly, experimental) + if make_topology is not None: + _topology_factories[name] = make_topology def make_cpu_client() -> xla_client.Client: @@ -535,9 +540,6 @@ def factory(): logger.debug( 'registering PJRT plugin %s from %s', plugin_name, library_path ) - experimental = plugin_name not in _nonexperimental_plugins - register_backend_factory(plugin_name, factory, priority=priority, - fail_quietly=False, experimental=experimental) if library_path is not None: c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) # type: ignore xla_client.profiler.register_plugin_profiler(c_api) @@ -545,6 +547,14 @@ def factory(): if xla_extension_version >= 236: assert c_api is not None xla_client.load_pjrt_plugin_with_c_api(plugin_name, c_api) + if xla_extension_version >= 239: + make_topology = partial(xla_client.make_c_api_device_topology, c_api) + else: + make_topology = None + experimental = plugin_name not in _nonexperimental_plugins + register_backend_factory(plugin_name, factory, priority=priority, + fail_quietly=False, experimental=experimental, + make_topology=make_topology) return c_api @@ -577,6 +587,23 @@ def register_pjrt_plugin_factories_from_env() -> None: register_plugin(plugin_name, library_path=library_path, options=options) +def _discover_and_register_pjrt_plugins(): + global _plugins_registered + + # Needs a separate lock because register_backend_factory (called from + # register_plugin) requires to hold _backend_lock. + with _plugin_lock: + if not _plugins_registered: + # Plugins in the namespace package `jax_plugins` or have an entry-point + # under the `jax_plugins` group will be imported. + discover_pjrt_plugins() + # Registers plugins names and paths set in env var + # PJRT_NAMES_AND_LIBRARY_PATHS, in the format of 'name1:path1,name2:path2' + # ('name1;path1,name2;path2' for windows). + register_pjrt_plugin_factories_from_env() + _plugins_registered = True + + _platform_aliases = { "cuda": "gpu", "rocm": "gpu", @@ -638,20 +665,8 @@ def backends() -> dict[str, xla_client.Client]: global _backends global _backend_errors global _default_backend - global _plugins_registered - # Needs a separate lock because register_backend_factory (called from - # register_plugin) requries to hold _backend_lock. - with _plugin_lock: - if not _plugins_registered: - # Plugins in the namespace package `jax_plugins` or have an entry-point - # under the `jax_plugins` group will be imported. - discover_pjrt_plugins() - # Registers plugins names and paths set in env var - # PJRT_NAMES_AND_LIBRARY_PATHS, in the format of 'name1:path1,name2:path2' - # ('name1;path1,name2;path2' for windows). - register_pjrt_plugin_factories_from_env() - _plugins_registered = True + _discover_and_register_pjrt_plugins() with _backend_lock: if _backends: @@ -978,6 +993,14 @@ def host_ids( def using_pjrt_c_api(backend=None): return "PJRT C API" in get_backend(backend).platform_version +def make_pjrt_topology(platform: str, topology_name='', **kwargs): + _discover_and_register_pjrt_plugins() + actual_platform = canonicalize_platform(platform) + with _backend_lock: + if actual_platform in _topology_factories: + return _topology_factories[actual_platform](topology_name, **kwargs) + raise NotImplementedError("topology not implemented for %s" % platform) + # TODO(parkers): Get rid of this in favor of a generic way to get topologies. def make_pjrt_tpu_topology(topology_name='', **kwargs): diff --git a/jax/experimental/topologies.py b/jax/experimental/topologies.py index 9ac9cb1dd59d..7866564e9c01 100644 --- a/jax/experimental/topologies.py +++ b/jax/experimental/topologies.py @@ -21,6 +21,7 @@ import jax from jax.experimental import mesh_utils from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension from jax._src import xla_bridge as xb Device = xc.Device @@ -44,9 +45,15 @@ def get_topology_desc( topology_name, **kwargs )._make_compile_only_devices() ) - raise NotImplementedError( - "get_topology_desc(platform=%s) is not implemented" % repr(platform) - ) + try: + topology = xb.make_pjrt_topology(platform, topology_name, **kwargs) + return TopologyDescription(topology._make_compile_only_devices()) + except xla_extension.XlaRuntimeError as e: + msg, *_ = e.args + if msg.startswith("UNIMPLEMENTED"): + raise NotImplementedError(msg) from e + else: + raise # -- future mesh_utils -- diff --git a/tests/aot_test.py b/tests/aot_test.py index e6b6817e7bad..dacfa620c628 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -42,7 +42,7 @@ class JaxAotTest(jtu.JaxTestCase): - @jtu.run_on_devices('tpu') + @jtu.run_on_devices('tpu', 'gpu') def test_pickle_pjit_lower(self): def fun(x): return x * x