Skip to content

Commit

Permalink
Allow flags to be set with greater flexibility (iree-org#17659)
Browse files Browse the repository at this point in the history
Changes to the python binding to allow iree.runtime.flags.parse_flags to
take effect at times other than before the first time a driver is
created. Also includes fixes for bugs exposed during the development of
this feature.

- Added "internal" API functions `create_hal_driver()` and
`clear_hal_driver_cache()` to create a driver object independent of the
cache, and to clear the cache, respectively
- Added `HalDriver` class implementation functions for the above new API
functions. Refactored class to share as much common code as possible.
- Factored out driver URI processing into its own nested class for
easier handling of URI components
- Fixed dangling pointer bug. In the C layer flags are being kept by
reference as string views, requiring the caller to keep the original
flag strings (argc, argv) around for as long as the flags are being
used. However, the python binding was using a local variable for those
strings, letting them go out of scope and causing garbage values later
on. The fix is to move the strings to a file scope variable. Flag
handling does not appear to be getting used in a multi-threaded
environment, as other aspects of flag handling use static variables with
no mutex guarding that I could find.
- Fixed runtime assert in Windows debug build for the improper use of
std::vector<>::front() on an empty vector. The code never used the value
of front(), as it was guarded by a check for the vector's size, but the
assert prevents the debug build from running.

---------

Signed-off-by: Dave Liddell <dave.liddell@amd.com>
Signed-off-by: daveliddell <dave.liddell@amd.com>
  • Loading branch information
daveliddell committed Jun 14, 2024
1 parent 3428231 commit c5d4b96
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 17 deletions.
48 changes: 34 additions & 14 deletions runtime/bindings/python/hal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -905,29 +905,40 @@ std::vector<std::string> HalDriver::Query() {
return driver_names;
}

py::object HalDriver::Create(const std::string& device_uri,
py::dict& driver_cache) {
iree_string_view_t driver_name, device_path, params_str;
HalDriver::DeviceUri::DeviceUri(const std::string& device_uri) {
iree_string_view_t device_uri_sv{
device_uri.data(), static_cast<iree_host_size_t>(device_uri.size())};
iree_uri_split(device_uri_sv, &driver_name, &device_path, &params_str);
}

// Check cache.
py::str cache_key(driver_name.data, driver_name.size);
py::object cached = driver_cache.attr("get")(cache_key);
if (!cached.is_none()) {
return cached;
}

// Create.
py::object HalDriver::Create(const DeviceUri& device_uri) {
iree_hal_driver_t* driver;
CheckApiStatus(iree_hal_driver_registry_try_create(
iree_hal_driver_registry_default(), driver_name,
iree_hal_driver_registry_default(), device_uri.driver_name,
iree_allocator_system(), &driver),
"Error creating driver");

// Cache.
py::object driver_obj = py::cast(HalDriver::StealFromRawPtr(driver));
return driver_obj;
}

py::object HalDriver::Create(const std::string& device_uri) {
DeviceUri parsed_uri(device_uri);
return HalDriver::Create(parsed_uri);
}

py::object HalDriver::Create(const std::string& device_uri,
py::dict& driver_cache) {
// Look up the driver by driver name in the cache, and return it if found.
DeviceUri parsed_uri(device_uri);
py::str cache_key(parsed_uri.driver_name.data, parsed_uri.driver_name.size);
py::object cached = driver_cache.attr("get")(cache_key);
if (!cached.is_none()) {
return cached;
}

// Create a new driver and put it in the cache.
py::object driver_obj = HalDriver::Create(parsed_uri);
driver_cache[cache_key] = driver_obj;
return driver_obj;
}
Expand Down Expand Up @@ -1026,7 +1037,8 @@ HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id,
std::vector<iree_string_pair_t> params;
iree_hal_device_t* device;
CheckApiStatus(iree_hal_driver_create_device_by_id(
raw_ptr(), device_id, params.size(), &params.front(),
raw_ptr(), device_id, params.size(),
(params.empty() ? nullptr : &params.front()),
iree_allocator_system(), &device),
"Error creating default device");
CheckApiStatus(ConfigureDevice(device, allocators),
Expand Down Expand Up @@ -1289,6 +1301,14 @@ void SetupHalBindings(nanobind::module_ m) {
},
py::arg("device_uri"));

m.def(
"create_hal_driver",
[](std::string device_uri) { return HalDriver::Create(device_uri); },
py::arg("device_uri"));

m.def("clear_hal_driver_cache",
[driver_cache]() { const_cast<py::dict&>(driver_cache).clear(); });

py::class_<HalAllocator>(m, "HalAllocator")
.def("trim",
[](HalAllocator& self) {
Expand Down
20 changes: 20 additions & 0 deletions runtime/bindings/python/hal.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "./binding.h"
#include "./status_utils.h"
#include "./vm.h"
#include "iree/base/string_view.h"
#include "iree/hal/api.h"

namespace iree {
Expand Down Expand Up @@ -142,8 +143,27 @@ class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
};

class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
// Object that holds the components of a device URI string.
struct DeviceUri {
iree_string_view_t driver_name;
iree_string_view_t device_path;
iree_string_view_t params_str;

DeviceUri(const std::string& device_uri);
};

// Create a stand-alone driver (not residing in a cache) given the name,
// path, and params components of a device URI.
static py::object Create(const DeviceUri& device_uri);

public:
static std::vector<std::string> Query();

// Create a stand-alone driver (not residing in a cache) given a device URI.
static py::object Create(const std::string& device_uri);

// Returns a driver from the given cache, creating it and placing it in
// the cache if not already found there.
static py::object Create(const std::string& device_uri,
py::dict& driver_cache);

Expand Down
27 changes: 25 additions & 2 deletions runtime/bindings/python/initialize_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <memory>

#include "./binding.h"
#include "./hal.h"
#include "./invoke.h"
Expand All @@ -16,6 +18,14 @@
#include "iree/base/internal/flags.h"
#include "iree/hal/drivers/init.h"

namespace {
// Stable storage for flag processing. Flag handling uses string views,
// expecting the caller to keep the original strings around for as long
// as the flags are in use. This object holds one set of flag strings
// for each invocation of parse_flags.
std::vector<std::unique_ptr<std::vector<std::string>>> alloced_flag_cache;
} // namespace

namespace iree {
namespace python {

Expand All @@ -34,19 +44,32 @@ NB_MODULE(_runtime, m) {
SetupPyModuleBindings(m);
SetupVmBindings(m);

// Adds the given set of strings to the global flags. These new flags
// take effect upon the next creation of a driver. They do not affect
// drivers already created.
m.def("parse_flags", [](py::args py_flags) {
std::vector<std::string> alloced_flags;
// Make a new set of strings at the back of the cache
alloced_flag_cache.emplace_back(
std::make_unique<std::vector<std::string>>(std::vector<std::string>()));
auto &alloced_flags = *alloced_flag_cache.back();

// Add the given python strings to the std::string set.
alloced_flags.push_back("python");
for (py::handle py_flag : py_flags) {
alloced_flags.push_back(py::cast<std::string>(py_flag));
}

// Must build pointer vector after filling so pointers are stable.
// As the flags-processing mechanism of the C API requires long-lived
// char * strings, create a set of char * strings from the std::strings,
// with the std::strings responsible for maintaining the storage.
// Must build pointer vector after filling std::strings so pointers are
// stable.
std::vector<char *> flag_ptrs;
for (auto &alloced_flag : alloced_flags) {
flag_ptrs.push_back(const_cast<char *>(alloced_flag.c_str()));
}

// Send the flags to the C API
char **argv = &flag_ptrs[0];
int argc = flag_ptrs.size();
CheckApiStatus(iree_flags_parse(IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP,
Expand Down
7 changes: 6 additions & 1 deletion runtime/bindings/python/iree/runtime/system_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def query_available_drivers() -> Collection[str]:


def get_driver(device_uri: str) -> HalDriver:
"""Returns a HAL driver by device_uri (or driver name)."""
"""Returns a HAL driver by device_uri (or driver name).
Args:
device_uri: The URI of the device, either just a driver name for the
default or a fully qualified "driver://path?params".
"""
return get_cached_hal_driver(device_uri)


Expand Down
24 changes: 24 additions & 0 deletions runtime/bindings/python/tests/system_setup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import unittest

from iree.runtime import system_setup as ss
from iree.runtime._binding import create_hal_driver, clear_hal_driver_cache


class DeviceSetupTest(unittest.TestCase):
Expand Down Expand Up @@ -65,6 +66,29 @@ def testCreateDeviceWithAllocators(self):
infos[0]["device_id"], allocators=["caching", "debug"]
)

def testDriverCacheInternals(self):
# Two drivers created with the same URI using the caching get_driver
# should return the same driver
driver1 = ss.get_driver("local-sync")
driver2 = ss.get_driver("local-sync")
self.assertIs(driver1, driver2)

# A driver created using the non-caching create_hal_driver should be
# unique from cached drivers of the same URI
driver3 = create_hal_driver("local-sync")
self.assertIsNot(driver3, driver1)

# Drivers created with create_hal_driver should all be unique from
# one another
driver4 = create_hal_driver("local-sync")
self.assertIsNot(driver4, driver3)

# Clearing the cache should make any new driver unique from previously
# cached ones
clear_hal_driver_cache()
driver5 = ss.get_driver("local-sync")
self.assertIsNot(driver5, driver1)


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
Expand Down

0 comments on commit c5d4b96

Please sign in to comment.