Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/winml/modelkit/commands/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
"-p",
type=str,
default=None,
help="Quantization precision: int8, int16, or w{x}a{y} (e.g., w8a16). "
help="Quantization precision. Accepted: auto, int8, int16, or w{x}a{y} "
"where x,y in {8,16} (e.g., w8a8, w8a16, w16a16). "
"Overridden by explicit --weight-type/--activation-type.",
)
@click.option(
Expand Down Expand Up @@ -192,8 +193,7 @@ def quantize(
# Show info
Comment thread
zhenchaoni marked this conversation as resolved.
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Output:[/bold blue] {output}")
if precision:
console.print(f"[bold blue]Precision:[/bold blue] {precision}")
console.print(f"[bold blue]Precision:[/bold blue] {precision or 'auto'}")
console.print(f"[bold blue]Weight type:[/bold blue] {resolved_weight}")
console.print(f"[bold blue]Activation type:[/bold blue] {resolved_activation}")
console.print(f"[bold blue]Samples:[/bold blue] {samples}")
Expand Down Expand Up @@ -260,8 +260,15 @@ def _resolve_quant_types(

if precision and is_quantized_precision(precision):
default_w, default_a = resolve_quant_types(precision)
else:
elif precision is None or precision.lower() == "auto":
Comment thread
zhenchaoni marked this conversation as resolved.
default_w, default_a = "uint8", "uint8"
else:
raise click.BadParameter(
f"'{precision}' is not a supported quantization precision. "
"Accepted: auto, int8, int16, or w{x}a{y} with x,y in {8,16} "
"(e.g., w8a8, w8a16, w16a16).",
param_hint="'-p' / '--precision'",
)

# Explicit flags override precision defaults
resolved_w = weight_type if weight_type else default_w
Expand Down
70 changes: 48 additions & 22 deletions src/winml/modelkit/datasets/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,42 @@ def _get_default_dataset(self) -> None:
if self._dataset_name is None:
self._dataset_name = "timm/mini-imagenet"
self._data_split = "train"
self._config.setdefault("streaming", True)
Comment thread
zhenchaoni marked this conversation as resolved.

def _initialize(self) -> None:
"""Initialize the image classification dataset.
def _load_and_sample(self) -> Any:
"""Load the configured dataset and apply sample/shuffle.

Simplified approach:
1. Set defaults if needed
2. Load dataset
3. Detect columns
4. Apply efficient processing pipeline
"""
# 1. Set defaults if no dataset specified
if self._dataset_name is None:
self._get_default_dataset()
Shared by ImageDataset and ObjectDetectionDataset. Column detection
is *not* done here — callers run their own detection on the returned
dataset because the column schema differs by task.

# 2. Load dataset
Returns:
A materialized arrow Dataset of up to ``self._max_samples`` rows.
"""
# Streaming only helps when capped by max_samples; otherwise we'd
# iterate the full remote stream into memory, which is worse than a
# bulk download.
streaming = self._config.get("streaming", False) and self._max_samples is not None
logger.info(f"Loading dataset: {self._dataset_name} with split: {self._data_split}")
try:
dataset = load_dataset(self._dataset_name, split=self._data_split)
dataset = load_dataset(self._dataset_name, split=self._data_split, streaming=streaming)
except Exception as e:
logger.error(f"Failed to load dataset {self._dataset_name}: {e}")
raise

# 3. Detect columns using Features API
self._detect_columns(dataset)

# 5. Efficient sampling and processing pipeline
shuffle = self._config.get("shuffle", False)
seed = self._config.get("seed", 42)

if self._max_samples is not None:
if streaming:
# Streaming datasets aren't indexable: shuffle reservoir-samples
# within a buffer; take() pulls only the slice we need. Keep the
# 1000-item reservoir for class diversity on class-ordered streams.
if shuffle:
dataset = dataset.shuffle(seed=seed, buffer_size=1000)
Comment thread
zhenchaoni marked this conversation as resolved.
dataset = dataset.take(self._max_samples)
from datasets import Dataset as ArrowDataset
dataset = ArrowDataset.from_list(list(dataset), features=dataset.features)
elif self._max_samples is not None:
max_samples = min(self._max_samples, len(dataset))
indices = (
Random(seed).sample(range(len(dataset)), max_samples)
Expand All @@ -90,15 +96,35 @@ def _initialize(self) -> None:
elif shuffle:
dataset = dataset.shuffle(seed=seed)

# 6. Load processor and apply batch processing
processor = AutoImageProcessor.from_pretrained(self._model_name, use_fast=True)
return dataset

def _initialize(self) -> None:
"""Initialize the image classification dataset.

Simplified approach:
1. Set defaults if needed
2. Load dataset
3. Detect columns
4. Apply efficient processing pipeline
"""
# 1. Set defaults if no dataset specified
if self._dataset_name is None:
self._get_default_dataset()

# 2. Load + sample (shared with subclasses)
dataset = self._load_and_sample()

# 3. Detect columns using Features API
self._detect_columns(dataset)

# 4. Load processor and apply batch processing
processor = AutoImageProcessor.from_pretrained(self._model_name, use_fast=True)

# 6. Conditional label alignment using should_align_labels()
# 5. Conditional label alignment using should_align_labels()
if should_align_labels(self._dataset_name):
dataset = dataset.align_labels_with_mapping(get_imagenet_label_map(), self._label_col)

# 7. Apply image processing with proper batch dimension
# 6. Apply image processing with proper batch dimension
def preprocess_single_sample(example):
# Process single image and add batch dimension
return processor(example[self._image_col].convert("RGB"), return_tensors="pt")
Expand Down
30 changes: 2 additions & 28 deletions src/winml/modelkit/datasets/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
from __future__ import annotations

import logging
from random import Random
from typing import Any

from datasets import load_dataset
from datasets.features import Image
from transformers import AutoImageProcessor

Expand Down Expand Up @@ -97,36 +95,12 @@ def _initialize(self) -> None:
if self._dataset_name is None:
self._get_default_dataset()

# Load dataset
logger.info(
"Loading object detection dataset: %s with split: %s",
self._dataset_name,
self._data_split,
)
try:
dataset = load_dataset(self._dataset_name, split=self._data_split)
except Exception as e:
logger.error("Failed to load dataset %s: %s", self._dataset_name, e)
raise
# Load + sample (shared with ImageDataset)
dataset = self._load_and_sample()

# Detect image column (object detection may not have simple ClassLabel)
self._detect_image_column(dataset)

# Efficient sampling
shuffle = self._config.get("shuffle", False)
seed = self._config.get("seed", 42)

if self._max_samples is not None:
max_samples = min(self._max_samples, len(dataset))
indices = (
Random(seed).sample(range(len(dataset)), max_samples)
if shuffle
else list(range(max_samples))
)
dataset = dataset.select(indices)
elif shuffle:
dataset = dataset.shuffle(seed=seed)

# Derive ONNX-aware overrides
io_config = self._config.get("io_config")
overrides = self._derive_overrides(io_config)
Expand Down
17 changes: 14 additions & 3 deletions src/winml/modelkit/eval/object_detection_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,22 @@ def prepare_pipeline(self) -> Pipeline:

io_config = getattr(self.model, "io_config", None) or {}
input_shapes = io_config.get("input_shapes", [[]])
input_names = io_config.get("input_names", [])
if input_shapes and len(input_shapes[0]) == 4:
_, _, h, w = input_shapes[0]
pipe.image_processor.size = {"height": h, "width": w}
if hasattr(pipe.image_processor, "do_pad"):
pipe.image_processor.do_pad = False
if "pixel_mask" in input_names:
pipe.image_processor.size = {
"shortest_edge": min(h, w),
"longest_edge": max(h, w),
}
if hasattr(pipe.image_processor, "pad_size"):
pipe.image_processor.pad_size = {"height": h, "width": w}
if hasattr(pipe.image_processor, "do_pad"):
pipe.image_processor.do_pad = True
else:
pipe.image_processor.size = {"height": h, "width": w}
if hasattr(pipe.image_processor, "do_pad"):
pipe.image_processor.do_pad = False

return pipe

Expand Down
13 changes: 10 additions & 3 deletions tests/e2e/test_quantize_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,16 +362,23 @@ def test_explicit_weight_activation_type_override_precision(
model = _assert_quantized_output(input_onnx=tiny_onnx, output_onnx=out, stdout=r.output)
assert _weight_dq_zero_point_dtype(model) == onnx.TensorProto.INT8

def test_unknown_precision_falls_back_to_uint8(
def test_non_quant_precision_rejected(
self, runner: CliRunner, tiny_onnx: Path, tmp_path: Path
):
"""Float precisions like fp16 must be rejected at CLI parse time.

Replaces the legacy ``test_unknown_precision_falls_back_to_uint8`` which
documented the silent-fallback bug that PR #680 fixed.
"""
out = tmp_path / "a6.onnx"
r = _invoke(
runner,
["-m", str(tiny_onnx), "-o", str(out), "--precision", "fp16", "--samples", "4"],
expect_success=False,
)
model = _assert_quantized_output(input_onnx=tiny_onnx, output_onnx=out, stdout=r.output)
assert _zero_point_dtype(model, "QuantizeLinear") == onnx.TensorProto.UINT8
assert r.exit_code != 0
assert "not a supported quantization precision" in r.output
assert not out.exists()


# ===========================================================================
Expand Down
46 changes: 42 additions & 4 deletions tests/unit/commands/test_compile_quantize_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,11 @@ def test_precision_case_insensitive(self):
assert a == "uint8"

def test_unknown_precision_uses_defaults(self):
"""Unknown precision string falls through to defaults."""
w, a = _resolve_quant_types("fp16", None, None)
assert w == "uint8"
assert a == "uint8"
"""Explicit non-quantized precision (e.g., fp16) is rejected."""
import click

with pytest.raises(click.BadParameter, match="not a supported quantization precision"):
_resolve_quant_types("fp16", None, None)


class TestCompileDeviceDisplayLabel:
Expand Down Expand Up @@ -516,3 +517,40 @@ def test_non_object_top_level_raises_usage_error(self, tmp_path):
r = self._invoke(["-m", str(model), "--config", str(bc)])
assert r.exit_code != 0
assert "Build config must be a JSON object" in r.output


class TestQuantizePrecisionValidation:
    """Guard against the silent-fallback behaviour reported in issue #555.

    An unrecognised ``--precision`` value passed to ``winml quantize`` has to
    fail the command before quantization runs; it must not quietly become
    uint8/uint8 and print "Success!".
    """

    @staticmethod
    def _invoke(args):
        """Run the ``quantize`` click command in-process and return its result."""
        # Local imports keep this helper self-contained.
        from click.testing import CliRunner

        from winml.modelkit.commands.quantize import quantize as quantize_cmd

        cli = CliRunner()
        # catch_exceptions=False: unexpected exceptions propagate to pytest
        # instead of being folded into the result object.
        return cli.invoke(quantize_cmd, args, obj={}, catch_exceptions=False)

    @pytest.mark.parametrize(
        "bad_precision",
        ["banana", "w4a16", "int4", "fp64"],
    )
    def test_unknown_precision_rejected(self, tmp_path, bad_precision):
        """Each bad precision exits non-zero without ever invoking the quantizer."""
        model, _ = TestQuantizeCliConfigPrecedence._setup(tmp_path)
        # Records every (unwanted) call to the patched quantizer.
        invocations: list[bool] = []

        def fail_if_called(*_args, **_kwargs):
            invocations.append(True)
            raise AssertionError("quantize_onnx must not be called for invalid precision")

        cli_args = ["-m", str(model), "--precision", bad_precision]
        with patch("winml.modelkit.quant.quantize_onnx", side_effect=fail_if_called):
            result = self._invoke(cli_args)

        assert result.exit_code != 0, result.output
        assert "not a supported quantization precision" in result.output
        assert not invocations

13 changes: 7 additions & 6 deletions tests/unit/config/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,12 +578,13 @@ def test_no_precision_defaults_uint8(self) -> None:
assert w == "uint8"
assert a == "uint8"

# ---- Unsupported precision falls back to uint8/uint8 ----
def test_unsupported_precision_falls_back(self) -> None:
"""Unsupported precision (w4a16) is not quantized -> fallback to uint8."""
w, a = self._resolve(precision="w4a16")
assert w == "uint8"
assert a == "uint8"
# ---- Unsupported precision is rejected ----
def test_unsupported_precision_rejected(self) -> None:
    """An unsupported precision such as w4a16 raises BadParameter rather than defaulting."""
    from click import BadParameter

    expected_message = "not a supported quantization precision"
    with pytest.raises(BadParameter, match=expected_message):
        self._resolve(precision="w4a16")

# ---- Explicit flags override precision ----
def test_explicit_weight_overrides_precision(self) -> None:
Expand Down
Loading
Loading