Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow flags to be set with greater flexibility #17659

Merged
merged 3 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions runtime/bindings/python/hal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -906,16 +906,19 @@ std::vector<std::string> HalDriver::Query() {
}

py::object HalDriver::Create(const std::string& device_uri,
py::dict& driver_cache) {
py::dict& driver_cache,
std::optional<bool> clean) {
daveliddell marked this conversation as resolved.
Show resolved Hide resolved
iree_string_view_t driver_name, device_path, params_str;
iree_string_view_t device_uri_sv{
device_uri.data(), static_cast<iree_host_size_t>(device_uri.size())};
iree_uri_split(device_uri_sv, &driver_name, &device_path, &params_str);

// Check cache.
// Check cache. Use the cached value if present and there is no request
// to clean out the old value.
py::str cache_key(driver_name.data, driver_name.size);
py::object cached = driver_cache.attr("get")(cache_key);
if (!cached.is_none()) {
bool clean_requested = clean.has_value() && clean.value();
if (!clean_requested && !cached.is_none()) {
return cached;
}

Expand Down Expand Up @@ -1026,7 +1029,8 @@ HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id,
std::vector<iree_string_pair_t> params;
iree_hal_device_t* device;
CheckApiStatus(iree_hal_driver_create_device_by_id(
raw_ptr(), device_id, params.size(), &params.front(),
raw_ptr(), device_id, params.size(),
(params.empty() ? nullptr : &params.front()),
daveliddell marked this conversation as resolved.
Show resolved Hide resolved
iree_allocator_system(), &device),
"Error creating default device");
CheckApiStatus(ConfigureDevice(device, allocators),
Expand Down Expand Up @@ -1283,11 +1287,11 @@ void SetupHalBindings(nanobind::module_ m) {

m.def(
daveliddell marked this conversation as resolved.
Show resolved Hide resolved
"get_cached_hal_driver",
[driver_cache](std::string device_uri) {
[driver_cache](std::string device_uri, std::optional<bool> clean) {
daveliddell marked this conversation as resolved.
Show resolved Hide resolved
return HalDriver::Create(device_uri,
const_cast<py::dict&>(driver_cache));
const_cast<py::dict&>(driver_cache), clean);
},
py::arg("device_uri"));
py::arg("device_uri"), py::arg("clean") = py::none());

py::class_<HalAllocator>(m, "HalAllocator")
.def("trim",
Expand Down
2 changes: 1 addition & 1 deletion runtime/bindings/python/hal.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
public:
static std::vector<std::string> Query();
static py::object Create(const std::string& device_uri,
py::dict& driver_cache);
py::dict& driver_cache, std::optional<bool> clean);

py::list QueryAvailableDevices();
HalDevice CreateDefaultDevice(std::optional<py::list> allocators);
Expand Down
9 changes: 8 additions & 1 deletion runtime/bindings/python/initialize_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
#include "iree/base/internal/flags.h"
#include "iree/hal/drivers/init.h"

namespace {
// Stable storage for flag processing. Flag handling uses string views,
// expecting the caller to keep the original strings around for as long
// as the flags are in use.
std::vector<std::string> alloced_flags;
daveliddell marked this conversation as resolved.
Show resolved Hide resolved
} // namespace

namespace iree {
namespace python {

Expand All @@ -35,7 +42,7 @@ NB_MODULE(_runtime, m) {
SetupVmBindings(m);

m.def("parse_flags", [](py::args py_flags) {
std::vector<std::string> alloced_flags;
alloced_flags.clear();
alloced_flags.push_back("python");
for (py::handle py_flag : py_flags) {
alloced_flags.push_back(py::cast<std::string>(py_flag));
Expand Down
13 changes: 10 additions & 3 deletions runtime/bindings/python/iree/runtime/system_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@ def query_available_drivers() -> Collection[str]:
return HalDriver.query()


def get_driver(device_uri: str) -> HalDriver:
"""Returns a HAL driver by device_uri (or driver name)."""
return get_cached_hal_driver(device_uri)
def get_driver(device_uri: str, clean: bool = False) -> HalDriver:
daveliddell marked this conversation as resolved.
Show resolved Hide resolved
"""Returns a HAL driver by device_uri (or driver name).

Args:
device_uri: The URI of the device, either just a driver name for the
default or a fully qualified "driver://path?params".
clean: Whether to clean out any cached driver and make a new one
(default False).
"""
return get_cached_hal_driver(device_uri, clean)


def get_device(device_uri: str, cache: bool = True) -> HalDevice:
Expand Down
Loading