diff --git a/src/winml/modelkit/commands/compile.py b/src/winml/modelkit/commands/compile.py index 142b57155..8ecf470a2 100644 --- a/src/winml/modelkit/commands/compile.py +++ b/src/winml/modelkit/commands/compile.py @@ -25,7 +25,7 @@ import click from rich.console import Console -from ..config.precision import _DEVICE_TO_PROVIDER, VALID_EPS +from ..config.precision import _DEVICE_TO_PROVIDER, _EP_TO_DEVICE, VALID_EPS from ..onnx import is_compiled_onnx from ..utils.logging import configure_logging @@ -191,7 +191,7 @@ def compile( # Show info console.print(f"[bold blue]Input:[/bold blue] {model}") - console.print(f"[bold blue]Device:[/bold blue] {device}") + console.print(f"[bold blue]Device:[/bold blue] {_EP_TO_DEVICE.get(provider, device)}") if ep: console.print(f"[bold blue]EP:[/bold blue] {ep}") console.print(f"[bold blue]Provider:[/bold blue] {provider}") diff --git a/tests/unit/commands/test_compile_quantize_flags.py b/tests/unit/commands/test_compile_quantize_flags.py index 2c8b43ce5..7f5274a9b 100644 --- a/tests/unit/commands/test_compile_quantize_flags.py +++ b/tests/unit/commands/test_compile_quantize_flags.py @@ -6,6 +6,8 @@ from __future__ import annotations +from unittest.mock import MagicMock, patch + import pytest from winml.modelkit.commands.compile import _resolve_compile_provider @@ -119,3 +121,31 @@ def test_unknown_precision_uses_defaults(self): w, a = _resolve_quant_types("fp16", None, None) assert w == "uint8" assert a == "uint8" + + +class TestCompileDeviceDisplayLabel: + """Device label in compile summary must reflect the resolved EP, not the CLI default.""" + + def test_dml_ep_shows_gpu_device(self, tmp_path): + from click.testing import CliRunner + + from winml.modelkit.commands.compile import compile + + model_file = tmp_path / "model.onnx" + model_file.write_bytes(b"fake") + + mock_result = MagicMock() + mock_result.success = True + mock_result.output_path = None + mock_result.compile_time = None + mock_result.total_time = None + + with ( + patch("winml.modelkit.commands.compile.is_compiled_onnx", return_value=False), + patch("winml.modelkit.compiler.compile_onnx", return_value=mock_result), + patch("winml.modelkit.compiler.WinMLCompileConfig"), + ): + result = CliRunner().invoke(compile, ["-m", str(model_file), "--ep", "dml"]) + + assert "Device: gpu" in result.output + assert "Device: npu" not in result.output