Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 54 additions & 52 deletions src/winml/modelkit/session/ep_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ def __init__(self) -> None:
self._initialized = True

self._ep_paths: dict[EPName, str] = {}
self._registered_eps: list[EPName] = []
self._registered_eps: dict[str, list[EPName]] = {
"onnxruntime": [],
Comment thread
xieofxie marked this conversation as resolved.
"onnxruntime_genai": [],
}
self._winml_available = False
self._catalog = None

self._discover_eps()

Expand All @@ -75,33 +77,17 @@ def _load_ep_catalog(self) -> None:
"""Load EP catalog from WinML."""
from windowsml import EpCatalog

self._catalog = EpCatalog()
providers = self._catalog.find_all_providers()

for provider in providers:
try:
provider.ensure_ready()
except Exception as e:
logger.debug("Failed to ensure EP %s is ready: %s", provider.name, e)
continue
if provider.library_path == "":
continue
self._ep_paths[cast("EPName", provider.name)] = provider.library_path
logger.debug("Found EP: %s at %s", provider.name, provider.library_path)

# Workaround: WinMLEpCatalogRelease (called by EpCatalog.close() /
# EpCatalog.__del__) crashes with ACCESS_VIOLATION (0xC0000005) on some
# QNN NPU driver configurations — a Windows SEH exception that Python
# try/except cannot catch. All provider paths have been extracted
# above, so the catalog handle is no longer needed. Null it out
# immediately so that EpCatalog.close() becomes a no-op for the
# remainder of the process lifetime, whether invoked from a background
# thread or interpreter shutdown. Native resources are reclaimed by
# the OS when the process exits.
# TODO: Remove once windowsml fixes WinMLEpCatalogRelease to be safe
# during process teardown on all QNN NPU driver configurations.
if hasattr(self._catalog, "_handle"):
self._catalog._handle = None
with EpCatalog() as catalog:
for provider in catalog.find_all_providers():
try:
provider.ensure_ready()
except Exception as e:
logger.debug("Failed to ensure EP %s is ready: %s", provider.name, e)
continue
if provider.library_path == "":
continue
self._ep_paths[cast("EPName", provider.name)] = provider.library_path
logger.debug("Found EP: %s at %s", provider.name, provider.library_path)

def register_to_ort(self) -> list[EPName]:
"""Register discovered EPs to ONNX Runtime.
Expand All @@ -113,21 +99,47 @@ def register_to_ort(self) -> list[EPName]:
logger.warning("WinML not available, skipping EP registration")
return []

import onnxruntime as ort
result = self.register_execution_providers(ort=True)
return result.get("onnxruntime", []).copy()

for name, dll_path in self._ep_paths.items():
if name in self._registered_eps:
continue
def register_execution_providers(
self, ort: bool = True, ort_genai: bool = False
) -> dict[str, list[EPName]]:
"""Register WinML execution providers for ONNX Runtime modules.

try:
# Use ORT's native EP registration API
ort.register_execution_provider_library(name, dll_path)
self._registered_eps.append(name)
logger.debug("Registered EP: %s -> %s", name, dll_path)
except Exception as e:
logger.warning("Failed to register EP %s: %s", name, e)
Args:
ort: Whether to register for ONNX Runtime.
ort_genai: Whether to register for ONNX Runtime GenAI.

return self._registered_eps.copy()
Returns:
Dictionary of registered execution provider names by module.
"""
modules = []
if ort:
import onnxruntime

modules.append(onnxruntime)
if ort_genai:
import onnxruntime_genai # type: ignore[import-not-found]

modules.append(onnxruntime_genai)
for name, path in self._ep_paths.items():
for module in modules:
if name not in self._registered_eps[module.__name__]:
try:
module.register_execution_provider_library(name, path)
self._registered_eps[module.__name__].append(name)
logger.debug(
"Registered EP: %s from %s for module %s", name, path, module.__name__
)
except Exception:
logger.exception(
"Failed to register %s from %s for module %s",
name,
path,
module.__name__,
)
return self._registered_eps

def get_ep_library_path(self, ep_name: EPName) -> str | None:
"""Get the library path for an EP."""
Expand All @@ -139,7 +151,7 @@ def get_available_eps(self) -> dict[EPName, str]:

def get_registered_eps(self) -> list[EPName]:
"""Get list of EPs registered with ORT."""
return self._registered_eps.copy()
return self._registered_eps["onnxruntime"].copy()

def is_ep_available(self, ep_name: EPName) -> bool:
"""Check if an EP is available."""
Expand All @@ -150,16 +162,6 @@ def winml_available(self) -> bool:
"""Whether WinML is available."""
return self._winml_available

def __del__(self) -> None:
"""Cleanup EP catalog."""
catalog = getattr(self, "_catalog", None)
if catalog is not None:
try:
catalog.close()
except Exception as e:
logger.debug("Error cleaning up EP catalog: %s", e)
self._catalog = None

@classmethod
def get_instance(cls) -> WinMLEPRegistry:
"""Get singleton instance."""
Expand Down
103 changes: 8 additions & 95 deletions src/winml/modelkit/winml.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,110 +15,23 @@

logger = logging.getLogger(__name__)

_winml_instance: WinML | None = None


class WinML:
"""Singleton class for managing WinML execution providers."""

_initialized: bool

def __new__(cls, *args: Any, **kwargs: Any) -> WinML:
"""Create or return the singleton instance."""
global _winml_instance
if _winml_instance is None:
_winml_instance = super().__new__(cls, *args, **kwargs)
_winml_instance._initialized = False
return _winml_instance

def __init__(self) -> None:
"""Initialize WinML execution provider catalog."""
if self._initialized:
return
self._initialized = True

from windowsml import EpCatalog

self._catalog = EpCatalog()
self._providers = self._catalog.find_all_providers()
self._ep_paths: dict[str, str] = {}
for provider in self._providers:
provider.ensure_ready()
if provider.library_path == "":
continue
self._ep_paths[provider.name] = provider.library_path
self._registered_eps: dict[str, list[str]] = {
"onnxruntime": [],
"onnxruntime_genai": [],
}

# Workaround: WinMLEpCatalogRelease (called by EpCatalog.close() /
# EpCatalog.__del__) crashes with ACCESS_VIOLATION (0xC0000005) on some
# QNN NPU driver configurations — a Windows SEH exception that Python
# try/except cannot catch. All provider paths have been extracted
# above, so the catalog handle is no longer needed. Null it out
# immediately so that EpCatalog.close() becomes a no-op for the
# remainder of the process lifetime, whether invoked from a background
# thread or interpreter shutdown. Native resources are reclaimed by
# the OS when the process exits.
# TODO: Remove once windowsml fixes WinMLEpCatalogRelease to be safe
# during process teardown on all QNN NPU driver configurations.
if hasattr(self._catalog, "_handle"):
self._catalog._handle = None

def __del__(self) -> None:
"""Clean up WinML resources."""
self._providers = None
self._catalog = None

def register_execution_providers(
self, ort: bool = True, ort_genai: bool = False
) -> dict[str, list[str]]:
"""Register WinML execution providers for ONNX Runtime modules.

Args:
ort: Whether to register for ONNX Runtime.
ort_genai: Whether to register for ONNX Runtime GenAI.

Returns:
Dictionary of registered execution provider names by module.
"""
modules = []
if ort:
import onnxruntime

modules.append(onnxruntime)
if ort_genai:
import onnxruntime_genai # type: ignore[import-not-found]

modules.append(onnxruntime_genai)
for name, path in self._ep_paths.items():
for module in modules:
if name not in self._registered_eps[module.__name__]:
try:
module.register_execution_provider_library(name, path)
self._registered_eps[module.__name__].append(name)
except Exception:
logger.exception(
"Failed to register %s for module %s",
name,
module.__name__,
)
return self._registered_eps


def register_execution_providers(ort: bool = True, ort_genai: bool = False) -> dict[str, list[str]]:

def register_execution_providers(
ort: bool = True, ort_genai: bool = False
) -> dict[str, list[EPName]]:
"""Register WinML execution providers for ONNX Runtime and ONNX Runtime GenAI.

Args:
ort (bool): Whether to register for ONNX Runtime.
ort_genai (bool): Whether to register for ONNX Runtime GenAI.

Returns:
dict[str, list[str]]: Dictionary of registered execution provider names
dict[str, list[EPName]]: Dictionary of registered execution provider names
by module.
"""
return WinML().register_execution_providers(ort=ort, ort_genai=ort_genai)
from .session import WinMLEPRegistry

return WinMLEPRegistry.get_instance().register_execution_providers(ort=ort, ort_genai=ort_genai)


@functools.lru_cache(maxsize=1)
Expand Down
Loading