Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/winml/modelkit/commands/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import click
from rich.console import Console

from ..config.precision import _DEVICE_TO_PROVIDER, VALID_EPS
from ..config.precision import _DEVICE_TO_PROVIDER, _EP_TO_DEVICE, VALID_EPS
from ..onnx import is_compiled_onnx
from ..utils.logging import configure_logging

Expand Down Expand Up @@ -191,7 +191,7 @@ def compile(

# Show info
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Device:[/bold blue] {device}")
console.print(f"[bold blue]Device:[/bold blue] {_EP_TO_DEVICE.get(provider, device)}")
if ep:
console.print(f"[bold blue]EP:[/bold blue] {ep}")
console.print(f"[bold blue]Provider:[/bold blue] {provider}")
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/commands/test_compile_quantize_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from winml.modelkit.commands.compile import _resolve_compile_provider
Expand Down Expand Up @@ -119,3 +121,31 @@ def test_unknown_precision_uses_defaults(self):
w, a = _resolve_quant_types("fp16", None, None)
assert w == "uint8"
assert a == "uint8"


class TestCompileDeviceDisplayLabel:
"""Device label in compile summary must reflect the resolved EP, not the CLI default."""

def test_dml_ep_shows_gpu_device(self, tmp_path):
from click.testing import CliRunner

from winml.modelkit.commands.compile import compile

model_file = tmp_path / "model.onnx"
model_file.write_bytes(b"fake")

mock_result = MagicMock()
mock_result.success = True
mock_result.output_path = None
mock_result.compile_time = None
mock_result.total_time = None

with (
patch("winml.modelkit.commands.compile.is_compiled_onnx", return_value=False),
patch("winml.modelkit.compiler.compile_onnx", return_value=mock_result),
patch("winml.modelkit.compiler.WinMLCompileConfig"),
):
result = CliRunner().invoke(compile, ["-m", str(model_file), "--ep", "dml"])

assert "Device: gpu" in result.output
assert "Device: npu" not in result.output
Loading