Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 205 additions & 3 deletions src/winml/modelkit/session/ep_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, cast
import os
from typing import TYPE_CHECKING, Any, cast


if TYPE_CHECKING:
Expand All @@ -20,6 +21,197 @@

logger = logging.getLogger(__name__)


Comment thread
xieofxie marked this conversation as resolved.
def _ep_download_timeout_default() -> int:
    """Return the EP download timeout in seconds.

    Reads ``WINMLCLI_EP_DOWNLOAD_TIMEOUT`` so users on slow networks can raise
    the cap without code changes. Falls back to 5 minutes when the variable is
    unset, empty, not an integer, or not strictly positive.

    Returns:
        The timeout in whole seconds; always greater than zero.
    """
    default_seconds = 5 * 60
    raw = os.environ.get("WINMLCLI_EP_DOWNLOAD_TIMEOUT")
    if not raw:
        return default_seconds

    try:
        timeout: int | None = int(raw)
    except ValueError:
        timeout = None

    # The value is used as a ``threading.Event.wait`` timeout. Zero or a
    # negative number makes that wait return immediately, which would report
    # every download as timed out, so treat such values as invalid too.
    if timeout is None or timeout <= 0:
        logger.warning("Invalid WINMLCLI_EP_DOWNLOAD_TIMEOUT=%r; using default 300s.", raw)
        return default_seconds
    return timeout


# Evaluated once at module import. Changing WINMLCLI_EP_DOWNLOAD_TIMEOUT
# after import does NOT take effect for the running process; tests that need
# a different value should monkeypatch ep_registry.EP_DOWNLOAD_TIMEOUT_SECONDS
# directly.
# Unit is seconds: _ensure_provider_ready passes this value as the timeout of
# the threading.Event.wait() call that waits for the download to complete.
EP_DOWNLOAD_TIMEOUT_SECONDS: int = _ep_download_timeout_default()


class _NoopBar:
    """Silent replacement for a tqdm bar when tqdm is not installed.

    Provides exactly the surface ``_ensure_provider_ready`` uses: the ``n``
    attribute plus ``refresh()`` and ``close()``. That lets the caller drive
    either a real bar or this stand-in without branching.
    """

    def __init__(self) -> None:
        # Progress counter; callers write to it, nothing here reads it.
        self.n = 0

    def refresh(self) -> None:
        """Do nothing; there is no bar to redraw."""

    def close(self) -> None:
        """Do nothing; there is no bar to release."""


def _make_progress_bar() -> Any:
    """Build the download progress bar, degrading to a no-op without tqdm.

    tqdm is a dev-only optional dependency of this package, so production
    installs that lack it must still be able to finish an EP download; they
    simply do not get the live bar. The pre-download Console notice is printed
    by the caller and does not depend on this function.

    Rendered format: ``Downloading... ████████████░░░░░░ 62%``

    Returns:
        A ``tqdm`` instance with ``total=100`` when tqdm can be imported,
        otherwise a ``_NoopBar``.
    """
    try:
        from tqdm import tqdm
    except ImportError:
        progress: Any = _NoopBar()
    else:
        bar_options = {
            "total": 100,
            "bar_format": "Downloading... {bar} {percentage:3.0f}%",
            # First character is the empty cell, last is the filled cell.
            "ascii": "░█",
            "leave": True,
        }
        progress = tqdm(**bar_options)
    return progress


def _parse_ep_metadata_from_path(library_path: str) -> tuple[str, str]:
r"""Best-effort ``(version, package_family_name)`` from an EP's install path.

WinML's ``ExecutionProvider`` handle sometimes returns empty ``version`` /
``package_family_name`` even after the EP is Ready. When the EP is delivered
as an MSIX package its ``library_path`` lives under ``WindowsApps`` in a
folder named with the full package identity::

...\\WindowsApps\\<Name>_<Version>_<Arch>_<ResourceId>_<PublisherId>\\...

e.g. ``MicrosoftCorporationII.WinML.Intel.OpenVINO.EP.1.8_1.8.79.0_x64__8wekyb3d8bbwe``
yields version ``1.8.79.0`` and package family name
``MicrosoftCorporationII.WinML.Intel.OpenVINO.EP.1.8_8wekyb3d8bbwe`` (the
package family name is ``<Name>_<PublisherId>``).
Comment thread
xieofxie marked this conversation as resolved.

Returns ``("", "")`` when the path is empty or does not match this layout.
"""
import re
from itertools import pairwise
from pathlib import PurePath

if not library_path:
return "", ""

parts = PurePath(library_path).parts
pkg_folder = next(
(child for parent, child in pairwise(parts) if parent.lower() == "windowsapps"),
"",
)
# Full MSIX package name: Name_Version_Arch_ResourceId_PublisherId (ResourceId
# is usually empty, giving the doubled "__" before the publisher id).
segments = pkg_folder.split("_")
if len(segments) < 5:
return "", ""

name, version, publisher = segments[0], segments[1], segments[-1]
# Guard against unexpected folder shapes: version must be dotted-numeric.
if not re.fullmatch(r"\d+(\.\d+)*", version):
version = ""
package_family_name = f"{name}_{publisher}" if name and publisher else ""
return version, package_family_name


def _ensure_provider_ready(provider: Any) -> None:
    """Ensure an EP is ready, showing a tqdm progress bar when downloading.

    Providers already in the ``Ready`` state take the synchronous fast path so
    cached EPs do not flash a 0-100% bar. Otherwise drives a tqdm bar from
    ``ensure_ready_async``'s ``on_progress`` callback (cumulative fraction
    0.0-1.0, per windowsml docs) and waits for the ``on_complete`` callback
    via a threading.Event with a ``EP_DOWNLOAD_TIMEOUT_SECONDS`` timeout. On
    timeout the async op is cancelled and ``TimeoutError`` is raised.

    Args:
        provider: windowsml execution-provider handle. This function reads its
            ``ready_state``, ``name``, ``version``, ``package_family_name`` and
            ``library_path`` attributes and calls ``ensure_ready`` /
            ``ensure_ready_async`` on it.

    Raises:
        TimeoutError: The download did not complete within
            ``EP_DOWNLOAD_TIMEOUT_SECONDS``.
        OSError: Propagated from ``op.get_status()`` when the native
            operation failed.
    """
    import threading

    from windowsml import EpReadyState

    if provider.ready_state == EpReadyState.Ready:
        provider.ensure_ready()
        return

    # Lazy-import to keep ep_registry import cheap (rich pulls in pygments etc.);
    # this branch only runs on the cold "EP needs download" path.
    from ..utils.console import get_console

    console = get_console()
    console.print(f"[WinML] Installing Execution Provider: [bold]{provider.name}[/bold]")

    bar = _make_progress_bar()
    done = threading.Event()

    def _on_progress(fraction: float) -> None:
        # Native ops may fire a stale on_progress after on_complete; once done
        # is set the main thread owns bar.n (forces it to 100 and closes the
        # bar), so silently drop late callbacks instead of clobbering 100 with
        # an earlier fraction or writing to a closed bar.
        if done.is_set():
            return
        bar.n = max(0, min(100, int(fraction * 100)))
        bar.refresh()

    op = None
    success = False
    try:
        op = provider.ensure_ready_async(on_complete=done.set, on_progress=_on_progress)
        if not done.wait(timeout=EP_DOWNLOAD_TIMEOUT_SECONDS):
            op.cancel()
            raise TimeoutError(
                f"EP {provider.name!r} download did not complete within "
                f"{EP_DOWNLOAD_TIMEOUT_SECONDS}s; cancelled."
            )
        # Surface any native failure (raises OSError on error).
        op.get_status()
        # Success: providers usually fire on_progress(1.0) before on_complete,
        # but force the bar to 100 in case they didn't.
        bar.n = 100
        bar.refresh()
        success = True
    finally:
        # Mark the operation finished on EVERY exit path before closing the
        # bar. On timeout or launch failure on_complete never set ``done``, so
        # without this a late on_progress callback would still pass the guard
        # above and redraw a bar that has already been closed.
        done.set()
        bar.close()
        if op is not None:
            op.close()
        if not success:
            # Failure-path notice — kept in finally so it fires for every
            # non-success exit (launch failure, timeout, get_status OSError).
            # Printed after bar.close() so it appears below the bar's last frame.
            console.print(f"[red]❌ Failed to download {provider.name} EP[/red]")
            console.print("Try:")
            console.print("  1. Check your internet connection")
            console.print("  2. Troubleshoot: https://aka.ms/winmlcli/ep-errors")

    # Only reached on success: any exception above propagates out of the
    # try/finally before this point.
    console.print(f"{provider.name} EP installed successfully.")

    # The native handle sometimes reports empty version/PFN even once Ready;
    # fall back to parsing them from the MSIX install path. Skip a line entirely
    # when its value can't be determined rather than printing a blank field.
    version = provider.version
    package_family_name = provider.package_family_name
    if not version or not package_family_name:
        parsed_version, parsed_pfn = _parse_ep_metadata_from_path(provider.library_path)
        version = version or parsed_version
        package_family_name = package_family_name or parsed_pfn
    if version:
        console.print(f"- Version: {version}", soft_wrap=True)
    if package_family_name:
        # soft_wrap so long package family names aren't hard-wrapped mid-string.
        console.print(f"- Package Family Name: {package_family_name}", soft_wrap=True)


# Singleton instance. Starts as None; the code that assigns it is not visible
# in this part of the file (presumably a lazy accessor — verify).
_winml_ep_registry: WinMLEPRegistry | None = None

Expand Down Expand Up @@ -80,9 +272,19 @@ def _load_ep_catalog(self) -> None:
with EpCatalog() as catalog:
for provider in catalog.find_all_providers():
try:
provider.ensure_ready()
_ensure_provider_ready(provider)
except OSError as e:
# windowsml maps native HRESULT failures to OSError; surface
# winerror so the HRESULT is grep-able in logs.
logger.info(
"Failed to ensure EP %s is ready: %s (winerror=%s)",
provider.name,
e,
getattr(e, "winerror", None),
)
continue
except Exception as e:
logger.debug("Failed to ensure EP %s is ready: %s", provider.name, e)
logger.info("Failed to ensure EP %s is ready: %s", provider.name, e)
Comment thread
xieofxie marked this conversation as resolved.
continue
if provider.library_path == "":
continue
Expand Down
Loading
Loading