From 7bad9821f025d07aa5fc2f97448534415ae23422 Mon Sep 17 00:00:00 2001 From: zhenchaoni Date: Wed, 20 May 2026 11:36:35 +0800 Subject: [PATCH 1/3] Fix quantization P0 bugs --- src/winml/modelkit/commands/quantize.py | 10 ++- src/winml/modelkit/datasets/image.py | 69 +++++++++++++------ .../modelkit/datasets/object_detection.py | 30 +------- .../eval/object_detection_evaluator.py | 17 ++++- .../commands/test_compile_quantize_flags.py | 46 +++++++++++-- .../eval/test_object_detection_evaluator.py | 69 +++++++++++++++++++ 6 files changed, 181 insertions(+), 60 deletions(-) diff --git a/src/winml/modelkit/commands/quantize.py b/src/winml/modelkit/commands/quantize.py index 7db030c02..f59da4e53 100644 --- a/src/winml/modelkit/commands/quantize.py +++ b/src/winml/modelkit/commands/quantize.py @@ -192,8 +192,6 @@ def quantize( # Show info console.print(f"[bold blue]Input:[/bold blue] {model}") console.print(f"[bold blue]Output:[/bold blue] {output}") - if precision: - console.print(f"[bold blue]Precision:[/bold blue] {precision}") console.print(f"[bold blue]Weight type:[/bold blue] {resolved_weight}") console.print(f"[bold blue]Activation type:[/bold blue] {resolved_activation}") console.print(f"[bold blue]Samples:[/bold blue] {samples}") @@ -260,8 +258,14 @@ def _resolve_quant_types( if precision and is_quantized_precision(precision): default_w, default_a = resolve_quant_types(precision) - else: + elif precision is None or precision.lower() == "auto": default_w, default_a = "uint8", "uint8" + else: + raise click.BadParameter( + f"'{precision}' is not a supported quantization precision. " + "See `winml quantize --help` for details.", + param_hint="'-p' / '--precision'", + ) # Explicit flags override precision defaults resolved_w = weight_type if weight_type else default_w diff --git a/src/winml/modelkit/datasets/image.py b/src/winml/modelkit/datasets/image.py index 57540a2f6..b570f6d02 100644 --- a/src/winml/modelkit/datasets/image.py +++ b/src/winml/modelkit/datasets/image.py @@ -50,36 +50,41 @@ def _get_default_dataset(self) -> None: if self._dataset_name is None: self._dataset_name = "timm/mini-imagenet" self._data_split = "train" + self._config.setdefault("streaming", True) - def _initialize(self) -> None: - """Initialize the image classification dataset. + def _load_and_sample(self) -> Any: + """Load the configured dataset and apply sample/shuffle. - Simplified approach: - 1. Set defaults if needed - 2. Load dataset - 3. Detect columns - 4. Apply efficient processing pipeline - """ - # 1. Set defaults if no dataset specified - if self._dataset_name is None: - self._get_default_dataset() + Shared by ImageDataset and ObjectDetectionDataset. Column detection + is *not* done here — callers run their own detection on the returned + dataset because the column schema differs by task. - # 2. Load dataset + Returns: + A materialized arrow Dataset of up to ``self._max_samples`` rows. + """ + # Streaming only helps when capped by max_samples; otherwise we'd + # iterate the full remote stream into memory, which is worse than a + # bulk download. + streaming = self._config.get("streaming", False) and self._max_samples is not None logger.info(f"Loading dataset: {self._dataset_name} with split: {self._data_split}") try: - dataset = load_dataset(self._dataset_name, split=self._data_split) + dataset = load_dataset(self._dataset_name, split=self._data_split, streaming=streaming) except Exception as e: logger.error(f"Failed to load dataset {self._dataset_name}: {e}") raise - # 3. Detect columns using Features API - self._detect_columns(dataset) - - # 5. Efficient sampling and processing pipeline shuffle = self._config.get("shuffle", False) seed = self._config.get("seed", 42) - if self._max_samples is not None: + if streaming: + # Streaming datasets aren't indexable: shuffle reservoir-samples + # within a buffer; take() pulls only the slice we need. + if shuffle: + dataset = dataset.shuffle(seed=seed, buffer_size=1000) + dataset = dataset.take(self._max_samples) + from datasets import Dataset as ArrowDataset + dataset = ArrowDataset.from_list(list(dataset), features=dataset.features) + elif self._max_samples is not None: max_samples = min(self._max_samples, len(dataset)) indices = ( Random(seed).sample(range(len(dataset)), max_samples) @@ -90,15 +95,35 @@ def _initialize(self) -> None: elif shuffle: dataset = dataset.shuffle(seed=seed) - # 6. Load processor and apply batch processing - processor = AutoImageProcessor.from_pretrained(self._model_name, use_fast=True) + return dataset + def _initialize(self) -> None: + """Initialize the image classification dataset. + + Simplified approach: + 1. Set defaults if needed + 2. Load dataset + 3. Detect columns + 4. Apply efficient processing pipeline + """ + # 1. Set defaults if no dataset specified + if self._dataset_name is None: + self._get_default_dataset() + + # 2. Load + sample (shared with subclasses) + dataset = self._load_and_sample() + + # 3. Detect columns using Features API + self._detect_columns(dataset) + + # 4. Load processor and apply batch processing + processor = AutoImageProcessor.from_pretrained(self._model_name, use_fast=True) - # 6. Conditional label alignment using should_align_labels() + # 5. Conditional label alignment using should_align_labels() if should_align_labels(self._dataset_name): dataset = dataset.align_labels_with_mapping(get_imagenet_label_map(), self._label_col) - # 7. Apply image processing with proper batch dimension + # 6. Apply image processing with proper batch dimension def preprocess_single_sample(example): # Process single image and add batch dimension return processor(example[self._image_col].convert("RGB"), return_tensors="pt") diff --git a/src/winml/modelkit/datasets/object_detection.py b/src/winml/modelkit/datasets/object_detection.py index cfdbabb20..628204f35 100644 --- a/src/winml/modelkit/datasets/object_detection.py +++ b/src/winml/modelkit/datasets/object_detection.py @@ -11,10 +11,8 @@ from __future__ import annotations import logging -from random import Random from typing import Any -from datasets import load_dataset from datasets.features import Image from transformers import AutoImageProcessor @@ -97,36 +95,12 @@ def _initialize(self) -> None: if self._dataset_name is None: self._get_default_dataset() - # Load dataset - logger.info( - "Loading object detection dataset: %s with split: %s", - self._dataset_name, - self._data_split, - ) - try: - dataset = load_dataset(self._dataset_name, split=self._data_split) - except Exception as e: - logger.error("Failed to load dataset %s: %s", self._dataset_name, e) - raise + # Load + sample (shared with ImageDataset) + dataset = self._load_and_sample() # Detect image column (object detection may not have simple ClassLabel) self._detect_image_column(dataset) - # Efficient sampling - shuffle = self._config.get("shuffle", False) - seed = self._config.get("seed", 42) - - if self._max_samples is not None: - max_samples = min(self._max_samples, len(dataset)) - indices = ( - Random(seed).sample(range(len(dataset)), max_samples) - if shuffle - else list(range(max_samples)) - ) - dataset = dataset.select(indices) - elif shuffle: - dataset = dataset.shuffle(seed=seed) - # Derive ONNX-aware overrides io_config = self._config.get("io_config") overrides = self._derive_overrides(io_config) diff --git a/src/winml/modelkit/eval/object_detection_evaluator.py b/src/winml/modelkit/eval/object_detection_evaluator.py index fd6d01c78..570a7f677 100644 --- a/src/winml/modelkit/eval/object_detection_evaluator.py +++ b/src/winml/modelkit/eval/object_detection_evaluator.py @@ -93,11 +93,22 @@ def prepare_pipeline(self) -> Pipeline: io_config = getattr(self.model, "io_config", None) or {} input_shapes = io_config.get("input_shapes", [[]]) + input_names = io_config.get("input_names", []) if input_shapes and len(input_shapes[0]) == 4: _, _, h, w = input_shapes[0] - pipe.image_processor.size = {"height": h, "width": w} - if hasattr(pipe.image_processor, "do_pad"): - pipe.image_processor.do_pad = False + if "pixel_mask" in input_names: + pipe.image_processor.size = { + "shortest_edge": min(h, w), + "longest_edge": max(h, w), + } + if hasattr(pipe.image_processor, "pad_size"): + pipe.image_processor.pad_size = {"height": h, "width": w} + if hasattr(pipe.image_processor, "do_pad"): + pipe.image_processor.do_pad = True + else: + pipe.image_processor.size = {"height": h, "width": w} + if hasattr(pipe.image_processor, "do_pad"): + pipe.image_processor.do_pad = False return pipe diff --git a/tests/unit/commands/test_compile_quantize_flags.py b/tests/unit/commands/test_compile_quantize_flags.py index 50739ba39..ae38a532e 100644 --- a/tests/unit/commands/test_compile_quantize_flags.py +++ b/tests/unit/commands/test_compile_quantize_flags.py @@ -239,10 +239,11 @@ def test_precision_case_insensitive(self): assert a == "uint8" def test_unknown_precision_uses_defaults(self): - """Unknown precision string falls through to defaults.""" - w, a = _resolve_quant_types("fp16", None, None) - assert w == "uint8" - assert a == "uint8" + """Explicit non-quantized precision (e.g., fp16) is rejected.""" + import click + + with pytest.raises(click.BadParameter, match="not a supported quantization precision"): + _resolve_quant_types("fp16", None, None) class TestCompileDeviceDisplayLabel: @@ -516,3 +517,40 @@ def test_non_object_top_level_raises_usage_error(self, tmp_path): r = self._invoke(["-m", str(model), "--config", str(bc)]) assert r.exit_code != 0 assert "Build config must be a JSON object" in r.output + + +class TestQuantizePrecisionValidation: + """Regression tests for issue #555. + + `winml quantize --precision ` must reject the value before + running quantization, instead of silently falling back to uint8/uint8 + and printing "Success!". + """ + + @staticmethod + def _invoke(args): + from click.testing import CliRunner + + from winml.modelkit.commands.quantize import quantize as quantize_cmd + + return CliRunner().invoke(quantize_cmd, args, obj={}, catch_exceptions=False) + + @pytest.mark.parametrize( + "bad_precision", + ["banana", "w4a16", "int4", "fp64"], + ) + def test_unknown_precision_rejected(self, tmp_path, bad_precision): + model, _ = TestQuantizeCliConfigPrecedence._setup(tmp_path) + ran: dict[str, bool] = {"called": False} + + def fake_quantize(*_args, **_kwargs): + ran["called"] = True + raise AssertionError("quantize_onnx must not be called for invalid precision") + + with patch("winml.modelkit.quant.quantize_onnx", side_effect=fake_quantize): + r = self._invoke(["-m", str(model), "--precision", bad_precision]) + + assert r.exit_code != 0, r.output + assert "not a supported quantization precision" in r.output + assert ran["called"] is False + diff --git a/tests/unit/eval/test_object_detection_evaluator.py b/tests/unit/eval/test_object_detection_evaluator.py index ea16a4ec4..83750b326 100644 --- a/tests/unit/eval/test_object_detection_evaluator.py +++ b/tests/unit/eval/test_object_detection_evaluator.py @@ -195,3 +195,72 @@ def test_user_mapping_overrides_model(self): result = ev.align_labels(ds, ds_config) assert result["objects"][0]["category"] == [1, 2] + + +# --------------------------------------------------------------------------- +# Pipeline preparation +# --------------------------------------------------------------------------- + + +class _FakeProcessor: + """Minimal stand-in for an HF image processor with the attributes we read.""" + + def __init__(self) -> None: + self.size: dict | None = None + self.pad_size: dict | None = None + self.do_pad: bool | None = None + + +class _FakePipe: + def __init__(self) -> None: + self.image_processor = _FakeProcessor() + + +def _patch_super_prepare_pipeline(monkeypatch, fake_pipe) -> None: + """Patch the *parent* prepare_pipeline so super() in the subclass returns fake_pipe. + + Uses monkeypatch so the override is reverted on teardown — directly assigning + to a parent class attribute leaks across tests in the same session. + """ + parent = WinMLObjectDetectionEvaluator.__mro__[1] + monkeypatch.setattr(parent, "prepare_pipeline", lambda self: fake_pipe) + + +class TestPreparePipeline: + """Cover the two branches in prepare_pipeline().""" + + def test_pixel_mask_path_enables_padding(self, monkeypatch) -> None: + """Model declaring pixel_mask: aspect-preserving size + pad_size + do_pad.""" + ev = make_evaluator() + ev.model.io_config = { + "input_names": ["pixel_values", "pixel_mask"], + "input_shapes": [[1, 3, 800, 800], [1, 800, 800]], + } + fake_pipe = _FakePipe() + _patch_super_prepare_pipeline(monkeypatch, fake_pipe) + + pipe = ev.prepare_pipeline() + proc = pipe.image_processor + assert proc.size == {"shortest_edge": 800, "longest_edge": 800} + assert proc.pad_size == {"height": 800, "width": 800} + assert proc.do_pad is True + + def test_no_pixel_mask_path_disables_padding(self, monkeypatch) -> None: + """No pixel_mask input: exact resize, no padding (legacy behavior).""" + ev = make_evaluator() + ev.model.io_config = { + "input_names": ["pixel_values"], + "input_shapes": [[1, 3, 640, 640]], + } + fake_pipe = _FakePipe() + # Sentinel values let us prove the pixel_mask branch did NOT run. + fake_pipe.image_processor.pad_size = "untouched" + fake_pipe.image_processor.do_pad = "untouched" + _patch_super_prepare_pipeline(monkeypatch, fake_pipe) + + pipe = ev.prepare_pipeline() + proc = pipe.image_processor + assert proc.size == {"height": 640, "width": 640} + assert proc.do_pad is False + # pad_size must remain the sentinel — pixel_mask branch never executed. + assert proc.pad_size == "untouched" From 93e54997f7b6f69dd5549a9ac8710d5f38067b15 Mon Sep 17 00:00:00 2001 From: zhenchaoni Date: Wed, 20 May 2026 13:36:13 +0800 Subject: [PATCH 2/3] fix unit --- tests/unit/config/test_precision.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/unit/config/test_precision.py b/tests/unit/config/test_precision.py index 2d472d8e7..c41ee4e31 100644 --- a/tests/unit/config/test_precision.py +++ b/tests/unit/config/test_precision.py @@ -578,12 +578,13 @@ def test_no_precision_defaults_uint8(self) -> None: assert w == "uint8" assert a == "uint8" - # ---- Unsupported precision falls back to uint8/uint8 ---- - def test_unsupported_precision_falls_back(self) -> None: - """Unsupported precision (w4a16) is not quantized -> fallback to uint8.""" - w, a = self._resolve(precision="w4a16") - assert w == "uint8" - assert a == "uint8" + # ---- Unsupported precision is rejected ---- + def test_unsupported_precision_rejected(self) -> None: + """Unsupported precision (w4a16) must raise BadParameter, not silently fall back.""" + import click + + with pytest.raises(click.BadParameter, match="not a supported quantization precision"): + self._resolve(precision="w4a16") # ---- Explicit flags override precision ---- def test_explicit_weight_overrides_precision(self) -> None: From ae69cb850b80f8fc91960efc30d94fac90be902a Mon Sep 17 00:00:00 2001 From: zhenchaoni Date: Thu, 21 May 2026 14:29:49 +0800 Subject: [PATCH 3/3] Resolve comments --- src/winml/modelkit/commands/quantize.py | 7 +- src/winml/modelkit/datasets/image.py | 3 +- tests/e2e/test_quantize_e2e.py | 13 +- tests/unit/datasets/test_image_streaming.py | 127 ++++++++++++++++++ .../eval/test_object_detection_evaluator.py | 25 ++-- 5 files changed, 161 insertions(+), 14 deletions(-) create mode 100644 tests/unit/datasets/test_image_streaming.py diff --git a/src/winml/modelkit/commands/quantize.py b/src/winml/modelkit/commands/quantize.py index f59da4e53..8c50540f6 100644 --- a/src/winml/modelkit/commands/quantize.py +++ b/src/winml/modelkit/commands/quantize.py @@ -48,7 +48,8 @@ "-p", type=str, default=None, - help="Quantization precision: int8, int16, or w{x}a{y} (e.g., w8a16). " + help="Quantization precision. Accepted: auto, int8, int16, or w{x}a{y} " + "where x,y in {8,16} (e.g., w8a8, w8a16, w16a16). " "Overridden by explicit --weight-type/--activation-type.", ) @click.option( @@ -192,6 +193,7 @@ def quantize( # Show info console.print(f"[bold blue]Input:[/bold blue] {model}") console.print(f"[bold blue]Output:[/bold blue] {output}") + console.print(f"[bold blue]Precision:[/bold blue] {precision or 'auto'}") console.print(f"[bold blue]Weight type:[/bold blue] {resolved_weight}") console.print(f"[bold blue]Activation type:[/bold blue] {resolved_activation}") console.print(f"[bold blue]Samples:[/bold blue] {samples}") @@ -263,7 +265,8 @@ def _resolve_quant_types( else: raise click.BadParameter( f"'{precision}' is not a supported quantization precision. " - "See `winml quantize --help` for details.", + "Accepted: auto, int8, int16, or w{x}a{y} with x,y in {8,16} " + "(e.g., w8a8, w8a16, w16a16).", param_hint="'-p' / '--precision'", ) diff --git a/src/winml/modelkit/datasets/image.py b/src/winml/modelkit/datasets/image.py index b570f6d02..92c6b885b 100644 --- a/src/winml/modelkit/datasets/image.py +++ b/src/winml/modelkit/datasets/image.py @@ -78,7 +78,8 @@ def _load_and_sample(self) -> Any: if streaming: # Streaming datasets aren't indexable: shuffle reservoir-samples - # within a buffer; take() pulls only the slice we need. + # within a buffer; take() pulls only the slice we need. Keep the + # 1000-item reservoir for class diversity on class-ordered streams. if shuffle: dataset = dataset.shuffle(seed=seed, buffer_size=1000) dataset = dataset.take(self._max_samples) diff --git a/tests/e2e/test_quantize_e2e.py b/tests/e2e/test_quantize_e2e.py index fe0df8e5e..745ca043b 100644 --- a/tests/e2e/test_quantize_e2e.py +++ b/tests/e2e/test_quantize_e2e.py @@ -362,16 +362,23 @@ def test_explicit_weight_activation_type_override_precision( model = _assert_quantized_output(input_onnx=tiny_onnx, output_onnx=out, stdout=r.output) assert _weight_dq_zero_point_dtype(model) == onnx.TensorProto.INT8 - def test_unknown_precision_falls_back_to_uint8( + def test_non_quant_precision_rejected( self, runner: CliRunner, tiny_onnx: Path, tmp_path: Path ): + """Float precisions like fp16 must be rejected at CLI parse time. + + Replaces the legacy ``test_unknown_precision_falls_back_to_uint8`` which + documented the silent-fallback bug that PR #680 fixed. + """ out = tmp_path / "a6.onnx" r = _invoke( runner, ["-m", str(tiny_onnx), "-o", str(out), "--precision", "fp16", "--samples", "4"], + expect_success=False, ) - model = _assert_quantized_output(input_onnx=tiny_onnx, output_onnx=out, stdout=r.output) - assert _zero_point_dtype(model, "QuantizeLinear") == onnx.TensorProto.UINT8 + assert r.exit_code != 0 + assert "not a supported quantization precision" in r.output + assert not out.exists() # =========================================================================== diff --git a/tests/unit/datasets/test_image_streaming.py b/tests/unit/datasets/test_image_streaming.py new file mode 100644 index 000000000..68d58ce7a --- /dev/null +++ b/tests/unit/datasets/test_image_streaming.py @@ -0,0 +1,127 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Regression tests for ImageDataset streaming behavior. + +Pins the default-dataset-streaming win (no full bulk download for tiny +calibration sets) and the documented fallback when no max_samples is set. +""" + +from __future__ import annotations + +import pytest + +from winml.modelkit.datasets.image import ImageDataset + + +def _make_uninitialized( + *, + dataset_name: str | None, + max_samples: int | None, + config: dict | None = None, +) -> ImageDataset: + """Build an ImageDataset bypassing __init__ (which triggers HF downloads).""" + ds = ImageDataset.__new__(ImageDataset) + ds._model_name = "x" + ds._dataset_name = dataset_name + ds._max_samples = max_samples + ds._data_split = None + ds._config = config if config is not None else {} + ds._dataset = None + ds._metadata = {} + return ds + + +class _FakeStreamingDataset: + """Minimal IterableDataset stand-in capturing shuffle args.""" + + def __init__(self) -> None: + self.shuffle_calls: list[dict] = [] + self.take_calls: list[int] = [] + self.features = None + + def shuffle(self, seed, buffer_size): + self.shuffle_calls.append({"seed": seed, "buffer_size": buffer_size}) + return self + + def take(self, n): + self.take_calls.append(n) + return self + + def __iter__(self): + return iter([]) + + +class TestDefaultDatasetStreaming: + def test_default_dataset_enables_streaming(self) -> None: + ds = _make_uninitialized(dataset_name=None, max_samples=10) + ds._get_default_dataset() + assert ds._config.get("streaming") is True + assert ds._dataset_name == "timm/mini-imagenet" + + def test_custom_dataset_does_not_force_streaming(self) -> None: + ds = _make_uninitialized(dataset_name="cifar10", max_samples=10) + ds._get_default_dataset() # no-op when dataset_name set + assert ds._config.get("streaming") in (None, False) + + def test_streaming_without_max_samples_degrades_to_bulk(self, monkeypatch) -> None: + """Documented fallback: streaming=True + max_samples=None => bulk load.""" + ds = _make_uninitialized( + dataset_name="cifar10", + max_samples=None, + config={"streaming": True}, + ) + + captured: dict = {} + + class _FakeBulk: + def __len__(self) -> int: + return 0 + + def shuffle(self, *a, **kw): + return self + + def select(self, *a, **kw): + return self + + def fake_load(name, split, streaming): + captured["streaming"] = streaming + return _FakeBulk() + + monkeypatch.setattr( + "winml.modelkit.datasets.image.load_dataset", fake_load + ) + ds._load_and_sample() + assert captured["streaming"] is False + + @pytest.mark.parametrize("max_samples", [10, 100, 5000]) + def test_streaming_buffer_is_1000_for_class_diversity( + self, monkeypatch, max_samples + ) -> None: + """Pin the 1000-item reservoir for class diversity on class-ordered streams.""" + ds = _make_uninitialized( + dataset_name="cifar10", + max_samples=max_samples, + config={"streaming": True, "shuffle": True}, + ) + + fake = _FakeStreamingDataset() + + def fake_load(name, split, streaming): + assert streaming is True + return fake + + monkeypatch.setattr( + "winml.modelkit.datasets.image.load_dataset", fake_load + ) + # Bypass ArrowDataset.from_list — fake has no real records. + monkeypatch.setattr( + "datasets.Dataset.from_list", + lambda records, features=None: records, + ) + + ds._load_and_sample() + assert len(fake.shuffle_calls) == 1 + assert fake.shuffle_calls[0]["buffer_size"] == 1000 + assert fake.take_calls == [max_samples] diff --git a/tests/unit/eval/test_object_detection_evaluator.py b/tests/unit/eval/test_object_detection_evaluator.py index 83750b326..8ece05e00 100644 --- a/tests/unit/eval/test_object_detection_evaluator.py +++ b/tests/unit/eval/test_object_detection_evaluator.py @@ -9,7 +9,7 @@ from datasets import ClassLabel, Dataset, Features, Sequence, Value from winml.modelkit.datasets import DatasetConfig -from winml.modelkit.eval import WinMLObjectDetectionEvaluator +from winml.modelkit.eval import WinMLEvaluator, WinMLObjectDetectionEvaluator # --------------------------------------------------------------------------- @@ -220,29 +220,38 @@ def _patch_super_prepare_pipeline(monkeypatch, fake_pipe) -> None: """Patch the *parent* prepare_pipeline so super() in the subclass returns fake_pipe. Uses monkeypatch so the override is reverted on teardown — directly assigning - to a parent class attribute leaks across tests in the same session. + to a parent class attribute leaks across tests in the same session. Patches + ``WinMLEvaluator`` explicitly rather than ``__mro__[1]`` so that inserting a + mixin in the future fails loudly at collection instead of silently mis-patching. """ - parent = WinMLObjectDetectionEvaluator.__mro__[1] - monkeypatch.setattr(parent, "prepare_pipeline", lambda self: fake_pipe) + monkeypatch.setattr(WinMLEvaluator, "prepare_pipeline", lambda self: fake_pipe) class TestPreparePipeline: """Cover the two branches in prepare_pipeline().""" - def test_pixel_mask_path_enables_padding(self, monkeypatch) -> None: + @pytest.mark.parametrize( + ("h", "w", "shortest", "longest"), + [ + (800, 800, 800, 800), + (800, 1333, 800, 1333), + (1333, 800, 800, 1333), + ], + ) + def test_pixel_mask_path_enables_padding(self, monkeypatch, h, w, shortest, longest) -> None: """Model declaring pixel_mask: aspect-preserving size + pad_size + do_pad.""" ev = make_evaluator() ev.model.io_config = { "input_names": ["pixel_values", "pixel_mask"], - "input_shapes": [[1, 3, 800, 800], [1, 800, 800]], + "input_shapes": [[1, 3, h, w], [1, h, w]], } fake_pipe = _FakePipe() _patch_super_prepare_pipeline(monkeypatch, fake_pipe) pipe = ev.prepare_pipeline() proc = pipe.image_processor - assert proc.size == {"shortest_edge": 800, "longest_edge": 800} - assert proc.pad_size == {"height": 800, "width": 800} + assert proc.size == {"shortest_edge": shortest, "longest_edge": longest} + assert proc.pad_size == {"height": h, "width": w} assert proc.do_pad is True def test_no_pixel_mask_path_disables_padding(self, monkeypatch) -> None: