Skip to content

Commit

Permalink
[PJRT C API] Support plugins registering a method to create a topology.
Browse files Browse the repository at this point in the history
- Add a topology factory registration to xla_bridge.py.
- Move discovery and registration of plugins to the first time backends() or make_pjrt_topology() is called.

PiperOrigin-RevId: 609544983
  • Loading branch information
jyingl3 authored and jax authors committed Feb 23, 2024
1 parent be002b5 commit e0bba8f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 21 deletions.
57 changes: 40 additions & 17 deletions jax/_src/xla_bridge.py
Expand Up @@ -168,6 +168,7 @@ def _log_warning():
# device.

BackendFactory = Callable[[], Union[xla_client.Client, None]]
TopologyFactory = Callable[..., Union[xla_client.DeviceTopology, None]]

@dataclasses.dataclass
class BackendRegistration:
Expand Down Expand Up @@ -195,6 +196,7 @@ class BackendRegistration:
_backend_lock = threading.Lock()
_plugins_registered: bool = False
_plugin_lock = threading.Lock()
_topology_factories: dict[str, TopologyFactory] = {}

# The set of known non-experimental plugins.
#
Expand All @@ -209,12 +211,15 @@ class BackendRegistration:
def register_backend_factory(name: str, factory: BackendFactory, *,
priority: int = 0,
fail_quietly: bool = True,
experimental: bool = False) -> None:
experimental: bool = False,
make_topology: TopologyFactory | None = None) -> None:
with _backend_lock:
if name in _backends:
raise RuntimeError(f"Backend {name} already initialized")
_backend_factories[name] = BackendRegistration(
factory, priority, fail_quietly, experimental)
if make_topology is not None:
_topology_factories[name] = make_topology


def make_cpu_client() -> xla_client.Client:
Expand Down Expand Up @@ -535,16 +540,21 @@ def factory():
logger.debug(
'registering PJRT plugin %s from %s', plugin_name, library_path
)
experimental = plugin_name not in _nonexperimental_plugins
register_backend_factory(plugin_name, factory, priority=priority,
fail_quietly=False, experimental=experimental)
if library_path is not None:
c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) # type: ignore
xla_client.profiler.register_plugin_profiler(c_api)
else:
if xla_extension_version >= 236:
assert c_api is not None
xla_client.load_pjrt_plugin_with_c_api(plugin_name, c_api)
if xla_extension_version >= 239:
make_topology = partial(xla_client.make_c_api_device_topology, c_api)
else:
make_topology = None
experimental = plugin_name not in _nonexperimental_plugins
register_backend_factory(plugin_name, factory, priority=priority,
fail_quietly=False, experimental=experimental,
make_topology=make_topology)
return c_api


Expand Down Expand Up @@ -577,6 +587,23 @@ def register_pjrt_plugin_factories_from_env() -> None:
register_plugin(plugin_name, library_path=library_path, options=options)


def _discover_and_register_pjrt_plugins():
  """Discover and register PJRT plugins, doing the work at most once.

  The `_plugins_registered` flag is checked and set while holding
  `_plugin_lock`; once it is set, later calls return without doing anything.
  """
  global _plugins_registered

  # A dedicated lock is used here because register_backend_factory (called
  # from register_plugin) needs to acquire _backend_lock itself.
  with _plugin_lock:
    if _plugins_registered:
      return
    # Import plugins that live in the `jax_plugins` namespace package or that
    # declare an entry point under the `jax_plugins` group.
    discover_pjrt_plugins()
    # Register the plugin names and paths given in the env var
    # PJRT_NAMES_AND_LIBRARY_PATHS, formatted as 'name1:path1,name2:path2'
    # ('name1;path1,name2;path2' on Windows).
    register_pjrt_plugin_factories_from_env()
    _plugins_registered = True


_platform_aliases = {
"cuda": "gpu",
"rocm": "gpu",
Expand Down Expand Up @@ -638,20 +665,8 @@ def backends() -> dict[str, xla_client.Client]:
global _backends
global _backend_errors
global _default_backend
global _plugins_registered

# Needs a separate lock because register_backend_factory (called from
# register_plugin) requries to hold _backend_lock.
with _plugin_lock:
if not _plugins_registered:
# Plugins in the namespace package `jax_plugins` or have an entry-point
# under the `jax_plugins` group will be imported.
discover_pjrt_plugins()
# Registers plugins names and paths set in env var
# PJRT_NAMES_AND_LIBRARY_PATHS, in the format of 'name1:path1,name2:path2'
# ('name1;path1,name2;path2' for windows).
register_pjrt_plugin_factories_from_env()
_plugins_registered = True
_discover_and_register_pjrt_plugins()

with _backend_lock:
if _backends:
Expand Down Expand Up @@ -978,6 +993,14 @@ def host_ids(
def using_pjrt_c_api(backend=None):
return "PJRT C API" in get_backend(backend).platform_version

def make_pjrt_topology(platform: str, topology_name='', **kwargs):
  """Create a topology for `platform` using its registered topology factory.

  Plugin discovery/registration is triggered first, then `platform` is
  canonicalized and looked up in `_topology_factories`.

  Args:
    platform: platform name; it is passed through `canonicalize_platform`
      before the lookup.
    topology_name: forwarded to the factory as its first positional argument.
    **kwargs: forwarded unchanged to the factory.

  Returns:
    Whatever the registered topology factory returns.

  Raises:
    NotImplementedError: if no topology factory is registered for the
      canonicalized platform.
  """
  _discover_and_register_pjrt_plugins()
  canonical_name = canonicalize_platform(platform)
  with _backend_lock:
    # The factory is invoked while the lock is still held.
    topology_factory = _topology_factories.get(canonical_name)
    if topology_factory is not None:
      return topology_factory(topology_name, **kwargs)
  # The message reports the name the caller passed, not the canonical one.
  raise NotImplementedError("topology not implemented for %s" % platform)


# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
def make_pjrt_tpu_topology(topology_name='', **kwargs):
Expand Down
13 changes: 10 additions & 3 deletions jax/experimental/topologies.py
Expand Up @@ -21,6 +21,7 @@
import jax
from jax.experimental import mesh_utils
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src import xla_bridge as xb

Device = xc.Device
Expand All @@ -44,9 +45,15 @@ def get_topology_desc(
topology_name, **kwargs
)._make_compile_only_devices()
)
raise NotImplementedError(
"get_topology_desc(platform=%s) is not implemented" % repr(platform)
)
try:
topology = xb.make_pjrt_topology(platform, topology_name, **kwargs)
return TopologyDescription(topology._make_compile_only_devices())
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if msg.startswith("UNIMPLEMENTED"):
raise NotImplementedError(msg) from e
else:
raise


# -- future mesh_utils --
Expand Down
2 changes: 1 addition & 1 deletion tests/aot_test.py
Expand Up @@ -42,7 +42,7 @@

class JaxAotTest(jtu.JaxTestCase):

@jtu.run_on_devices('tpu')
@jtu.run_on_devices('tpu', 'gpu')
def test_pickle_pjit_lower(self):
def fun(x):
return x * x
Expand Down

0 comments on commit e0bba8f

Please sign in to comment.