diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index 18771af9e0be..8799e4884fbd 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -905,29 +905,40 @@ std::vector HalDriver::Query() { return driver_names; } -py::object HalDriver::Create(const std::string& device_uri, - py::dict& driver_cache) { - iree_string_view_t driver_name, device_path, params_str; +HalDriver::DeviceUri::DeviceUri(const std::string& device_uri) { iree_string_view_t device_uri_sv{ device_uri.data(), static_cast(device_uri.size())}; iree_uri_split(device_uri_sv, &driver_name, &device_path, ¶ms_str); +} - // Check cache. - py::str cache_key(driver_name.data, driver_name.size); - py::object cached = driver_cache.attr("get")(cache_key); - if (!cached.is_none()) { - return cached; - } - - // Create. +py::object HalDriver::Create(const DeviceUri& device_uri) { iree_hal_driver_t* driver; CheckApiStatus(iree_hal_driver_registry_try_create( - iree_hal_driver_registry_default(), driver_name, + iree_hal_driver_registry_default(), device_uri.driver_name, iree_allocator_system(), &driver), "Error creating driver"); - // Cache. py::object driver_obj = py::cast(HalDriver::StealFromRawPtr(driver)); + return driver_obj; +} + +py::object HalDriver::Create(const std::string& device_uri) { + DeviceUri parsed_uri(device_uri); + return HalDriver::Create(parsed_uri); +} + +py::object HalDriver::Create(const std::string& device_uri, + py::dict& driver_cache) { + // Look up the driver by driver name in the cache, and return it if found. + DeviceUri parsed_uri(device_uri); + py::str cache_key(parsed_uri.driver_name.data, parsed_uri.driver_name.size); + py::object cached = driver_cache.attr("get")(cache_key); + if (!cached.is_none()) { + return cached; + } + + // Create a new driver and put it in the cache. + py::object driver_obj = HalDriver::Create(parsed_uri); driver_cache[cache_key] = driver_obj; return driver_obj; } @@ -1026,7 +1037,8 @@ HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id, std::vector params; iree_hal_device_t* device; CheckApiStatus(iree_hal_driver_create_device_by_id( - raw_ptr(), device_id, params.size(), ¶ms.front(), + raw_ptr(), device_id, params.size(), + (params.empty() ? nullptr : ¶ms.front()), iree_allocator_system(), &device), "Error creating default device"); CheckApiStatus(ConfigureDevice(device, allocators), @@ -1289,6 +1301,14 @@ void SetupHalBindings(nanobind::module_ m) { }, py::arg("device_uri")); + m.def( + "create_hal_driver", + [](std::string device_uri) { return HalDriver::Create(device_uri); }, + py::arg("device_uri")); + + m.def("clear_hal_driver_cache", + [driver_cache]() { const_cast(driver_cache).clear(); }); + py::class_(m, "HalAllocator") .def("trim", [](HalAllocator& self) { diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h index 29d02334e959..7dbc108917c3 100644 --- a/runtime/bindings/python/hal.h +++ b/runtime/bindings/python/hal.h @@ -12,6 +12,7 @@ #include "./binding.h" #include "./status_utils.h" #include "./vm.h" +#include "iree/base/string_view.h" #include "iree/hal/api.h" namespace iree { @@ -142,8 +143,27 @@ class HalDevice : public ApiRefCounted { }; class HalDriver : public ApiRefCounted { + // Object that holds the components of a device URI string. + struct DeviceUri { + iree_string_view_t driver_name; + iree_string_view_t device_path; + iree_string_view_t params_str; + + DeviceUri(const std::string& device_uri); + }; + + // Create a stand-alone driver (not residing in a cache) given the name, + // path, and params components of a device URI. + static py::object Create(const DeviceUri& device_uri); + public: static std::vector Query(); + + // Create a stand-alone driver (not residing in a cache) given a device URI. + static py::object Create(const std::string& device_uri); + + // Returns a driver from the given cache, creating it and placing it in + // the cache if not already found there. static py::object Create(const std::string& device_uri, py::dict& driver_cache); diff --git a/runtime/bindings/python/initialize_module.cc b/runtime/bindings/python/initialize_module.cc index 7eb9254cf3b0..c79da46353fc 100644 --- a/runtime/bindings/python/initialize_module.cc +++ b/runtime/bindings/python/initialize_module.cc @@ -4,6 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include + #include "./binding.h" #include "./hal.h" #include "./invoke.h" @@ -16,6 +18,14 @@ #include "iree/base/internal/flags.h" #include "iree/hal/drivers/init.h" +namespace { +// Stable storage for flag processing. Flag handling uses string views, +// expecting the caller to keep the original strings around for as long +// as the flags are in use. This object holds one set of flag strings +// for each invocation of parse_flags. +std::vector>> alloced_flag_cache; +} // namespace + namespace iree { namespace python { @@ -34,19 +44,32 @@ NB_MODULE(_runtime, m) { SetupPyModuleBindings(m); SetupVmBindings(m); + // Adds the given set of strings to the global flags. These new flags + // take effect upon the next creation of a driver. They do not affect + // drivers already created. m.def("parse_flags", [](py::args py_flags) { - std::vector alloced_flags; + // Make a new set of strings at the back of the cache + alloced_flag_cache.emplace_back( + std::make_unique>(std::vector())); + auto &alloced_flags = *alloced_flag_cache.back(); + + // Add the given python strings to the std::string set. alloced_flags.push_back("python"); for (py::handle py_flag : py_flags) { alloced_flags.push_back(py::cast(py_flag)); } - // Must build pointer vector after filling so pointers are stable. + // As the flags-processing mechanism of the C API requires long-lived + // char * strings, create a set of char * strings from the std::strings, + // with the std::strings responsible for maintaining the storage. + // Must build pointer vector after filling std::strings so pointers are + // stable. std::vector flag_ptrs; for (auto &alloced_flag : alloced_flags) { flag_ptrs.push_back(const_cast(alloced_flag.c_str())); } + // Send the flags to the C API char **argv = &flag_ptrs[0]; int argc = flag_ptrs.size(); CheckApiStatus(iree_flags_parse(IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP, diff --git a/runtime/bindings/python/iree/runtime/system_setup.py b/runtime/bindings/python/iree/runtime/system_setup.py index 0560003d5f1d..8cd117d56acc 100644 --- a/runtime/bindings/python/iree/runtime/system_setup.py +++ b/runtime/bindings/python/iree/runtime/system_setup.py @@ -26,7 +26,12 @@ def query_available_drivers() -> Collection[str]: def get_driver(device_uri: str) -> HalDriver: - """Returns a HAL driver by device_uri (or driver name).""" + """Returns a HAL driver by device_uri (or driver name). + + Args: + device_uri: The URI of the device, either just a driver name for the + default or a fully qualified "driver://path?params". + """ return get_cached_hal_driver(device_uri) diff --git a/runtime/bindings/python/tests/system_setup_test.py b/runtime/bindings/python/tests/system_setup_test.py index 2d0ddf9ea2cb..c55dc466598f 100644 --- a/runtime/bindings/python/tests/system_setup_test.py +++ b/runtime/bindings/python/tests/system_setup_test.py @@ -8,6 +8,7 @@ import unittest from iree.runtime import system_setup as ss +from iree.runtime._binding import create_hal_driver, clear_hal_driver_cache class DeviceSetupTest(unittest.TestCase): @@ -65,6 +66,29 @@ def testCreateDeviceWithAllocators(self): infos[0]["device_id"], allocators=["caching", "debug"] ) + def testDriverCacheInternals(self): + # Two drivers created with the same URI using the caching get_driver + # should return the same driver + driver1 = ss.get_driver("local-sync") + driver2 = ss.get_driver("local-sync") + self.assertIs(driver1, driver2) + + # A driver created using the non-caching create_hal_driver should be + # unique from cached drivers of the same URI + driver3 = create_hal_driver("local-sync") + self.assertIsNot(driver3, driver1) + + # Drivers created with create_hal_driver should all be unique from + # one another + driver4 = create_hal_driver("local-sync") + self.assertIsNot(driver4, driver3) + + # Clearing the cache should make any new driver unique from previously + # cached ones + clear_hal_driver_cache() + driver5 = ss.get_driver("local-sync") + self.assertIsNot(driver5, driver1) + if __name__ == "__main__": logging.basicConfig(level=logging.INFO)