Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/winml/modelkit/inspect/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,12 @@ def get_known_tasks() -> set[str]:
"""
tasks: set[str] = set()

# From HF_MODEL_CLASS_MAPPING values (task part of each (model_type, task) key)
# From HF_MODEL_CLASS_MAPPING values (task part of each (model_type, task) key).
# Some entries use task=None as a sentinel for the per-model-type default task;
# skip those so sorted() in callers never receives a None value.
for _model_type, task in HF_MODEL_CLASS_MAPPING:
tasks.add(task)
if task is not None:
tasks.add(task)

# From HF_TASK_DEFAULTS keys
tasks.update(HF_TASK_DEFAULTS.keys())
Expand Down
22 changes: 22 additions & 0 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Shared test helpers for winml CLI invocation.

Provides ``run_inspect``, a thin wrapper around ``CliRunner.invoke`` used
by both ``tests/cli/test_inspect_cli.py`` and ``tests/e2e/test_inspect_e2e.py``
so that the invocation envelope (``obj={}``, ``mix_stderr`` defaults, etc.)
lives in a single place.
"""

from __future__ import annotations

from click.testing import CliRunner, Result

from winml.modelkit.commands.inspect import inspect


def run_inspect(*args: str) -> Result:
"""Invoke the ``inspect`` Click command with *args and return the Result."""
return CliRunner().invoke(inspect, list(args), obj={})
116 changes: 116 additions & 0 deletions tests/cli/test_inspect_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""CLI surface tests for `winml inspect`.

Covers help text, no-args UsageError, invalid option choices, and
--list-tasks — all without downloading any model weights or hitting
the network.

These tests run under the default CI filter (no special marker required).
"""

from __future__ import annotations

import re

import pytest

from tests._helpers import run_inspect as _run


# ===========================================================================
# CLI surface
# ===========================================================================


class TestInspectCliSurface:
"""Help text, no-args errors, and format validation."""

@pytest.fixture(scope="class")
def help_output(self) -> str:
"""Invoke --help once and share the output across all parametrized cases."""
return _run("--help").output

def test_no_args_exits_usage_error(self) -> None:
"""Invoked with no arguments inspect must exit 2 with a UsageError."""
result = _run()
assert result.exit_code == 2
assert "At least one of" in result.output

def test_help_exits_zero(self) -> None:
"""--help must exit 0."""
assert _run("--help").exit_code == 0

@pytest.mark.parametrize(
"flags",
[
("-m", "--model"),
("-f", "--format"),
("--model-type",),
("--model-class",),
("--list-tasks",),
("-v", "--verbose"),
("-H", "--hierarchy"),
],
)
def test_help_documents_flag(self, help_output: str, flags: tuple[str, ...]) -> None:
"""Every documented flag appears in --help output."""
for flag in flags:
assert flag in help_output

def test_invalid_format(self) -> None:
"""An unrecognised --format value must exit non-zero and name the bad value."""
result = _run("--model-type", "bert", "--format", "xml")
assert result.exit_code != 0
output_lower = result.output.lower()
assert "xml" in output_lower or "choice" in output_lower or "invalid" in output_lower


# ===========================================================================
# --list-tasks
# ===========================================================================


class TestInspectListTasks:
"""--list-tasks must exit 0 and print one task-name per line."""

def test_list_tasks_exits_zero(self) -> None:
"""--list-tasks should not require a model argument and must exit 0."""
result = _run("--list-tasks")
assert result.exit_code == 0, f"--list-tasks exited {result.exit_code}:\n{result.output}"
Comment thread
DingmaomaoBJTU marked this conversation as resolved.

def test_list_tasks_output_is_nonempty(self) -> None:
"""--list-tasks must print at least one task."""
result = _run("--list-tasks")
assert result.exit_code == 0
lines = [line.strip() for line in result.output.splitlines() if line.strip()]
assert len(lines) > 0, "Expected at least one task line"

def test_list_tasks_lines_match_task_name_pattern(self) -> None:
"""Every line must be a valid HF task-name (lowercase, hyphens only)."""
result = _run("--list-tasks")
Comment thread
DingmaomaoBJTU marked this conversation as resolved.
assert result.exit_code == 0
for line in result.output.splitlines():
task = line.strip()
if task:
assert re.match(r"^[a-z][a-z0-9-]*$", task), (
f"Line does not match task-name pattern: {task!r}"
)

def test_list_tasks_includes_known_tasks(self) -> None:
"""Output must include ModelKit-registered tasks."""
result = _run("--list-tasks")
assert result.exit_code == 0
tasks = {line.strip() for line in result.output.splitlines() if line.strip()}
assert "feature-extraction" in tasks
assert "mask-generation" in tasks

@pytest.mark.parametrize("extra_args", [[], ["--model-type", "bert"]])
def test_list_tasks_is_sorted(self, extra_args: list[str]) -> None:
"""Output lines must be in ascending lexicographic order."""
result = _run("--list-tasks", *extra_args)
assert result.exit_code == 0
lines = [line.strip() for line in result.output.splitlines() if line.strip()]
assert lines == sorted(lines), "Task list is not sorted"
Loading
Loading