Skip to content

Commit

Permalink
Extend plugin discovery to also include entry-points.
Browse files Browse the repository at this point in the history
This effectively implements a mix of option 2 and option 3 from https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ as a pragmatic way to cover all packaging cases. The namespace/path based iteration works for situations where code has not been packaged and is present on the PYTHONPATH, whereas the advertised entry-points work around setuptools/pkgutil issues that make it impossible to reliably iterate over installed modules in certain scenarios (noted for editable installs which use a custom finder that does not implement iter_modules()).

A plugin entry-point can be advertised in setup.py (or equivalent pyproject.toml) with something like:

```
    entry_points={
        "jax_plugins": [
          "openxla-cpu = jax_plugins.openxla_cpu",
        ],
    }
```
  • Loading branch information
stellaraccident committed May 19, 2023
1 parent 1d20d2f commit 221aa76
Showing 1 changed file with 52 additions and 8 deletions.
60 changes: 52 additions & 8 deletions jax/_src/xla_bridge.py
Expand Up @@ -27,6 +27,7 @@
import os
import platform as py_platform
import pkgutil
import sys
import threading
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import warnings
Expand Down Expand Up @@ -324,14 +325,57 @@ def discover_pjrt_plugins() -> None:
method, which calls jax._src.xla_bridge.register_plugin with its plugin_name,
path to .so file, and optional create options.
"""
if jax_plugins is None:
return
for _, name, _ in pkgutil.iter_modules(
jax_plugins.__path__, jax_plugins.__name__ + '.'
):
# TODO(b/261345120): Add a try-catch to defend against a broken plugin.
module = importlib.import_module(name)
module.initialize()
plugin_modules = set()
# Scan installed modules under |jax_plugins|. Note that not all packaging
# scenarios are amenable to such scanning, so we also use the entry-point
# method to seed the list.
if jax_plugins:
for _, name, _ in pkgutil.iter_modules(
jax_plugins.__path__, jax_plugins.__name__ + '.'
):
logger.debug("Discovered path based JAX plugin: %s", name)
plugin_modules.add(name)
else:
logger.debug("No jax_plugins namespace packages available")

# Augment with advertised entrypoints.
if sys.version_info < (3, 10):
# Use the backport library because it provides a forward-compatible
# implementation.
try:
from importlib_metadata import entry_points
except ModuleNotFoundError:
logger.debug(
f"No importlib_metadata found (for Python < 3.10): "
f"Plugins advertised from entrypoints will not be found.")
entry_points = None
else:
from importlib.metadata import entry_points
if entry_points:
for entry_point in entry_points(group="jax_plugins"):
logger.debug("Discovered entry-point based JAX plugin: %s",
entry_point.value)
plugin_modules.add(entry_point.value)

# Now load and initialize them all.
for plugin_module_name in plugin_modules:
logger.debug("Loading plugin module %s", plugin_module_name)
plugin_module = None
try:
plugin_module = importlib.import_module(plugin_module_name)
except ModuleNotFoundError:
logger.warning("Jax plugin configuration error: Plugin module %s "
"does not exist", plugin_module_name)
except ImportError:
logger.exception("Jax plugin configuration error: Plugin module %s "
"could not be loaded")

if plugin_module:
try:
plugin_module.initialize()
except:
logger.exception("Jax plugin configuration error: Exception when "
"calling %s.initialize()", plugin_module_name)


# TODO(b/261345120): decide on a public name and expose a public method which is
Expand Down

0 comments on commit 221aa76

Please sign in to comment.