diff --git a/src/winml/modelkit/inspect/resolver.py b/src/winml/modelkit/inspect/resolver.py index 27550ffe6..94bfb0c57 100644 --- a/src/winml/modelkit/inspect/resolver.py +++ b/src/winml/modelkit/inspect/resolver.py @@ -64,9 +64,12 @@ def get_known_tasks() -> set[str]: """ tasks: set[str] = set() - # From HF_MODEL_CLASS_MAPPING values (task part of each (model_type, task) key) + # From HF_MODEL_CLASS_MAPPING values (task part of each (model_type, task) key). + # Some entries use task=None as a sentinel for the per-model-type default task; + # skip those so sorted() in callers never receives a None value. for _model_type, task in HF_MODEL_CLASS_MAPPING: - tasks.add(task) + if task is not None: + tasks.add(task) # From HF_TASK_DEFAULTS keys tasks.update(HF_TASK_DEFAULTS.keys()) diff --git a/tests/_helpers.py b/tests/_helpers.py new file mode 100644 index 000000000..4424cab6d --- /dev/null +++ b/tests/_helpers.py @@ -0,0 +1,22 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Shared test helpers for winml CLI invocation. + +Provides ``run_inspect``, a thin wrapper around ``CliRunner.invoke`` used +by both ``tests/cli/test_inspect_cli.py`` and ``tests/e2e/test_inspect_e2e.py`` +so that the invocation envelope (``obj={}``, ``mix_stderr`` defaults, etc.) +lives in a single place. +""" + +from __future__ import annotations + +from click.testing import CliRunner, Result + +from winml.modelkit.commands.inspect import inspect + + +def run_inspect(*args: str) -> Result: + """Invoke the ``inspect`` Click command with *args and return the Result.""" + return CliRunner().invoke(inspect, list(args), obj={}) diff --git a/tests/cli/test_inspect_cli.py b/tests/cli/test_inspect_cli.py new file mode 100644 index 000000000..0d0b9793e --- /dev/null +++ b/tests/cli/test_inspect_cli.py @@ -0,0 +1,116 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""CLI surface tests for `winml inspect`. + +Covers help text, no-args UsageError, invalid option choices, and +--list-tasks — all without downloading any model weights or hitting +the network. + +These tests run under the default CI filter (no special marker required). +""" + +from __future__ import annotations + +import re + +import pytest + +from tests._helpers import run_inspect as _run + + +# =========================================================================== +# CLI surface +# =========================================================================== + + +class TestInspectCliSurface: + """Help text, no-args errors, and format validation.""" + + @pytest.fixture(scope="class") + def help_output(self) -> str: + """Invoke --help once and share the output across all parametrized cases.""" + return _run("--help").output + + def test_no_args_exits_usage_error(self) -> None: + """Invoked with no arguments inspect must exit 2 with a UsageError.""" + result = _run() + assert result.exit_code == 2 + assert "At least one of" in result.output + + def test_help_exits_zero(self) -> None: + """--help must exit 0.""" + assert _run("--help").exit_code == 0 + + @pytest.mark.parametrize( + "flags", + [ + ("-m", "--model"), + ("-f", "--format"), + ("--model-type",), + ("--model-class",), + ("--list-tasks",), + ("-v", "--verbose"), + ("-H", "--hierarchy"), + ], + ) + def test_help_documents_flag(self, help_output: str, flags: tuple[str, ...]) -> None: + """Every documented flag appears in --help output.""" + for flag in flags: + assert flag in help_output + + def test_invalid_format(self) -> None: + """An unrecognised --format value must exit non-zero and name the bad value.""" + result = _run("--model-type", "bert", "--format", "xml") + assert result.exit_code != 0 + output_lower = result.output.lower() + assert "xml" in output_lower or "choice" in output_lower or "invalid" in output_lower + + +# =========================================================================== +# --list-tasks +# =========================================================================== + + +class TestInspectListTasks: + """--list-tasks must exit 0 and print one task-name per line.""" + + def test_list_tasks_exits_zero(self) -> None: + """--list-tasks should not require a model argument and must exit 0.""" + result = _run("--list-tasks") + assert result.exit_code == 0, f"--list-tasks exited {result.exit_code}:\n{result.output}" + + def test_list_tasks_output_is_nonempty(self) -> None: + """--list-tasks must print at least one task.""" + result = _run("--list-tasks") + assert result.exit_code == 0 + lines = [line.strip() for line in result.output.splitlines() if line.strip()] + assert len(lines) > 0, "Expected at least one task line" + + def test_list_tasks_lines_match_task_name_pattern(self) -> None: + """Every line must be a valid HF task-name (lowercase, hyphens only).""" + result = _run("--list-tasks") + assert result.exit_code == 0 + for line in result.output.splitlines(): + task = line.strip() + if task: + assert re.match(r"^[a-z][a-z0-9-]*$", task), ( + f"Line does not match task-name pattern: {task!r}" + ) + + def test_list_tasks_includes_known_tasks(self) -> None: + """Output must include ModelKit-registered tasks.""" + result = _run("--list-tasks") + assert result.exit_code == 0 + tasks = {line.strip() for line in result.output.splitlines() if line.strip()} + assert "feature-extraction" in tasks + assert "mask-generation" in tasks + + @pytest.mark.parametrize("extra_args", [[], ["--model-type", "bert"]]) + def test_list_tasks_is_sorted(self, extra_args: list[str]) -> None: + """Output lines must be in ascending lexicographic order.""" + result = _run("--list-tasks", *extra_args) + assert result.exit_code == 0 + lines = [line.strip() for line in result.output.splitlines() if line.strip()] + assert lines == sorted(lines), "Task list is not sorted" diff --git a/tests/e2e/test_inspect_e2e.py b/tests/e2e/test_inspect_e2e.py index 5b774c1d0..1531440a3 100644 --- a/tests/e2e/test_inspect_e2e.py +++ b/tests/e2e/test_inspect_e2e.py @@ -4,20 +4,19 @@ # -------------------------------------------------------------------------- """E2E tests for the inspect CLI command. -These tests exercise the full inspect pipeline with REAL models -downloaded from HuggingFace Hub. They validate JSON output structure -and content for various model-task combinations. +Offline tests (no network) use --model-type or --model-class flags and +validate JSON structure invariants without downloading any model weights. -Note: The inspect command's validate_task() has a limited task vocabulary -(feature-extraction, image-feature-extraction, image-segmentation, -mask-generation, next-sentence-prediction). Tasks outside this set are -rejected when passed via --task override. Auto-detect (no --task) uses -TasksManager directly and supports a broader set of tasks. +Network tests use real HuggingFace model IDs (-m flag) and validate +auto-detected task, model_type, and full JSON structure. + +CLI surface and --list-tasks tests live in tests/cli/test_inspect_cli.py. Markers: - e2e: Full end-to-end test with real models + e2e: Full end-to-end test (required to run any test in this file) network: Requires network access to HuggingFace Hub """ + from __future__ import annotations import json @@ -25,13 +24,14 @@ import pytest from click.testing import CliRunner +from tests._helpers import run_inspect as _run from winml.modelkit.commands.inspect import inspect -pytestmark = [pytest.mark.e2e, pytest.mark.network] +pytestmark = [pytest.mark.e2e] # --------------------------------------------------------------------------- -# Helpers +# Constants # --------------------------------------------------------------------------- EXPECTED_TOP_KEYS = { @@ -51,56 +51,153 @@ "io_config", } +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + -def _run_inspect(model: str, task: str | None = None) -> dict: - """Invoke the inspect command and return parsed JSON output. +def _run_json(*args: str) -> dict: + """Invoke inspect with *args + '-f json' and return parsed JSON. + + Asserts exit_code == 0 and that the output is valid JSON. + """ + result = _run(*args, "-f", "json") + assert result.exit_code == 0, f"inspect exited {result.exit_code}:\n{result.output}" + return json.loads(result.output) + + +def _run_network(model: str, task: str | None = None) -> dict: + """Invoke inspect with a real model ID and return parsed JSON output. Raises AssertionError when the command exits non-zero or the output is not valid JSON. """ - runner = CliRunner() - args = ["-m", model, "-f", "json"] + args: list[str] = ["-m", model, "-f", "json"] if task: args.extend(["-t", task]) - result = runner.invoke(inspect, args, obj={}, catch_exceptions=False) - assert result.exit_code == 0, ( - f"inspect failed (exit {result.exit_code}):\n{result.output}" - ) + result = CliRunner().invoke(inspect, args, obj={}, catch_exceptions=False) + assert result.exit_code == 0, f"inspect failed (exit {result.exit_code}):\n{result.output}" return json.loads(result.output) def _assert_common_structure(data: dict, model_id: str, expected_task: str) -> None: """Assert the standard JSON structure returned by inspect.""" - # All top-level keys present assert EXPECTED_TOP_KEYS.issubset(data.keys()), ( f"Missing keys: {EXPECTED_TOP_KEYS - data.keys()}" ) - assert data["model_id"] == model_id assert data["task"] == expected_task - # Loader section loader = data["loader"] assert "hf_model_class" in loader assert "support_level" in loader - # Exporter section exporter = data["exporter"] assert "onnx_config_class" in exporter assert "support_level" in exporter assert isinstance(exporter.get("input_tensors"), list) assert isinstance(exporter.get("output_tensors"), list) - # WinML section winml = data["winml"] assert "winml_class" in winml assert "support_level" in winml # =========================================================================== -# BERT +# Offline inspection — no network required # =========================================================================== + +class TestInspectModelTypeOnly: + """Use --model-type / --model-class without downloading any weights.""" + + def test_bert_default_task_json(self): + """--model-type bert resolves to a bert model_type with some task.""" + data = _run_json("--model-type", "bert") + assert data["model_type"] == "bert" + assert isinstance(data["task"], str) and data["task"] + + def test_bert_feature_extraction_json(self): + """--model-type bert -t feature-extraction resolves correctly.""" + data = _run_json("--model-type", "bert", "-t", "feature-extraction") + assert data["model_type"] == "bert" + assert data["task"] == "feature-extraction" + + def test_resnet_default_task_json(self): + """--model-type resnet resolves to a resnet model_type.""" + data = _run_json("--model-type", "resnet") + assert data["model_type"] == "resnet" + assert isinstance(data["task"], str) and data["task"] + + def test_model_class_bert_for_masked_lm(self): + """--model-class BertForMaskedLM resolves to bert / fill-mask.""" + data = _run_json("--model-class", "BertForMaskedLM") + assert data["model_type"] == "bert" + assert data["task"] == "fill-mask" + + def test_verbose_flag_accepted(self): + """--verbose must be accepted without error.""" + data = _run_json("--model-type", "bert", "--verbose") + assert data["model_type"] == "bert" + + def test_short_verbose_flag_accepted(self): + """-v short flag must be accepted without error.""" + data = _run_json("--model-type", "bert", "-v") + assert data["model_type"] == "bert" + + def test_json_output_contains_all_top_level_keys(self): + """JSON output must include every key in EXPECTED_TOP_KEYS.""" + data = _run_json("--model-type", "bert") + missing = EXPECTED_TOP_KEYS - data.keys() + assert not missing, f"Missing top-level keys: {missing}" + + def test_loader_section_structure(self): + """loader section must have hf_model_class and support_level.""" + data = _run_json("--model-type", "bert") + loader = data["loader"] + assert "hf_model_class" in loader + assert "support_level" in loader + + def test_exporter_section_structure(self): + """exporter section must have onnx_config_class, support_level, tensors.""" + data = _run_json("--model-type", "bert") + exporter = data["exporter"] + assert "onnx_config_class" in exporter + assert "support_level" in exporter + assert isinstance(exporter.get("input_tensors"), list) + assert isinstance(exporter.get("output_tensors"), list) + + def test_winml_section_structure(self): + """winml section must have winml_class and support_level.""" + data = _run_json("--model-type", "bert") + winml = data["winml"] + assert "winml_class" in winml + assert "support_level" in winml + + def test_table_format_exits_zero(self): + """Default table format must exit 0 (Rich output is not captured, but exit code is).""" + result = _run("--model-type", "bert") + assert result.exit_code == 0 + + def test_unknown_model_type_exits_nonzero(self): + """An unrecognised model type must produce a non-zero exit code.""" + result = _run("--model-type", "totally_nonexistent_model_xyz_123") + assert result.exit_code != 0 + + def test_hierarchy_flag_accepted_without_model(self): + """--hierarchy flag must be accepted even without a model download.""" + # Without -m, hierarchy_info will be None (skipped), but command should succeed + data = _run_json("--model-type", "bert", "--hierarchy") + assert data["model_type"] == "bert" + assert data["hierarchy"] is None # hierarchy requires -m + + +# =========================================================================== +# Network tests — require HuggingFace Hub access +# =========================================================================== + + +@pytest.mark.network class TestInspectBert: """Inspect bert-base-uncased with auto-detect and explicit tasks.""" @@ -108,47 +205,37 @@ class TestInspectBert: def test_auto_detect_fill_mask(self): """Auto-detect should resolve BERT to fill-mask via TasksManager.""" - data = _run_inspect(self.MODEL) + data = _run_network(self.MODEL) _assert_common_structure(data, self.MODEL, "fill-mask") assert data["model_type"] == "bert" assert data["task_source"] == "TasksManager" def test_feature_extraction(self): - """feature-extraction is in the known task list; explicit override works.""" - data = _run_inspect(self.MODEL, task="feature-extraction") + """feature-extraction task override must work.""" + data = _run_network(self.MODEL, task="feature-extraction") _assert_common_structure(data, self.MODEL, "feature-extraction") assert data["model_type"] == "bert" - def test_explicit_unknown_task_rejected(self): - """Tasks not in validate_task vocabulary are cleanly rejected.""" - runner = CliRunner() - args = ["-m", self.MODEL, "-f", "json", "-t", "text-classification"] - result = runner.invoke(inspect, args, obj={}) - assert result.exit_code != 0 - assert "Unknown task" in result.output - def test_next_sentence_prediction(self): - """next-sentence-prediction is in the known task list. + """next-sentence-prediction task override: clean success or clean error. We assert it either succeeds with valid JSON or fails with a clean ClickException (non-zero exit code), but never crashes with an unhandled traceback. """ - runner = CliRunner() - args = ["-m", self.MODEL, "-f", "json", "-t", "next-sentence-prediction"] - result = runner.invoke(inspect, args, obj={}) + result = CliRunner().invoke( + inspect, + ["-m", self.MODEL, "-f", "json", "-t", "next-sentence-prediction"], + obj={}, + ) if result.exit_code == 0: data = json.loads(result.output) assert "model_id" in data else: - # Should be a clean error, not a raw traceback assert "Traceback (most recent call last)" not in result.output -# =========================================================================== -# Vision models -# =========================================================================== - +@pytest.mark.network class TestInspectVision: """Inspect vision models via auto-detect (image-classification).""" @@ -163,15 +250,12 @@ class TestInspectVision: ) def test_auto_detect_image_classification(self, model_id: str): """Auto-detect should resolve vision models to image-classification.""" - data = _run_inspect(model_id) + data = _run_network(model_id) _assert_common_structure(data, model_id, "image-classification") assert data["model_type"] in {"resnet", "convnext", "vit"} -# =========================================================================== -# CLIP -# =========================================================================== - +@pytest.mark.network class TestInspectCLIP: """Inspect CLIP with multi-modal tasks.""" @@ -179,21 +263,18 @@ class TestInspectCLIP: def test_auto_detect_feature_extraction(self): """Auto-detect should resolve CLIP to feature-extraction.""" - data = _run_inspect(self.MODEL) + data = _run_network(self.MODEL) assert data["model_type"] == "clip" assert data["task"] in {"feature-extraction", "zero-shot-image-classification"} def test_image_feature_extraction(self): - """image-feature-extraction is in the known task list.""" - data = _run_inspect(self.MODEL, task="image-feature-extraction") + """image-feature-extraction task override must work.""" + data = _run_network(self.MODEL, task="image-feature-extraction") _assert_common_structure(data, self.MODEL, "image-feature-extraction") assert data["model_type"] == "clip" -# =========================================================================== -# DETR -# =========================================================================== - +@pytest.mark.network class TestInspectDETR: """Inspect DETR with object-detection.""" @@ -201,19 +282,7 @@ class TestInspectDETR: def test_auto_detect_object_detection(self): """Auto-detect should resolve DETR to object-detection.""" - data = _run_inspect(self.MODEL) + data = _run_network(self.MODEL) assert data["model_id"] == self.MODEL assert data["model_type"] == "detr" assert data["task"] == "object-detection" - - def test_explicit_object_detection_rejected(self): - """object-detection is NOT in validate_task vocabulary. - - Explicit override should be cleanly rejected, while - auto-detect (tested above) succeeds. - """ - runner = CliRunner() - args = ["-m", self.MODEL, "-f", "json", "-t", "object-detection"] - result = runner.invoke(inspect, args, obj={}) - assert result.exit_code != 0 - assert "Unknown task" in result.output