Skip to content

Commit

Permalink
[PJRT C API] Check whether the PJRT_Api* for the device type already …
Browse files Browse the repository at this point in the history
…exists before calling dlopen and dlsym.

PiperOrigin-RevId: 531295150
  • Loading branch information
jax authors committed May 11, 2023
1 parent 1bef7c9 commit 0037ab6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
7 changes: 6 additions & 1 deletion jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,12 @@ def factory():
library_path = path
options = None

xla_client.load_pjrt_plugin_dynamically(name, library_path)
if xla_extension_version >= 152:
# Plugin may already be statically linked in some configurations.
if not xla_client.pjrt_plugin_loaded(name):
xla_client.load_pjrt_plugin_dynamically(name, library_path)
else:
xla_client.load_pjrt_plugin_dynamically(name, library_path)
return xla_client.make_c_api_client(name, options)

return factory
Expand Down
31 changes: 25 additions & 6 deletions tests/xla_bridge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ def test_register_plugin(self):
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
client_factory()
if xc._version >= 152:
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
client_factory()
else:
client_factory()

self.assertRegex(
log_output[1][0],
Expand All @@ -108,7 +114,10 @@ def test_register_plugin(self):
self.assertIn("name1", xb._backend_factories)
self.assertIn("name2", xb._backend_factories)
self.assertEqual(priotiy, 400)
mock_load_plugin.assert_called_once_with("name1", "path1")
if xc._version >= 152:
mock_plugin_loaded.assert_called_once_with("name1")
else:
mock_load_plugin.assert_called_once_with("name1", "path1")
if xc._version >= 134:
mock_make.assert_called_once_with("name1", None)
else:
Expand All @@ -126,13 +135,22 @@ def test_register_plugin_with_config(self):
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
client_factory()
if xc._version >= 152:
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
client_factory()
else:
client_factory()

self.assertIn("name1", xb._backend_factories)
self.assertEqual(priority, 400)
mock_load_plugin.assert_called_once_with(
"name1", "/path/pjrt_plugin_name1.so"
)
if xc._version >= 152:
mock_plugin_loaded.assert_called_once_with("name1")
else:
mock_load_plugin.assert_called_once_with(
"name1", "/path/pjrt_plugin_name1.so"
)
mock_make.assert_called_once_with(
"name1",
{
Expand All @@ -147,6 +165,7 @@ def test_register_plugin_with_config(self):
class GetBackendTest(jtu.JaxTestCase):

class _DummyBackend:

def __init__(self, platform, device_count):
self.platform = platform
self._device_count = device_count
Expand Down

0 comments on commit 0037ab6

Please sign in to comment.