From 1b8da4c126a9b0ce5a07b6172b05193ee10c0c31 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Tue, 14 Apr 2026 11:45:09 +0800 Subject: [PATCH 01/32] re init branch --- docs/design/pipeline-model.md | 133 +++++++ src/winml/modelkit/commands/build.py | 9 + src/winml/modelkit/commands/config.py | 120 ++++++ src/winml/modelkit/config/build.py | 1 + src/winml/modelkit/export/htp/exporter.py | 10 +- src/winml/modelkit/models/auto.py | 1 + src/winml/modelkit/models/hf/__init__.py | 14 + .../modelkit/models/hf/encoder_decoder.py | 323 +++++++++++++++ src/winml/modelkit/models/hf/kv_cache.py | 98 +++++ src/winml/modelkit/models/hf/mu2.py | 288 ++++++++++++++ src/winml/modelkit/models/hf/qwen.py | 329 +++++++++++++++ src/winml/modelkit/models/hf/t5.py | 331 ++++++++++++++++ .../modelkit/models/winml/decoder_only.py | 375 ++++++++++++++++++ .../modelkit/models/winml/pipeline_model.py | 205 ++++++++++ tests/unit/export/test_io.py | 157 ++++++++ 15 files changed, 2392 insertions(+), 2 deletions(-) create mode 100644 docs/design/pipeline-model.md create mode 100644 src/winml/modelkit/models/hf/encoder_decoder.py create mode 100644 src/winml/modelkit/models/hf/kv_cache.py create mode 100644 src/winml/modelkit/models/hf/mu2.py create mode 100644 src/winml/modelkit/models/hf/qwen.py create mode 100644 src/winml/modelkit/models/hf/t5.py create mode 100644 src/winml/modelkit/models/winml/decoder_only.py create mode 100644 src/winml/modelkit/models/winml/pipeline_model.py diff --git a/docs/design/pipeline-model.md b/docs/design/pipeline-model.md new file mode 100644 index 000000000..45773f280 --- /dev/null +++ b/docs/design/pipeline-model.md @@ -0,0 +1,133 @@ +# Multi-Model Pipeline Design + +## Problem + +`WinMLAutoModel.from_pretrained` builds ONE ONNX model. Multi-component +architectures (T5 encoder+decoder, SD text_encoder+unet+vae) need multiple +ONNX models composed together. + +## Class Hierarchy + +``` +WinMLPipelineModel(PreTrainedModel) — multi-component base + └─ WinMLEncoderDecoderModel(GenerationMixin) — encoder-decoder with StaticCache + └─ WinMLT5Model — T5 tasks + generation config +``` + +- **WinMLPipelineModel**: `_SUB_MODEL_CONFIG` mapping, `from_pretrained` builds + each component via `WinMLAutoModel`, provides `device`/`to`/`dtype`. +- **WinMLEncoderDecoderModel**: `forward()` with StaticCache KV management, + `_EncoderWithInputPadding` wrapper, `get_encoder()`, `prepare_inputs_for_generation()`. + Auto-pads undersized inputs to ONNX expected shapes via `_pad_inputs`. +- **WinMLT5Model**: declares `_SUB_MODEL_CONFIG` and `generation_config` only. + +## Registry + +`@register_pipeline_model(model_type, task)` registers a pipeline class. +`winml config` checks the registry to generate per-component configs. + +```python +@register_pipeline_model("t5", "translation") +class WinMLT5Model(WinMLEncoderDecoderModel): + _SUB_MODEL_CONFIG = { + "encoder": "feature-extraction", + "decoder": "text2text-generation", + } +``` + +## ONNX Export + +Each component is exported independently via the existing pipeline +(export → optimize → compile). Export wrappers in `models/hf/t5.py`: + +| Component | Class | Description | +|---|---|---| +| Encoder | `T5EncoderWrapper` | `forward(input_ids, attention_mask) → encoder_hidden_states` | +| Decoder | `T5DecoderWrapper` | StaticCache + EncoderDecoderCache from flat KV inputs, extracts new token KV via `gather` | +| Decoder IO | `T5DecoderIOConfig` | OnnxConfig with custom DummyInputGenerators for KV cache tensors | + +### Decoder ONNX I/O (all static shapes) + +``` +Inputs: + decoder_input_ids [1, 1] + encoder_hidden_states [1, enc_seq, d_model] + attention_mask [1, enc_seq] + decoder_attention_mask [1, max_decode] + cache_position [1] + past_{i}_key [1, heads, max_decode, d_kv] # i=0..num_layers-1 + past_{i}_value [1, heads, max_decode, d_kv] + +Outputs: + logits [1, 1, vocab_size] + present_{i}_key [1, heads, 1, d_kv] # new token only + present_{i}_value [1, heads, 1, d_kv] +``` + +Cross-attention KV is always recomputed from `encoder_hidden_states` +(empty cross-attention cache → `is_updated=False` → never constant-folded). + +## KV Cache Design + +Uses HF `StaticCache` for both export and inference: + +- **Export**: `StaticCache.update()` uses `index_copy_` which traces correctly + in `torch.onnx.export`. `KV_index = sequence_position` always holds, so T5's + relative position bias computes correct distances. +- **Inference**: Same `StaticCache` object persists across generation steps, + mutated in-place via `cache.update()`. `get_seq_length()` counts non-zero + positions automatically. +- **GenerationMixin integration**: `StaticCache` flows through the generate loop + via `Seq2SeqLMOutput.past_key_values`. GenerationMixin may wrap it in an + `EncoderDecoderCache`; `forward()` unwraps to find the `StaticCache`. + +Known limitation: OpenVINO EP does not support ScatterElements, requires CPU +EP fallback for decoder inference. + +## Usage + +### 1. Generate configs (one per component) + +``` +winml config -m google-t5/t5-small --task translation --device cpu -o t5.json +``` + +Produces two files: +- `t5_encoder.json` — task `feature-extraction` +- `t5_decoder.json` — task `text2text-generation` + +### 2. Build ONNX models independently + +``` +winml build -c t5_encoder.json -m google-t5/t5-small -o output/encoder +winml build -c t5_decoder.json -m google-t5/t5-small -o output/decoder +``` + +### 3. Run translation pipeline + +```python +from winml.modelkit.models.winml.seq2seq import WinMLPipelineModel +from transformers import AutoTokenizer, pipeline + +model = WinMLPipelineModel.from_pretrained("google-t5/t5-small", "translation") +tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + +pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer) +result = pipe("Hello, how are you?", num_beams=1) +print(result[0]["translation_text"]) +# Bonjour, comment êtes-vous ? +``` + +`from_pretrained` resolves the concrete class from `PIPELINE_MODEL_REGISTRY`, +builds both ONNX sub-models via `WinMLAutoModel`, and returns a model that +plugs into HF `transformers.pipeline` as a drop-in replacement for +`T5ForConditionalGeneration`. + +## Open Questions + +- Manage KV cache and attention mask jointly in same cache class? +- Update KV in numpy to avoid pytorch tensor <-> numpy array round trip? +- Handle quantized cache (channel-wise quantization for accuracy)? +- EP-specific KV cache management to avoid ORT <-> EP round trip? +- Beam search support (requires dynamic batch)? +- Is it possible/better to use a shared model class for both export and inference? diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index ee9a9159c..a15d7851e 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -304,6 +304,12 @@ def _build_modules( default=None, help="Maximum autoconf re-optimization rounds (default: 3). --no-analyze sets this to 0.", ) +@click.option( + "--trust-remote-code", + is_flag=True, + default=False, + help="Trust remote code for custom model architectures (e.g., Mu2).", +) @click.option( "-v", "--verbose", @@ -327,6 +333,7 @@ def build( device: str | None, no_analyze: bool, max_optim_iterations: int | None, + trust_remote_code: bool, verbose: bool, ) -> None: r"""Build a WinML-optimized ONNX model from a HuggingFace model or .onnx file. @@ -408,6 +415,8 @@ def build( extra_kwargs["hack_max_optim_iterations"] = 0 elif max_optim_iterations is not None: extra_kwargs["hack_max_optim_iterations"] = max_optim_iterations + if trust_remote_code: + extra_kwargs["trust_remote_code"] = True if is_module_mode: # ---- MODULE MODE: array config, one build per submodule ---- diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index f37ea22a4..3e28f823a 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -166,6 +166,11 @@ def _is_onnx_file(model_input: str) -> bool: default=False, help="Allow running custom code from model repository", ) +@click.option( + "--subfolder", + default=None, + help="Subfolder within HF repo to load from (e.g., 'text_encoder' for Stable Diffusion).", +) def config( hf_model: str | None, task: str | None, @@ -183,6 +188,7 @@ def config( no_quant: bool, no_compile: bool, trust_remote_code: bool, + subfolder: str | None, ) -> None: r"""Generate WinMLBuildConfig for a HuggingFace model or .onnx file. @@ -302,6 +308,31 @@ def config( label = hf_model or model_type console.print(f"[dim]Generating config for {label}...[/dim]") + # Check pipeline model registry: (model_type, task) → multi-config + pipeline_components = _resolve_pipeline_components( + hf_model, model_type, task, trust_remote_code=trust_remote_code + ) + if pipeline_components: + # Pipeline model: generate one config per sub-component + _generate_pipeline_configs( + pipeline_components, + hf_model=hf_model, + model_class=model_class, + model_type=model_type, + override=override, + shape_config=shape_config, + library_name=library_name, + device=device, + precision=precision, + trust_remote_code=trust_remote_code, + ep=ep, + no_quant=no_quant, + no_compile=no_compile, + output=output, + console=console, + ) + return + # Generate config(s) - returns single or list based on module parameter result = generate_hf_build_config( model_id=hf_model, @@ -316,6 +347,7 @@ def config( precision=precision, trust_remote_code=trust_remote_code, ep=ep, + subfolder=subfolder, ) # Handle output format @@ -363,3 +395,91 @@ def config( if verbose: logger.exception("Unexpected error during config generation") raise click.ClickException(f"Unexpected error: {e}") from e + + +def _resolve_pipeline_components( + hf_model: str | None, + model_type: str | None, + task: str | None, + trust_remote_code: bool = False, +) -> dict[str, str] | None: + """Check if (model_type, task) is a registered pipeline model. + + Returns _SUB_MODEL_CONFIG dict if found, None otherwise. + """ + if task is None: + return None + + import winml.modelkit.models.hf # noqa: F401 # trigger pipeline registrations + + from ..models.winml.pipeline_model import PIPELINE_MODEL_REGISTRY + + # Resolve model_type from HF config if not provided + resolved_type = model_type + if resolved_type is None and hf_model is not None: + from transformers import AutoConfig + + resolved_type = AutoConfig.from_pretrained( + hf_model, trust_remote_code=trust_remote_code + ).model_type + + if resolved_type is None: + return None + + cls = PIPELINE_MODEL_REGISTRY.get((resolved_type, task)) + return cls._SUB_MODEL_CONFIG if cls is not None else None + + +def _generate_pipeline_configs( + components: dict[str, str], + *, + hf_model: str | None, + model_class: str | None, + model_type: str | None, + override: Any, + shape_config: dict | None, + library_name: str, + device: str, + precision: str, + trust_remote_code: bool, + ep: str | None, + no_quant: bool, + no_compile: bool, + output: str | None, + console: Any, +) -> None: + """Generate and save one config file per pipeline sub-component.""" + from ..config import generate_hf_build_config + + for component_name, component_task in components.items(): + console.print( + f"[dim]Generating config for component '{component_name}' " + f"(task={component_task})...[/dim]" + ) + + cfg = generate_hf_build_config( + model_id=hf_model, + task=component_task, + model_class=model_class, + model_type=model_type, + override=override, + shape_config=shape_config, + library_name=library_name, + device=device, + precision=precision, + trust_remote_code=trust_remote_code, + ep=ep, + ) + _apply_stage_overrides(cfg, no_quant=no_quant, no_compile=no_compile) + + config_json = json.dumps(cfg.to_dict(), indent=2) + + if output: + out_path = Path(output) + suffixed = out_path.with_stem(f"{out_path.stem}_{component_name}") + suffixed.parent.mkdir(parents=True, exist_ok=True) + suffixed.write_text(config_json) + console.print(f"[green]Config saved to:[/green] {suffixed}") + else: + console.print(f"[bold]--- {component_name} ({component_task}) ---[/bold]") + print(config_json) diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index 1eca3bfc4..7f63d4ec0 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -855,6 +855,7 @@ def _merge_export_config( dynamic_axes=( override.dynamic_axes if override.dynamic_axes is not None else base.dynamic_axes ), + dynamo=override.dynamo if override.dynamo else base.dynamo, ) diff --git a/src/winml/modelkit/export/htp/exporter.py b/src/winml/modelkit/export/htp/exporter.py index f6466be07..0c5fa8214 100644 --- a/src/winml/modelkit/export/htp/exporter.py +++ b/src/winml/modelkit/export/htp/exporter.py @@ -441,9 +441,15 @@ def _convert_model_to_onnx( if export_config.dynamic_axes: onnx_kwargs["dynamic_axes"] = export_config.dynamic_axes - tuple(inputs.values()) with self._get_optimum_patcher(model, task): - torch.onnx.export(model, (), output_path, kwargs=inputs, **onnx_kwargs) + # Models can override input binding by implementing + # get_export_args(inputs) → tuple of positional args. + # Default: pass inputs dict as kwargs. + if hasattr(model, "get_export_args"): + export_args = model.get_export_args(inputs) + torch.onnx.export(model, export_args, output_path, **onnx_kwargs) + else: + torch.onnx.export(model, (), output_path, kwargs=inputs, **onnx_kwargs) @staticmethod def _get_optimum_patcher(model: nn.Module, task: str | None) -> Any: diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 3195d0fee..45db1299b 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -292,6 +292,7 @@ def from_pretrained( shape_config=shape_config, device=device, precision=precision, + trust_remote_code=trust_remote_code, ep=kwargs.get("ep"), ) diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index d26685b12..97b7cda64 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -36,11 +36,21 @@ from .depth_anything import DepthAnythingIOConfig as _DepthAnythingIOConfig # triggers registration from .depth_pro import DepthProIOConfig as _DepthProIOConfig # triggers registration from .detr import DETR_CONFIG +from .mu2 import MODEL_CLASS_MAPPING as _MU2_CLASS_MAPPING +from .mu2 import Mu2DecoderIOConfig as _Mu2DecoderIOConfig # triggers registration +from .mu2 import Mu2EncoderIOConfig as _Mu2EncoderIOConfig # triggers registration +from .qwen import MODEL_CLASS_MAPPING as _QWEN_CLASS_MAPPING +from .qwen import QWEN_CONFIG +from .qwen import QwenGenIOConfig as _QwenGenIOConfig +from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig from .roberta import ROBERTA_FAMILY_CONFIG from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration from .sam import MODEL_CLASS_MAPPING as _SAM2_CLASS_MAPPING from .segformer import MODEL_CLASS_MAPPING as _SEGFORMER_CLASS_MAPPING from .segformer import SegformerIOConfig as _SegformerIOConfig # triggers registration +from .t5 import MODEL_CLASS_MAPPING as _T5_CLASS_MAPPING +from .t5 import T5DecoderIOConfig as _T5DecoderIOConfig # triggers registration +from .t5 import T5EncoderIOConfig as _T5EncoderIOConfig # triggers registration from .vision_encoder_decoder import VISION_ENCODER_DECODER_CONFIG from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration @@ -48,8 +58,11 @@ # Aggregated model class mappings: (model_type, task) -> HF model class MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { **_CLIP_CLASS_MAPPING, + **_MU2_CLASS_MAPPING, + **_QWEN_CLASS_MAPPING, **_SAM2_CLASS_MAPPING, **_SEGFORMER_CLASS_MAPPING, + **_T5_CLASS_MAPPING, } # Registry: model_type -> WinMLBuildConfig @@ -64,6 +77,7 @@ "clip-vision-model": CLIP_CONFIG, "detr": DETR_CONFIG, "roberta": ROBERTA_FAMILY_CONFIG, + "qwen3": QWEN_CONFIG, "vision-encoder-decoder": VISION_ENCODER_DECODER_CONFIG, "xlm-roberta": ROBERTA_FAMILY_CONFIG, } diff --git a/src/winml/modelkit/models/hf/encoder_decoder.py b/src/winml/modelkit/models/hf/encoder_decoder.py new file mode 100644 index 000000000..7710bee00 --- /dev/null +++ b/src/winml/modelkit/models/hf/encoder_decoder.py @@ -0,0 +1,323 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""WinML Encoder-Decoder inference model and shared input generator. + +Provides ``WinMLEncoderDecoderModel`` — inference wrapper for encoder-decoder +pipelines (T5, mBART, etc.) with static KV cache, and +``EncoderDecoderInputGenerator`` — reusable ``DummyInputGenerator`` for +decoder base inputs shared across encoder-decoder architectures. + +Class hierarchy:: + + WinMLPipelineModel(PreTrainedModel) — multi-component base + └─ WinMLEncoderDecoderModel(GenerationMixin) — encoder-decoder with StaticCache + └─ WinMLT5Model (in t5.py) — T5 tasks + generation config + +How it works: + +1. Each pipeline model declares ``_SUB_MODEL_CONFIG = {"encoder": "feature-extraction", + "decoder": "text2text-generation"}``. ``from_pretrained()`` builds each component + via ``WinMLAutoModel`` (export → optimize → compile) independently. + +2. The encoder is wrapped in ``_EncoderWithInputPadding`` which reads ONNX input + names/shapes from ``io_config`` and zero-pads any undersized inputs. + +3. ``forward()`` takes ``(*, encoder_outputs, past_key_values, input_ids, **model_kwargs)`` + where ``model_kwargs`` carries decoder inputs like ``decoder_input_ids`` and + ``attention_mask``. Feeds are built from model_kwargs + generated inputs + (encoder_hidden_states, decoder_attention_mask, cache_position, KV cache), + filtered to decoder ONNX input names, and auto-padded. + +4. KV cache uses HF ``StaticCache`` — same class for both export (``index_copy_`` + traces correctly in ``torch.onnx.export``) and inference (mutated in-place via + ``cache.update()``). The ONNX decoder takes the full static buffer as input + and outputs only the new token's KV ``[batch, heads, 1, d_kv]``. + +Key findings from T5 KV cache study: + +- HF's ``DynamicCache`` is stateful (same object, mutated in-place via ``cat``). + ``GenerationMixin._update_model_kwargs_for_generation`` reads ``past_key_values`` + from the output and reassigns it in ``model_kwargs`` — but for stateful caches + it's the same reference. +- ``StaticCache`` uses ``index_copy_`` at ``cache_position`` (traces correctly). + ``StaticCache.get_seq_length()`` counts non-zero positions automatically. +- ``EncoderDecoderCache`` with empty cross-attn cache → ``is_updated`` dict is + empty → cross-attention always recomputed from ``encoder_hidden_states`` → + prevents constant-folding during ONNX export. +- ``GenerationMixin`` may wrap our ``StaticCache`` in an ``EncoderDecoderCache`` + before passing it back. ``forward()`` must unwrap to find the ``StaticCache``. +- ``TranslationPipeline`` passes its own ``generation_config`` with ``num_beams=4`` + to ``generate()``. Use ``num_beams=1`` at call time or override in subclass. + +Design principles: + +- NEVER guard config access with default values. Use ``self.config.param`` + directly and let AttributeError raise if the config is missing a field. +- ONNX I/O names and shapes are read from ``io_config``, never hardcoded. +- Inputs smaller than ONNX expected shape are zero-padded automatically. + Inputs larger than expected are NOT truncated — let ORT raise the error. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +import torch +from optimum.utils.input_generators import DummyInputGenerator +from transformers import Cache, StaticCache +from transformers.generation.utils import GenerationMixin +from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput + +from ..winml.pipeline_model import WinMLPipelineModel + + +if TYPE_CHECKING: + from optimum.utils import NormalizedConfig + from transformers import PretrainedConfig + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# EncoderDecoderInputGenerator — shared dummy input generator +# ============================================================================= + + +class EncoderDecoderInputGenerator(DummyInputGenerator): + """Generates decoder base inputs for encoder-decoder models. + + Produces ``decoder_input_ids``, ``encoder_hidden_states``, + ``attention_mask`` (encoder), ``decoder_attention_mask``, and + ``cache_position``. Reads dimensions from ``NormalizedConfig``. + """ + + SUPPORTED_INPUT_NAMES = ( + "decoder_input_ids", + "encoder_hidden_states", + "attention_mask", + "decoder_attention_mask", + "cache_position", + ) + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = 1, + max_cache_len: int | None = None, + sequence_length: int | None = None, + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.d_model = normalized_config.hidden_size + self.enc_seq = sequence_length or getattr(normalized_config, "sequence_length", 16) + self.max_cache_len = max_cache_len or normalized_config.max_cache_len + self.vocab_size = normalized_config.vocab_size + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: + """Generate a dummy tensor for the given input name.""" + if input_name == "decoder_input_ids": + return self.random_int_tensor( + (self.batch_size, 1), + max_value=self.vocab_size, + framework=framework, + dtype=int_dtype, + ) + if input_name == "encoder_hidden_states": + return self.random_float_tensor( + (self.batch_size, self.enc_seq, self.d_model), + framework=framework, + dtype=float_dtype, + ) + if input_name == "attention_mask": + return torch.ones(self.batch_size, self.enc_seq, dtype=torch.int64) + if input_name == "decoder_attention_mask": + return torch.ones(self.batch_size, self.max_cache_len, dtype=torch.int64) + if input_name == "cache_position": + return torch.tensor([5], dtype=torch.int64) # arbitrary position for tracing + raise ValueError(f"Unknown input: {input_name}") + + +# ============================================================================= +# WinMLEncoderDecoderModel — encoder-decoder with StaticCache +# ============================================================================= + + +class WinMLEncoderDecoderModel(WinMLPipelineModel, GenerationMixin): + """Pipeline model with HF GenerationMixin support. + + Expects sub-components ``"encoder"`` and ``"decoder"`` in + ``_SUB_MODEL_CONFIG``. Provides the full interface required by + ``GenerationMixin.generate()`` for encoder-decoder models with + static KV cache. + + Input/output names and shapes are read from ONNX I/O metadata — no + model-specific names are assumed. + """ + + main_input_name = "input_ids" + base_model_prefix = "" + _is_stateful = False + _supports_cache_class = False + + def __init__( + self, + sub_models: dict[str, Any], + config: PretrainedConfig, + ) -> None: + super().__init__(sub_models, config) + raw_encoder = sub_models["encoder"] + self._decoder = sub_models["decoder"] + + # Build {name: shape} lookups from ONNX I/O metadata + enc_io = raw_encoder.io_config + enc_expected = dict( + zip(enc_io.get("input_names", []), enc_io.get("input_shapes", []), strict=False) + ) + # Wrap encoder with auto-padding so all callsites just use self._encoder(...) + self._encoder = self._EncoderWithInputPadding(raw_encoder, enc_expected) + + dec_io = self._decoder.io_config + self._dec_expected = dict( + zip(dec_io.get("input_names", []), dec_io.get("input_shapes", []), strict=False) + ) + + # Max decode length from decoder ONNX KV input shape + self._max_dec = self._dec_expected["past_0_key"][2] + self._num_kv_layers = sum( + 1 for n in self._dec_expected if n.startswith("past_") and n.endswith("_key") + ) + + # ----- Encoder ----- + + class _EncoderWithInputPadding(torch.nn.Module): + """Wraps an encoder sub-model with auto-padding to ONNX expected shapes. + + Matches kwargs against ONNX input names, pads undersized tensors, + and forwards to the underlying WinMLAutoModel. Used as both + ``self._encoder`` (direct calls) and the return value of + ``get_encoder()`` (GenerationMixin contract). + """ + + def __init__(self, encoder: Any, expected: dict[str, list[int]]) -> None: + super().__init__() + self._encoder = encoder + self._expected = expected + + def forward(self, **kwargs: Any) -> BaseModelOutput: + feeds = WinMLPipelineModel._pad_inputs(kwargs, self._expected) + return self._encoder(**feeds) + + def get_encoder(self) -> torch.nn.Module: + """Return encoder for GenerationMixin (already wrapped with padding).""" + return self._encoder + + def can_generate(self) -> bool: # noqa: D102 + return True + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Cache | None = None, + attention_mask: torch.Tensor | None = None, + encoder_outputs: BaseModelOutput | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Build decoder inputs for each generate() step.""" + return { + "decoder_input_ids": input_ids[:, -1:], + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + } + + # ----- Forward (decoder via WinMLAutoModel + KV cache) ----- + + def forward( + self, + *, + encoder_outputs: BaseModelOutput | tuple | None = None, + past_key_values: Cache | None = None, + input_ids: torch.Tensor | None = None, + **model_kwargs: Any, + ) -> Seq2SeqLMOutput: + """Run decoder with static KV cache. + + Args: + encoder_outputs: Pre-computed encoder hidden states. + past_key_values: StaticCache (or wrapper) from previous step. + input_ids: Fallback — run encoder if encoder_outputs is None. + **model_kwargs: Remaining kwargs forwarded to the decoder ONNX + (e.g., decoder_input_ids, attention_mask). Each tensor is + auto-padded to match the ONNX model's expected input shape. + """ + # Encoder hidden states + if encoder_outputs is None and input_ids is not None: + encoder_outputs = self._encoder(input_ids=input_ids, **model_kwargs) + if encoder_outputs is None: + raise ValueError("Either encoder_outputs or input_ids required") + enc_h = encoder_outputs["last_hidden_state"] + + # Resolve the self-attention cache. + # GenerationMixin may pass None, a StaticCache, or an + # EncoderDecoderCache wrapping a DynamicCache (auto-created). + cache = None + if isinstance(past_key_values, StaticCache): + cache = past_key_values + elif hasattr(past_key_values, "self_attention_cache"): + sa = past_key_values.self_attention_cache + if isinstance(sa, StaticCache): + cache = sa + if cache is None: + # Read KV geometry from ONNX metadata (architecture-agnostic) + kv_shape = self._dec_expected["past_0_key"] # [batch, heads, max_dec, head_dim] + cache = StaticCache(self.config, max_cache_len=self._max_dec) + cache.early_initialization( + batch_size=1, + num_heads=kv_shape[1], + head_dim=kv_shape[3], + dtype=torch.float32, + device=torch.device("cpu"), + ) + + # Determine write position from cache occupancy + fc = cache.get_seq_length() + dec_mask = torch.zeros(1, self._max_dec, dtype=torch.int64) + dec_mask[0, : fc + 1] = 1 + + # Build feeds: model_kwargs first, then fill in generated inputs + feeds: dict[str, Any] = dict(model_kwargs) + feeds.setdefault("encoder_hidden_states", enc_h.detach()) + feeds.setdefault("decoder_attention_mask", dec_mask) + feeds.setdefault("cache_position", torch.tensor([fc], dtype=torch.int64)) + for i in range(self._num_kv_layers): + layer = cache.layers[i] + feeds[f"past_{i}_key"] = layer.keys.detach() + feeds[f"past_{i}_value"] = layer.values.detach() + + # Filter to decoder ONNX inputs and pad any undersized tensors + outputs = self._decoder(**self._pad_inputs(feeds, self._dec_expected)) + + # Write new token's KV into the StaticCache in-place + cache_kwargs = {"cache_position": torch.tensor([fc], dtype=torch.int64)} + for i in range(self._num_kv_layers): + cache.update( + outputs[f"present_{i}_key"], + outputs[f"present_{i}_value"], + layer_idx=i, + cache_kwargs=cache_kwargs, + ) + + return Seq2SeqLMOutput( + logits=outputs["logits"], + past_key_values=cache, + ) diff --git a/src/winml/modelkit/models/hf/kv_cache.py b/src/winml/modelkit/models/hf/kv_cache.py new file mode 100644 index 000000000..f4bc53289 --- /dev/null +++ b/src/winml/modelkit/models/hf/kv_cache.py @@ -0,0 +1,98 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Shared KV cache utilities for ONNX export wrappers. + +Provides ``CapturingStaticCache`` — a ``StaticCache`` subclass that captures +each layer's new-token KV from ``update()``, eliminating the scatter→gather +round-trip in the exported ONNX graph. + +Also provides ``PastKeyValueInputGenerator`` — a reusable ``DummyInputGenerator`` +for static KV cache inputs (``past_{i}_key``, ``past_{i}_value``), shared by +T5, Qwen, and future models with static KV cache export. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from optimum.utils.input_generators import DummyInputGenerator +from transformers import StaticCache + + +if TYPE_CHECKING: + import torch + from optimum.utils import NormalizedConfig + + +class CapturingStaticCache(StaticCache): + """StaticCache that captures each layer's new-token KV from ``update()``. + + Standard ``StaticCache.update()`` does ``index_copy_`` (ScatterElements in + ONNX) to write the new KV into the full buffer, then returns the full + buffer for attention. The old approach then used ``gather`` + (GatherElements) to extract the same KV back — a pointless round-trip. + + This subclass intercepts ``update()`` to save the *incoming* + ``key_states`` / ``value_states`` before they enter the buffer, so the + wrapper can return them directly as ONNX outputs. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.captured: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Capture new-token KV, then delegate to parent ``index_copy_``.""" + self.captured[layer_idx] = (key_states, value_states) + return super().update(key_states, value_states, layer_idx, cache_kwargs) + + +class PastKeyValueInputGenerator(DummyInputGenerator): + """Generates ``past_{i}_key`` / ``past_{i}_value`` tensors for static KV cache. + + Reads ``num_layers``, ``num_attention_heads``, ``head_dim``, and + ``max_cache_len`` from the ``NormalizedConfig``. Each model's + ``NORMALIZED_CONFIG_CLASS`` maps these to the appropriate HF config fields + (e.g. T5: ``head_dim="d_kv"``, ``max_cache_len="n_positions"``). + """ + + SUPPORTED_INPUT_NAMES = () # dynamic — built in __init__ + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = 1, + max_cache_len: int | None = None, + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.num_layers: int = normalized_config.num_layers + self.num_heads: int = normalized_config.num_attention_heads + self.head_dim: int = normalized_config.head_dim + self.max_cache_len: int = max_cache_len or normalized_config.max_cache_len + self.SUPPORTED_INPUT_NAMES = tuple( + name for i in range(self.num_layers) for name in (f"past_{i}_key", f"past_{i}_value") + ) + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: + """Return a random float tensor of shape ``[batch, heads, max_cache_len, head_dim]``.""" + return self.random_float_tensor( + (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim), + framework=framework, + dtype=float_dtype, + ) diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py new file mode 100644 index 000000000..c77ff774d --- /dev/null +++ b/src/winml/modelkit/models/hf/mu2.py @@ -0,0 +1,288 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Mu2 HuggingFace Model Configuration. + +Provides encoder/decoder export wrappers and OnnxConfig registrations for +Mu2 encoder-decoder models with static KV cache. + +Export Strategy (split by task): +- Mu2EncoderWrapper + Mu2EncoderIOConfig: ``feature-extraction`` task + → encoder-only ONNX (input_ids, attention_mask → encoder_hidden_states) +- Mu2DecoderWrapper + Mu2DecoderIOConfig: ``text2text-generation`` task + → decoder ONNX with static KV buffer input + single-token KV output. + Input past KV: full static buffer [batch, n_kv_head, max_decode, head_dim]. + Output present KV: new token only [batch, n_kv_head, 1, head_dim]. + +The Mu2 model's native attention (MuAttentionSDPA) does NOT support HF's +cache mechanism. The decoder wrapper reimplements the decoder forward pass +using the original layer weights, adding CapturingStaticCache for +self-attention KV. Cross-attention KV is always recomputed from +encoder_hidden_states (no cache needed). + +Model: local Mu2ForCausalLM with trust_remote_code=True. + +Usage: + wmk config -m path/to/mu2 --task feature-extraction → encoder + wmk config -m path/to/mu2 --task text2text-generation → decoder +""" + +from __future__ import annotations + +from typing import Any, ClassVar + +import torch +import torch.nn as nn +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedConfig +from optimum.utils.input_generators import DummyTextInputGenerator + +from ...export import register_onnx_overwrite +from ..winml.pipeline_model import register_pipeline_model +from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel +from .kv_cache import CapturingStaticCache as _CapturingStaticCache +from .kv_cache import PastKeyValueInputGenerator + + +# ============================================================================= +# Wrapper nn.Modules +# ============================================================================= + + +class Mu2EncoderWrapper(nn.Module): + """Wraps Mu2 encoder for standalone ONNX export.""" + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.encoder = model.encoder + self.config = model.config + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> Mu2EncoderWrapper: + """Load full Mu2, extract encoder.""" + from transformers import AutoModelForSeq2SeqLM + + full_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **kwargs) + wrapper = cls(full_model) + wrapper.eval() + return wrapper + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + """Return encoder last hidden state.""" + return self.encoder( + input_ids=input_ids, attention_mask=attention_mask.bool() + ).last_hidden_state + + +class Mu2DecoderWrapper(nn.Module): + """Wraps Mu2 decoder with CapturingStaticCache for ONNX export. + + Delegates to the model's own decoder (which now accepts ``past_key_values`` + and ``cache_position``). This wrapper just builds the cache from flat + KV inputs, calls the decoder, and collects captured KV outputs. + + Same pattern as ``T5DecoderWrapper``. + """ + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + self.config = model.config + self.num_layers = model.config.n_decoder_layer + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> Mu2DecoderWrapper: + """Load full Mu2, wrap for cached decoder export.""" + from transformers import AutoModelForSeq2SeqLM + + full_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **kwargs) + wrapper = cls(full_model) + wrapper.eval() + return wrapper + + def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + """Convert dict inputs to positional args for torch.onnx.export.""" + return tuple(inputs.values()) + + def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Run decoder with static KV cache. + + Positional args (order matches OnnxConfig.inputs): + decoder_input_ids, encoder_hidden_states, attention_mask (encoder), + decoder_attention_mask, cache_position, + past_0_key, past_0_value, past_1_key, past_1_value, ... + + Returns: + (logits, present_0_key, present_0_value, ...) where each + present KV is [batch, n_kv_head, 1, head_dim]. + """ + decoder_input_ids = args[0] + encoder_hidden_states = args[1] + encoder_attention_mask = args[2] # "attention_mask" in OnnxConfig + decoder_attention_mask = args[3] + cache_position = args[4] + kv_start = 5 + + # Build CapturingStaticCache from input KV tensors + self_attn_cache = _CapturingStaticCache(self.config, max_cache_len=args[kv_start].size(2)) + self_attn_cache.early_initialization( + batch_size=decoder_input_ids.size(0), + num_heads=self.config.n_kv_head, + head_dim=self.config.head_dim, + dtype=args[kv_start].dtype, + device=decoder_input_ids.device, + ) + for i in range(self.num_layers): + self_attn_cache.layers[i].keys = args[kv_start + i * 2] + self_attn_cache.layers[i].values = args[kv_start + i * 2 + 1] + + # Delegate to model's decoder (now supports past_key_values + cache_position) + hidden_states = self.model.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=self_attn_cache, + cache_position=cache_position, + ) + logits = self.model.lm_head(hidden_states) + + # Collect captured KV + result: list[torch.Tensor] = [logits] + for i in range(self.num_layers): + k, v = self_attn_cache.captured[i] + result.extend([k, v]) + return tuple(result) + + +# ============================================================================= +# OnnxConfig Registrations +# ============================================================================= + + +@register_onnx_overwrite("mu2", "feature-extraction", library_name="transformers") +class Mu2EncoderIOConfig(OnnxConfig): + """ONNX config for Mu2 encoder (feature-extraction task).""" + + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + vocab_size="vocab_size", + allow_new=True, + ) + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + } + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return { + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + } + + +@register_onnx_overwrite("mu2", "text2text-generation", library_name="transformers") +class Mu2DecoderIOConfig(OnnxConfig): + """ONNX config for Mu2 decoder with static KV cache.""" + + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + hidden_size="n_embd", + num_layers="n_decoder_layer", + num_attention_heads="n_kv_head", + head_dim="head_dim", + max_cache_len="block_size", + vocab_size="vocab_size", + allow_new=True, + ) + DUMMY_INPUT_GENERATOR_CLASSES = ( + EncoderDecoderInputGenerator, + PastKeyValueInputGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + result: dict[str, dict[int, str]] = { + "decoder_input_ids": {0: "batch_size"}, + "encoder_hidden_states": {0: "batch_size"}, + "attention_mask": {0: "batch_size"}, + "decoder_attention_mask": {0: "batch_size"}, + "cache_position": {}, + } + num_layers = self._normalized_config.num_layers + for i in range(num_layers): + result[f"past_{i}_key"] = {0: "batch_size"} + result[f"past_{i}_value"] = {0: "batch_size"} + return result + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + result: dict[str, dict[int, str]] = {"logits": {0: "batch_size"}} + num_layers = self._normalized_config.num_layers + for i in range(num_layers): + result[f"present_{i}_key"] = {0: "batch_size"} + result[f"present_{i}_value"] = {0: "batch_size"} + return result + + +# ============================================================================= +# Model Class Mapping + WinML Inference Model +# ============================================================================= + +MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("mu2", "feature-extraction"): Mu2EncoderWrapper, + ("mu2", "text2text-generation"): Mu2DecoderWrapper, +} + + +@register_pipeline_model("mu2", "translation") +class WinMLMu2Model(WinMLEncoderDecoderModel): + """Mu2 encoder-decoder model for translation. + + Declares Mu2 sub-component tasks and generation config defaults. + All encoder-decoder forward/cache logic lives in ``WinMLEncoderDecoderModel``. + """ + + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "encoder": "feature-extraction", + "decoder": "text2text-generation", + } + + @property + def generation_config(self): # noqa: D102 + if not hasattr(self, "_generation_config"): + from transformers import GenerationConfig + + gc_kw: dict[str, Any] = {} + if self.config is not None: + for attr in ( + "decoder_start_token_id", + "bos_token_id", + "eos_token_id", + "pad_token_id", + ): + val = getattr(self.config, attr, None) + if val is not None: + gc_kw[attr] = val + gc_kw.setdefault("max_new_tokens", self._max_dec - 1) + gc_kw.setdefault("num_beams", 1) + gc_kw.setdefault("do_sample", False) + self._generation_config = GenerationConfig(**gc_kw) + return self._generation_config + + @generation_config.setter + def generation_config(self, value: Any) -> None: + self._generation_config = value + + +__all__ = [ + "MODEL_CLASS_MAPPING", + "Mu2DecoderIOConfig", + "Mu2DecoderWrapper", + "Mu2EncoderIOConfig", + "Mu2EncoderWrapper", + "WinMLMu2Model", +] diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py new file mode 100644 index 000000000..35cbcd956 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen.py @@ -0,0 +1,329 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Qwen3 HuggingFace Model Configuration. + +Provides decoder export wrappers and OnnxConfig registrations for +Qwen3 decoder-only models with static KV cache, split into prefill +and generation sub-models. + +Export Strategy (split by task): +- QwenDecoderWrapper + QwenPrefillIOConfig: ``feature-extraction`` task + → prefill ONNX (input_ids [1, 64] → logits [1, 64, vocab] + KV [1, kv_heads, 64, head_dim]) +- QwenDecoderWrapper + QwenGenIOConfig: ``text-generation`` task + → generation ONNX (input_ids [1, 1] → logits [1, 1, vocab] + KV [1, kv_heads, 1, head_dim]) + +Both tasks share the same wrapper class; OnnxConfig determines static shapes. +Uses ``CapturingStaticCache`` (from ``kv_cache.py``) to return new-token KV +directly as ONNX outputs, eliminating the scatter→gather round-trip. + +How it works: + +1. ``QwenDecoderWrapper.forward()`` takes positional args (order matches + OnnxConfig.inputs): input_ids, attention_mask, position_ids, cache_position, + past_0_key, past_0_value, ... It builds a ``CapturingStaticCache`` from the + input KV buffers, runs ``Qwen3ForCausalLM``, and returns logits + captured KV. + +2. Decoder-only models need NO ``EncoderDecoderCache`` wrapping — + ``StaticCache`` is passed directly as ``past_key_values``. (Contrast with + T5 where ``EncoderDecoderCache`` is required to route self-attention and + cross-attention to separate caches.) + +3. Logits are returned for ALL input positions (not just last token). + This matches HF convention and enables both generation (last-token logits) + and perplexity evaluation (all-position logits with shifted labels). + +4. ``dynamo=True`` is required for Qwen3 ONNX export — the TorchScript + exporter fails with an internal error. Dynamo produces opset 18 models; + opset 17 downconversion currently fails for these graphs. + +Task name constraints (Optimum compatibility): + +- Task names must exist in ``TasksManager.get_all_tasks()`` to pass + validation in ``register_onnx_overwrite``. Custom names like + ``"causal-lm-prefill"`` require pre-registration in + ``TasksManager._LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP``. +- ``"causal-lm"`` is a synonym for ``"text-generation"`` in Optimum's + ``_SYNONYM_TASK_MAP`` — registering an OnnxConfig under ``"causal-lm"`` + silently resolves to ``"text-generation"`` at lookup time. +- ``"text-generation-with-past"`` requires the OnnxConfig to implement + ``with_past`` support (raises ``ValueError`` otherwise). +- We use ``"feature-extraction"`` (prefill) and ``"text-generation"`` (gen) + as they are standard tasks with no normalization surprises. + +Model: Qwen/Qwen3-0.6B, Qwen/Qwen3-1.7B, etc. + +Usage:: + + # Generate both configs (pipeline mode) + winml config -m Qwen/Qwen3-0.6B --task text-generation -o qwen.json + + # Build both sub-models + from winml.modelkit.models.winml.decoder_only import WinMLQwen3Model + model = WinMLQwen3Model.from_pretrained("Qwen/Qwen3-0.6B") + + # Or load pre-built ONNX directly (skip_build=True avoids re-optimization) + from winml.modelkit.models.auto import WinMLAutoModel + prefill = WinMLAutoModel.from_pretrained("prefill.onnx", skip_build=True) + gen = WinMLAutoModel.from_pretrained("gen.onnx", skip_build=True) + model = WinMLQwen3Model(sub_models={...}, config=hf_config) +""" + +from __future__ import annotations + +from typing import Any, ClassVar + +import torch +import torch.nn as nn +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedConfig +from transformers import AutoModelForCausalLM + +from ...config import WinMLBuildConfig +from ...export import register_onnx_overwrite +from ...export.config import WinMLExportConfig +from ..winml import register_specialization +from ..winml.decoder_only import ( + DecoderOnlyInputGenerator, + DecoderOnlyPrefillInputGenerator, + WinMLDecoderOnlyModel, +) +from ..winml.pipeline_model import register_pipeline_model +from .kv_cache import CapturingStaticCache as _CapturingStaticCache +from .kv_cache import PastKeyValueInputGenerator + + +# ============================================================================= +# Wrapper nn.Module +# ============================================================================= + + +class QwenDecoderWrapper(nn.Module): + """Wraps Qwen3ForCausalLM with static KV cache I/O. + + Used for both prefill and generation ONNX export — same forward logic, + different OnnxConfig determines the static input shapes. + + Input KV: full static buffer ``[batch, kv_heads, max_cache_len, head_dim]``. + Output KV: new positions only ``[batch, kv_heads, seq_len, head_dim]``. + Logits: last position only ``[batch, 1, vocab_size]`` (both prefill and gen). + """ + + def __init__(self, model: nn.Module, num_layers: int) -> None: + super().__init__() + self.model = model + self.num_layers = num_layers + self.config = model.config + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenDecoderWrapper: + """Load Qwen3ForCausalLM and wrap for export.""" + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs) + wrapper = cls(model, model.config.num_hidden_layers) + wrapper.eval() + return wrapper + + def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + """Convert dict inputs to positional args for torch.onnx.export.""" + return tuple(inputs.values()) + + def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Run decoder with static KV cache. + + Positional args (order matches OnnxConfig.inputs): + input_ids, attention_mask, position_ids, cache_position, + past_0_key, past_0_value, past_1_key, past_1_value, ... + + Returns: + (logits, present_0_key, present_0_value, ...) where: + - logits is ``[batch, 1, vocab_size]`` (last position only) + - present KV is ``[batch, kv_heads, seq_len, head_dim]`` + """ + input_ids = args[0] + attention_mask = args[1] + position_ids = args[2] + cache_position = args[3] + kv_start = 4 + + # Build CapturingStaticCache from input KV tensors. + # Decoder-only: pass StaticCache directly (no EncoderDecoderCache needed). + cache = _CapturingStaticCache(self.config, max_cache_len=args[kv_start].size(2)) + cache.early_initialization( + batch_size=input_ids.size(0), + num_heads=args[kv_start].size(1), + head_dim=args[kv_start].size(3), + dtype=args[kv_start].dtype, + device=input_ids.device, + ) + for i in range(self.num_layers): + cache.layers[i].keys = args[kv_start + i * 2] + cache.layers[i].values = args[kv_start + i * 2 + 1] + + out = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=cache, + use_cache=True, + cache_position=cache_position, + ) + + # All logits + captured KV directly (no gather). + # forward() selects the right position for padded prefill inputs. + result: list[torch.Tensor] = [out.logits] + for i in range(self.num_layers): + k, v = cache.captured[i] + result.extend([k, v]) + return tuple(result) + + +# Sub-models must use GenericTask (raw ONNX outputs) — task-specific +# wrappers like WinMLModelForFeatureExtraction would discard KV outputs. +register_specialization("qwen3", "feature-extraction", "WinMLModelForGenericTask") +register_specialization("qwen3", "text-generation", "WinMLModelForGenericTask") + + +# ============================================================================= +# OnnxConfig Registrations (using standard Optimum task names) +# ============================================================================= + +_QWEN_NORMALIZED = NormalizedConfig.with_args( + hidden_size="hidden_size", + num_layers="num_hidden_layers", + num_attention_heads="num_key_value_heads", # KV cache uses GQA heads + head_dim="head_dim", + max_cache_len="max_position_embeddings", + vocab_size="vocab_size", + allow_new=True, +) + + +def _qwen_io_inputs(num_layers: int) -> dict[str, dict[int, str]]: + result: dict[str, dict[int, str]] = { + "input_ids": {0: "batch_size"}, + "attention_mask": {0: "batch_size"}, + "position_ids": {0: "batch_size"}, + "cache_position": {}, + } + for i in range(num_layers): + result[f"past_{i}_key"] = {0: "batch_size"} + result[f"past_{i}_value"] = {0: "batch_size"} + return result + + +def _qwen_io_outputs(num_layers: int) -> dict[str, dict[int, str]]: + result: dict[str, dict[int, str]] = {"logits": {0: "batch_size"}} + for i in range(num_layers): + result[f"present_{i}_key"] = {0: "batch_size"} + result[f"present_{i}_value"] = {0: "batch_size"} + return result + + +@register_onnx_overwrite("qwen3", "feature-extraction", library_name="transformers") +class QwenPrefillIOConfig(OnnxConfig): + """ONNX config for Qwen3 prefill (feature-extraction task). + + Inputs: input_ids [1, 64], attention_mask [1, 256], position_ids [1, 64], + cache_position [64], past_{i}_key/value [1, 8, 256, 128] + Outputs: logits [1, 1, vocab], present_{i}_key/value [1, 8, 64, 128] + """ + + NORMALIZED_CONFIG_CLASS = _QWEN_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = (DecoderOnlyPrefillInputGenerator, PastKeyValueInputGenerator) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return _qwen_io_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return _qwen_io_outputs(self._normalized_config.num_layers) + + +@register_onnx_overwrite("qwen3", "text-generation", library_name="transformers") +class QwenGenIOConfig(OnnxConfig): + """ONNX config for Qwen3 generation (text-generation task). + + Inputs: input_ids [1, 1], attention_mask [1, 256], position_ids [1, 1], + cache_position [1], past_{i}_key/value [1, 8, 256, 128] + Outputs: logits [1, 1, vocab], present_{i}_key/value [1, 8, 1, 128] + """ + + NORMALIZED_CONFIG_CLASS = _QWEN_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = (DecoderOnlyInputGenerator, PastKeyValueInputGenerator) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return _qwen_io_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return _qwen_io_outputs(self._normalized_config.num_layers) + + +# ============================================================================= +# Build Config (dynamo=True required for Qwen3) +# ============================================================================= + +QWEN_CONFIG = WinMLBuildConfig( + export=WinMLExportConfig(dynamo=True, opset_version=18), +) + + +# ============================================================================= +# Model Class Mapping +# ============================================================================= + +MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("qwen3", "feature-extraction"): QwenDecoderWrapper, + ("qwen3", "text-generation"): QwenDecoderWrapper, +} + +# ============================================================================= +# WinMLQwen3Model — inference wrapper (registered as pipeline model) +# ============================================================================= + + +@register_pipeline_model("qwen3", "text-generation") +class WinMLQwen3Model(WinMLDecoderOnlyModel): + """Qwen3 decoder-only model for text generation. + + Declares Qwen3 sub-component tasks and generation config defaults. + All forward/cache logic lives in ``WinMLDecoderOnlyModel``. + """ + + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "decoder_prefill": "feature-extraction", + "decoder_gen": "text-generation", + } + + @property + def generation_config(self): # noqa: D102 + if not hasattr(self, "_generation_config"): + from transformers import GenerationConfig + + gc_kw: dict[str, Any] = {} + for attr in ("bos_token_id", "eos_token_id", "pad_token_id"): + val = getattr(self.config, attr, None) + if val is not None: + gc_kw[attr] = val + gc_kw.setdefault("max_new_tokens", self._max_cache_len - self._prefill_seq_len) + gc_kw.setdefault("num_beams", 1) + gc_kw.setdefault("do_sample", False) + self._generation_config = GenerationConfig(**gc_kw) + return self._generation_config + + @generation_config.setter + def generation_config(self, value: Any) -> None: + self._generation_config = value + + +__all__ = [ + "MODEL_CLASS_MAPPING", + "QWEN_CONFIG", + "QwenDecoderWrapper", + "QwenGenIOConfig", + "QwenPrefillIOConfig", + "WinMLQwen3Model", +] diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py new file mode 100644 index 000000000..02c142f97 --- /dev/null +++ b/src/winml/modelkit/models/hf/t5.py @@ -0,0 +1,331 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""T5 HuggingFace Model Configuration. + +Provides encoder/decoder export wrappers and OnnxConfig registrations for +T5 encoder-decoder models with static KV cache. + +Export Strategy (split by task): +- T5EncoderWrapper + T5EncoderIOConfig: ``feature-extraction`` task + → encoder-only ONNX (input_ids, attention_mask → encoder_hidden_states) +- T5DecoderWrapper + T5DecoderIOConfig: ``text2text-generation`` task + → decoder ONNX with static buffer input + single-token KV output. + Uses HF StaticCache (index_copy_ at cache_position) for attention. + Output is only the new token's KV [batch, heads, 1, d_kv]. + +Model: google-t5/t5-small, google-t5/t5-base, etc. + +Usage: + wmk config -m google-t5/t5-small --task feature-extraction → encoder + wmk config -m google-t5/t5-small --task text2text-generation → decoder +""" + +from __future__ import annotations + +from typing import Any, ClassVar + +import torch +import torch.nn as nn +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedConfig +from optimum.utils.input_generators import DummyTextInputGenerator +from transformers import T5ForConditionalGeneration +from transformers.cache_utils import DynamicCache, EncoderDecoderCache + +from ...export import register_onnx_overwrite +from ..winml.pipeline_model import register_pipeline_model +from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel +from .kv_cache import CapturingStaticCache as _CapturingStaticCache +from .kv_cache import PastKeyValueInputGenerator + + +# ============================================================================= +# Wrapper nn.Modules (with from_pretrained, like SAM2 wrappers) +# ============================================================================= + + +class T5EncoderWrapper(nn.Module): + """Wraps T5 encoder for standalone ONNX export. + + Loads the full T5ForConditionalGeneration and extracts the encoder. + """ + + def __init__(self, encoder: nn.Module) -> None: + super().__init__() + self.encoder = encoder + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> T5EncoderWrapper: + """Load full T5, extract encoder.""" + full_model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, **kwargs) + wrapper = cls(full_model.encoder) + wrapper.eval() + return wrapper + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + """Return encoder last hidden state.""" + return self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state + + +class T5DecoderWrapper(nn.Module): + """Wraps T5ForConditionalGeneration with static KV cache I/O. + + Input: full static buffer ``[batch, heads, max_decode, d_kv]`` per layer. + Output: only the new token's KV ``[batch, heads, 1, d_kv]`` per layer. + + Uses HF ``StaticCache`` (``index_copy_`` at ``cache_position``) wrapped + in ``EncoderDecoderCache`` (cross-attn empty → always recomputed from + ``encoder_hidden_states``). ``KV_index = sequence_position`` holds, so + T5's relative position bias computes correct distances. + + The inference wrapper (WinMLT5Model) uses the same + ``StaticCache`` class — it writes the single-token output KV back + into the buffer via ``cache.update()`` before the next step. + """ + + def __init__(self, model: nn.Module, num_layers: int) -> None: + super().__init__() + self.model = model + self.num_layers = num_layers + # Expose config for OnnxConfig / NormalizedConfig access + self.config = model.config + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> T5DecoderWrapper: + """Load full T5, wrap with static cache.""" + full_model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, **kwargs) + num_layers = full_model.config.num_layers + wrapper = cls(full_model, num_layers) + wrapper.eval() + return wrapper + + def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + """Convert dict inputs to positional args for torch.onnx.export.""" + return tuple(inputs.values()) + + def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Run decoder with static KV cache. + + Positional args (order matches OnnxConfig.inputs): + decoder_input_ids, encoder_hidden_states, attention_mask, + decoder_attention_mask, cache_position, + past_0_key, past_0_value, past_1_key, past_1_value, ... + + Returns: + (logits, present_0_key, present_0_value, ...) where each + present KV is [batch, heads, 1, d_kv] — the new token only. + """ + decoder_input_ids = args[0] + encoder_hidden_states = args[1] + attention_mask = args[2] + decoder_attention_mask = args[3] + cache_position = args[4] + kv_start = 5 + + # Build CapturingStaticCache from input KV tensors. + # update() uses index_copy_ at cache_position for correct attention, + # and captures the incoming key/value states for direct output + # (eliminating the old scatter→gather round-trip in the ONNX graph). + self_attn_cache = _CapturingStaticCache(self.config, max_cache_len=args[kv_start].size(2)) + self_attn_cache.early_initialization( + batch_size=decoder_input_ids.size(0), + num_heads=args[kv_start].size(1), + head_dim=args[kv_start].size(3), + dtype=args[kv_start].dtype, + device=decoder_input_ids.device, + ) + for i in range(self.num_layers): + self_attn_cache.layers[i].keys = args[kv_start + i * 2] + self_attn_cache.layers[i].values = args[kv_start + i * 2 + 1] + + # EncoderDecoderCache is structurally required: T5Attention routes + # self-attention → self_attention_cache, cross-attention → cross_attention_cache. + # Without the wrapper, both would share the same cache + layer indices. + # DynamicCache for cross-attn is a no-op during export (each layer + # computes fresh from encoder_hidden_states, never reuses). + cross_attn_cache = DynamicCache() + cache = EncoderDecoderCache(self_attn_cache, cross_attn_cache) + + out = self.model( + decoder_input_ids=decoder_input_ids, + encoder_outputs=(encoder_hidden_states,), + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + past_key_values=cache, + use_cache=True, + cache_position=cache_position, + ) + + # Return new-token KV directly from the capturing cache. + # The old approach did gather(ScatterElements output) — a round-trip. + # _CapturingStaticCache already saved the incoming key/value states. + result: list[torch.Tensor] = [out.logits] + for i in range(self.num_layers): + k, v = self_attn_cache.captured[i] + result.extend([k, v]) + return tuple(result) + + +# ============================================================================= +# OnnxConfig Registrations +# ============================================================================= + + +@register_onnx_overwrite("t5", "feature-extraction", library_name="transformers") +class T5EncoderIOConfig(OnnxConfig): + """ONNX config for T5 encoder (feature-extraction task). + + Inputs: input_ids, attention_mask + Outputs: encoder_hidden_states + """ + + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + vocab_size="vocab_size", + allow_new=True, + ) + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + } + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return { + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + } + + +@register_onnx_overwrite("t5", "text2text-generation", library_name="transformers") +class T5DecoderIOConfig(OnnxConfig): + """ONNX config for T5 decoder with static KV cache. + + Inputs: decoder_input_ids, encoder_hidden_states, attention_mask, + decoder_attention_mask, cache_position, past_{i}_key/value + Outputs: logits, present_{i}_key/value + + Input past KV: full static buffer [batch, heads, max_decode, d_kv]. + Output present KV: new token only [batch, heads, 1, d_kv]. + """ + + # T5Config: d_model, num_layers, num_heads, d_kv, vocab_size, n_positions. + # sequence_length uses Optimum default (16) — NOT n_positions (512, too large). + # head_dim maps to d_kv for PastKeyValueInputGenerator. + # max_cache_len maps to n_positions (decoder static buffer size). + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + hidden_size="d_model", + num_layers="num_layers", + num_attention_heads="num_heads", + head_dim="d_kv", + max_cache_len="n_positions", + vocab_size="vocab_size", + allow_new=True, + ) + DUMMY_INPUT_GENERATOR_CLASSES = ( + EncoderDecoderInputGenerator, + PastKeyValueInputGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + result: dict[str, dict[int, str]] = { + "decoder_input_ids": {0: "batch_size"}, + "encoder_hidden_states": {0: "batch_size"}, + "attention_mask": {0: "batch_size"}, + "decoder_attention_mask": {0: "batch_size"}, + "cache_position": {}, + } + num_layers = self._normalized_config.num_layers + for i in range(num_layers): + result[f"past_{i}_key"] = {0: "batch_size"} + result[f"past_{i}_value"] = {0: "batch_size"} + return result + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + result: dict[str, dict[int, str]] = { + "logits": {0: "batch_size"}, + } + num_layers = self._normalized_config.num_layers + for i in range(num_layers): + result[f"present_{i}_key"] = {0: "batch_size"} + result[f"present_{i}_value"] = {0: "batch_size"} + return result + + +# ============================================================================= +# Model Class Mapping (same pattern as SAM2 and CLIP) +# ============================================================================= + +MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("t5", "feature-extraction"): T5EncoderWrapper, + ("t5", "text2text-generation"): T5DecoderWrapper, +} + + +# ============================================================================= +# WinMLT5Model — inference wrapper (registered as pipeline model) +# ============================================================================= + + +@register_pipeline_model("t5", "translation") +class WinMLT5Model(WinMLEncoderDecoderModel): + """T5 encoder-decoder model for translation. + + Declares T5 sub-component tasks and generation config defaults. + All encoder-decoder forward/cache logic lives in ``WinMLEncoderDecoderModel``. + """ + + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "encoder": "feature-extraction", + "decoder": "text2text-generation", + } + + @property + def generation_config(self): # noqa: D102 + if not hasattr(self, "_generation_config"): + from transformers import GenerationConfig + + gc_kw: dict[str, Any] = {} + if self.config is not None: + for attr in ( + "decoder_start_token_id", + "bos_token_id", + "eos_token_id", + "pad_token_id", + ): + val = getattr(self.config, attr, None) + if val is not None: + gc_kw[attr] = val + gc_kw.setdefault("max_new_tokens", self._max_dec - 1) + # Static batch=1 ONNX models don't support beam search + gc_kw.setdefault("num_beams", 1) + gc_kw.setdefault("do_sample", False) + self._generation_config = GenerationConfig(**gc_kw) + return self._generation_config + + @generation_config.setter + def generation_config(self, value: Any) -> None: + self._generation_config = value + + +__all__ = [ + "MODEL_CLASS_MAPPING", + "T5DecoderIOConfig", + "T5DecoderWrapper", + "T5EncoderIOConfig", + "T5EncoderWrapper", + "WinMLT5Model", +] diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py new file mode 100644 index 000000000..18cf9c21e --- /dev/null +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -0,0 +1,375 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""WinML Decoder-Only Pipeline Model. + +Class hierarchy:: + + WinMLPipelineModel(PreTrainedModel) — multi-component base + └─ WinMLDecoderOnlyModel(GenerationMixin) — prefill + gen with StaticCache + └─ WinMLQwen3Model — Qwen3 tasks + generation config + +How it works: + +1. ``@register_pipeline_model("qwen3", "text-generation")`` hooks into + ``winml config`` so that ``winml config -m Qwen/Qwen3-0.6B --task text-generation`` + generates ``qwen_decoder_prefill.json`` + ``qwen_decoder_gen.json``. + +2. ``from_pretrained()`` builds each component via ``WinMLAutoModel`` + independently. Sub-models are registered as ``WinMLModelForGenericTask`` + (via ``register_specialization``) so their raw ONNX outputs (logits + KV) + are returned as-is — task-specific wrappers like + ``WinMLModelForFeatureExtraction`` would discard the KV outputs. + +3. ``forward()`` is called by ``GenerationMixin.generate()`` on each step: + + - **Prefill** (``input_ids`` has multiple tokens): chunks into + ``prefill_seq_len`` pieces and runs the prefill ONNX model in a loop. + Right-pads the last chunk; only writes real tokens' KV into the cache + (padding positions are discarded). Returns logits for ALL real + positions ``[1, seq_len, vocab]`` — matches HF convention, enabling + both generation (last-token selection) and perplexity evaluation + (shifted cross-entropy over all positions). + + - **Generation** (``input_ids`` has 1 token): runs the gen ONNX model + with the single token + full KV cache buffer as input. + +4. KV cache uses HF ``StaticCache`` — same class as T5. ``get_seq_length()`` + counts non-zero positions; ``cache.update()`` writes new KV via + ``index_copy_``. The cache persists across generate() steps via + ``CausalLMOutputWithPast.past_key_values``. + +5. ``prepare_inputs_for_generation()`` handles a subtle interaction with + ``GenerationMixin``: on the FIRST call, GenerationMixin may pass an + auto-created ``DynamicCache`` (empty). We detect this (not a + ``StaticCache`` or empty) and pass the full prompt through for prefill + rather than trimming to the last token. On subsequent calls with a + populated ``StaticCache``, we trim to the last token as usual. + +Design principles (same as pipeline_model.py): + +- ONNX I/O names and shapes are read from ``io_config``, never hardcoded. +- Inputs smaller than ONNX expected shape are zero-padded via ``_pad_inputs``. +- ``_pad_inputs`` is reused from ``WinMLEncoderDecoderModel`` (static method). +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +import torch +from optimum.utils.input_generators import DummyInputGenerator +from transformers import Cache, StaticCache +from transformers.generation.utils import GenerationMixin +from transformers.modeling_outputs import CausalLMOutputWithPast + +from .pipeline_model import WinMLPipelineModel + + +_pad_inputs = WinMLPipelineModel._pad_inputs + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + +logger = logging.getLogger(__name__) + + +# ========================================================================= +# DecoderOnlyInputGenerator — shared dummy input generator +# ========================================================================= + + +class DecoderOnlyInputGenerator(DummyInputGenerator): + """Generates base inputs for decoder-only models with static KV cache. + + Produces ``input_ids``, ``attention_mask``, ``position_ids``, and + ``cache_position``. Reads ``vocab_size``, ``max_cache_len``, and + ``seq_len`` from the ``NormalizedConfig``. + + ``seq_len`` controls the input token count and is read from + ``normalized_config.seq_len`` (falls back to ``_default_seq_len``). + Subclasses override the default for prefill vs generation: + + - ``DecoderOnlyPrefillInputGenerator``: ``_default_seq_len = 64`` + - ``DecoderOnlyInputGenerator`` (base / gen): ``_default_seq_len = 1`` + + To override at config time, set ``config.seq_len = N`` on the HF config. + """ + + SUPPORTED_INPUT_NAMES = ( + "input_ids", + "attention_mask", + "position_ids", + "cache_position", + ) + + _default_seq_len: int = 1 + + def __init__( + self, + task: str, + normalized_config: Any, + batch_size: int = 1, + seq_len: int | None = None, + max_cache_len: int | None = None, + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.vocab_size = normalized_config.vocab_size + self.max_cache_len = max_cache_len or normalized_config.max_cache_len + self.seq_len: int = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len) + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: + """Generate a dummy tensor for the given input name.""" + if input_name == "input_ids": + return self.random_int_tensor( + (self.batch_size, self.seq_len), + max_value=self.vocab_size, + framework=framework, + dtype=int_dtype, + ) + if input_name == "attention_mask": + mask = torch.zeros(self.batch_size, self.max_cache_len, dtype=torch.int64) + mask[:, : self.seq_len] = 1 + return mask + if input_name == "position_ids": + return torch.arange(self.seq_len, dtype=torch.int64).unsqueeze(0) + if input_name == "cache_position": + return torch.arange(self.seq_len, dtype=torch.int64) + raise ValueError(f"Unknown input: {input_name}") + + +class DecoderOnlyPrefillInputGenerator(DecoderOnlyInputGenerator): + """Prefill variant with ``_default_seq_len = 64``.""" + + _default_seq_len: int = 64 + + +# ========================================================================= +# WinMLDecoderOnlyModel — prefill + gen with StaticCache +# ========================================================================= + + +class WinMLDecoderOnlyModel(WinMLPipelineModel, GenerationMixin): + """Decoder-only pipeline model with HF GenerationMixin support. + + Expects sub-components ``"decoder_prefill"`` and ``"decoder_gen"`` in + ``_SUB_MODEL_CONFIG``. Provides the full interface required by + ``GenerationMixin.generate()`` for decoder-only models with static KV cache. + + Input/output names and shapes are read from ONNX I/O metadata. + """ + + main_input_name = "input_ids" + base_model_prefix = "" + _is_stateful = False + _supports_cache_class = False + + def __init__( + self, + sub_models: dict[str, Any], + config: PretrainedConfig, + ) -> None: + super().__init__(sub_models, config) + self._prefill_model = sub_models["decoder_prefill"] + self._gen_model = sub_models["decoder_gen"] + + # Build {name: shape} lookups from ONNX I/O metadata + prefill_io = self._prefill_model.io_config + self._prefill_expected = dict( + zip( + prefill_io.get("input_names", []), + prefill_io.get("input_shapes", []), + strict=False, + ) + ) + gen_io = self._gen_model.io_config + self._gen_expected = dict( + zip(gen_io.get("input_names", []), gen_io.get("input_shapes", []), strict=False) + ) + + # Cache geometry from gen model's KV input shape + self._max_cache_len = self._gen_expected["past_0_key"][2] + self._num_kv_heads = self._gen_expected["past_0_key"][1] + self._head_dim = self._gen_expected["past_0_key"][3] + self._num_kv_layers = sum( + 1 for n in self._gen_expected if n.startswith("past_") and n.endswith("_key") + ) + + # Prefill chunk size + self._prefill_seq_len = self._prefill_expected["input_ids"][1] + + # ----- GenerationMixin interface ----- + + def can_generate(self) -> bool: # noqa: D102 + return True + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Cache | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Build inputs for each generate() step. + + GenerationMixin may pass a DynamicCache (auto-created, empty) on the + first call. Only trim to last token when we have a populated + StaticCache (i.e., after prefill). + """ + if isinstance(past_key_values, StaticCache) and past_key_values.get_seq_length() > 0: + input_ids = input_ids[:, -1:] + else: + # First call or empty cache: pass full prompt for prefill + past_key_values = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "attention_mask": attention_mask, + } + + # ----- Forward ----- + + def forward( + self, + *, + input_ids: torch.Tensor, + past_key_values: Cache | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs: Any, + ) -> CausalLMOutputWithPast: + """Run prefill or gen with static KV cache. + + Called by ``GenerationMixin.generate()`` on each step: + - First call: ``input_ids`` is the full prompt → prefill (chunked). + - Subsequent calls: ``input_ids`` is 1 token → gen. + + Args: + input_ids: Token IDs ``[batch, seq_len]``. + past_key_values: StaticCache from previous step (None on first call). + attention_mask: Not used directly — rebuilt from cache occupancy. + **kwargs: Absorbed for GenerationMixin compatibility. + + Returns: + CausalLMOutputWithPast with logits and updated StaticCache. + """ + # Resolve or create StaticCache (same pattern as T5) + cache = past_key_values if isinstance(past_key_values, StaticCache) else None + if cache is None: + cache = StaticCache(self.config, max_cache_len=self._max_cache_len) + cache.early_initialization( + batch_size=1, + num_heads=self._num_kv_heads, + head_dim=self._head_dim, + dtype=torch.float32, + device=torch.device("cpu"), + ) + + seq_len = input_ids.shape[1] + if seq_len > 1: + logits = self._run_prefill(input_ids, cache) + else: + logits = self._run_gen(input_ids, cache) + + return CausalLMOutputWithPast( + logits=logits, + past_key_values=cache, + ) + + # ----- Prefill (chunked) ----- + + def _run_prefill(self, input_ids: torch.Tensor, cache: StaticCache) -> torch.Tensor: + """Run prefill model in a loop over chunks of ``prefill_seq_len``. + + Returns logits for ALL real input positions ``[1, seq_len, vocab_size]`` + (same convention as HF CausalLM — enables perplexity evaluation). + """ + seq_len = input_ids.shape[1] + all_logits: list[torch.Tensor] = [] + + for start in range(0, seq_len, self._prefill_seq_len): + end = min(start + self._prefill_seq_len, seq_len) + chunk_len = end - start + + # Pad chunk to prefill_seq_len (right-padding) + padded_ids = torch.zeros(1, self._prefill_seq_len, dtype=input_ids.dtype) + padded_ids[0, :chunk_len] = input_ids[0, start:end] + + position_ids = torch.arange( + start, start + self._prefill_seq_len, dtype=torch.int64 + ).unsqueeze(0) + cache_position = torch.arange(start, start + self._prefill_seq_len, dtype=torch.int64) + + # Attention mask: 1 for all real tokens so far + attn_mask = torch.zeros(1, self._max_cache_len, dtype=torch.int64) + attn_mask[0, : start + chunk_len] = 1 + + feeds: dict[str, Any] = { + "input_ids": padded_ids, + "attention_mask": attn_mask, + "position_ids": position_ids, + "cache_position": cache_position, + } + for i in range(self._num_kv_layers): + feeds[f"past_{i}_key"] = cache.layers[i].keys.detach() + feeds[f"past_{i}_value"] = cache.layers[i].values.detach() + + outputs = self._prefill_model(**_pad_inputs(feeds, self._prefill_expected)) + + # Write only real tokens' KV into cache (skip padding) + real_positions = cache_position[:chunk_len] + ck = {"cache_position": real_positions} + for i in range(self._num_kv_layers): + cache.update( + outputs[f"present_{i}_key"][:, :, :chunk_len, :], + outputs[f"present_{i}_value"][:, :, :chunk_len, :], + layer_idx=i, + cache_kwargs=ck, + ) + + # Keep logits for real tokens only (discard padding positions) + all_logits.append(outputs["logits"][:, :chunk_len, :]) + + return torch.cat(all_logits, dim=1) + + # ----- Generation (single token) ----- + + def _run_gen(self, input_ids: torch.Tensor, cache: StaticCache) -> torch.Tensor: + """Run gen model for a single token. Returns logits ``[1, 1, vocab_size]``.""" + fc = cache.get_seq_length() + + attn_mask = torch.zeros(1, self._max_cache_len, dtype=torch.int64) + attn_mask[0, : fc + 1] = 1 + + feeds: dict[str, Any] = { + "input_ids": input_ids, + "attention_mask": attn_mask, + "position_ids": torch.tensor([[fc]], dtype=torch.int64), + "cache_position": torch.tensor([fc], dtype=torch.int64), + } + for i in range(self._num_kv_layers): + feeds[f"past_{i}_key"] = cache.layers[i].keys.detach() + feeds[f"past_{i}_value"] = cache.layers[i].values.detach() + + outputs = self._gen_model(**_pad_inputs(feeds, self._gen_expected)) + + # Write new token's KV into cache + ck = {"cache_position": torch.tensor([fc], dtype=torch.int64)} + for i in range(self._num_kv_layers): + cache.update( + outputs[f"present_{i}_key"], + outputs[f"present_{i}_value"], + layer_idx=i, + cache_kwargs=ck, + ) + + return outputs["logits"] diff --git a/src/winml/modelkit/models/winml/pipeline_model.py b/src/winml/modelkit/models/winml/pipeline_model.py new file mode 100644 index 000000000..a3e541035 --- /dev/null +++ b/src/winml/modelkit/models/winml/pipeline_model.py @@ -0,0 +1,205 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""WinML Pipeline Model base and registry. + +Provides ``WinMLPipelineModel`` — a base class for models composed of +multiple ``WinMLAutoModel`` sub-components (e.g., encoder + decoder, +prefill + gen). Each subclass declares ``_SUB_MODEL_CONFIG`` mapping +component names to HF tasks; ``from_pretrained()`` builds them all. + +Also provides the ``PIPELINE_MODEL_REGISTRY`` and ``register_pipeline_model`` +decorator, used by ``wmk config`` to generate per-component config files. + +Concrete pipeline models live alongside their export configs: + +- ``models.hf.t5.WinMLT5Model`` (encoder-decoder, T5) +- ``models.hf.qwen.WinMLQwen3Model`` (decoder-only, Qwen3) +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, ClassVar + +import torch + +from .base import PreTrainedModel + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + +logger = logging.getLogger(__name__) + + +# ========================================================================= +# Pipeline Model Registry +# ========================================================================= + +# Maps (model_type, task) → pipeline class with _SUB_MODEL_CONFIG. +# Used by `wmk config` to generate one config file per sub-component. +PIPELINE_MODEL_REGISTRY: dict[tuple[str, str], type] = {} + + +def register_pipeline_model(model_type: str, task: str): + """Class decorator that registers a pipeline model for `wmk config`.""" + + def decorator(cls: type) -> type: + PIPELINE_MODEL_REGISTRY[(model_type, task)] = cls + return cls + + return decorator + + +# ========================================================================= +# WinMLPipelineModel — multi-component base +# ========================================================================= + + +class WinMLPipelineModel(PreTrainedModel): + """Base class for models composed of multiple WinMLAutoModel sub-components. + + Subclasses declare ``_SUB_MODEL_CONFIG``: a mapping of component name to + the HF task used to build it via ``WinMLAutoModel.from_pretrained``. + + After construction, sub-components are available in ``self.sub_models``. + """ + + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = {} + + def __init__( + self, + sub_models: dict[str, Any], + config: PretrainedConfig, + ) -> None: + self.sub_models = sub_models + self.config = config + + @classmethod + def from_pretrained( + cls, + model_id: str, + task: str, + *, + device: str = "cpu", + use_cache: bool = True, + force_rebuild: bool = False, + sub_model_kwargs: dict[str, dict[str, Any]] | None = None, + **kwargs: Any, + ) -> WinMLPipelineModel: + """Build all sub-components and return ready-to-use model. + + When called on ``WinMLPipelineModel`` directly (not a subclass), + ``task`` is required to resolve the concrete class from + ``PIPELINE_MODEL_REGISTRY``. When called on a registered subclass + (e.g., ``WinMLT5Model``), ``task`` is optional. + + Args: + model_id: HuggingFace model ID or local path. + task: Pipeline task name (e.g., ``"translation"``, + ``"text-generation"``). Required when calling on the base + class; ignored when calling on a registered subclass. + device: Target device. + use_cache: Use persistent cache. + force_rebuild: Force rebuild even if cached. + sub_model_kwargs: Per-component kwargs forwarded to + ``WinMLAutoModel.from_pretrained()``. Keys are component + names from ``_SUB_MODEL_CONFIG`` (e.g., ``"decoder_prefill"``, + ``"decoder_gen"``). Values are dicts merged on top of the + shared ``**kwargs``. Use this to pass different + ``shape_config`` per sub-model. + **kwargs: Forwarded to ``WinMLAutoModel.from_pretrained()`` + for every sub-component (overridden by ``sub_model_kwargs``). + """ + from transformers import AutoConfig + + trust_remote_code = kwargs.get("trust_remote_code", False) + hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + model_type = hf_config.model_type + + if not cls._SUB_MODEL_CONFIG: + # Resolve concrete class from registry when called on the base class + resolved_cls = PIPELINE_MODEL_REGISTRY.get((model_type, task)) + if resolved_cls is None: + raise ValueError( + f"No pipeline model registered for ({model_type!r}, {task!r}). " + f"Registered: {list(PIPELINE_MODEL_REGISTRY.keys())}" + ) + return resolved_cls.from_pretrained( + model_id, + task, + device=device, + use_cache=use_cache, + force_rebuild=force_rebuild, + sub_model_kwargs=sub_model_kwargs, + **kwargs, + ) + from ..auto import WinMLAutoModel + + per_component = sub_model_kwargs or {} + sub_models: dict[str, Any] = {} + for name, component_task in cls._SUB_MODEL_CONFIG.items(): + logger.info("Building %s for %s...", name, model_id) + merged = {**kwargs, **per_component.get(name, {})} + sub_models[name] = WinMLAutoModel.from_pretrained( + model_id, + task=component_task, + device=device, + use_cache=use_cache, + force_rebuild=force_rebuild, + **merged, + ) + + return cls(sub_models=sub_models, config=hf_config) + + @property + def device(self) -> torch.device: + """Device (CPU — ORT handles actual placement).""" + return torch.device("cpu") + + @property + def dtype(self) -> torch.dtype: + """Model dtype for HF compatibility.""" + return torch.float32 + + def to(self, *args: Any, **kwargs: Any) -> WinMLPipelineModel: + """No-op for HF pipeline compatibility.""" + return self + + def __call__(self, **kwargs: Any) -> Any: + """Inference entry point.""" + return self.forward(**kwargs) + + def forward(self, **kwargs: Any) -> Any: + """Subclasses implement task-specific logic.""" + raise NotImplementedError + + @staticmethod + def _pad_inputs( + source: dict[str, Any], + expected: dict[str, list[int]], + ) -> dict[str, Any]: + """Filter *source* to keys in *expected* and pad undersized tensors. + + For each name in *expected*, if *source* has a tensor for it, pad + any dimension smaller than the ONNX expected shape (skips batch dim). + Non-tensor values are passed through. Missing names are skipped. + """ + result: dict[str, Any] = {} + for name, expected_shape in expected.items(): + val = source.get(name) + if val is None: + continue + if isinstance(val, torch.Tensor): + # TODO: support dynamic shape ONNX models (None in expected_shape) + ndim = min(len(val.shape), len(expected_shape)) + pad: list[int] = [] + for dim in reversed(range(1, ndim)): + deficit = expected_shape[dim] - val.shape[dim] + pad.extend([0, max(deficit, 0)]) + if any(p > 0 for p in pad): + val = torch.nn.functional.pad(val, pad) + result[name] = val + return result diff --git a/tests/unit/export/test_io.py b/tests/unit/export/test_io.py index 66fe5d3e1..8f3a575fa 100644 --- a/tests/unit/export/test_io.py +++ b/tests/unit/export/test_io.py @@ -13,11 +13,13 @@ from __future__ import annotations +from types import SimpleNamespace from unittest.mock import patch import pytest import torch from transformers import ( + AutoConfig, CLIPTextConfig, CLIPTextModelWithProjection, CLIPVisionConfig, @@ -36,6 +38,7 @@ _get_onnx_config, _populate_image_size_from_preprocessor, ) +from winml.modelkit.models.hf.kv_cache import PastKeyValueInputGenerator # ============================================================================= @@ -672,3 +675,157 @@ def test_no_size_key_in_config(self) -> None: assert "height" not in shape_kwargs assert "width" not in shape_kwargs + + +# ============================================================================= +# PastKeyValueInputGenerator — shared KV cache dummy input generation +# ============================================================================= + + +def _make_normalized_config( + num_layers: int = 4, + num_attention_heads: int = 2, + head_dim: int = 32, + max_cache_len: int = 16, +) -> SimpleNamespace: + """Create a lightweight object that quacks like NormalizedConfig.""" + return SimpleNamespace( + num_layers=num_layers, + num_attention_heads=num_attention_heads, + head_dim=head_dim, + max_cache_len=max_cache_len, + ) + + +@pytest.fixture(scope="module") +def t5_config(): + """T5-small config with n_positions overridden to 32 for fast tests.""" + cfg = AutoConfig.from_pretrained("google-t5/t5-small") + cfg.n_positions = 32 + return cfg + + +@pytest.fixture(scope="module") +def qwen_config(): + """Qwen3-0.6B config with max_position_embeddings overridden to 256.""" + cfg = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B") + cfg.max_position_embeddings = 256 + return cfg + + +class TestPastKeyValueInputGenerator: + """Direct tests for PastKeyValueInputGenerator.""" + + def test_supported_input_names(self) -> None: + nc = _make_normalized_config(num_layers=3) + gen = PastKeyValueInputGenerator("text-generation", nc) + expected = ( + "past_0_key", + "past_0_value", + "past_1_key", + "past_1_value", + "past_2_key", + "past_2_value", + ) + assert expected == gen.SUPPORTED_INPUT_NAMES + + def test_generate_key_shape(self) -> None: + nc = _make_normalized_config( + num_layers=2, + num_attention_heads=4, + head_dim=16, + max_cache_len=64, + ) + gen = PastKeyValueInputGenerator("text-generation", nc, batch_size=2) + tensor = gen.generate("past_0_key") + assert tensor.shape == (2, 4, 64, 16) + + def test_generate_value_shape(self) -> None: + nc = _make_normalized_config( + num_layers=2, + num_attention_heads=4, + head_dim=16, + max_cache_len=64, + ) + gen = PastKeyValueInputGenerator("text-generation", nc, batch_size=1) + tensor = gen.generate("past_1_value") + assert tensor.shape == (1, 4, 64, 16) + + def test_generate_returns_float_tensor(self) -> None: + nc = _make_normalized_config() + gen = PastKeyValueInputGenerator("text-generation", nc) + tensor = gen.generate("past_0_key") + assert isinstance(tensor, torch.Tensor) + assert tensor.dtype == torch.float32 + + def test_single_layer(self) -> None: + nc = _make_normalized_config(num_layers=1) + gen = PastKeyValueInputGenerator("text-generation", nc) + assert gen.SUPPORTED_INPUT_NAMES == ("past_0_key", "past_0_value") + + def test_batch_size_propagated(self) -> None: + nc = _make_normalized_config() + gen = PastKeyValueInputGenerator("text-generation", nc, batch_size=8) + assert gen.batch_size == 8 + tensor = gen.generate("past_0_key") + assert tensor.shape[0] == 8 + + +class TestT5DecoderKVInputs: + """T5 decoder dummy inputs use PastKeyValueInputGenerator.""" + + def test_kv_input_names(self, t5_config) -> None: + inputs = generate_dummy_inputs("t5", "text2text-generation", t5_config) + num_layers = t5_config.num_layers # 6 + for i in range(num_layers): + assert f"past_{i}_key" in inputs + assert f"past_{i}_value" in inputs + + def test_kv_shape(self, t5_config) -> None: + inputs = generate_dummy_inputs("t5", "text2text-generation", t5_config) + kv = inputs["past_0_key"] + # [batch=1, heads=8, max_cache_len=32, d_kv=64] + assert kv.shape == (1, t5_config.num_heads, 32, t5_config.d_kv) + + def test_decoder_attention_mask_matches_cache_len(self, t5_config) -> None: + inputs = generate_dummy_inputs("t5", "text2text-generation", t5_config) + assert inputs["decoder_attention_mask"].shape[1] == 32 + + def test_all_kv_layers_present(self, t5_config) -> None: + inputs = generate_dummy_inputs("t5", "text2text-generation", t5_config) + kv_names = [n for n in inputs if n.startswith("past_")] + assert len(kv_names) == t5_config.num_layers * 2 + + +class TestQwenPrefillKVInputs: + """Qwen3 prefill dummy inputs use PastKeyValueInputGenerator.""" + + def test_kv_input_names(self, qwen_config) -> None: + inputs = generate_dummy_inputs("qwen3", "feature-extraction", qwen_config) + num_layers = qwen_config.num_hidden_layers # 28 + for i in range(num_layers): + assert f"past_{i}_key" in inputs + assert f"past_{i}_value" in inputs + + def test_kv_shape(self, qwen_config) -> None: + inputs = generate_dummy_inputs("qwen3", "feature-extraction", qwen_config) + kv = inputs["past_0_key"] + # [batch=1, kv_heads=8, max_cache_len=256, head_dim=128] + assert kv.shape == (1, qwen_config.num_key_value_heads, 256, qwen_config.head_dim) + + def test_attention_mask_matches_cache_len(self, qwen_config) -> None: + inputs = generate_dummy_inputs("qwen3", "feature-extraction", qwen_config) + assert inputs["attention_mask"].shape[1] == 256 + + +class TestQwenGenKVInputs: + """Qwen3 generation dummy inputs use PastKeyValueInputGenerator.""" + + def test_kv_shape_matches_prefill(self, qwen_config) -> None: + inputs = generate_dummy_inputs("qwen3", "text-generation", qwen_config) + kv = inputs["past_0_key"] + assert kv.shape == (1, qwen_config.num_key_value_heads, 256, qwen_config.head_dim) + + def test_input_ids_single_token(self, qwen_config) -> None: + inputs = generate_dummy_inputs("qwen3", "text-generation", qwen_config) + assert inputs["input_ids"].shape == (1, 1) From eed03eb411de1bff03c0b30b56b51306685ca6b3 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Tue, 14 Apr 2026 12:03:35 +0800 Subject: [PATCH 02/32] feat: fp16 KV cache support and session_options passthrough Derive KV cache dtype from ONNX model metadata instead of hardcoding float32, enabling fp16 models to use fp16 StaticCache. Thread session_options from WinMLAutoModel.from_onnx through to WinMLSession. --- src/winml/modelkit/models/auto.py | 3 +++ src/winml/modelkit/models/hf/encoder_decoder.py | 12 ++++++++++-- src/winml/modelkit/models/winml/base.py | 3 +++ src/winml/modelkit/models/winml/decoder_only.py | 10 +++++++++- 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 45db1299b..800f2555a 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -107,6 +107,7 @@ def from_onnx( use_cache: bool = True, force_rebuild: bool = False, skip_build: bool = False, + session_options: Any | None = None, **kwargs: Any, ) -> WinMLPreTrainedModel: """Build from a pre-exported ONNX file. @@ -165,6 +166,7 @@ def from_onnx( onnx_path=onnx_path, config=None, device=device, + session_options=session_options, ) # Resolve output directory @@ -200,6 +202,7 @@ def from_onnx( onnx_path=result.final_onnx_path, config=None, # No HF PretrainedConfig for bare ONNX builds device=device, + session_options=session_options, ) @classmethod diff --git a/src/winml/modelkit/models/hf/encoder_decoder.py b/src/winml/modelkit/models/hf/encoder_decoder.py index 7710bee00..b8ca897c3 100644 --- a/src/winml/modelkit/models/hf/encoder_decoder.py +++ b/src/winml/modelkit/models/hf/encoder_decoder.py @@ -191,11 +191,19 @@ def __init__( zip(dec_io.get("input_names", []), dec_io.get("input_shapes", []), strict=False) ) - # Max decode length from decoder ONNX KV input shape + # Max decode length and KV dtype from decoder ONNX metadata self._max_dec = self._dec_expected["past_0_key"][2] self._num_kv_layers = sum( 1 for n in self._dec_expected if n.startswith("past_") and n.endswith("_key") ) + # Resolve KV cache dtype from ONNX input types (fp32 or fp16) + dec_type_map = dict( + zip(dec_io.get("input_names", []), dec_io.get("input_types", []), strict=False) + ) + import numpy as np + + _np_dtype = dec_type_map.get("past_0_key", np.float32) + self._kv_dtype = torch.from_numpy(np.zeros(1, dtype=_np_dtype)).dtype # ----- Encoder ----- @@ -285,7 +293,7 @@ def forward( batch_size=1, num_heads=kv_shape[1], head_dim=kv_shape[3], - dtype=torch.float32, + dtype=self._kv_dtype, device=torch.device("cpu"), ) diff --git a/src/winml/modelkit/models/winml/base.py b/src/winml/modelkit/models/winml/base.py index 4d4e892bd..2f4d973d2 100644 --- a/src/winml/modelkit/models/winml/base.py +++ b/src/winml/modelkit/models/winml/base.py @@ -64,6 +64,7 @@ def __init__( onnx_path: str | Path, config: PretrainedConfig | None = None, device: str = "auto", + session_options: Any | None = None, ) -> None: """Initialize inference model. @@ -71,6 +72,7 @@ def __init__( onnx_path: Path to ONNX model file config: HuggingFace PretrainedConfig (num_labels, id2label, etc.) device: Target device ("auto", "npu", "gpu", "cpu") + session_options: ORT SessionOptions (e.g., for graph_optimization_level) """ self._onnx_path = Path(onnx_path) self.config = config @@ -80,6 +82,7 @@ def __init__( self._session = WinMLSession( onnx_path=self._onnx_path, device=device, + session_options=session_options, ) @property diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 18cf9c21e..4308f7fa2 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -204,6 +204,14 @@ def __init__( self._num_kv_layers = sum( 1 for n in self._gen_expected if n.startswith("past_") and n.endswith("_key") ) + # Resolve KV cache dtype from ONNX input types (fp32 or fp16) + gen_type_map = dict( + zip(gen_io.get("input_names", []), gen_io.get("input_types", []), strict=False) + ) + import numpy as np + + _np_dtype = gen_type_map.get("past_0_key", np.float32) + self._kv_dtype = torch.from_numpy(np.zeros(1, dtype=_np_dtype)).dtype # Prefill chunk size self._prefill_seq_len = self._prefill_expected["input_ids"][1] @@ -270,7 +278,7 @@ def forward( batch_size=1, num_heads=self._num_kv_heads, head_dim=self._head_dim, - dtype=torch.float32, + dtype=self._kv_dtype, device=torch.device("cpu"), ) From d4064348a1efa20cfa6b26e2cd0cca0a05b53666 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Tue, 14 Apr 2026 13:20:40 +0800 Subject: [PATCH 03/32] revert: remove subfolder from config command (not part of this PR) --- src/winml/modelkit/commands/config.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index 3e28f823a..21c777204 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -166,11 +166,6 @@ def _is_onnx_file(model_input: str) -> bool: default=False, help="Allow running custom code from model repository", ) -@click.option( - "--subfolder", - default=None, - help="Subfolder within HF repo to load from (e.g., 'text_encoder' for Stable Diffusion).", -) def config( hf_model: str | None, task: str | None, @@ -188,7 +183,6 @@ def config( no_quant: bool, no_compile: bool, trust_remote_code: bool, - subfolder: str | None, ) -> None: r"""Generate WinMLBuildConfig for a HuggingFace model or .onnx file. @@ -347,7 +341,6 @@ def config( precision=precision, trust_remote_code=trust_remote_code, ep=ep, - subfolder=subfolder, ) # Handle output format From 4488fd2bd5e3a70988b986d04247ce777db02dbb Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Wed, 15 Apr 2026 14:09:59 +0800 Subject: [PATCH 04/32] refactor: WinMLCache hierarchy with polymorphic interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce WinMLCache abstract base (step, num_layers, build_decoder_mask, update_all_layers, reset, create). Two concrete implementations: - WinMLStaticCache (ScatterElements/index_copy_, T5/Qwen) - WinMLSlidingWindowCache (Slice+Concat FIFO, Mu2) WinMLEncoderDecoderModel.forward is now cache-agnostic — calls only WinMLCache interface methods. Each model subclass declares get_cache_class(). No isinstance checks in the forward path. --- .../modelkit/models/hf/encoder_decoder.py | 87 ++++---- src/winml/modelkit/models/hf/kv_cache.py | 197 ++++++++++++++++-- src/winml/modelkit/models/hf/mu2.py | 49 +++-- src/winml/modelkit/models/hf/qwen.py | 11 +- src/winml/modelkit/models/hf/t5.py | 9 +- 5 files changed, 260 insertions(+), 93 deletions(-) diff --git a/src/winml/modelkit/models/hf/encoder_decoder.py b/src/winml/modelkit/models/hf/encoder_decoder.py index b8ca897c3..279297502 100644 --- a/src/winml/modelkit/models/hf/encoder_decoder.py +++ b/src/winml/modelkit/models/hf/encoder_decoder.py @@ -67,16 +67,16 @@ import torch from optimum.utils.input_generators import DummyInputGenerator -from transformers import Cache, StaticCache from transformers.generation.utils import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from ..winml.pipeline_model import WinMLPipelineModel +from .kv_cache import WinMLStaticCache if TYPE_CHECKING: from optimum.utils import NormalizedConfig - from transformers import PretrainedConfig + from transformers import Cache, PretrainedConfig logger = logging.getLogger(__name__) @@ -100,6 +100,7 @@ class EncoderDecoderInputGenerator(DummyInputGenerator): "attention_mask", "decoder_attention_mask", "cache_position", + "position_id", ) def __init__( @@ -144,6 +145,8 @@ def generate( return torch.ones(self.batch_size, self.max_cache_len, dtype=torch.int64) if input_name == "cache_position": return torch.tensor([5], dtype=torch.int64) # arbitrary position for tracing + if input_name == "position_id": + return torch.tensor([5], dtype=torch.int64) # absolute seq position for RoPE raise ValueError(f"Unknown input: {input_name}") @@ -248,6 +251,36 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, } + # ----- Cache management ----- + + @classmethod + def get_cache_class(cls) -> type: + """Return the WinMLCache subclass for this model. Subclasses override.""" + return WinMLStaticCache + + def _resolve_cache(self, past_key_values: Any) -> Any: + """Unwrap or create the WinMLCache for this generation step. + + 1. Unwrap EncoderDecoderCache wrapper (GenerationMixin may add it). + 2. If already a WinMLCache, return directly. + 3. Otherwise create a fresh one and reset it. + """ + from .kv_cache import WinMLCache + + # (1) Unwrap EncoderDecoderCache + if hasattr(past_key_values, "self_attention_cache"): + past_key_values = past_key_values.self_attention_cache + + # (2) Already our cache — return as-is + if isinstance(past_key_values, WinMLCache): + return past_key_values + + # (3) Create fresh cache and reset + kv_shape = self._dec_expected["past_0_key"] + cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype) + cache.reset() + return cache + # ----- Forward (decoder via WinMLAutoModel + KV cache) ----- def forward( @@ -275,55 +308,25 @@ def forward( raise ValueError("Either encoder_outputs or input_ids required") enc_h = encoder_outputs["last_hidden_state"] - # Resolve the self-attention cache. - # GenerationMixin may pass None, a StaticCache, or an - # EncoderDecoderCache wrapping a DynamicCache (auto-created). - cache = None - if isinstance(past_key_values, StaticCache): - cache = past_key_values - elif hasattr(past_key_values, "self_attention_cache"): - sa = past_key_values.self_attention_cache - if isinstance(sa, StaticCache): - cache = sa - if cache is None: - # Read KV geometry from ONNX metadata (architecture-agnostic) - kv_shape = self._dec_expected["past_0_key"] # [batch, heads, max_dec, head_dim] - cache = StaticCache(self.config, max_cache_len=self._max_dec) - cache.early_initialization( - batch_size=1, - num_heads=kv_shape[1], - head_dim=kv_shape[3], - dtype=self._kv_dtype, - device=torch.device("cpu"), - ) + # Resolve or create cache (subclasses override _create_cache). + cache = self._resolve_cache(past_key_values) - # Determine write position from cache occupancy - fc = cache.get_seq_length() - dec_mask = torch.zeros(1, self._max_dec, dtype=torch.int64) - dec_mask[0, : fc + 1] = 1 + fc = cache.step + dec_mask = cache.build_decoder_mask(self._max_dec) - # Build feeds: model_kwargs first, then fill in generated inputs feeds: dict[str, Any] = dict(model_kwargs) feeds.setdefault("encoder_hidden_states", enc_h.detach()) feeds.setdefault("decoder_attention_mask", dec_mask) - feeds.setdefault("cache_position", torch.tensor([fc], dtype=torch.int64)) + feeds.setdefault(cache.position_input_name, torch.tensor([fc], dtype=torch.int64)) for i in range(self._num_kv_layers): - layer = cache.layers[i] - feeds[f"past_{i}_key"] = layer.keys.detach() - feeds[f"past_{i}_value"] = layer.values.detach() + feeds[f"past_{i}_key"] = cache.layers[i].keys.detach() + feeds[f"past_{i}_value"] = cache.layers[i].values.detach() - # Filter to decoder ONNX inputs and pad any undersized tensors + # Run decoder ONNX (pad_inputs filters to expected names + pads) outputs = self._decoder(**self._pad_inputs(feeds, self._dec_expected)) - # Write new token's KV into the StaticCache in-place - cache_kwargs = {"cache_position": torch.tensor([fc], dtype=torch.int64)} - for i in range(self._num_kv_layers): - cache.update( - outputs[f"present_{i}_key"], - outputs[f"present_{i}_value"], - layer_idx=i, - cache_kwargs=cache_kwargs, - ) + # Write present KV back and advance step + cache.update_all_layers(outputs) return Seq2SeqLMOutput( logits=outputs["logits"], diff --git a/src/winml/modelkit/models/hf/kv_cache.py b/src/winml/modelkit/models/hf/kv_cache.py index f4bc53289..5b0a6e439 100644 --- a/src/winml/modelkit/models/hf/kv_cache.py +++ b/src/winml/modelkit/models/hf/kv_cache.py @@ -2,19 +2,22 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Shared KV cache utilities for ONNX export wrappers. +"""Shared KV cache utilities for ONNX export and inference. -Provides ``CapturingStaticCache`` — a ``StaticCache`` subclass that captures -each layer's new-token KV from ``update()``, eliminating the scatter→gather -round-trip in the exported ONNX graph. +Hierarchy:: + + StaticCache (HF transformers) + └─ WinMLCache — common interface + ├─ WinMLStaticCache — ScatterElements (index_copy_), T5/Qwen + └─ WinMLSlidingWindowCache — Slice+Concat (FIFO), Mu2 Also provides ``PastKeyValueInputGenerator`` — a reusable ``DummyInputGenerator`` -for static KV cache inputs (``past_{i}_key``, ``past_{i}_value``), shared by -T5, Qwen, and future models with static KV cache export. +for static KV cache inputs (``past_{i}_key``, ``past_{i}_value``). """ from __future__ import annotations +from abc import abstractmethod from typing import TYPE_CHECKING, Any from optimum.utils.input_generators import DummyInputGenerator @@ -24,21 +27,89 @@ if TYPE_CHECKING: import torch from optimum.utils import NormalizedConfig + from transformers import PretrainedConfig + + +# ============================================================================= +# WinMLCache — common interface +# ============================================================================= + + +class WinMLCache(StaticCache): + """Abstract base for WinML KV caches (export + inference). + + Subclasses set ``position_input_name`` and implement + ``build_decoder_mask`` and ``update_all_layers``. + + ``step`` tracks the absolute generation position + (used for RoPE and mask construction). + ``num_layers`` is set from ``config.num_hidden_layers``. + """ + + #: ONNX input name for the position tensor (subclasses override). + position_input_name: str + + def __init__(self, config: PretrainedConfig, *args: Any, **kwargs: Any) -> None: + super().__init__(config, *args, **kwargs) + self.step: int = 0 + self.num_layers: int = config.num_hidden_layers + + # ----- Interface for WinMLEncoderDecoderModel.forward ----- + + @abstractmethod + def build_decoder_mask(self, max_len: int) -> torch.Tensor: + """Build the decoder attention mask for the current step.""" + + @abstractmethod + def update_all_layers(self, outputs: dict[str, Any]) -> None: + """Write present KV for all layers from ONNX output and advance step.""" + def reset(self) -> None: + """Zero out all layers and reset step (start of new generation).""" + self.step = 0 + for i in range(self.num_layers): + self.layers[i].keys.zero_() + self.layers[i].values.zero_() -class CapturingStaticCache(StaticCache): - """StaticCache that captures each layer's new-token KV from ``update()``. + @classmethod + def create( + cls, config: PretrainedConfig, kv_shape: list[int], dtype: torch.dtype + ) -> WinMLCache: + """Create and initialize a cache from ONNX KV shape metadata. - Standard ``StaticCache.update()`` does ``index_copy_`` (ScatterElements in - ONNX) to write the new KV into the full buffer, then returns the full - buffer for attention. The old approach then used ``gather`` - (GatherElements) to extract the same KV back — a pointless round-trip. + Args: + config: HF model config (must have ``num_hidden_layers``). + kv_shape: ``[batch, heads, max_cache_len, head_dim]`` from ONNX. + dtype: KV dtype (fp32 or fp16). + """ + import torch - This subclass intercepts ``update()`` to save the *incoming* - ``key_states`` / ``value_states`` before they enter the buffer, so the - wrapper can return them directly as ONNX outputs. + cache = cls(config, max_cache_len=kv_shape[2]) + cache.early_initialization( + batch_size=1, + num_heads=kv_shape[1], + head_dim=kv_shape[3], + dtype=dtype, + device=torch.device("cpu"), + ) + return cache + + +# ============================================================================= +# WinMLStaticCache — ScatterElements (index_copy_) +# ============================================================================= + + +class WinMLStaticCache(WinMLCache): + """Cache using ``index_copy_`` at ``cache_position`` (ScatterElements). + + **Export**: intercepts ``update()`` to capture incoming KV for ONNX output. + **Inference**: ``update_all_layers`` writes new-token KV at the current step. + Mask is left-aligned: ``[1, 1, ..., 1, 0, 0, ..., 0]``. """ + position_input_name: str = "cache_position" + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.captured: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} @@ -54,14 +125,104 @@ def update( self.captured[layer_idx] = (key_states, value_states) return super().update(key_states, value_states, layer_idx, cache_kwargs) + def build_decoder_mask(self, max_len: int) -> torch.Tensor: + """Left-aligned: first ``step + 1`` positions are 1.""" + import torch + + mask = torch.zeros(1, max_len, dtype=torch.int64) + mask[0, : self.step + 1] = 1 + return mask + + def update_all_layers(self, outputs: dict[str, Any]) -> None: + """Write new-token KV at current step for all layers, then advance.""" + import torch + + ck = {"cache_position": torch.tensor([self.step], dtype=torch.int64)} + for i in range(self.num_layers): + k = outputs[f"present_{i}_key"] + v = outputs[f"present_{i}_value"] + k = k if isinstance(k, torch.Tensor) else torch.tensor(k) + v = v if isinstance(v, torch.Tensor) else torch.tensor(v) + super(WinMLCache, self).update(k, v, i, cache_kwargs=ck) + self.step += 1 + + +# ============================================================================= +# WinMLSlidingWindowCache — Slice + Concat (FIFO) +# ============================================================================= + + +class WinMLSlidingWindowCache(WinMLCache): + """FIFO cache: evict oldest, append new at end (Slice+Concat). + + **Export**: ``update()`` traces as Slice+Concat — no ScatterElements. + Present KV output is the full updated buffer. + **Inference**: ``update_all_layers`` replaces the full buffer. + Mask is right-aligned: ``[0, 0, ..., 0, 1, 1, ..., 1]``. + """ + + position_input_name: str = "position_id" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.updated: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Drop oldest entry, append new KV at end. Return full buffer.""" + import torch + + old_k = self.layers[layer_idx].keys[:, :, 1:, :] + new_k = torch.cat([old_k, key_states], dim=2) + self.layers[layer_idx].keys = new_k + + old_v = self.layers[layer_idx].values[:, :, 1:, :] + new_v = torch.cat([old_v, value_states], dim=2) + self.layers[layer_idx].values = new_v + + self.updated[layer_idx] = (new_k, new_v) + return new_k, new_v + + def build_decoder_mask(self, max_len: int) -> torch.Tensor: + """Right-aligned: rightmost ``step + 1`` positions are 1.""" + import torch + + mask = torch.zeros(1, max_len, dtype=torch.int64) + mask[0, max(0, max_len - self.step - 1) :] = 1 + return mask + + def update_all_layers(self, outputs: dict[str, Any]) -> None: + """Replace full KV buffers for all layers, then advance.""" + import torch + + for i in range(self.num_layers): + k = outputs[f"present_{i}_key"] + v = outputs[f"present_{i}_value"] + self.layers[i].keys = k if isinstance(k, torch.Tensor) else torch.tensor(k) + self.layers[i].values = v if isinstance(v, torch.Tensor) else torch.tensor(v) + self.step += 1 + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Filled positions: ``min(step, max_cache_len)``.""" + max_len = self.layers[layer_idx].keys.shape[2] + return min(self.step, max_len) + + +# ============================================================================= +# PastKeyValueInputGenerator +# ============================================================================= + class PastKeyValueInputGenerator(DummyInputGenerator): """Generates ``past_{i}_key`` / ``past_{i}_value`` tensors for static KV cache. Reads ``num_layers``, ``num_attention_heads``, ``head_dim``, and - ``max_cache_len`` from the ``NormalizedConfig``. Each model's - ``NORMALIZED_CONFIG_CLASS`` maps these to the appropriate HF config fields - (e.g. T5: ``head_dim="d_kv"``, ``max_cache_len="n_positions"``). + ``max_cache_len`` from the ``NormalizedConfig``. """ SUPPORTED_INPUT_NAMES = () # dynamic — built in __init__ diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index c77ff774d..9db2687d9 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -17,7 +17,7 @@ The Mu2 model's native attention (MuAttentionSDPA) does NOT support HF's cache mechanism. The decoder wrapper reimplements the decoder forward pass -using the original layer weights, adding CapturingStaticCache for +using the original layer weights, adding WinMLCache for self-attention KV. Cross-attention KV is always recomputed from encoder_hidden_states (no cache needed). @@ -41,8 +41,7 @@ from ...export import register_onnx_overwrite from ..winml.pipeline_model import register_pipeline_model from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel -from .kv_cache import CapturingStaticCache as _CapturingStaticCache -from .kv_cache import PastKeyValueInputGenerator +from .kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache # ============================================================================= @@ -76,7 +75,7 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torc class Mu2DecoderWrapper(nn.Module): - """Wraps Mu2 decoder with CapturingStaticCache for ONNX export. + """Wraps Mu2 decoder for ONNX export. Delegates to the model's own decoder (which now accepts ``past_key_values`` and ``cache_position``). This wrapper just builds the cache from flat @@ -106,27 +105,28 @@ def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor return tuple(inputs.values()) def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: - """Run decoder with static KV cache. + """Run decoder with FIFO KV cache (Slice+Concat). Positional args (order matches OnnxConfig.inputs): decoder_input_ids, encoder_hidden_states, attention_mask (encoder), - decoder_attention_mask, cache_position, + decoder_attention_mask, position_id, past_0_key, past_0_value, past_1_key, past_1_value, ... Returns: (logits, present_0_key, present_0_value, ...) where each - present KV is [batch, n_kv_head, 1, head_dim]. + present KV is the full updated buffer [batch, n_kv_head, max_cache_len, head_dim] + (oldest entry evicted, new token appended at end). """ decoder_input_ids = args[0] encoder_hidden_states = args[1] encoder_attention_mask = args[2] # "attention_mask" in OnnxConfig decoder_attention_mask = args[3] - cache_position = args[4] + position_id = args[4] # absolute sequence position for RoPE kv_start = 5 - # Build CapturingStaticCache from input KV tensors - self_attn_cache = _CapturingStaticCache(self.config, max_cache_len=args[kv_start].size(2)) - self_attn_cache.early_initialization( + # Build WinMLSlidingWindowCache (FIFO: Slice+Concat instead of ScatterElements) + cache = WinMLSlidingWindowCache(self.config, max_cache_len=args[kv_start].size(2)) + cache.early_initialization( batch_size=decoder_input_ids.size(0), num_heads=self.config.n_kv_head, head_dim=self.config.head_dim, @@ -134,24 +134,25 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: device=decoder_input_ids.device, ) for i in range(self.num_layers): - self_attn_cache.layers[i].keys = args[kv_start + i * 2] - self_attn_cache.layers[i].values = args[kv_start + i * 2 + 1] + cache.layers[i].keys = args[kv_start + i * 2] + cache.layers[i].values = args[kv_start + i * 2 + 1] - # Delegate to model's decoder (now supports past_key_values + cache_position) + # Delegate to model's decoder — position_id is passed as cache_position + # for RoPE computation (WinMLSlidingWindowCache.update ignores it for indexing) hidden_states = self.model.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_values=self_attn_cache, - cache_position=cache_position, + past_key_values=cache, + cache_position=position_id, ) logits = self.model.lm_head(hidden_states) - # Collect captured KV + # Output full updated cache buffers (not just new token) result: list[torch.Tensor] = [logits] for i in range(self.num_layers): - k, v = self_attn_cache.captured[i] + k, v = cache.updated[i] result.extend([k, v]) return tuple(result) @@ -210,7 +211,7 @@ def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 "encoder_hidden_states": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "decoder_attention_mask": {0: "batch_size"}, - "cache_position": {}, + "position_id": {}, } num_layers = self._normalized_config.num_layers for i in range(num_layers): @@ -240,10 +241,10 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 @register_pipeline_model("mu2", "translation") class WinMLMu2Model(WinMLEncoderDecoderModel): - """Mu2 encoder-decoder model for translation. + """Mu2 encoder-decoder model with sliding-window KV cache. - Declares Mu2 sub-component tasks and generation config defaults. - All encoder-decoder forward/cache logic lives in ``WinMLEncoderDecoderModel``. + Only differs from T5 in ``get_cache_class`` and ``_SUB_MODEL_CONFIG``. + All forward/cache logic lives in ``WinMLEncoderDecoderModel``. """ _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { @@ -251,6 +252,10 @@ class WinMLMu2Model(WinMLEncoderDecoderModel): "decoder": "text2text-generation", } + @classmethod + def get_cache_class(cls) -> type: # noqa: D102 + return WinMLSlidingWindowCache + @property def generation_config(self): # noqa: D102 if not hasattr(self, "_generation_config"): diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index 35cbcd956..c54ad3285 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -15,14 +15,14 @@ → generation ONNX (input_ids [1, 1] → logits [1, 1, vocab] + KV [1, kv_heads, 1, head_dim]) Both tasks share the same wrapper class; OnnxConfig determines static shapes. -Uses ``CapturingStaticCache`` (from ``kv_cache.py``) to return new-token KV +Uses ``WinMLStaticCache`` (from ``kv_cache.py``) to return new-token KV directly as ONNX outputs, eliminating the scatter→gather round-trip. How it works: 1. ``QwenDecoderWrapper.forward()`` takes positional args (order matches OnnxConfig.inputs): input_ids, attention_mask, position_ids, cache_position, - past_0_key, past_0_value, ... It builds a ``CapturingStaticCache`` from the + past_0_key, past_0_value, ... It builds a ``WinMLStaticCache`` from the input KV buffers, runs ``Qwen3ForCausalLM``, and returns logits + captured KV. 2. Decoder-only models need NO ``EncoderDecoderCache`` wrapping — @@ -90,8 +90,7 @@ WinMLDecoderOnlyModel, ) from ..winml.pipeline_model import register_pipeline_model -from .kv_cache import CapturingStaticCache as _CapturingStaticCache -from .kv_cache import PastKeyValueInputGenerator +from .kv_cache import PastKeyValueInputGenerator, WinMLStaticCache # ============================================================================= @@ -146,9 +145,9 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: cache_position = args[3] kv_start = 4 - # Build CapturingStaticCache from input KV tensors. + # Build WinMLStaticCache from input KV tensors. # Decoder-only: pass StaticCache directly (no EncoderDecoderCache needed). - cache = _CapturingStaticCache(self.config, max_cache_len=args[kv_start].size(2)) + cache = WinMLStaticCache(self.config, max_cache_len=args[kv_start].size(2)) cache.early_initialization( batch_size=input_ids.size(0), num_heads=args[kv_start].size(1), diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 02c142f97..562bc4063 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -37,8 +37,7 @@ from ...export import register_onnx_overwrite from ..winml.pipeline_model import register_pipeline_model from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel -from .kv_cache import CapturingStaticCache as _CapturingStaticCache -from .kv_cache import PastKeyValueInputGenerator +from .kv_cache import PastKeyValueInputGenerator, WinMLStaticCache # ============================================================================= @@ -131,11 +130,11 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: cache_position = args[4] kv_start = 5 - # Build CapturingStaticCache from input KV tensors. + # Build WinMLStaticCache from input KV tensors. # update() uses index_copy_ at cache_position for correct attention, # and captures the incoming key/value states for direct output # (eliminating the old scatter→gather round-trip in the ONNX graph). - self_attn_cache = _CapturingStaticCache(self.config, max_cache_len=args[kv_start].size(2)) + self_attn_cache = WinMLStaticCache(self.config, max_cache_len=args[kv_start].size(2)) self_attn_cache.early_initialization( batch_size=decoder_input_ids.size(0), num_heads=args[kv_start].size(1), @@ -167,7 +166,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: # Return new-token KV directly from the capturing cache. # The old approach did gather(ScatterElements output) — a round-trip. - # _CapturingStaticCache already saved the incoming key/value states. + # WinMLStaticCache already saved the incoming key/value states. result: list[torch.Tensor] = [out.logits] for i in range(self.num_layers): k, v = self_attn_cache.captured[i] From 4ab05cd5cb2fad7b82faaa5d267b8017d4f4b2dd Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Wed, 15 Apr 2026 14:47:07 +0800 Subject: [PATCH 05/32] docs: consolidate pipeline design docs into source file docstrings Move content from docs/design/pipeline-model.md into the files that own the code: kv_cache.py (cache compatibility), encoder_decoder.py (forward gotchas), mu2.py (custom model integration), pipeline_model.py (registry + sub_model_kwargs). Add T5 get_cache_class with explanation of why sliding window is incompatible. Remove the now-redundant design doc. --- docs/design/pipeline-model.md | 133 ------------------ .../modelkit/models/hf/encoder_decoder.py | 92 ++++++------ src/winml/modelkit/models/hf/kv_cache.py | 21 ++- src/winml/modelkit/models/hf/mu2.py | 59 ++++---- src/winml/modelkit/models/hf/t5.py | 12 ++ .../modelkit/models/winml/pipeline_model.py | 24 +++- 6 files changed, 129 insertions(+), 212 deletions(-) delete mode 100644 docs/design/pipeline-model.md diff --git a/docs/design/pipeline-model.md b/docs/design/pipeline-model.md deleted file mode 100644 index 45773f280..000000000 --- a/docs/design/pipeline-model.md +++ /dev/null @@ -1,133 +0,0 @@ -# Multi-Model Pipeline Design - -## Problem - -`WinMLAutoModel.from_pretrained` builds ONE ONNX model. Multi-component -architectures (T5 encoder+decoder, SD text_encoder+unet+vae) need multiple -ONNX models composed together. - -## Class Hierarchy - -``` -WinMLPipelineModel(PreTrainedModel) — multi-component base - └─ WinMLEncoderDecoderModel(GenerationMixin) — encoder-decoder with StaticCache - └─ WinMLT5Model — T5 tasks + generation config -``` - -- **WinMLPipelineModel**: `_SUB_MODEL_CONFIG` mapping, `from_pretrained` builds - each component via `WinMLAutoModel`, provides `device`/`to`/`dtype`. -- **WinMLEncoderDecoderModel**: `forward()` with StaticCache KV management, - `_EncoderWithInputPadding` wrapper, `get_encoder()`, `prepare_inputs_for_generation()`. - Auto-pads undersized inputs to ONNX expected shapes via `_pad_inputs`. -- **WinMLT5Model**: declares `_SUB_MODEL_CONFIG` and `generation_config` only. - -## Registry - -`@register_pipeline_model(model_type, task)` registers a pipeline class. -`winml config` checks the registry to generate per-component configs. - -```python -@register_pipeline_model("t5", "translation") -class WinMLT5Model(WinMLEncoderDecoderModel): - _SUB_MODEL_CONFIG = { - "encoder": "feature-extraction", - "decoder": "text2text-generation", - } -``` - -## ONNX Export - -Each component is exported independently via the existing pipeline -(export → optimize → compile). Export wrappers in `models/hf/t5.py`: - -| Component | Class | Description | -|---|---|---| -| Encoder | `T5EncoderWrapper` | `forward(input_ids, attention_mask) → encoder_hidden_states` | -| Decoder | `T5DecoderWrapper` | StaticCache + EncoderDecoderCache from flat KV inputs, extracts new token KV via `gather` | -| Decoder IO | `T5DecoderIOConfig` | OnnxConfig with custom DummyInputGenerators for KV cache tensors | - -### Decoder ONNX I/O (all static shapes) - -``` -Inputs: - decoder_input_ids [1, 1] - encoder_hidden_states [1, enc_seq, d_model] - attention_mask [1, enc_seq] - decoder_attention_mask [1, max_decode] - cache_position [1] - past_{i}_key [1, heads, max_decode, d_kv] # i=0..num_layers-1 - past_{i}_value [1, heads, max_decode, d_kv] - -Outputs: - logits [1, 1, vocab_size] - present_{i}_key [1, heads, 1, d_kv] # new token only - present_{i}_value [1, heads, 1, d_kv] -``` - -Cross-attention KV is always recomputed from `encoder_hidden_states` -(empty cross-attention cache → `is_updated=False` → never constant-folded). - -## KV Cache Design - -Uses HF `StaticCache` for both export and inference: - -- **Export**: `StaticCache.update()` uses `index_copy_` which traces correctly - in `torch.onnx.export`. `KV_index = sequence_position` always holds, so T5's - relative position bias computes correct distances. -- **Inference**: Same `StaticCache` object persists across generation steps, - mutated in-place via `cache.update()`. `get_seq_length()` counts non-zero - positions automatically. -- **GenerationMixin integration**: `StaticCache` flows through the generate loop - via `Seq2SeqLMOutput.past_key_values`. GenerationMixin may wrap it in an - `EncoderDecoderCache`; `forward()` unwraps to find the `StaticCache`. - -Known limitation: OpenVINO EP does not support ScatterElements, requires CPU -EP fallback for decoder inference. - -## Usage - -### 1. Generate configs (one per component) - -``` -winml config -m google-t5/t5-small --task translation --device cpu -o t5.json -``` - -Produces two files: -- `t5_encoder.json` — task `feature-extraction` -- `t5_decoder.json` — task `text2text-generation` - -### 2. Build ONNX models independently - -``` -winml build -c t5_encoder.json -m google-t5/t5-small -o output/encoder -winml build -c t5_decoder.json -m google-t5/t5-small -o output/decoder -``` - -### 3. Run translation pipeline - -```python -from winml.modelkit.models.winml.seq2seq import WinMLPipelineModel -from transformers import AutoTokenizer, pipeline - -model = WinMLPipelineModel.from_pretrained("google-t5/t5-small", "translation") -tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") - -pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer) -result = pipe("Hello, how are you?", num_beams=1) -print(result[0]["translation_text"]) -# Bonjour, comment êtes-vous ? -``` - -`from_pretrained` resolves the concrete class from `PIPELINE_MODEL_REGISTRY`, -builds both ONNX sub-models via `WinMLAutoModel`, and returns a model that -plugs into HF `transformers.pipeline` as a drop-in replacement for -`T5ForConditionalGeneration`. - -## Open Questions - -- Manage KV cache and attention mask jointly in same cache class? -- Update KV in numpy to avoid pytorch tensor <-> numpy array round trip? -- Handle quantized cache (channel-wise quantization for accuracy)? -- EP-specific KV cache management to avoid ORT <-> EP round trip? -- Beam search support (requires dynamic batch)? -- Is it possible/better to use a shared model class for both export and inference? diff --git a/src/winml/modelkit/models/hf/encoder_decoder.py b/src/winml/modelkit/models/hf/encoder_decoder.py index 279297502..f76b0a4dd 100644 --- a/src/winml/modelkit/models/hf/encoder_decoder.py +++ b/src/winml/modelkit/models/hf/encoder_decoder.py @@ -4,60 +4,48 @@ # -------------------------------------------------------------------------- """WinML Encoder-Decoder inference model and shared input generator. -Provides ``WinMLEncoderDecoderModel`` — inference wrapper for encoder-decoder -pipelines (T5, mBART, etc.) with static KV cache, and -``EncoderDecoderInputGenerator`` — reusable ``DummyInputGenerator`` for -decoder base inputs shared across encoder-decoder architectures. - Class hierarchy:: - WinMLPipelineModel(PreTrainedModel) — multi-component base - └─ WinMLEncoderDecoderModel(GenerationMixin) — encoder-decoder with StaticCache - └─ WinMLT5Model (in t5.py) — T5 tasks + generation config - -How it works: - -1. Each pipeline model declares ``_SUB_MODEL_CONFIG = {"encoder": "feature-extraction", - "decoder": "text2text-generation"}``. ``from_pretrained()`` builds each component - via ``WinMLAutoModel`` (export → optimize → compile) independently. - -2. The encoder is wrapped in ``_EncoderWithInputPadding`` which reads ONNX input - names/shapes from ``io_config`` and zero-pads any undersized inputs. - -3. ``forward()`` takes ``(*, encoder_outputs, past_key_values, input_ids, **model_kwargs)`` - where ``model_kwargs`` carries decoder inputs like ``decoder_input_ids`` and - ``attention_mask``. Feeds are built from model_kwargs + generated inputs - (encoder_hidden_states, decoder_attention_mask, cache_position, KV cache), - filtered to decoder ONNX input names, and auto-padded. - -4. KV cache uses HF ``StaticCache`` — same class for both export (``index_copy_`` - traces correctly in ``torch.onnx.export``) and inference (mutated in-place via - ``cache.update()``). The ONNX decoder takes the full static buffer as input - and outputs only the new token's KV ``[batch, heads, 1, d_kv]``. - -Key findings from T5 KV cache study: - -- HF's ``DynamicCache`` is stateful (same object, mutated in-place via ``cat``). - ``GenerationMixin._update_model_kwargs_for_generation`` reads ``past_key_values`` - from the output and reassigns it in ``model_kwargs`` — but for stateful caches - it's the same reference. -- ``StaticCache`` uses ``index_copy_`` at ``cache_position`` (traces correctly). - ``StaticCache.get_seq_length()`` counts non-zero positions automatically. -- ``EncoderDecoderCache`` with empty cross-attn cache → ``is_updated`` dict is - empty → cross-attention always recomputed from ``encoder_hidden_states`` → - prevents constant-folding during ONNX export. -- ``GenerationMixin`` may wrap our ``StaticCache`` in an ``EncoderDecoderCache`` - before passing it back. ``forward()`` must unwrap to find the ``StaticCache``. -- ``TranslationPipeline`` passes its own ``generation_config`` with ``num_beams=4`` - to ``generate()``. Use ``num_beams=1`` at call time or override in subclass. - -Design principles: - -- NEVER guard config access with default values. Use ``self.config.param`` - directly and let AttributeError raise if the config is missing a field. -- ONNX I/O names and shapes are read from ``io_config``, never hardcoded. -- Inputs smaller than ONNX expected shape are zero-padded automatically. - Inputs larger than expected are NOT truncated — let ORT raise the error. + WinMLPipelineModel — multi-component base + └─ WinMLEncoderDecoderModel(GenerationMixin) — encoder-decoder inference + ├─ WinMLT5Model (t5.py) — WinMLStaticCache + └─ WinMLMu2Model (mu2.py) — WinMLSlidingWindowCache + +How ``forward()`` works: + +1. Encoder runs once (via ``get_encoder()``), hidden states cached by + GenerationMixin across decode steps. + +2. Each decode step: ``_resolve_cache`` unwraps GenerationMixin's + ``EncoderDecoderCache`` wrapper (or creates a fresh ``WinMLCache`` + on first call). Cache type is determined by ``get_cache_class()``. + +3. Feeds are built from ``model_kwargs`` (decoder_input_ids, attention_mask) + plus generated inputs (encoder_hidden_states, decoder_attention_mask, + position input, KV buffers). ``_pad_inputs`` filters to ONNX input + names and pads undersized tensors. + +4. After ONNX inference, ``cache.update_all_layers(outputs)`` writes + present KV back and advances step — fully polymorphic, no isinstance. + +Cache-type gotchas (lessons learned): + +- **GenerationMixin wraps cache**: On the first decode call, GenerationMixin + may pass an ``EncoderDecoderCache`` (not None). ``_resolve_cache`` must + unwrap it, and cache reset must check ``not isinstance(WinMLCache)``. + +- **Causal mask with seq_len=1**: ``torch.tril(ones(1, N))`` only keeps + column 0. For single-token KV-cached decoding, the decoder_attention_mask + alone is sufficient — no tril needed. + +- **RoPE position vs buffer position**: With ``WinMLSlidingWindowCache``, + the ONNX input is ``position_id`` (absolute sequence position for RoPE). + With ``WinMLStaticCache``, it's ``cache_position`` (= buffer position = + sequence position). + +- **T5 cannot use sliding window**: ``T5Attention.compute_bias`` assumes + ``buffer_position == sequence_position`` via ``arange(key_length)``. + See ``WinMLT5Model.get_cache_class()`` for details. """ from __future__ import annotations diff --git a/src/winml/modelkit/models/hf/kv_cache.py b/src/winml/modelkit/models/hf/kv_cache.py index 5b0a6e439..3e1e2e1b8 100644 --- a/src/winml/modelkit/models/hf/kv_cache.py +++ b/src/winml/modelkit/models/hf/kv_cache.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Shared KV cache utilities for ONNX export and inference. +"""WinML KV cache classes for ONNX export and inference. Hierarchy:: @@ -11,6 +11,25 @@ ├─ WinMLStaticCache — ScatterElements (index_copy_), T5/Qwen └─ WinMLSlidingWindowCache — Slice+Concat (FIFO), Mu2 +Cache type compatibility: + +- **WinMLStaticCache**: Required for models using learned relative position bias + (T5, mBART) where ``buffer_position == sequence_position`` must hold. + ``T5Attention.compute_bias`` uses ``memory_position = arange(key_length)`` + so KV entries must stay at their original buffer positions. + +- **WinMLSlidingWindowCache**: Compatible with models using RoPE (Mu2, Llama) + where position encoding is baked into K/V tensors. Buffer positions don't + matter — attention scores depend only on the RoPE embeddings in each K. + +Common interface (called by ``WinMLEncoderDecoderModel.forward``): + +- ``position_input_name``: ONNX input name (``"cache_position"`` or ``"position_id"``) +- ``build_decoder_mask(max_len)``: attention mask for current step +- ``update_all_layers(outputs)``: write present KV from ONNX output, advance step +- ``reset()``: zero out for new generation +- ``create(config, kv_shape, dtype)``: factory from ONNX metadata + Also provides ``PastKeyValueInputGenerator`` — a reusable ``DummyInputGenerator`` for static KV cache inputs (``past_{i}_key``, ``past_{i}_value``). """ diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index 9db2687d9..8fda3dd38 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -2,30 +2,41 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Mu2 HuggingFace Model Configuration. - -Provides encoder/decoder export wrappers and OnnxConfig registrations for -Mu2 encoder-decoder models with static KV cache. - -Export Strategy (split by task): -- Mu2EncoderWrapper + Mu2EncoderIOConfig: ``feature-extraction`` task - → encoder-only ONNX (input_ids, attention_mask → encoder_hidden_states) -- Mu2DecoderWrapper + Mu2DecoderIOConfig: ``text2text-generation`` task - → decoder ONNX with static KV buffer input + single-token KV output. - Input past KV: full static buffer [batch, n_kv_head, max_decode, head_dim]. - Output present KV: new token only [batch, n_kv_head, 1, head_dim]. - -The Mu2 model's native attention (MuAttentionSDPA) does NOT support HF's -cache mechanism. The decoder wrapper reimplements the decoder forward pass -using the original layer weights, adding WinMLCache for -self-attention KV. Cross-attention KV is always recomputed from -encoder_hidden_states (no cache needed). - -Model: local Mu2ForCausalLM with trust_remote_code=True. - -Usage: - wmk config -m path/to/mu2 --task feature-extraction → encoder - wmk config -m path/to/mu2 --task text2text-generation → decoder +"""Mu2 encoder-decoder model with sliding-window KV cache. + +Export wrappers, OnnxConfig registrations, and ``WinMLMu2Model`` inference +class for Mu2 (custom ``trust_remote_code`` model). + +Export Strategy: +- Mu2EncoderWrapper (``feature-extraction``): encoder-only ONNX. +- Mu2DecoderWrapper (``text2text-generation``): decoder with + ``WinMLSlidingWindowCache`` (Slice+Concat, no ScatterElements). + Present KV output is the full updated buffer. + +Custom model integration (``auto_map``): + The Mu2 model uses ``trust_remote_code=True`` with ``auto_map`` in + ``config.json`` pointing to ``modeling_mu.py`` / ``configuration_mu.py`` + alongside the weights. KV cache support was added to the model source + (``MuAttentionSDPA`` accepts ``past_key_value`` + ``cache_position``). + +Key decisions: +- Uses ``WinMLSlidingWindowCache`` (not Static) because Mu2 uses RoPE, + not learned relative position bias. RoPE is baked into K tensors, + so buffer positions don't affect attention — sliding window is safe. +- The decoder ONNX input is ``position_id`` (absolute seq position for + RoPE), not ``cache_position`` (which implies buffer-position indexing). +- Mu2's ``generate_sin_cos_pos_emb`` was patched for transformers < 5.x + compatibility (computes inv_freq directly instead of using + ``LlamaRotaryEmbedding.compute_default_rope_parameters``). +- Mu2's ``Mu2Config`` must pass ``pad_token_id`` / ``bos_token_id`` / + ``eos_token_id`` to ``super().__init__()`` or PretrainedConfig + overrides them to None. + +Usage:: + + wmk config -m path/to/mu2 --task translation --trust-remote-code -o mu2.json + wmk build -c mu2_encoder.json -m path/to/mu2 --trust-remote-code -o output/encoder + wmk build -c mu2_decoder.json -m path/to/mu2 --trust-remote-code -o output/decoder """ from __future__ import annotations diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 562bc4063..2977ebb55 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -292,6 +292,18 @@ class WinMLT5Model(WinMLEncoderDecoderModel): "decoder": "text2text-generation", } + @classmethod + def get_cache_class(cls) -> type: + """T5 requires WinMLStaticCache (cannot use sliding window). + + T5's relative position bias (``T5Attention.compute_bias``) computes + ``memory_position = arange(key_length)`` — it assumes buffer + position == sequence position. With sliding window, KV entries + shift left each step, so buffer positions no longer correspond to + sequence positions, producing wrong relative distances. + """ + return WinMLStaticCache + @property def generation_config(self): # noqa: D102 if not hasattr(self, "_generation_config"): diff --git a/src/winml/modelkit/models/winml/pipeline_model.py b/src/winml/modelkit/models/winml/pipeline_model.py index a3e541035..db2a76603 100644 --- a/src/winml/modelkit/models/winml/pipeline_model.py +++ b/src/winml/modelkit/models/winml/pipeline_model.py @@ -9,12 +9,32 @@ prefill + gen). Each subclass declares ``_SUB_MODEL_CONFIG`` mapping component names to HF tasks; ``from_pretrained()`` builds them all. -Also provides the ``PIPELINE_MODEL_REGISTRY`` and ``register_pipeline_model`` -decorator, used by ``wmk config`` to generate per-component config files. +Registry +-------- +``@register_pipeline_model(model_type, task)`` registers a pipeline class. +``wmk config`` checks the registry to generate one config file per component:: + + wmk config -m google-t5/t5-small --task translation -o t5.json + # → t5_encoder.json (feature-extraction) + t5_decoder.json (text2text-generation) + + wmk build -c t5_encoder.json -m google-t5/t5-small -o output/encoder + wmk build -c t5_decoder.json -m google-t5/t5-small -o output/decoder + +Per-component kwargs +-------------------- +``sub_model_kwargs`` in ``from_pretrained`` allows different ``shape_config`` +per sub-model (e.g., different ``max_cache_len`` for prefill vs gen):: + + WinMLPipelineModel.from_pretrained(model_id, task="text-generation", + sub_model_kwargs={ + "decoder_prefill": {"shape_config": {"max_cache_len": 256, "seq_len": 64}}, + "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, + }) Concrete pipeline models live alongside their export configs: - ``models.hf.t5.WinMLT5Model`` (encoder-decoder, T5) +- ``models.hf.mu2.WinMLMu2Model`` (encoder-decoder, Mu2) - ``models.hf.qwen.WinMLQwen3Model`` (decoder-only, Qwen3) """ From cd8125e98bad4651a2c8af999cdfc91eb6767e44 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Wed, 15 Apr 2026 15:01:08 +0800 Subject: [PATCH 06/32] refactor: sliding window cache outputs new-token KV (not full buffer) WinMLSlidingWindowCache.update() now captures new-token KV (like WinMLStaticCache) and stores in self.captured. Export wrappers output captured[i] for both cache types. update_all_layers() does Slice+Concat on the inference side. Move captured dict to WinMLCache base class. --- src/winml/modelkit/models/hf/kv_cache.py | 33 ++++++++++++------------ src/winml/modelkit/models/hf/mu2.py | 4 +-- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/winml/modelkit/models/hf/kv_cache.py b/src/winml/modelkit/models/hf/kv_cache.py index 3e1e2e1b8..c8c543d42 100644 --- a/src/winml/modelkit/models/hf/kv_cache.py +++ b/src/winml/modelkit/models/hf/kv_cache.py @@ -72,6 +72,9 @@ def __init__(self, config: PretrainedConfig, *args: Any, **kwargs: Any) -> None: super().__init__(config, *args, **kwargs) self.step: int = 0 self.num_layers: int = config.num_hidden_layers + #: New-token KV captured during ``update()``, keyed by layer index. + #: Export wrappers read ``captured[i]`` to build ONNX present outputs. + self.captured: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} # ----- Interface for WinMLEncoderDecoderModel.forward ----- @@ -129,10 +132,6 @@ class WinMLStaticCache(WinMLCache): position_input_name: str = "cache_position" - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.captured: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} - def update( self, key_states: torch.Tensor, @@ -174,18 +173,15 @@ def update_all_layers(self, outputs: dict[str, Any]) -> None: class WinMLSlidingWindowCache(WinMLCache): """FIFO cache: evict oldest, append new at end (Slice+Concat). - **Export**: ``update()`` traces as Slice+Concat — no ScatterElements. - Present KV output is the full updated buffer. - **Inference**: ``update_all_layers`` replaces the full buffer. + **Export**: ``update()`` does Slice+Concat on the buffer and captures + the new-token KV (same as ``WinMLStaticCache.captured``). Present KV + output is the new token only ``[batch, heads, 1, head_dim]``. + **Inference**: ``update_all_layers`` does Slice+Concat from present KV. Mask is right-aligned: ``[0, 0, ..., 0, 1, 1, ..., 1]``. """ position_input_name: str = "position_id" - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.updated: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} - def update( self, key_states: torch.Tensor, @@ -193,9 +189,12 @@ def update( layer_idx: int, cache_kwargs: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Drop oldest entry, append new KV at end. Return full buffer.""" + """Drop oldest, append new KV at end. Capture new-token KV for output.""" import torch + # Capture new-token KV for ONNX output (same interface as WinMLStaticCache) + self.captured[layer_idx] = (key_states, value_states) + old_k = self.layers[layer_idx].keys[:, :, 1:, :] new_k = torch.cat([old_k, key_states], dim=2) self.layers[layer_idx].keys = new_k @@ -204,7 +203,6 @@ def update( new_v = torch.cat([old_v, value_states], dim=2) self.layers[layer_idx].values = new_v - self.updated[layer_idx] = (new_k, new_v) return new_k, new_v def build_decoder_mask(self, max_len: int) -> torch.Tensor: @@ -216,14 +214,17 @@ def build_decoder_mask(self, max_len: int) -> torch.Tensor: return mask def update_all_layers(self, outputs: dict[str, Any]) -> None: - """Replace full KV buffers for all layers, then advance.""" + """Slice+Concat present KV into buffer for all layers, then advance.""" import torch for i in range(self.num_layers): k = outputs[f"present_{i}_key"] v = outputs[f"present_{i}_value"] - self.layers[i].keys = k if isinstance(k, torch.Tensor) else torch.tensor(k) - self.layers[i].values = v if isinstance(v, torch.Tensor) else torch.tensor(v) + k = k if isinstance(k, torch.Tensor) else torch.tensor(k) + v = v if isinstance(v, torch.Tensor) else torch.tensor(v) + # FIFO: drop oldest, append new token KV + self.layers[i].keys = torch.cat([self.layers[i].keys[:, :, 1:, :], k], dim=2) + self.layers[i].values = torch.cat([self.layers[i].values[:, :, 1:, :], v], dim=2) self.step += 1 def get_seq_length(self, layer_idx: int = 0) -> int: diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index 8fda3dd38..e0e165653 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -160,10 +160,10 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ) logits = self.model.lm_head(hidden_states) - # Output full updated cache buffers (not just new token) + # Output new-token KV only (same as T5 — captured during update) result: list[torch.Tensor] = [logits] for i in range(self.num_layers): - k, v = cache.updated[i] + k, v = cache.captured[i] result.extend([k, v]) return tuple(result) From d10a9e8930e0187715acddb62317d8011d6e05b3 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Wed, 15 Apr 2026 15:05:03 +0800 Subject: [PATCH 07/32] WinMLPipelineModel -> WinMLCompositeModel --- src/winml/modelkit/models/hf/encoder_decoder.py | 8 ++++---- src/winml/modelkit/models/winml/decoder_only.py | 8 ++++---- src/winml/modelkit/models/winml/pipeline_model.py | 14 +++++++------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/winml/modelkit/models/hf/encoder_decoder.py b/src/winml/modelkit/models/hf/encoder_decoder.py index f76b0a4dd..cba9d82d2 100644 --- a/src/winml/modelkit/models/hf/encoder_decoder.py +++ b/src/winml/modelkit/models/hf/encoder_decoder.py @@ -6,7 +6,7 @@ Class hierarchy:: - WinMLPipelineModel — multi-component base + WinMLCompositeModel — multi-component base └─ WinMLEncoderDecoderModel(GenerationMixin) — encoder-decoder inference ├─ WinMLT5Model (t5.py) — WinMLStaticCache └─ WinMLMu2Model (mu2.py) — WinMLSlidingWindowCache @@ -58,7 +58,7 @@ from transformers.generation.utils import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput -from ..winml.pipeline_model import WinMLPipelineModel +from ..winml.pipeline_model import WinMLCompositeModel from .kv_cache import WinMLStaticCache @@ -143,7 +143,7 @@ def generate( # ============================================================================= -class WinMLEncoderDecoderModel(WinMLPipelineModel, GenerationMixin): +class WinMLEncoderDecoderModel(WinMLCompositeModel, GenerationMixin): """Pipeline model with HF GenerationMixin support. Expects sub-components ``"encoder"`` and ``"decoder"`` in @@ -213,7 +213,7 @@ def __init__(self, encoder: Any, expected: dict[str, list[int]]) -> None: self._expected = expected def forward(self, **kwargs: Any) -> BaseModelOutput: - feeds = WinMLPipelineModel._pad_inputs(kwargs, self._expected) + feeds = WinMLCompositeModel._pad_inputs(kwargs, self._expected) return self._encoder(**feeds) def get_encoder(self) -> torch.nn.Module: diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 4308f7fa2..f516bfe83 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -6,7 +6,7 @@ Class hierarchy:: - WinMLPipelineModel(PreTrainedModel) — multi-component base + WinMLCompositeModel(PreTrainedModel) — multi-component base └─ WinMLDecoderOnlyModel(GenerationMixin) — prefill + gen with StaticCache └─ WinMLQwen3Model — Qwen3 tasks + generation config @@ -65,10 +65,10 @@ from transformers.generation.utils import GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast -from .pipeline_model import WinMLPipelineModel +from .pipeline_model import WinMLCompositeModel -_pad_inputs = WinMLPipelineModel._pad_inputs +_pad_inputs = WinMLCompositeModel._pad_inputs if TYPE_CHECKING: @@ -159,7 +159,7 @@ class DecoderOnlyPrefillInputGenerator(DecoderOnlyInputGenerator): # ========================================================================= -class WinMLDecoderOnlyModel(WinMLPipelineModel, GenerationMixin): +class WinMLDecoderOnlyModel(WinMLCompositeModel, GenerationMixin): """Decoder-only pipeline model with HF GenerationMixin support. Expects sub-components ``"decoder_prefill"`` and ``"decoder_gen"`` in diff --git a/src/winml/modelkit/models/winml/pipeline_model.py b/src/winml/modelkit/models/winml/pipeline_model.py index db2a76603..6ba35f409 100644 --- a/src/winml/modelkit/models/winml/pipeline_model.py +++ b/src/winml/modelkit/models/winml/pipeline_model.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- """WinML Pipeline Model base and registry. -Provides ``WinMLPipelineModel`` — a base class for models composed of +Provides ``WinMLCompositeModel`` — a base class for models composed of multiple ``WinMLAutoModel`` sub-components (e.g., encoder + decoder, prefill + gen). Each subclass declares ``_SUB_MODEL_CONFIG`` mapping component names to HF tasks; ``from_pretrained()`` builds them all. @@ -25,7 +25,7 @@ ``sub_model_kwargs`` in ``from_pretrained`` allows different ``shape_config`` per sub-model (e.g., different ``max_cache_len`` for prefill vs gen):: - WinMLPipelineModel.from_pretrained(model_id, task="text-generation", + WinMLCompositeModel.from_pretrained(model_id, task="text-generation", sub_model_kwargs={ "decoder_prefill": {"shape_config": {"max_cache_len": 256, "seq_len": 64}}, "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, @@ -74,11 +74,11 @@ def decorator(cls: type) -> type: # ========================================================================= -# WinMLPipelineModel — multi-component base +# WinMLCompositeModel — multi-component base # ========================================================================= -class WinMLPipelineModel(PreTrainedModel): +class WinMLCompositeModel(PreTrainedModel): """Base class for models composed of multiple WinMLAutoModel sub-components. Subclasses declare ``_SUB_MODEL_CONFIG``: a mapping of component name to @@ -108,10 +108,10 @@ def from_pretrained( force_rebuild: bool = False, sub_model_kwargs: dict[str, dict[str, Any]] | None = None, **kwargs: Any, - ) -> WinMLPipelineModel: + ) -> WinMLCompositeModel: """Build all sub-components and return ready-to-use model. - When called on ``WinMLPipelineModel`` directly (not a subclass), + When called on ``WinMLCompositeModel`` directly (not a subclass), ``task`` is required to resolve the concrete class from ``PIPELINE_MODEL_REGISTRY``. When called on a registered subclass (e.g., ``WinMLT5Model``), ``task`` is optional. @@ -184,7 +184,7 @@ def dtype(self) -> torch.dtype: """Model dtype for HF compatibility.""" return torch.float32 - def to(self, *args: Any, **kwargs: Any) -> WinMLPipelineModel: + def to(self, *args: Any, **kwargs: Any) -> WinMLCompositeModel: """No-op for HF pipeline compatibility.""" return self From cac83d89e84c5f936f1198cd502fc7728e4b185a Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Wed, 15 Apr 2026 17:26:22 +0800 Subject: [PATCH 08/32] feat: sliding window KV cache for Qwen3 + refactor cache interface - Generalize WinMLSlidingWindowCache.update() for N tokens (prefill+gen) - Qwen3 uses sliding window: cache_position computed internally as right-aligned buffer positions, position_ids handles RoPE separately - Left-pad prefill chunks so real tokens are at END (matches causal mask) - Add _resolve_cache to WinMLDecoderOnlyModel (same pattern as enc-dec) - Make get_cache_class abstract in both base classes - Rename pipeline_model.py -> composite_model.py, register_composite_model - Remove position_id ONNX input from Qwen (no longer needed) - update_all_layers moved to WinMLCache base (calls subclass update()) --- src/winml/modelkit/commands/config.py | 2 +- .../modelkit/models/hf/encoder_decoder.py | 7 +- src/winml/modelkit/models/hf/kv_cache.py | 61 ++++----- src/winml/modelkit/models/hf/mu2.py | 4 +- src/winml/modelkit/models/hf/qwen.py | 35 +++-- src/winml/modelkit/models/hf/t5.py | 4 +- .../{pipeline_model.py => composite_model.py} | 4 +- .../modelkit/models/winml/decoder_only.py | 120 ++++++++---------- 8 files changed, 117 insertions(+), 120 deletions(-) rename src/winml/modelkit/models/winml/{pipeline_model.py => composite_model.py} (98%) diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index aceea18b0..5d7885284 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -528,7 +528,7 @@ def _resolve_pipeline_components( import winml.modelkit.models.hf # noqa: F401 # trigger pipeline registrations - from ..models.winml.pipeline_model import PIPELINE_MODEL_REGISTRY + from ..models.winml.composite_model import PIPELINE_MODEL_REGISTRY # Resolve model_type from HF config if not provided resolved_type = model_type diff --git a/src/winml/modelkit/models/hf/encoder_decoder.py b/src/winml/modelkit/models/hf/encoder_decoder.py index cba9d82d2..7d9cf5503 100644 --- a/src/winml/modelkit/models/hf/encoder_decoder.py +++ b/src/winml/modelkit/models/hf/encoder_decoder.py @@ -58,8 +58,7 @@ from transformers.generation.utils import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput -from ..winml.pipeline_model import WinMLCompositeModel -from .kv_cache import WinMLStaticCache +from ..winml.composite_model import WinMLCompositeModel if TYPE_CHECKING: @@ -243,8 +242,8 @@ def prepare_inputs_for_generation( @classmethod def get_cache_class(cls) -> type: - """Return the WinMLCache subclass for this model. Subclasses override.""" - return WinMLStaticCache + """Return the WinMLCache subclass. Subclasses must override.""" + raise NotImplementedError def _resolve_cache(self, past_key_values: Any) -> Any: """Unwrap or create the WinMLCache for this generation step. diff --git a/src/winml/modelkit/models/hf/kv_cache.py b/src/winml/modelkit/models/hf/kv_cache.py index c8c543d42..bfa88374d 100644 --- a/src/winml/modelkit/models/hf/kv_cache.py +++ b/src/winml/modelkit/models/hf/kv_cache.py @@ -57,8 +57,8 @@ class WinMLCache(StaticCache): """Abstract base for WinML KV caches (export + inference). - Subclasses set ``position_input_name`` and implement - ``build_decoder_mask`` and ``update_all_layers``. + Subclasses set ``position_input_name``, implement ``build_decoder_mask``, + and override ``update()`` for cache-specific write logic. ``step`` tracks the absolute generation position (used for RoPE and mask construction). @@ -82,9 +82,24 @@ def __init__(self, config: PretrainedConfig, *args: Any, **kwargs: Any) -> None: def build_decoder_mask(self, max_len: int) -> torch.Tensor: """Build the decoder attention mask for the current step.""" - @abstractmethod def update_all_layers(self, outputs: dict[str, Any]) -> None: - """Write present KV for all layers from ONNX output and advance step.""" + """Write present KV for all layers via ``update()`` and advance step. + + Step advances by N where N is the seq_len of the present KV tensors + (1 for gen, chunk_len for prefill). + """ + import torch + + ck = {"cache_position": torch.tensor([self.step], dtype=torch.int64)} + n = 0 + for i in range(self.num_layers): + k = outputs[f"present_{i}_key"] + v = outputs[f"present_{i}_value"] + k = k if isinstance(k, torch.Tensor) else torch.tensor(k) + v = v if isinstance(v, torch.Tensor) else torch.tensor(v) + n = k.size(2) + self.update(k, v, i, cache_kwargs=ck) + self.step += n def reset(self) -> None: """Zero out all layers and reset step (start of new generation).""" @@ -151,19 +166,6 @@ def build_decoder_mask(self, max_len: int) -> torch.Tensor: mask[0, : self.step + 1] = 1 return mask - def update_all_layers(self, outputs: dict[str, Any]) -> None: - """Write new-token KV at current step for all layers, then advance.""" - import torch - - ck = {"cache_position": torch.tensor([self.step], dtype=torch.int64)} - for i in range(self.num_layers): - k = outputs[f"present_{i}_key"] - v = outputs[f"present_{i}_value"] - k = k if isinstance(k, torch.Tensor) else torch.tensor(k) - v = v if isinstance(v, torch.Tensor) else torch.tensor(v) - super(WinMLCache, self).update(k, v, i, cache_kwargs=ck) - self.step += 1 - # ============================================================================= # WinMLSlidingWindowCache — Slice + Concat (FIFO) @@ -189,17 +191,20 @@ def update( layer_idx: int, cache_kwargs: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Drop oldest, append new KV at end. Capture new-token KV for output.""" + """Drop N oldest, append N new KV at end (N = key_states.size(2)). + + Works for both single-token gen (N=1) and multi-token prefill (N>1). + """ import torch - # Capture new-token KV for ONNX output (same interface as WinMLStaticCache) self.captured[layer_idx] = (key_states, value_states) - old_k = self.layers[layer_idx].keys[:, :, 1:, :] + n = key_states.size(2) + old_k = self.layers[layer_idx].keys[:, :, n:, :] new_k = torch.cat([old_k, key_states], dim=2) self.layers[layer_idx].keys = new_k - old_v = self.layers[layer_idx].values[:, :, 1:, :] + old_v = self.layers[layer_idx].values[:, :, n:, :] new_v = torch.cat([old_v, value_states], dim=2) self.layers[layer_idx].values = new_v @@ -213,20 +218,6 @@ def build_decoder_mask(self, max_len: int) -> torch.Tensor: mask[0, max(0, max_len - self.step - 1) :] = 1 return mask - def update_all_layers(self, outputs: dict[str, Any]) -> None: - """Slice+Concat present KV into buffer for all layers, then advance.""" - import torch - - for i in range(self.num_layers): - k = outputs[f"present_{i}_key"] - v = outputs[f"present_{i}_value"] - k = k if isinstance(k, torch.Tensor) else torch.tensor(k) - v = v if isinstance(v, torch.Tensor) else torch.tensor(v) - # FIFO: drop oldest, append new token KV - self.layers[i].keys = torch.cat([self.layers[i].keys[:, :, 1:, :], k], dim=2) - self.layers[i].values = torch.cat([self.layers[i].values[:, :, 1:, :], v], dim=2) - self.step += 1 - def get_seq_length(self, layer_idx: int = 0) -> int: """Filled positions: ``min(step, max_cache_len)``.""" max_len = self.layers[layer_idx].keys.shape[2] diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index e0e165653..4ff39e303 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -50,7 +50,7 @@ class for Mu2 (custom ``trust_remote_code`` model). from optimum.utils.input_generators import DummyTextInputGenerator from ...export import register_onnx_overwrite -from ..winml.pipeline_model import register_pipeline_model +from ..winml.composite_model import register_composite_model from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel from .kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache @@ -250,7 +250,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 } -@register_pipeline_model("mu2", "translation") +@register_composite_model("mu2", "translation") class WinMLMu2Model(WinMLEncoderDecoderModel): """Mu2 encoder-decoder model with sliding-window KV cache. diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index c54ad3285..c79837ce5 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -84,13 +84,13 @@ from ...export import register_onnx_overwrite from ...export.config import WinMLExportConfig from ..winml import register_specialization +from ..winml.composite_model import register_composite_model from ..winml.decoder_only import ( DecoderOnlyInputGenerator, DecoderOnlyPrefillInputGenerator, WinMLDecoderOnlyModel, ) -from ..winml.pipeline_model import register_pipeline_model -from .kv_cache import PastKeyValueInputGenerator, WinMLStaticCache +from .kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache # ============================================================================= @@ -131,7 +131,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: """Run decoder with static KV cache. Positional args (order matches OnnxConfig.inputs): - input_ids, attention_mask, position_ids, cache_position, + input_ids, attention_mask, position_ids, position_id, past_0_key, past_0_value, past_1_key, past_1_value, ... Returns: @@ -142,12 +142,12 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: input_ids = args[0] attention_mask = args[1] position_ids = args[2] - cache_position = args[3] - kv_start = 4 + kv_start = 3 - # Build WinMLStaticCache from input KV tensors. - # Decoder-only: pass StaticCache directly (no EncoderDecoderCache needed). - cache = WinMLStaticCache(self.config, max_cache_len=args[kv_start].size(2)) + seq_len = input_ids.size(1) + + # Build WinMLSlidingWindowCache from input KV tensors. + cache = WinMLSlidingWindowCache(self.config, max_cache_len=args[kv_start].size(2)) cache.early_initialization( batch_size=input_ids.size(0), num_heads=args[kv_start].size(1), @@ -155,10 +155,22 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: dtype=args[kv_start].dtype, device=input_ids.device, ) + max_cache_len = args[kv_start].size(2) for i in range(self.num_layers): cache.layers[i].keys = args[kv_start + i * 2] cache.layers[i].values = args[kv_start + i * 2 + 1] + # Sliding window: tokens always append at the END of the buffer. + # cache_position = buffer positions (right-aligned) so HF's + # create_causal_mask builds correct kv_idx <= q_idx constraint. + # position_ids (separate) handles RoPE with absolute positions. + cache_position = torch.arange( + max_cache_len - seq_len, + max_cache_len, + dtype=torch.int64, + device=input_ids.device, + ) + out = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -203,7 +215,6 @@ def _qwen_io_inputs(num_layers: int) -> dict[str, dict[int, str]]: "input_ids": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "position_ids": {0: "batch_size"}, - "cache_position": {}, } for i in range(num_layers): result[f"past_{i}_key"] = {0: "batch_size"} @@ -284,7 +295,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 # ============================================================================= -@register_pipeline_model("qwen3", "text-generation") +@register_composite_model("qwen3", "text-generation") class WinMLQwen3Model(WinMLDecoderOnlyModel): """Qwen3 decoder-only model for text generation. @@ -297,6 +308,10 @@ class WinMLQwen3Model(WinMLDecoderOnlyModel): "decoder_gen": "text-generation", } + @classmethod + def get_cache_class(cls) -> type: # noqa: D102 + return WinMLSlidingWindowCache + @property def generation_config(self): # noqa: D102 if not hasattr(self, "_generation_config"): diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 2977ebb55..e534ac626 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -35,7 +35,7 @@ from transformers.cache_utils import DynamicCache, EncoderDecoderCache from ...export import register_onnx_overwrite -from ..winml.pipeline_model import register_pipeline_model +from ..winml.composite_model import register_composite_model from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel from .kv_cache import PastKeyValueInputGenerator, WinMLStaticCache @@ -279,7 +279,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 # ============================================================================= -@register_pipeline_model("t5", "translation") +@register_composite_model("t5", "translation") class WinMLT5Model(WinMLEncoderDecoderModel): """T5 encoder-decoder model for translation. diff --git a/src/winml/modelkit/models/winml/pipeline_model.py b/src/winml/modelkit/models/winml/composite_model.py similarity index 98% rename from src/winml/modelkit/models/winml/pipeline_model.py rename to src/winml/modelkit/models/winml/composite_model.py index 6ba35f409..f1c9192f5 100644 --- a/src/winml/modelkit/models/winml/pipeline_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -11,7 +11,7 @@ Registry -------- -``@register_pipeline_model(model_type, task)`` registers a pipeline class. +``@register_composite_model(model_type, task)`` registers a pipeline class. ``wmk config`` checks the registry to generate one config file per component:: wmk config -m google-t5/t5-small --task translation -o t5.json @@ -63,7 +63,7 @@ PIPELINE_MODEL_REGISTRY: dict[tuple[str, str], type] = {} -def register_pipeline_model(model_type: str, task: str): +def register_composite_model(model_type: str, task: str): """Class decorator that registers a pipeline model for `wmk config`.""" def decorator(cls: type) -> type: diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index f516bfe83..47fa3037c 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -12,7 +12,7 @@ How it works: -1. ``@register_pipeline_model("qwen3", "text-generation")`` hooks into +1. ``@register_composite_model("qwen3", "text-generation")`` hooks into ``winml config`` so that ``winml config -m Qwen/Qwen3-0.6B --task text-generation`` generates ``qwen_decoder_prefill.json`` + ``qwen_decoder_gen.json``. @@ -47,7 +47,7 @@ rather than trimming to the last token. On subsequent calls with a populated ``StaticCache``, we trim to the last token as usual. -Design principles (same as pipeline_model.py): +Design principles (same as composite_model.py): - ONNX I/O names and shapes are read from ``io_config``, never hardcoded. - Inputs smaller than ONNX expected shape are zero-padded via ``_pad_inputs``. @@ -61,18 +61,17 @@ import torch from optimum.utils.input_generators import DummyInputGenerator -from transformers import Cache, StaticCache from transformers.generation.utils import GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast -from .pipeline_model import WinMLCompositeModel +from .composite_model import WinMLCompositeModel _pad_inputs = WinMLCompositeModel._pad_inputs if TYPE_CHECKING: - from transformers import PretrainedConfig + from transformers import Cache, PretrainedConfig logger = logging.getLogger(__name__) @@ -104,6 +103,7 @@ class DecoderOnlyInputGenerator(DummyInputGenerator): "attention_mask", "position_ids", "cache_position", + "position_id", ) _default_seq_len: int = 1 @@ -145,6 +145,8 @@ def generate( return torch.arange(self.seq_len, dtype=torch.int64).unsqueeze(0) if input_name == "cache_position": return torch.arange(self.seq_len, dtype=torch.int64) + if input_name == "position_id": + return torch.arange(self.seq_len, dtype=torch.int64) raise ValueError(f"Unknown input: {input_name}") @@ -216,7 +218,24 @@ def __init__( # Prefill chunk size self._prefill_seq_len = self._prefill_expected["input_ids"][1] - # ----- GenerationMixin interface ----- + # ----- Cache + GenerationMixin interface ----- + + @classmethod + def get_cache_class(cls) -> type: + """Return the WinMLCache subclass. Subclasses must override.""" + raise NotImplementedError + + def _resolve_cache(self, past_key_values: Any) -> Any: + """Unwrap or create WinMLCache for this generation step.""" + from ..hf.kv_cache import WinMLCache + + if isinstance(past_key_values, WinMLCache): + return past_key_values + + kv_shape = [1, self._num_kv_heads, self._max_cache_len, self._head_dim] + cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype) + cache.reset() + return cache def can_generate(self) -> bool: # noqa: D102 return True @@ -228,16 +247,12 @@ def prepare_inputs_for_generation( attention_mask: torch.Tensor | None = None, **kwargs: Any, ) -> dict[str, Any]: - """Build inputs for each generate() step. + """Build inputs for each generate() step.""" + from ..hf.kv_cache import WinMLCache - GenerationMixin may pass a DynamicCache (auto-created, empty) on the - first call. Only trim to last token when we have a populated - StaticCache (i.e., after prefill). - """ - if isinstance(past_key_values, StaticCache) and past_key_values.get_seq_length() > 0: + if isinstance(past_key_values, WinMLCache) and past_key_values.get_seq_length() > 0: input_ids = input_ids[:, -1:] else: - # First call or empty cache: pass full prompt for prefill past_key_values = None return { "input_ids": input_ids, @@ -270,17 +285,7 @@ def forward( Returns: CausalLMOutputWithPast with logits and updated StaticCache. """ - # Resolve or create StaticCache (same pattern as T5) - cache = past_key_values if isinstance(past_key_values, StaticCache) else None - if cache is None: - cache = StaticCache(self.config, max_cache_len=self._max_cache_len) - cache.early_initialization( - batch_size=1, - num_heads=self._num_kv_heads, - head_dim=self._head_dim, - dtype=self._kv_dtype, - device=torch.device("cpu"), - ) + cache = self._resolve_cache(past_key_values) seq_len = input_ids.shape[1] if seq_len > 1: @@ -295,11 +300,10 @@ def forward( # ----- Prefill (chunked) ----- - def _run_prefill(self, input_ids: torch.Tensor, cache: StaticCache) -> torch.Tensor: + def _run_prefill(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor: """Run prefill model in a loop over chunks of ``prefill_seq_len``. - Returns logits for ALL real input positions ``[1, seq_len, vocab_size]`` - (same convention as HF CausalLM — enables perplexity evaluation). + Returns logits for ALL real input positions ``[1, seq_len, vocab_size]``. """ seq_len = input_ids.shape[1] all_logits: list[torch.Tensor] = [] @@ -308,76 +312,64 @@ def _run_prefill(self, input_ids: torch.Tensor, cache: StaticCache) -> torch.Ten end = min(start + self._prefill_seq_len, seq_len) chunk_len = end - start - # Pad chunk to prefill_seq_len (right-padding) + # Left-pad: real tokens at the END of the chunk (matches sliding + # window right-alignment so causal mask kv_idx<=q_idx is correct). + pad_len = self._prefill_seq_len - chunk_len padded_ids = torch.zeros(1, self._prefill_seq_len, dtype=input_ids.dtype) - padded_ids[0, :chunk_len] = input_ids[0, start:end] + padded_ids[0, pad_len:] = input_ids[0, start:end] - position_ids = torch.arange( - start, start + self._prefill_seq_len, dtype=torch.int64 - ).unsqueeze(0) - cache_position = torch.arange(start, start + self._prefill_seq_len, dtype=torch.int64) + position_ids = torch.zeros(1, self._prefill_seq_len, dtype=torch.int64) + position_ids[0, pad_len:] = torch.arange(start, start + chunk_len, dtype=torch.int64) - # Attention mask: 1 for all real tokens so far + # Mask: 1s for real tokens (previously cached + current chunk). + # With left-padding, real tokens are at the rightmost chunk_len + # positions of the prefill_seq_len slot. + filled = min(cache.step + chunk_len, self._max_cache_len) attn_mask = torch.zeros(1, self._max_cache_len, dtype=torch.int64) - attn_mask[0, : start + chunk_len] = 1 + attn_mask[0, max(0, self._max_cache_len - filled) :] = 1 feeds: dict[str, Any] = { "input_ids": padded_ids, "attention_mask": attn_mask, "position_ids": position_ids, - "cache_position": cache_position, } for i in range(self._num_kv_layers): feeds[f"past_{i}_key"] = cache.layers[i].keys.detach() feeds[f"past_{i}_value"] = cache.layers[i].values.detach() - outputs = self._prefill_model(**_pad_inputs(feeds, self._prefill_expected)) + outputs = self._prefill_model(**feeds) - # Write only real tokens' KV into cache (skip padding) - real_positions = cache_position[:chunk_len] - ck = {"cache_position": real_positions} - for i in range(self._num_kv_layers): - cache.update( - outputs[f"present_{i}_key"][:, :, :chunk_len, :], - outputs[f"present_{i}_value"][:, :, :chunk_len, :], - layer_idx=i, - cache_kwargs=ck, - ) + # update_all_layers advances step by present KV size (prefill_seq_len). + # Correct to chunk_len since padding KV should not count. + cache.update_all_layers(outputs) + if pad_len > 0: + cache.step -= pad_len - # Keep logits for real tokens only (discard padding positions) - all_logits.append(outputs["logits"][:, :chunk_len, :]) + # With left-padding, real token logits are at positions pad_len: + all_logits.append(outputs["logits"][:, pad_len : pad_len + chunk_len, :]) return torch.cat(all_logits, dim=1) # ----- Generation (single token) ----- - def _run_gen(self, input_ids: torch.Tensor, cache: StaticCache) -> torch.Tensor: + def _run_gen(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor: """Run gen model for a single token. Returns logits ``[1, 1, vocab_size]``.""" - fc = cache.get_seq_length() + fc = cache.step + filled = min(fc + 1, self._max_cache_len) attn_mask = torch.zeros(1, self._max_cache_len, dtype=torch.int64) - attn_mask[0, : fc + 1] = 1 + attn_mask[0, max(0, self._max_cache_len - filled) :] = 1 feeds: dict[str, Any] = { "input_ids": input_ids, "attention_mask": attn_mask, "position_ids": torch.tensor([[fc]], dtype=torch.int64), - "cache_position": torch.tensor([fc], dtype=torch.int64), } for i in range(self._num_kv_layers): feeds[f"past_{i}_key"] = cache.layers[i].keys.detach() feeds[f"past_{i}_value"] = cache.layers[i].values.detach() - outputs = self._gen_model(**_pad_inputs(feeds, self._gen_expected)) - - # Write new token's KV into cache - ck = {"cache_position": torch.tensor([fc], dtype=torch.int64)} - for i in range(self._num_kv_layers): - cache.update( - outputs[f"present_{i}_key"], - outputs[f"present_{i}_value"], - layer_idx=i, - cache_kwargs=ck, - ) + outputs = self._gen_model(**feeds) + cache.update_all_layers(outputs) return outputs["logits"] From 6551adb6551aba7834f2d8c06b9f7e8f9c186547 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Thu, 16 Apr 2026 11:04:12 +0800 Subject: [PATCH 09/32] refactor: polymorphic KV cache for decoder-only prefill + gen MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Make WinMLDecoderOnlyModel cache-agnostic by delegating padding, mask construction, and cache updates to the WinMLCache subclass. kv_cache.py: - build_decoder_mask: add num_new_tokens param (default 1) - prepare_prefill_chunk: new abstract method — left-pad (sliding window) vs right-pad (static cache) - update_all_layers: cache_position as range instead of scalar so StaticCache.index_copy_ works with multi-token prefill KV decoder_only.py: - _run_prefill: delegates to cache.prepare_prefill_chunk and cache.build_decoder_mask; slices padding from outputs before update_all_layers - _run_gen: uses cache.build_decoder_mask instead of inline mask - Both pass cache_position in feeds when the ONNX model expects it Verified: Qwen3-0.6B e2e with both WinMLSlidingWindowCache and WinMLStaticCache produces correct results. --- src/winml/modelkit/models/hf/kv_cache.py | 82 +++++++++++++++++-- src/winml/modelkit/models/hf/qwen.py | 33 ++++++-- .../modelkit/models/winml/decoder_only.py | 48 +++++------ 3 files changed, 121 insertions(+), 42 deletions(-) diff --git a/src/winml/modelkit/models/hf/kv_cache.py b/src/winml/modelkit/models/hf/kv_cache.py index bfa88374d..beb83bc97 100644 --- a/src/winml/modelkit/models/hf/kv_cache.py +++ b/src/winml/modelkit/models/hf/kv_cache.py @@ -79,8 +79,34 @@ def __init__(self, config: PretrainedConfig, *args: Any, **kwargs: Any) -> None: # ----- Interface for WinMLEncoderDecoderModel.forward ----- @abstractmethod - def build_decoder_mask(self, max_len: int) -> torch.Tensor: - """Build the decoder attention mask for the current step.""" + def build_decoder_mask(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor: + """Build the decoder attention mask for the current step. + + Args: + max_len: Total cache buffer length. + num_new_tokens: Number of new tokens being added (1 for gen, + chunk_len for prefill). + """ + + @abstractmethod + def prepare_prefill_chunk( + self, + chunk_ids: torch.Tensor, + start: int, + prefill_seq_len: int, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + """Pad tokens and build position IDs for one prefill chunk. + + Args: + chunk_ids: ``[1, chunk_len]`` — real tokens for this chunk. + start: Absolute position of the first real token. + prefill_seq_len: ONNX model's fixed prefill input length. + + Returns: + padded_ids: ``[1, prefill_seq_len]`` — padded input token IDs. + position_ids: ``[1, prefill_seq_len]`` — position encoding input. + pad_len: Number of leading padding positions (0 for right-pad). + """ def update_all_layers(self, outputs: dict[str, Any]) -> None: """Write present KV for all layers via ``update()`` and advance step. @@ -90,7 +116,6 @@ def update_all_layers(self, outputs: dict[str, Any]) -> None: """ import torch - ck = {"cache_position": torch.tensor([self.step], dtype=torch.int64)} n = 0 for i in range(self.num_layers): k = outputs[f"present_{i}_key"] @@ -98,6 +123,7 @@ def update_all_layers(self, outputs: dict[str, Any]) -> None: k = k if isinstance(k, torch.Tensor) else torch.tensor(k) v = v if isinstance(v, torch.Tensor) else torch.tensor(v) n = k.size(2) + ck = {"cache_position": torch.arange(self.step, self.step + n, dtype=torch.int64)} self.update(k, v, i, cache_kwargs=ck) self.step += n @@ -158,14 +184,31 @@ def update( self.captured[layer_idx] = (key_states, value_states) return super().update(key_states, value_states, layer_idx, cache_kwargs) - def build_decoder_mask(self, max_len: int) -> torch.Tensor: - """Left-aligned: first ``step + 1`` positions are 1.""" + def build_decoder_mask(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor: + """Left-aligned: first ``step + num_new_tokens`` positions are 1.""" import torch mask = torch.zeros(1, max_len, dtype=torch.int64) - mask[0, : self.step + 1] = 1 + mask[0, : self.step + num_new_tokens] = 1 return mask + def prepare_prefill_chunk( + self, + chunk_ids: torch.Tensor, + start: int, + prefill_seq_len: int, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + """Right-pad: real tokens at START, padding at end.""" + import torch + + chunk_len = chunk_ids.shape[1] + padded_ids = torch.zeros(1, prefill_seq_len, dtype=chunk_ids.dtype) + padded_ids[0, :chunk_len] = chunk_ids[0] + + position_ids = torch.arange(start, start + prefill_seq_len, dtype=torch.int64).unsqueeze(0) + + return padded_ids, position_ids, 0 + # ============================================================================= # WinMLSlidingWindowCache — Slice + Concat (FIFO) @@ -210,14 +253,35 @@ def update( return new_k, new_v - def build_decoder_mask(self, max_len: int) -> torch.Tensor: - """Right-aligned: rightmost ``step + 1`` positions are 1.""" + def build_decoder_mask(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor: + """Right-aligned: rightmost ``step + num_new_tokens`` positions are 1.""" import torch + filled = min(self.step + num_new_tokens, max_len) mask = torch.zeros(1, max_len, dtype=torch.int64) - mask[0, max(0, max_len - self.step - 1) :] = 1 + mask[0, max(0, max_len - filled) :] = 1 return mask + def prepare_prefill_chunk( + self, + chunk_ids: torch.Tensor, + start: int, + prefill_seq_len: int, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + """Left-pad: padding at start, real tokens at END.""" + import torch + + chunk_len = chunk_ids.shape[1] + pad_len = prefill_seq_len - chunk_len + + padded_ids = torch.zeros(1, prefill_seq_len, dtype=chunk_ids.dtype) + padded_ids[0, pad_len:] = chunk_ids[0] + + position_ids = torch.zeros(1, prefill_seq_len, dtype=torch.int64) + position_ids[0, pad_len:] = torch.arange(start, start + chunk_len, dtype=torch.int64) + + return padded_ids, position_ids, pad_len + def get_seq_length(self, layer_idx: int = 0) -> int: """Filled positions: ``min(step, max_cache_len)``.""" max_len = self.layers[layer_idx].keys.shape[2] diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index c79837ce5..7869d100f 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -5,8 +5,8 @@ """Qwen3 HuggingFace Model Configuration. Provides decoder export wrappers and OnnxConfig registrations for -Qwen3 decoder-only models with static KV cache, split into prefill -and generation sub-models. +Qwen3 decoder-only models with KV cache, split into prefill and +generation sub-models. Export Strategy (split by task): - QwenDecoderWrapper + QwenPrefillIOConfig: ``feature-extraction`` task @@ -15,15 +15,16 @@ → generation ONNX (input_ids [1, 1] → logits [1, 1, vocab] + KV [1, kv_heads, 1, head_dim]) Both tasks share the same wrapper class; OnnxConfig determines static shapes. -Uses ``WinMLStaticCache`` (from ``kv_cache.py``) to return new-token KV -directly as ONNX outputs, eliminating the scatter→gather round-trip. +The wrapper captures new-token KV directly as ONNX outputs, eliminating the +scatter→gather round-trip. How it works: 1. ``QwenDecoderWrapper.forward()`` takes positional args (order matches - OnnxConfig.inputs): input_ids, attention_mask, position_ids, cache_position, - past_0_key, past_0_value, ... It builds a ``WinMLStaticCache`` from the - input KV buffers, runs ``Qwen3ForCausalLM``, and returns logits + captured KV. + OnnxConfig.inputs): input_ids, attention_mask, position_ids, + past_0_key, past_0_value, ... It builds a ``WinMLSlidingWindowCache`` + from the input KV buffers, computes right-aligned ``cache_position`` + internally, runs ``Qwen3ForCausalLM``, and returns logits + captured KV. 2. Decoder-only models need NO ``EncoderDecoderCache`` wrapping — ``StaticCache`` is passed directly as ``past_key_values``. (Contrast with @@ -38,6 +39,24 @@ exporter fails with an internal error. Dynamo produces opset 18 models; opset 17 downconversion currently fails for these graphs. +Cache type: + +The default configuration uses ``WinMLSlidingWindowCache`` (FIFO +Slice+Concat). ``WinMLDecoderOnlyModel`` is cache-agnostic — padding, +mask construction, and cache updates are all delegated to the cache class +via ``prepare_prefill_chunk``, ``build_decoder_mask``, and +``update_all_layers``. To switch to ``WinMLStaticCache`` (index_copy_): + +1. **Export wrapper**: change ``QwenDecoderWrapper.forward()`` to use + ``WinMLStaticCache``, take ``cache_position`` as an explicit ONNX + input (instead of computing it internally), and set ``kv_start = 4``. +2. **OnnxConfig inputs**: add ``"cache_position": {}`` to + ``_qwen_io_inputs`` (after ``position_ids``, before ``past_*``). +3. **Inference**: override ``get_cache_class()`` to return + ``WinMLStaticCache``. ``WinMLDecoderOnlyModel`` passes + ``cache_position`` in feeds automatically when the ONNX model + expects it. + Task name constraints (Optimum compatibility): - Task names must exist in ``TasksManager.get_all_tasks()`` to pass diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 47fa3037c..271f039fc 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -312,41 +312,38 @@ def _run_prefill(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor: end = min(start + self._prefill_seq_len, seq_len) chunk_len = end - start - # Left-pad: real tokens at the END of the chunk (matches sliding - # window right-alignment so causal mask kv_idx<=q_idx is correct). - pad_len = self._prefill_seq_len - chunk_len - padded_ids = torch.zeros(1, self._prefill_seq_len, dtype=input_ids.dtype) - padded_ids[0, pad_len:] = input_ids[0, start:end] - - position_ids = torch.zeros(1, self._prefill_seq_len, dtype=torch.int64) - position_ids[0, pad_len:] = torch.arange(start, start + chunk_len, dtype=torch.int64) - - # Mask: 1s for real tokens (previously cached + current chunk). - # With left-padding, real tokens are at the rightmost chunk_len - # positions of the prefill_seq_len slot. - filled = min(cache.step + chunk_len, self._max_cache_len) - attn_mask = torch.zeros(1, self._max_cache_len, dtype=torch.int64) - attn_mask[0, max(0, self._max_cache_len - filled) :] = 1 + padded_ids, position_ids, pad_len = cache.prepare_prefill_chunk( + input_ids[:, start:end], + start, + self._prefill_seq_len, + ) + attn_mask = cache.build_decoder_mask(self._max_cache_len, chunk_len) feeds: dict[str, Any] = { "input_ids": padded_ids, "attention_mask": attn_mask, "position_ids": position_ids, } + if "cache_position" in self._prefill_expected: + feeds["cache_position"] = position_ids.squeeze(0) for i in range(self._num_kv_layers): feeds[f"past_{i}_key"] = cache.layers[i].keys.detach() feeds[f"past_{i}_value"] = cache.layers[i].values.detach() outputs = self._prefill_model(**feeds) - # update_all_layers advances step by present KV size (prefill_seq_len). - # Correct to chunk_len since padding KV should not count. - cache.update_all_layers(outputs) - if pad_len > 0: - cache.step -= pad_len + # Slice out padding — real tokens are at [pad_len : pad_len+chunk_len] + real = slice(pad_len, pad_len + chunk_len) + all_logits.append(outputs["logits"][:, real, :]) - # With left-padding, real token logits are at positions pad_len: - all_logits.append(outputs["logits"][:, pad_len : pad_len + chunk_len, :]) + # Strip padding KV before updating cache so step advances by + # chunk_len (not prefill_seq_len). + real_outputs = {k: v for k, v in outputs.items() if not k.startswith("present_")} + for k, v in outputs.items(): + if k.startswith("present_"): + t = v if isinstance(v, torch.Tensor) else torch.tensor(v) + real_outputs[k] = t[:, :, real, :] + cache.update_all_layers(real_outputs) return torch.cat(all_logits, dim=1) @@ -355,16 +352,15 @@ def _run_prefill(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor: def _run_gen(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor: """Run gen model for a single token. Returns logits ``[1, 1, vocab_size]``.""" fc = cache.step - - filled = min(fc + 1, self._max_cache_len) - attn_mask = torch.zeros(1, self._max_cache_len, dtype=torch.int64) - attn_mask[0, max(0, self._max_cache_len - filled) :] = 1 + attn_mask = cache.build_decoder_mask(self._max_cache_len) feeds: dict[str, Any] = { "input_ids": input_ids, "attention_mask": attn_mask, "position_ids": torch.tensor([[fc]], dtype=torch.int64), } + if "cache_position" in self._gen_expected: + feeds["cache_position"] = feeds["position_ids"].squeeze(0) for i in range(self._num_kv_layers): feeds[f"past_{i}_key"] = cache.layers[i].keys.detach() feeds[f"past_{i}_value"] = cache.layers[i].values.detach() From 989dd9440aeacbc93a4e6130d7854533d693d2a4 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Thu, 16 Apr 2026 11:08:14 +0800 Subject: [PATCH 10/32] fix: remove unused _pad_inputs from decoder_only.py Fixes CodeQL finding: global variable '_pad_inputs' is not used. The refactoring to polymorphic cache methods removed the last call site. --- src/winml/modelkit/models/winml/decoder_only.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 271f039fc..2794407c7 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -50,8 +50,6 @@ Design principles (same as composite_model.py): - ONNX I/O names and shapes are read from ``io_config``, never hardcoded. -- Inputs smaller than ONNX expected shape are zero-padded via ``_pad_inputs``. -- ``_pad_inputs`` is reused from ``WinMLEncoderDecoderModel`` (static method). """ from __future__ import annotations @@ -67,9 +65,6 @@ from .composite_model import WinMLCompositeModel -_pad_inputs = WinMLCompositeModel._pad_inputs - - if TYPE_CHECKING: from transformers import Cache, PretrainedConfig From 9e42116900b6ca1588670290235693aee3467413 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Thu, 16 Apr 2026 11:19:09 +0800 Subject: [PATCH 11/32] docs: add static cache switching instructions to mu2.py Document how to switch Mu2 from WinMLSlidingWindowCache to WinMLStaticCache (3 changes: wrapper, OnnxConfig, get_cache_class). Verified: Mu2 e2e correct with both cache types (6/6 queries). --- src/winml/modelkit/models/hf/mu2.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index 4ff39e303..0e92fcd1a 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Mu2 encoder-decoder model with sliding-window KV cache. +"""Mu2 encoder-decoder model with KV cache. Export wrappers, OnnxConfig registrations, and ``WinMLMu2Model`` inference class for Mu2 (custom ``trust_remote_code`` model). @@ -11,7 +11,7 @@ class for Mu2 (custom ``trust_remote_code`` model). - Mu2EncoderWrapper (``feature-extraction``): encoder-only ONNX. - Mu2DecoderWrapper (``text2text-generation``): decoder with ``WinMLSlidingWindowCache`` (Slice+Concat, no ScatterElements). - Present KV output is the full updated buffer. + Present KV output is the new-token KV only. Custom model integration (``auto_map``): The Mu2 model uses ``trust_remote_code=True`` with ``auto_map`` in @@ -32,6 +32,24 @@ class for Mu2 (custom ``trust_remote_code`` model). ``eos_token_id`` to ``super().__init__()`` or PretrainedConfig overrides them to None. +Cache type: + +The default configuration uses ``WinMLSlidingWindowCache`` (FIFO +Slice+Concat). ``WinMLEncoderDecoderModel`` is cache-agnostic — mask +construction and cache updates are delegated to the cache class via +``build_decoder_mask``, ``position_input_name``, and +``update_all_layers``. To switch to ``WinMLStaticCache`` (index_copy_): + +1. **Export wrapper**: change ``Mu2DecoderWrapper.forward()`` to use + ``WinMLStaticCache`` and rename the position arg from ``position_id`` + to ``cache_position``. +2. **OnnxConfig inputs**: change ``"position_id"`` to + ``"cache_position"`` in ``Mu2DecoderIOConfig.inputs``. +3. **Inference**: override ``get_cache_class()`` to return + ``WinMLStaticCache``. ``WinMLEncoderDecoderModel`` uses + ``cache.position_input_name`` to select the correct ONNX input name + automatically. + Usage:: wmk config -m path/to/mu2 --task translation --trust-remote-code -o mu2.json From 0973dbd9313bc86f1548c11844b926643f943c32 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Thu, 16 Apr 2026 14:47:33 +0800 Subject: [PATCH 12/32] feat: WinMLCompositeModel.from_onnx + from_pretrained composite routing auto.py: - Fix circular import: move WinMLCompositeModel import under TYPE_CHECKING, lazy import in from_onnx - from_onnx: delegate dict onnx_path to WinMLCompositeModel.from_onnx - from_pretrained: check PIPELINE_MODEL_REGISTRY before config phase, delegate to WinMLCompositeModel.from_pretrained for composite models composite_model.py: - Implement from_onnx: resolves concrete class from registry using task + hf_config.model_type, builds each sub-component via WinMLAutoModel.from_onnx with per-component task from _SUB_MODEL_CONFIG Verified: T5 translation and Mu2 translation via both from_onnx(dict) and from_pretrained produce correct results. --- src/winml/modelkit/models/auto.py | 52 ++++++++++++++++++- .../modelkit/models/winml/composite_model.py | 51 ++++++++++++++++++ 2 files changed, 101 insertions(+), 2 deletions(-) diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 37be2b5b9..ac7b8a284 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: from ..config import WinMLBuildConfig from .winml.base import WinMLPreTrainedModel + from .winml.composite_model import WinMLCompositeModel logger = logging.getLogger(__name__) @@ -96,7 +97,7 @@ def __init__(self) -> None: @classmethod def from_onnx( cls, - onnx_path: str | Path, + onnx_path: str | Path | dict[str, str | Path], *, task: str | None = None, config: WinMLBuildConfig | None = None, @@ -109,7 +110,7 @@ def from_onnx( skip_build: bool = False, session_options: Any | None = None, **kwargs: Any, - ) -> WinMLPreTrainedModel: + ) -> WinMLPreTrainedModel | WinMLCompositeModel: """Build from a pre-exported ONNX file. Runs optimize -> [quantize] -> [compile] via ``build_onnx_model()``. @@ -130,6 +131,24 @@ def from_onnx( Returns: WinMLPreTrainedModel inference wrapper. """ + if isinstance(onnx_path, dict): + from .winml.composite_model import WinMLCompositeModel + + return WinMLCompositeModel.from_onnx( + onnx_path, + task=task, + config=config, + device=device, + precision=precision, + ep=ep, + cache_dir=cache_dir, + use_cache=use_cache, + force_rebuild=force_rebuild, + skip_build=skip_build, + session_options=session_options, + **kwargs, + ) + onnx_path = Path(onnx_path) if not onnx_path.is_file(): raise FileNotFoundError( @@ -281,6 +300,35 @@ def from_pretrained( **kwargs, ) + # ===================================================================== + # COMPOSITE MODEL CHECK — delegate to WinMLCompositeModel.from_pretrained + # when (model_type, task) is a registered composite (e.g., T5 translation, + # Qwen text-generation). AutoConfig is lightweight (~config.json only). + # ===================================================================== + if task is not None: + from transformers import AutoConfig + + _hf_cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + _model_type = getattr(_hf_cfg, "model_type", None) + from .winml.composite_model import PIPELINE_MODEL_REGISTRY + + if (_model_type, task) in PIPELINE_MODEL_REGISTRY: + from .winml.composite_model import WinMLCompositeModel + + return WinMLCompositeModel.from_pretrained( + model_id, + task, + device=device, + use_cache=use_cache, + force_rebuild=force_rebuild, + trust_remote_code=trust_remote_code, + shape_config=shape_config, + precision=precision, + config=config, + cache_dir=cache_dir, + **kwargs, + ) + # ===================================================================== # [1] CONFIG PHASE - Generate complete config with I/O specs (Lightweight, ~2s) # ===================================================================== diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index f1c9192f5..e2bc74b47 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -174,6 +174,57 @@ def from_pretrained( return cls(sub_models=sub_models, config=hf_config) + @classmethod + def from_onnx( + cls, + onnx_path: dict[str, str], + *, + task: str | None = None, + **kwargs: Any, + ) -> WinMLCompositeModel: + """Load composite model from pre-built ONNX files. + + Resolves the concrete model class from the registry using *task* + and ``hf_config.model_type``, then builds each sub-component via + ``WinMLAutoModel.from_onnx``. + + Args: + onnx_path: Maps component name (e.g., ``"encoder"``, + ``"decoder_prefill"``) to its ONNX file path. + task: Pipeline task (e.g., ``"translation"``, + ``"text-generation"``). + **kwargs: Must include ``hf_config`` (``PretrainedConfig``). + May include ``sub_model_kwargs`` for per-component + overrides. Remaining kwargs are forwarded to + ``WinMLAutoModel.from_onnx`` for every component. + """ + from pathlib import Path + + hf_config = kwargs.pop("hf_config", None) + sub_model_kwargs = kwargs.pop("sub_model_kwargs", None) or {} + + # Resolve concrete class from registry + model_type = getattr(hf_config, "model_type", None) if hf_config else None + if not cls._SUB_MODEL_CONFIG: + resolved_cls = PIPELINE_MODEL_REGISTRY.get((model_type, task)) + if resolved_cls is None: + raise ValueError( + f"No composite model for ({model_type!r}, {task!r}). " + f"Registered: {list(PIPELINE_MODEL_REGISTRY.keys())}" + ) + else: + resolved_cls = cls + + from ..auto import WinMLAutoModel + + sub_models: dict[str, Any] = {} + for name, path in onnx_path.items(): + component_task = resolved_cls._SUB_MODEL_CONFIG.get(name) + merged = {**kwargs, "task": component_task, **sub_model_kwargs.get(name, {})} + sub_models[name] = WinMLAutoModel.from_onnx(Path(path), **merged) + + return resolved_cls(sub_models=sub_models, config=hf_config) + @property def device(self) -> torch.device: """Device (CPU — ORT handles actual placement).""" From 0d9d10ec7362d370471a669cecb02f7273f7ae3d Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Thu, 16 Apr 2026 16:13:45 +0800 Subject: [PATCH 13/32] feat: composite model support in run_eval.py + T5 summarization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit run_eval.py: - Generalize _run_build: winml config produces a list of config JSONs (1 for single model, N for composite), build loop handles both - Generalize run_model (perf): takes list of ONNX paths, runs perf for each, merges results — single model is list-of-1 case t5.py: - Register WinMLT5Model for ("t5", "summarization") in addition to ("t5", "translation") timeout_skip_list.json: - Remove T5 entries (t5-small, t5-base, t5-3b) — composite model build now works Verified: T5-small summarization + translation via both from_pretrained and run_eval.py perf pipeline. --- scripts/e2e_eval/cache/timeout_skip_list.json | 15 - scripts/e2e_eval/run_eval.py | 278 +++++++++++------- src/winml/modelkit/models/hf/t5.py | 3 +- 3 files changed, 174 insertions(+), 122 deletions(-) diff --git a/scripts/e2e_eval/cache/timeout_skip_list.json b/scripts/e2e_eval/cache/timeout_skip_list.json index 02e9e2caf..d8d3988fb 100644 --- a/scripts/e2e_eval/cache/timeout_skip_list.json +++ b/scripts/e2e_eval/cache/timeout_skip_list.json @@ -29,21 +29,6 @@ "task": "translation", "reason": "Build hangs >14000s, m2m_100 enc-dec export issue" }, - { - "hf_id": "google-t5/t5-3b", - "task": "translation", - "reason": "Build hangs >14000s, t5 enc-dec + network download issue" - }, - { - "hf_id": "google-t5/t5-base", - "task": "summarization", - "reason": "Build hangs >2000s, t5 enc-dec export issue" - }, - { - "hf_id": "google-t5/t5-small", - "task": "translation", - "reason": "Build hangs >8000s, t5 enc-dec export issue" - }, { "hf_id": "knkarthick/MEETING_SUMMARY", "task": "summarization", diff --git a/scripts/e2e_eval/run_eval.py b/scripts/e2e_eval/run_eval.py index adb7bd858..10228f466 100644 --- a/scripts/e2e_eval/run_eval.py +++ b/scripts/e2e_eval/run_eval.py @@ -353,7 +353,11 @@ def _run_build( ) -> dict: """Run winml config + winml build for one model. Returns build result dict. - Flow: winml config → config.json → winml build --use-cache → ONNX path. + Flow: winml config → list of config JSONs → winml build each → ONNX paths. + + Single models produce one config; composite models (e.g., T5 translation) + produce one per sub-component (suffixed names). Both go through the same + build loop — single model is just the list-of-1 case. """ config_path = model_dir / "build_config.json" model_dir.mkdir(parents=True, exist_ok=True) @@ -378,40 +382,68 @@ def _run_build( if config_proc["exit_code"] != 0: return { "success": False, - "onnx_path": None, + "onnx_paths": [], "stage": "config", "proc": config_proc, } - # Step 2: winml build --use-cache - build_args = [ - *WINML_CLI, - "build", - "-c", - str(config_path), - "-m", - entry.hf_id, - "--use-cache", - ] + # Collect config files: composite models produce suffixed files + # (e.g., build_config_encoder.json); single models produce config_path itself. + sub_configs = sorted(config_path.parent.glob(f"{config_path.stem}_*.json")) + if not sub_configs: + sub_configs = [config_path] - build_proc = _run_subprocess(build_args, timeout) - if build_proc["exit_code"] != 0: - return { - "success": False, - "onnx_path": None, - "stage": "build", - "proc": build_proc, - } + # Step 2: build each sub-config + onnx_paths: list[str] = [] + last_proc = config_proc + + for sub_cfg in sub_configs: + label = sub_cfg.stem.removeprefix(f"{config_path.stem}_") if len(sub_configs) > 1 else "" + if label: + safe_print(f" building component: {label}") + + build_args = [ + *WINML_CLI, + "build", + "-c", + str(sub_cfg), + "-m", + entry.hf_id, + "--use-cache", + ] + + build_proc = _run_subprocess(build_args, timeout) + last_proc = build_proc + if build_proc["exit_code"] != 0: + stage = f"build_{label}" if label else "build" + return { + "success": False, + "onnx_paths": onnx_paths, + "stage": stage, + "proc": build_proc, + } + + task_hint = _extract_task_from_config(sub_cfg) or entry.task + path = _extract_onnx_path(build_proc, entry.hf_id, task_hint) + if path: + onnx_paths.append(path) + + return { + "success": len(onnx_paths) == len(sub_configs), + "onnx_paths": onnx_paths, + "stage": "complete", + "proc": last_proc, + } - # Extract ONNX path from build output - # winml build prints "Final artifact: " in stderr + +def _extract_onnx_path(build_proc: dict, hf_id: str, task: str | None) -> str | None: + """Extract ONNX path from build subprocess output.""" onnx_path = None for line in build_proc["stderr"].splitlines(): if "Final artifact:" in line: onnx_path = line.split("Final artifact:")[-1].strip() break - # Fallback: search cache for the built model if not onnx_path: for line in build_proc["stdout"].splitlines(): if "Final artifact:" in line: @@ -419,16 +451,19 @@ def _run_build( break if not onnx_path or not Path(onnx_path).exists(): - # Last resort: find _model.onnx in the cache - onnx_path = _find_cached_model(entry.hf_id, build_proc, entry.task) + onnx_path = _find_cached_model(hf_id, build_proc, task) - return { - "success": onnx_path is not None, - "onnx_path": onnx_path, - "stage": "complete", - "proc": build_proc, - "config_path": str(config_path), - } + return onnx_path + + +def _extract_task_from_config(config_path: Path) -> str | None: + """Read the task from a build config JSON file.""" + try: + data = json.loads(config_path.read_text(encoding="utf-8")) + loader = data.get("loader", {}) + return loader.get("task") + except (OSError, json.JSONDecodeError): + return None def _find_cached_model(hf_id: str, build_proc: dict, task: str | None = None) -> str | None: @@ -447,6 +482,7 @@ def _find_cached_model(hf_id: str, build_proc: dict, task: str | None = None) -> return None from winml.modelkit.loader.task import get_task_abbrev + prefix = get_task_abbrev(task) + "_" model_files = sorted( @@ -466,16 +502,16 @@ def run_model( entry: ModelEntry, device: str, timeout: int, - onnx_path: str | None = None, + onnx_paths: list[str] | None = None, ) -> dict: - """Execute winml perf for one model. Returns raw subprocess result dict. + """Execute winml perf for one or more ONNX models. Returns merged result dict. - When onnx_path is provided, benchmarks the pre-built ONNX directly - (skips internal build). Otherwise falls back to HF model ID. + When onnx_paths is provided, benchmarks each pre-built ONNX directly. + Single model is just the list-of-1 case. Results are merged (worst exit + code, concatenated stdout/stderr, summed elapsed). """ - if onnx_path: - args = [*WINML_CLI, "perf", "-m", onnx_path, "--device", device] - else: + if not onnx_paths: + # No pre-built paths: fall back to HF model ID (single model only) args = [ *WINML_CLI, "perf", @@ -488,22 +524,69 @@ def run_model( ] if entry.task: args += ["--task", entry.task] + args += ["--iterations", "10", "--warmup", "2"] + args += entry.perf_args + + proc = _run_subprocess(args, timeout) + proc["device"] = device + proc["timestamp"] = _utc_now() + proc["error_summary"] = ( + "" + if proc["exit_code"] == 0 + else f"timeout ({timeout}s)" + if proc["timeout"] + else f"exit code {proc['exit_code']}" + ) + return proc + + # Run perf for each sub-model and merge results + all_stdout: list[str] = [] + all_stderr: list[str] = [] + total_elapsed = 0.0 + worst_exit = 0 + any_timeout = False + commands: list[str] = [] + + for path in onnx_paths: + component = Path(path).parent.name if len(onnx_paths) > 1 else "" + if component: + safe_print(f" perf: {component}") + + args = [*WINML_CLI, "perf", "-m", path, "--device", device] + args += ["--iterations", "10", "--warmup", "2"] + args += entry.perf_args + + proc = _run_subprocess(args, timeout) + if component: + all_stdout.append(f"=== {component} ===\n{proc['stdout']}") + all_stderr.append(f"=== {component} ===\n{proc['stderr']}") + else: + all_stdout.append(proc["stdout"]) + all_stderr.append(proc["stderr"]) + total_elapsed += proc["elapsed"] + commands.append(proc["command"]) + if proc["exit_code"] != 0: + worst_exit = proc["exit_code"] + if proc["timeout"]: + any_timeout = True - args += ["--iterations", "10", "--warmup", "2"] - args += entry.perf_args - - proc = _run_subprocess(args, timeout) - # Attach device and timestamp for build_eval_result - proc["device"] = device - proc["timestamp"] = _utc_now() - proc["error_summary"] = ( - "" - if proc["exit_code"] == 0 - else f"timeout ({timeout}s)" - if proc["timeout"] - else f"exit code {proc['exit_code']}" - ) - return proc + return { + "stdout": "\n".join(all_stdout), + "stderr": "\n".join(all_stderr), + "exit_code": worst_exit, + "elapsed": round(total_elapsed, 1), + "timeout": any_timeout, + "command": commands[0] if len(commands) == 1 else " | ".join(commands), + "device": device, + "timestamp": _utc_now(), + "error_summary": ( + "" + if worst_exit == 0 + else f"timeout ({timeout}s)" + if any_timeout + else f"exit code {worst_exit}" + ), + } # --------------------------------------------------------------------------- @@ -1182,70 +1265,53 @@ def main() -> None: perf_proc: dict | None = None accuracy_result: dict | None = None - # Build phase: winml config + winml build → ONNX path + # Build phase: winml config + winml build → list of ONNX paths # Build is shared by perf and eval, avoiding redundant builds. - onnx_path: str | None = None - if args.eval_type in ("perf", "both"): - build_result = _run_build( - entry, - args.device, - _DEFAULT_PRECISION, - args.timeout, - model_dir, - ) - if build_result["success"]: - onnx_path = build_result["onnx_path"] - - if args.eval_type == "accuracy": - # Accuracy-only: build + eval (no perf) - build_result = _run_build( + build_result = _run_build( + entry, + args.device, + _DEFAULT_PRECISION, + args.timeout, + model_dir, + ) + onnx_paths = build_result["onnx_paths"] if build_result["success"] else [] + + if not build_result["success"]: + # Build failed — synthesize failed result for downstream phases + fail_proc = build_result["proc"] + fail_proc["device"] = args.device + fail_proc["timestamp"] = _utc_now() + fail_proc["error_summary"] = f"build_{build_result['stage']}_failed" + + if args.eval_type != "accuracy": + perf_proc = fail_proc + if args.eval_type != "perf": + accuracy_result = {"skipped": True, "skip_reason": "build_failed"} + elif args.eval_type == "accuracy": + accuracy_result = _run_accuracy_phase( entry, args.device, - _DEFAULT_PRECISION, args.timeout, model_dir, + # TODO: fix for composite model once supported + onnx_paths[0] if onnx_paths else None, ) - if build_result["success"]: - onnx_path = build_result["onnx_path"] + elif args.eval_type == "perf": + perf_proc = run_model(entry, args.device, args.timeout, onnx_paths) + else: + # "both": perf → eval + perf_proc = run_model(entry, args.device, args.timeout, onnx_paths) + if perf_proc["exit_code"] != 0: + accuracy_result = {"skipped": True, "skip_reason": "perf_failed"} + else: accuracy_result = _run_accuracy_phase( entry, args.device, args.timeout, model_dir, - onnx_path, + # TODO: fix for composite model once supported + onnx_paths[0] if onnx_paths else None, ) - else: - accuracy_result = {"skipped": True, "skip_reason": "build_failed"} - elif args.eval_type == "perf": - if onnx_path: - perf_proc = run_model(entry, args.device, args.timeout, onnx_path) - else: - # Build failed — synthesize a failed perf result - perf_proc = build_result["proc"] - perf_proc["device"] = args.device - perf_proc["timestamp"] = _utc_now() - perf_proc["error_summary"] = f"build_{build_result['stage']}_failed" - else: - # "both": build → perf → eval - if onnx_path: - perf_proc = run_model(entry, args.device, args.timeout, onnx_path) - if perf_proc["exit_code"] != 0: - accuracy_result = {"skipped": True, "skip_reason": "perf_failed"} - else: - accuracy_result = _run_accuracy_phase( - entry, - args.device, - args.timeout, - model_dir, - onnx_path, - ) - else: - # Build failed - perf_proc = build_result["proc"] - perf_proc["device"] = args.device - perf_proc["timestamp"] = _utc_now() - perf_proc["error_summary"] = f"build_{build_result['stage']}_failed" - accuracy_result = {"skipped": True, "skip_reason": "build_failed"} except KeyboardInterrupt: safe_print("\n\n[Ctrl+C] Interrupted — generating reports for completed models...") diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index e534ac626..d8062d0b7 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -280,8 +280,9 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 @register_composite_model("t5", "translation") +@register_composite_model("t5", "summarization") class WinMLT5Model(WinMLEncoderDecoderModel): - """T5 encoder-decoder model for translation. + """T5 encoder-decoder model for seq2seq tasks (translation, summarization). Declares T5 sub-component tasks and generation config defaults. All encoder-decoder forward/cache logic lives in ``WinMLEncoderDecoderModel``. From a4cdecccbfdf7ad0531abf39cd48be788a7384c0 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Thu, 16 Apr 2026 17:19:57 +0800 Subject: [PATCH 14/32] feat: remove_isnan_in_attention_mask surgery + optim configs for T5/Qwen/Mu2 surgery capability: - Add REMOVE_ISNAN_IN_ATTENTION_MASK: removes Softmax->IsNaN->Where NaN guard patterns (dead code when clamp_constant_values replaces -inf with finite value) model configs (gelu_fusion, fuse_rmsnorm, matmul_add_fusion, clamp_constant_values, remove_isnan_in_attention_mask): - T5_CONFIG: registered for model_type "t5" - MU2_CONFIG: registered for model_type "mu2" - QWEN_CONFIG: add optim flags (already had export config) run_eval.py: - Fix _extract_onnx_path: match "Artifact:" and "Existing artifact found:" markers in addition to "Final artifact:" --- scripts/e2e_eval/run_eval.py | 19 ++-- src/winml/modelkit/models/hf/__init__.py | 4 + src/winml/modelkit/models/hf/mu2.py | 13 +++ src/winml/modelkit/models/hf/qwen.py | 8 ++ src/winml/modelkit/models/hf/t5.py | 13 +++ .../modelkit/optim/capabilities/surgery.py | 11 ++ src/winml/modelkit/optim/pipes/surgery.py | 102 +++++++++++++++++- 7 files changed, 160 insertions(+), 10 deletions(-) diff --git a/scripts/e2e_eval/run_eval.py b/scripts/e2e_eval/run_eval.py index 10228f466..9d586b855 100644 --- a/scripts/e2e_eval/run_eval.py +++ b/scripts/e2e_eval/run_eval.py @@ -438,18 +438,19 @@ def _run_build( def _extract_onnx_path(build_proc: dict, hf_id: str, task: str | None) -> str | None: """Extract ONNX path from build subprocess output.""" + # Patterns used by winml build to report the artifact path + markers = ("Final artifact:", "Existing artifact found:", "Artifact:") onnx_path = None - for line in build_proc["stderr"].splitlines(): - if "Final artifact:" in line: - onnx_path = line.split("Final artifact:")[-1].strip() + for line in (build_proc["stderr"] + build_proc["stdout"]).splitlines(): + for marker in markers: + if marker in line: + candidate = line.split(marker)[-1].strip() + if candidate and Path(candidate).exists(): + onnx_path = candidate + break + if onnx_path: break - if not onnx_path: - for line in build_proc["stdout"].splitlines(): - if "Final artifact:" in line: - onnx_path = line.split("Final artifact:")[-1].strip() - break - if not onnx_path or not Path(onnx_path).exists(): onnx_path = _find_cached_model(hf_id, build_proc, task) diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index 97b7cda64..6b46a6567 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -37,6 +37,7 @@ from .depth_pro import DepthProIOConfig as _DepthProIOConfig # triggers registration from .detr import DETR_CONFIG from .mu2 import MODEL_CLASS_MAPPING as _MU2_CLASS_MAPPING +from .mu2 import MU2_CONFIG from .mu2 import Mu2DecoderIOConfig as _Mu2DecoderIOConfig # triggers registration from .mu2 import Mu2EncoderIOConfig as _Mu2EncoderIOConfig # triggers registration from .qwen import MODEL_CLASS_MAPPING as _QWEN_CLASS_MAPPING @@ -49,6 +50,7 @@ from .segformer import MODEL_CLASS_MAPPING as _SEGFORMER_CLASS_MAPPING from .segformer import SegformerIOConfig as _SegformerIOConfig # triggers registration from .t5 import MODEL_CLASS_MAPPING as _T5_CLASS_MAPPING +from .t5 import T5_CONFIG from .t5 import T5DecoderIOConfig as _T5DecoderIOConfig # triggers registration from .t5 import T5EncoderIOConfig as _T5EncoderIOConfig # triggers registration from .vision_encoder_decoder import VISION_ENCODER_DECODER_CONFIG @@ -77,7 +79,9 @@ "clip-vision-model": CLIP_CONFIG, "detr": DETR_CONFIG, "roberta": ROBERTA_FAMILY_CONFIG, + "mu2": MU2_CONFIG, "qwen3": QWEN_CONFIG, + "t5": T5_CONFIG, "vision-encoder-decoder": VISION_ENCODER_DECODER_CONFIG, "xlm-roberta": ROBERTA_FAMILY_CONFIG, } diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index 0e92fcd1a..5f351bc4a 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -67,7 +67,9 @@ class for Mu2 (custom ``trust_remote_code`` model). from optimum.utils import NormalizedConfig from optimum.utils.input_generators import DummyTextInputGenerator +from ...config import WinMLBuildConfig from ...export import register_onnx_overwrite +from ...optim import WinMLOptimizationConfig from ..winml.composite_model import register_composite_model from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel from .kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache @@ -267,6 +269,16 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 ("mu2", "text2text-generation"): Mu2DecoderWrapper, } +MU2_CONFIG = WinMLBuildConfig( + optim=WinMLOptimizationConfig( + gelu_fusion=True, + fuse_rmsnorm=True, + matmul_add_fusion=True, + clamp_constant_values=True, + remove_isnan_in_attention_mask=True, + ), +) + @register_composite_model("mu2", "translation") class WinMLMu2Model(WinMLEncoderDecoderModel): @@ -314,6 +326,7 @@ def generation_config(self, value: Any) -> None: __all__ = [ "MODEL_CLASS_MAPPING", + "MU2_CONFIG", "Mu2DecoderIOConfig", "Mu2DecoderWrapper", "Mu2EncoderIOConfig", diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index 7869d100f..2fba45902 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -102,6 +102,7 @@ from ...config import WinMLBuildConfig from ...export import register_onnx_overwrite from ...export.config import WinMLExportConfig +from ...optim import WinMLOptimizationConfig from ..winml import register_specialization from ..winml.composite_model import register_composite_model from ..winml.decoder_only import ( @@ -297,6 +298,13 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 QWEN_CONFIG = WinMLBuildConfig( export=WinMLExportConfig(dynamo=True, opset_version=18), + optim=WinMLOptimizationConfig( + gelu_fusion=True, + fuse_rmsnorm=True, + matmul_add_fusion=True, + clamp_constant_values=True, + remove_isnan_in_attention_mask=True, + ), ) diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index d8062d0b7..923206972 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -34,7 +34,9 @@ from transformers import T5ForConditionalGeneration from transformers.cache_utils import DynamicCache, EncoderDecoderCache +from ...config import WinMLBuildConfig from ...export import register_onnx_overwrite +from ...optim import WinMLOptimizationConfig from ..winml.composite_model import register_composite_model from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel from .kv_cache import PastKeyValueInputGenerator, WinMLStaticCache @@ -273,6 +275,16 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 ("t5", "text2text-generation"): T5DecoderWrapper, } +T5_CONFIG = WinMLBuildConfig( + optim=WinMLOptimizationConfig( + gelu_fusion=True, + fuse_rmsnorm=True, + matmul_add_fusion=True, + clamp_constant_values=True, + remove_isnan_in_attention_mask=True, + ), +) + # ============================================================================= # WinMLT5Model — inference wrapper (registered as pipeline model) @@ -335,6 +347,7 @@ def generation_config(self, value: Any) -> None: __all__ = [ "MODEL_CLASS_MAPPING", + "T5_CONFIG", "T5DecoderIOConfig", "T5DecoderWrapper", "T5EncoderIOConfig", diff --git a/src/winml/modelkit/optim/capabilities/surgery.py b/src/winml/modelkit/optim/capabilities/surgery.py index 7d851c6fa..8b2048f00 100644 --- a/src/winml/modelkit/optim/capabilities/surgery.py +++ b/src/winml/modelkit/optim/capabilities/surgery.py @@ -26,3 +26,14 @@ category=CapabilityCategory.SURGERY, default=False, ) + +# Remove Softmax -> IsNaN -> Where NaN guard patterns in attention. +# These guards are dead code when clamp_constant_values replaces -inf +# with a finite value (Softmax never produces NaN). +REMOVE_ISNAN_IN_ATTENTION_MASK = BoolCapability( + name="remove-isnan-in-attention-mask", + ort_name=None, # Custom implementation, not ORT optimizer + description="Remove Softmax->IsNaN->Where NaN guard patterns in attention", + category=CapabilityCategory.SURGERY, + default=False, +) diff --git a/src/winml/modelkit/optim/pipes/surgery.py b/src/winml/modelkit/optim/pipes/surgery.py index 0a90227b2..fa4fa6bcf 100644 --- a/src/winml/modelkit/optim/pipes/surgery.py +++ b/src/winml/modelkit/optim/pipes/surgery.py @@ -37,6 +37,7 @@ SURGERY_CAPABILITIES: dict[str, Any] = caps_dict( surgery.CLAMP_CONSTANT_VALUES, + surgery.REMOVE_ISNAN_IN_ATTENTION_MASK, ) @@ -53,12 +54,16 @@ class SurgeryPipeConfig(PipeConfig): clamp_constant_values: Whether to clamp extreme float constants clamp_min: Minimum value for constant clamping (default: -1e3) clamp_max: Maximum value for constant clamping (default: 1e3) + fix_nan_attention_mask: Replace -inf attention mask with finite value + and remove Softmax->IsNaN->Where NaN guard patterns + mask_value: Replacement value for -inf (default: -1e3) verbose: Enable verbose logging """ clamp_constant_values: bool = False clamp_min: float = -1e3 clamp_max: float = 1e3 + remove_isnan_in_attention_mask: bool = False verbose: bool = False @@ -90,6 +95,7 @@ def build_config(cls, **kwargs: Any) -> SurgeryPipeConfig: - clamp_constant_values: Enable/disable constant clamping - clamp_min: Minimum value for clamping (default: -1e3) - clamp_max: Maximum value for clamping (default: 1e3) + - remove_isnan_in_attention_mask: Remove IsNaN guard patterns - verbose: Enable verbose logging Returns: @@ -99,6 +105,7 @@ def build_config(cls, **kwargs: Any) -> SurgeryPipeConfig: clamp_constant_values=kwargs.get("clamp_constant_values", False), clamp_min=kwargs.get("clamp_min", -1e3), clamp_max=kwargs.get("clamp_max", 1e3), + remove_isnan_in_attention_mask=kwargs.get("remove_isnan_in_attention_mask", False), verbose=kwargs.get("verbose", False), ) @@ -112,7 +119,7 @@ def should_process(cls, config: SurgeryPipeConfig) -> bool: Returns: True if any surgery operation is enabled """ - return config.clamp_constant_values + return config.clamp_constant_values or config.remove_isnan_in_attention_mask def process(self, model: onnx.ModelProto, config: SurgeryPipeConfig) -> onnx.ModelProto: """Apply surgery operations to the model. @@ -139,6 +146,9 @@ def process(self, model: onnx.ModelProto, config: SurgeryPipeConfig) -> onnx.Mod model_copy, config.clamp_min, config.clamp_max, config.verbose ) + if config.remove_isnan_in_attention_mask: + model_copy = self._remove_isnan_in_attention_mask(model_copy, config.verbose) + return model_copy def _clamp_constant_values( @@ -219,3 +229,93 @@ def _clamp_constant_values( logger.debug("Clamped tensors: %s", clamped_tensors) return model + + # ----------------------------------------------------------------- + # remove-isnan-in-attention-mask + # ----------------------------------------------------------------- + + def _remove_isnan_in_attention_mask( + self, + model: onnx.ModelProto, + verbose: bool = False, + ) -> onnx.ModelProto: + """Remove Softmax → IsNaN → Where NaN guard patterns in attention. + + Pattern: Softmax → IsNaN → Where(isnan, 0, softmax_out) + Remove IsNaN + guard Where, use Softmax output directly. + + These guards are dead code when clamp_constant_values has already + replaced -inf with a finite value (Softmax never produces NaN). + + Args: + model: ONNX model (modified in place). + verbose: Log details about each removal. + + Returns: + Model with IsNaN guard patterns removed. + """ + guard_count = 0 + + # Build output→node map + output_to_node: dict[str, onnx.NodeProto] = {} + for node in model.graph.node: + for out in node.output: + output_to_node[out] = node + + nodes_to_remove: list[onnx.NodeProto] = [] + rewire_map: dict[str, str] = {} + + for node in list(model.graph.node): + if node.op_type != "IsNaN": + continue + producer = output_to_node.get(node.input[0]) + if producer is None or producer.op_type != "Softmax": + continue + softmax_out = producer.output[0] + isnan_out = node.output[0] + + # Find guard Where consuming IsNaN output + guard_wheres = [ + n for n in model.graph.node if n.op_type == "Where" and isnan_out in n.input + ] + if len(guard_wheres) != 1: + continue + guard_where = guard_wheres[0] + if softmax_out not in guard_where.input: + continue + + guard_out = guard_where.output[0] + nodes_to_remove.extend([node, guard_where]) + rewire_map[guard_out] = softmax_out + guard_count += 1 + if verbose: + logger.info( + " remove-isnan: remove %s + %s, rewire %s -> %s", + node.name, + guard_where.name, + guard_out, + softmax_out, + ) + + # Apply rewiring + for node in model.graph.node: + for i, inp in enumerate(node.input): + if inp in rewire_map: + node.input[i] = rewire_map[inp] + for out in model.graph.output: + if out.name in rewire_map: + out.name = rewire_map[out.name] + + # Remove dead nodes + remove_ids = {id(n) for n in nodes_to_remove} + remaining = [n for n in model.graph.node if id(n) not in remove_ids] + del model.graph.node[:] + model.graph.node.extend(remaining) + + if guard_count: + logger.info( + "SurgeryPipe: remove-isnan-in-attention-mask: %d IsNaN+Where guards removed", + guard_count, + ) + + return model From e57b33767352033804aa6d2c5ee6803adf56804e Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Thu, 16 Apr 2026 17:39:51 +0800 Subject: [PATCH 15/32] fix: enable DepthPro ONNX registration + update timeout skip list - Uncomment @register_onnx_overwrite for depth_pro (was disabled pending quantization fix) - Add apple/DepthPro-hf and Qwen3 models to timeout skip list (OOM segfault during in-process quantization) --- scripts/e2e_eval/cache/timeout_skip_list.json | 20 +++++++++++++++++++ src/winml/modelkit/models/hf/depth_pro.py | 5 +++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/scripts/e2e_eval/cache/timeout_skip_list.json b/scripts/e2e_eval/cache/timeout_skip_list.json index d8d3988fb..246344435 100644 --- a/scripts/e2e_eval/cache/timeout_skip_list.json +++ b/scripts/e2e_eval/cache/timeout_skip_list.json @@ -43,5 +43,25 @@ "hf_id": "philschmid/bart-large-cnn-samsum", "task": "summarization", "reason": "Build hangs >10000s, bart enc-dec export issue" + }, + { + "hf_id": "apple/DepthPro-hf", + "task": "depth-estimation", + "reason": "OOM in quantization (model too large for in-process quantize)" + }, + { + "hf_id": "Qwen/Qwen3-0.6B", + "task": "text-generation", + "reason": "OOM in quantization (segfault during quantize, 2.9GB model)" + }, + { + "hf_id": "Qwen/Qwen3-1.7B", + "task": "text-generation", + "reason": "OOM in quantization (model too large for in-process quantize)" + }, + { + "hf_id": "Qwen/Qwen3-8B", + "task": "text-generation", + "reason": "OOM in quantization (model too large for in-process quantize)" } ] diff --git a/src/winml/modelkit/models/hf/depth_pro.py b/src/winml/modelkit/models/hf/depth_pro.py index d24850900..a5d53a770 100644 --- a/src/winml/modelkit/models/hf/depth_pro.py +++ b/src/winml/modelkit/models/hf/depth_pro.py @@ -27,6 +27,8 @@ from optimum.utils import NormalizedConfig from optimum.utils.input_generators import DummyVisionInputGenerator +from ...export import register_onnx_overwrite + class _DepthProNormalizedConfig(NormalizedConfig): """Normalized config for DepthPro with computed image_size. @@ -44,8 +46,7 @@ def image_size(self) -> int: return int(self.config.patch_size / min(self.config.scaled_images_ratios)) -# TODO: enable registration once quantization can be done with enough RAM -# @register_onnx_overwrite("depth_pro", "depth-estimation", library_name="transformers") +@register_onnx_overwrite("depth_pro", "depth-estimation", library_name="transformers") class DepthProIOConfig(OnnxConfig): """ONNX config for DepthPro depth estimation. From d4be9dc521d7f7a63eec750efd83fda92b9f9f20 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Thu, 16 Apr 2026 17:48:15 +0800 Subject: [PATCH 16/32] test: add unit tests for polymorphic KV cache and composite from_onnx test_io.py (14 new tests): - TestStaticCacheBuildDecoderMask: left-aligned mask, num_new_tokens - TestSlidingWindowCacheBuildDecoderMask: right-aligned mask, saturation - TestStaticCachePreparePrefillChunk: right-pad, pad_len=0 - TestSlidingWindowCachePreparePrefillChunk: left-pad, pad_len>0 test_auto_onnx.py (1 new test): - TestFromOnnxDictDispatch: dict onnx_path delegates to WinMLCompositeModel.from_onnx with correct kwargs --- tests/unit/export/test_io.py | 132 +++++++++++++++++++++++ tests/unit/models/auto/test_auto_onnx.py | 27 +++++ 2 files changed, 159 insertions(+) diff --git a/tests/unit/export/test_io.py b/tests/unit/export/test_io.py index 8f3a575fa..2e85e1615 100644 --- a/tests/unit/export/test_io.py +++ b/tests/unit/export/test_io.py @@ -829,3 +829,135 @@ def test_kv_shape_matches_prefill(self, qwen_config) -> None: def test_input_ids_single_token(self, qwen_config) -> None: inputs = generate_dummy_inputs("qwen3", "text-generation", qwen_config) assert inputs["input_ids"].shape == (1, 1) + + +# ============================================================================= +# WinMLCache — build_decoder_mask and prepare_prefill_chunk +# ============================================================================= + + +def _make_cache(cls, num_layers=2, num_heads=2, max_cache_len=16, head_dim=8): + """Create a WinMLCache instance with minimal config. + + Uses a real PretrainedConfig subclass because HF StaticCache.__init__ + calls config.get_text_config(). + """ + from transformers import PretrainedConfig + + config = PretrainedConfig(num_hidden_layers=num_layers) + cache = cls.create(config, [1, num_heads, max_cache_len, head_dim], torch.float32) + cache.reset() + return cache + + +class TestStaticCacheBuildDecoderMask: + """WinMLStaticCache.build_decoder_mask — left-aligned mask.""" + + def test_default_single_token(self) -> None: + from winml.modelkit.models.hf.kv_cache import WinMLStaticCache + + cache = _make_cache(WinMLStaticCache) + cache.step = 3 + mask = cache.build_decoder_mask(16) + assert mask.shape == (1, 16) + assert mask[0, :4].tolist() == [1, 1, 1, 1] + assert mask[0, 4:].sum().item() == 0 + + def test_num_new_tokens(self) -> None: + from winml.modelkit.models.hf.kv_cache import WinMLStaticCache + + cache = _make_cache(WinMLStaticCache) + cache.step = 2 + mask = cache.build_decoder_mask(16, num_new_tokens=4) + assert mask[0, :6].tolist() == [1, 1, 1, 1, 1, 1] + assert mask[0, 6:].sum().item() == 0 + + +class TestSlidingWindowCacheBuildDecoderMask: + """WinMLSlidingWindowCache.build_decoder_mask — right-aligned mask.""" + + def test_default_single_token(self) -> None: + from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + + cache = _make_cache(WinMLSlidingWindowCache) + cache.step = 3 + mask = cache.build_decoder_mask(16) + # rightmost 4 positions should be 1 + assert mask[0, -4:].tolist() == [1, 1, 1, 1] + assert mask[0, :-4].sum().item() == 0 + + def test_num_new_tokens(self) -> None: + from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + + cache = _make_cache(WinMLSlidingWindowCache) + cache.step = 2 + mask = cache.build_decoder_mask(16, num_new_tokens=4) + # rightmost 6 positions + assert mask[0, -6:].tolist() == [1, 1, 1, 1, 1, 1] + assert mask[0, :-6].sum().item() == 0 + + def test_saturates_at_max_len(self) -> None: + from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + + cache = _make_cache(WinMLSlidingWindowCache, max_cache_len=8) + cache.step = 10 + mask = cache.build_decoder_mask(8, num_new_tokens=4) + # min(10+4, 8)=8 → all 1s + assert mask[0].sum().item() == 8 + + +class TestStaticCachePreparePrefillChunk: + """WinMLStaticCache.prepare_prefill_chunk — right-pad.""" + + def test_full_chunk_no_padding(self) -> None: + from winml.modelkit.models.hf.kv_cache import WinMLStaticCache + + cache = _make_cache(WinMLStaticCache) + chunk = torch.tensor([[10, 20, 30, 40]]) + padded_ids, pos_ids, pad_len = cache.prepare_prefill_chunk( + chunk, start=0, prefill_seq_len=4 + ) + assert pad_len == 0 + assert padded_ids[0].tolist() == [10, 20, 30, 40] + assert pos_ids[0].tolist() == [0, 1, 2, 3] + + def test_partial_chunk_right_padded(self) -> None: + from winml.modelkit.models.hf.kv_cache import WinMLStaticCache + + cache = _make_cache(WinMLStaticCache) + chunk = torch.tensor([[10, 20]]) + padded_ids, pos_ids, pad_len = cache.prepare_prefill_chunk( + chunk, start=4, prefill_seq_len=4 + ) + assert pad_len == 0 + assert padded_ids[0, :2].tolist() == [10, 20] + assert padded_ids[0, 2:].tolist() == [0, 0] + assert pos_ids[0].tolist() == [4, 5, 6, 7] + + +class TestSlidingWindowCachePreparePrefillChunk: + """WinMLSlidingWindowCache.prepare_prefill_chunk — left-pad.""" + + def test_full_chunk_no_padding(self) -> None: + from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + + cache = _make_cache(WinMLSlidingWindowCache) + chunk = torch.tensor([[10, 20, 30, 40]]) + padded_ids, pos_ids, pad_len = cache.prepare_prefill_chunk( + chunk, start=0, prefill_seq_len=4 + ) + assert pad_len == 0 + assert padded_ids[0].tolist() == [10, 20, 30, 40] + assert pos_ids[0].tolist() == [0, 1, 2, 3] + + def test_partial_chunk_left_padded(self) -> None: + from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + + cache = _make_cache(WinMLSlidingWindowCache) + chunk = torch.tensor([[10, 20]]) + padded_ids, pos_ids, pad_len = cache.prepare_prefill_chunk( + chunk, start=4, prefill_seq_len=4 + ) + assert pad_len == 2 + assert padded_ids[0].tolist() == [0, 0, 10, 20] + assert pos_ids[0].tolist() == [0, 0, 4, 5] diff --git a/tests/unit/models/auto/test_auto_onnx.py b/tests/unit/models/auto/test_auto_onnx.py index 5c9b66f9c..ba3a8b530 100644 --- a/tests/unit/models/auto/test_auto_onnx.py +++ b/tests/unit/models/auto/test_auto_onnx.py @@ -195,3 +195,30 @@ def test_passes_ep_from_kwargs(self, fake_onnx: Path, tmp_path: Path): call_kwargs = mock_from_onnx.call_args.kwargs assert call_kwargs["ep"] == "qnn" + + +# ============================================================================= +# from_onnx dict dispatch → WinMLCompositeModel.from_onnx +# ============================================================================= + + +class TestFromOnnxDictDispatch: + """from_onnx with dict onnx_path delegates to WinMLCompositeModel.from_onnx.""" + + def test_dict_dispatches_to_composite(self, tmp_path: Path): + """Dict onnx_path calls WinMLCompositeModel.from_onnx.""" + with patch( + "winml.modelkit.models.winml.composite_model.WinMLCompositeModel.from_onnx" + ) as mock_from_onnx: + mock_from_onnx.return_value = MagicMock() + + WinMLAutoModel.from_onnx( + {"encoder": str(tmp_path / "enc.onnx"), "decoder": str(tmp_path / "dec.onnx")}, + task="translation", + skip_build=True, + ) + + mock_from_onnx.assert_called_once() + call_kwargs = mock_from_onnx.call_args.kwargs + assert call_kwargs["task"] == "translation" + assert call_kwargs["skip_build"] is True From 80cb39af310c9d0bae9a70a6974b03050777b3b2 Mon Sep 17 00:00:00 2001 From: vortex-captain <75063846+vortex-captain@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:58:20 +0800 Subject: [PATCH 17/32] Potential fix for pull request finding 'CodeQL / Cyclic import' Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- src/winml/modelkit/models/auto.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index ac7b8a284..10512e7ea 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -38,7 +38,6 @@ if TYPE_CHECKING: from ..config import WinMLBuildConfig from .winml.base import WinMLPreTrainedModel - from .winml.composite_model import WinMLCompositeModel logger = logging.getLogger(__name__) @@ -110,7 +109,7 @@ def from_onnx( skip_build: bool = False, session_options: Any | None = None, **kwargs: Any, - ) -> WinMLPreTrainedModel | WinMLCompositeModel: + ) -> WinMLPreTrainedModel | "WinMLCompositeModel": """Build from a pre-exported ONNX file. Runs optimize -> [quantize] -> [compile] via ``build_onnx_model()``. From ebcb91662b0fd94f4d7bad58ea11131366bdcc10 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Thu, 16 Apr 2026 18:05:34 +0800 Subject: [PATCH 18/32] fix: resolve ruff F821/UP037 for WinMLCompositeModel type annotation --- src/winml/modelkit/models/auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 10512e7ea..6668bef7f 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -109,7 +109,7 @@ def from_onnx( skip_build: bool = False, session_options: Any | None = None, **kwargs: Any, - ) -> WinMLPreTrainedModel | "WinMLCompositeModel": + ) -> WinMLPreTrainedModel | WinMLCompositeModel: # noqa: F821 """Build from a pre-exported ONNX file. Runs optimize -> [quantize] -> [compile] via ``build_onnx_model()``. From 2000b69c0feb3fd3e511d1a075a2c85c55a711a8 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Fri, 17 Apr 2026 10:11:17 +0800 Subject: [PATCH 19/32] refactor: move encoder_decoder and kv_cache from hf/ to winml/ These modules are WinML-specific infrastructure (inference wrappers and KV cache classes), not HuggingFace model configs. Move alongside composite_model.py and decoder_only.py. - src/winml/modelkit/models/hf/encoder_decoder.py -> winml/ - src/winml/modelkit/models/hf/kv_cache.py -> winml/ Updated imports in mu2.py, qwen.py, t5.py (hf modules now use ..winml.kv_cache), decoder_only.py, encoder_decoder.py (lazy imports now use .kv_cache), and tests/unit/export/test_io.py. Verified: Mu2 e2e (6/6 queries correct), all 79 UTs pass. --- src/winml/modelkit/models/hf/clip.py | 40 ++++++++++++++----- src/winml/modelkit/models/hf/mu2.py | 4 +- src/winml/modelkit/models/hf/qwen.py | 2 +- src/winml/modelkit/models/hf/t5.py | 4 +- .../modelkit/models/winml/decoder_only.py | 4 +- .../models/{hf => winml}/encoder_decoder.py | 2 +- .../modelkit/models/{hf => winml}/kv_cache.py | 0 tests/unit/export/test_io.py | 20 +++++----- 8 files changed, 49 insertions(+), 27 deletions(-) rename src/winml/modelkit/models/{hf => winml}/encoder_decoder.py (99%) rename src/winml/modelkit/models/{hf => winml}/kv_cache.py (100%) diff --git a/src/winml/modelkit/models/hf/clip.py b/src/winml/modelkit/models/hf/clip.py index 045fcb160..f3b079c56 100644 --- a/src/winml/modelkit/models/hf/clip.py +++ b/src/winml/modelkit/models/hf/clip.py @@ -74,18 +74,17 @@ # ============================================================================= @register_onnx_overwrite("clip_text_model", "feature-extraction", library_name="transformers") class CLIPTextModelIOConfig(CLIPTextWithProjectionOnnxConfig): - """ONNX config for CLIPTextModelWithProjection from transformers. + """ONNX config for CLIP text models (both with and without projection). - Model: openai/clip-vit-base-patch32 (text encoder only) - model.config.model_type = "clip_text_model" + Handles two architectures that share model_type ``clip_text_model``: - Inputs: - - input_ids: {0: "batch_size", 1: "sequence_length"} - - attention_mask: {0: "batch_size", 1: "sequence_length"} + - **CLIPTextModelWithProjection** (standalone CLIP text encoder): + Outputs ``text_embeds`` + ``last_hidden_state`` + - **CLIPTextModel** (e.g., Stable Diffusion text encoder): + Outputs ``last_hidden_state`` only - Outputs: - - text_embeds: {0: "batch_size"} - - last_hidden_state: {0: "batch_size", 1: "sequence_length"} + The ``outputs`` property auto-detects which variant to use based on + ``config.architectures``. Key difference from Optimum's default: - sequence_length = max_position_embeddings (77 for CLIP) @@ -106,6 +105,29 @@ def inputs(self) -> dict[str, dict[int, str]]: "attention_mask": {0: "batch_size", 1: "sequence_length"}, } + @property + def outputs(self) -> dict[str, dict[int, str]]: + """Return output tensors, adapting to the model architecture. + + CLIPTextModelWithProjection produces ``text_embeds`` (projected) + and ``last_hidden_state``. CLIPTextModel (used in Stable Diffusion) + produces only ``last_hidden_state``. + + Default (architectures unset) assumes projection model for backward + compatibility. + """ + architectures = getattr(self._config, "architectures", None) or [] + # CLIPTextModel (non-projection, e.g. SD text_encoder) + if "CLIPTextModel" in architectures and "CLIPTextModelWithProjection" not in architectures: + return { + "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + } + # Default: projection model outputs (backward compatible) + return { + "text_embeds": {0: "batch_size"}, + "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + } + @register_onnx_overwrite("clip_vision_model", "feature-extraction", library_name="transformers") class CLIPVisionModelIOConfig(CLIPVisionModelOnnxConfig): diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index 5f351bc4a..6f519f17c 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -71,8 +71,8 @@ class for Mu2 (custom ``trust_remote_code`` model). from ...export import register_onnx_overwrite from ...optim import WinMLOptimizationConfig from ..winml.composite_model import register_composite_model -from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel -from .kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache +from ..winml.encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel +from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache # ============================================================================= diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index 2fba45902..12c5fd717 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -110,7 +110,7 @@ DecoderOnlyPrefillInputGenerator, WinMLDecoderOnlyModel, ) -from .kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache +from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache # ============================================================================= diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 923206972..916e96b94 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -38,8 +38,8 @@ from ...export import register_onnx_overwrite from ...optim import WinMLOptimizationConfig from ..winml.composite_model import register_composite_model -from .encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel -from .kv_cache import PastKeyValueInputGenerator, WinMLStaticCache +from ..winml.encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel +from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLStaticCache # ============================================================================= diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 2794407c7..3827c764c 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -222,7 +222,7 @@ def get_cache_class(cls) -> type: def _resolve_cache(self, past_key_values: Any) -> Any: """Unwrap or create WinMLCache for this generation step.""" - from ..hf.kv_cache import WinMLCache + from .kv_cache import WinMLCache if isinstance(past_key_values, WinMLCache): return past_key_values @@ -243,7 +243,7 @@ def prepare_inputs_for_generation( **kwargs: Any, ) -> dict[str, Any]: """Build inputs for each generate() step.""" - from ..hf.kv_cache import WinMLCache + from .kv_cache import WinMLCache if isinstance(past_key_values, WinMLCache) and past_key_values.get_seq_length() > 0: input_ids = input_ids[:, -1:] diff --git a/src/winml/modelkit/models/hf/encoder_decoder.py b/src/winml/modelkit/models/winml/encoder_decoder.py similarity index 99% rename from src/winml/modelkit/models/hf/encoder_decoder.py rename to src/winml/modelkit/models/winml/encoder_decoder.py index 7d9cf5503..aecf11cc4 100644 --- a/src/winml/modelkit/models/hf/encoder_decoder.py +++ b/src/winml/modelkit/models/winml/encoder_decoder.py @@ -58,7 +58,7 @@ from transformers.generation.utils import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput -from ..winml.composite_model import WinMLCompositeModel +from .composite_model import WinMLCompositeModel if TYPE_CHECKING: diff --git a/src/winml/modelkit/models/hf/kv_cache.py b/src/winml/modelkit/models/winml/kv_cache.py similarity index 100% rename from src/winml/modelkit/models/hf/kv_cache.py rename to src/winml/modelkit/models/winml/kv_cache.py diff --git a/tests/unit/export/test_io.py b/tests/unit/export/test_io.py index 2e85e1615..20dbdcdc4 100644 --- a/tests/unit/export/test_io.py +++ b/tests/unit/export/test_io.py @@ -38,7 +38,7 @@ _get_onnx_config, _populate_image_size_from_preprocessor, ) -from winml.modelkit.models.hf.kv_cache import PastKeyValueInputGenerator +from winml.modelkit.models.winml.kv_cache import PastKeyValueInputGenerator # ============================================================================= @@ -854,7 +854,7 @@ class TestStaticCacheBuildDecoderMask: """WinMLStaticCache.build_decoder_mask — left-aligned mask.""" def test_default_single_token(self) -> None: - from winml.modelkit.models.hf.kv_cache import WinMLStaticCache + from winml.modelkit.models.winml.kv_cache import WinMLStaticCache cache = _make_cache(WinMLStaticCache) cache.step = 3 @@ -864,7 +864,7 @@ def test_default_single_token(self) -> None: assert mask[0, 4:].sum().item() == 0 def test_num_new_tokens(self) -> None: - from winml.modelkit.models.hf.kv_cache import WinMLStaticCache + from winml.modelkit.models.winml.kv_cache import WinMLStaticCache cache = _make_cache(WinMLStaticCache) cache.step = 2 @@ -877,7 +877,7 @@ class TestSlidingWindowCacheBuildDecoderMask: """WinMLSlidingWindowCache.build_decoder_mask — right-aligned mask.""" def test_default_single_token(self) -> None: - from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache cache = _make_cache(WinMLSlidingWindowCache) cache.step = 3 @@ -887,7 +887,7 @@ def test_default_single_token(self) -> None: assert mask[0, :-4].sum().item() == 0 def test_num_new_tokens(self) -> None: - from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache cache = _make_cache(WinMLSlidingWindowCache) cache.step = 2 @@ -897,7 +897,7 @@ def test_num_new_tokens(self) -> None: assert mask[0, :-6].sum().item() == 0 def test_saturates_at_max_len(self) -> None: - from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache cache = _make_cache(WinMLSlidingWindowCache, max_cache_len=8) cache.step = 10 @@ -910,7 +910,7 @@ class TestStaticCachePreparePrefillChunk: """WinMLStaticCache.prepare_prefill_chunk — right-pad.""" def test_full_chunk_no_padding(self) -> None: - from winml.modelkit.models.hf.kv_cache import WinMLStaticCache + from winml.modelkit.models.winml.kv_cache import WinMLStaticCache cache = _make_cache(WinMLStaticCache) chunk = torch.tensor([[10, 20, 30, 40]]) @@ -922,7 +922,7 @@ def test_full_chunk_no_padding(self) -> None: assert pos_ids[0].tolist() == [0, 1, 2, 3] def test_partial_chunk_right_padded(self) -> None: - from winml.modelkit.models.hf.kv_cache import WinMLStaticCache + from winml.modelkit.models.winml.kv_cache import WinMLStaticCache cache = _make_cache(WinMLStaticCache) chunk = torch.tensor([[10, 20]]) @@ -939,7 +939,7 @@ class TestSlidingWindowCachePreparePrefillChunk: """WinMLSlidingWindowCache.prepare_prefill_chunk — left-pad.""" def test_full_chunk_no_padding(self) -> None: - from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache cache = _make_cache(WinMLSlidingWindowCache) chunk = torch.tensor([[10, 20, 30, 40]]) @@ -951,7 +951,7 @@ def test_full_chunk_no_padding(self) -> None: assert pos_ids[0].tolist() == [0, 1, 2, 3] def test_partial_chunk_left_padded(self) -> None: - from winml.modelkit.models.hf.kv_cache import WinMLSlidingWindowCache + from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache cache = _make_cache(WinMLSlidingWindowCache) chunk = torch.tensor([[10, 20]]) From b47d9838c5c64b803f35ab0c03dec7179fdb3b8c Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Fri, 17 Apr 2026 10:14:53 +0800 Subject: [PATCH 20/32] revert: undo clip.py change that slipped into refactor commit --- src/winml/modelkit/models/hf/clip.py | 40 +++++++--------------------- 1 file changed, 9 insertions(+), 31 deletions(-) diff --git a/src/winml/modelkit/models/hf/clip.py b/src/winml/modelkit/models/hf/clip.py index f3b079c56..045fcb160 100644 --- a/src/winml/modelkit/models/hf/clip.py +++ b/src/winml/modelkit/models/hf/clip.py @@ -74,17 +74,18 @@ # ============================================================================= @register_onnx_overwrite("clip_text_model", "feature-extraction", library_name="transformers") class CLIPTextModelIOConfig(CLIPTextWithProjectionOnnxConfig): - """ONNX config for CLIP text models (both with and without projection). + """ONNX config for CLIPTextModelWithProjection from transformers. - Handles two architectures that share model_type ``clip_text_model``: + Model: openai/clip-vit-base-patch32 (text encoder only) + model.config.model_type = "clip_text_model" - - **CLIPTextModelWithProjection** (standalone CLIP text encoder): - Outputs ``text_embeds`` + ``last_hidden_state`` - - **CLIPTextModel** (e.g., Stable Diffusion text encoder): - Outputs ``last_hidden_state`` only + Inputs: + - input_ids: {0: "batch_size", 1: "sequence_length"} + - attention_mask: {0: "batch_size", 1: "sequence_length"} - The ``outputs`` property auto-detects which variant to use based on - ``config.architectures``. + Outputs: + - text_embeds: {0: "batch_size"} + - last_hidden_state: {0: "batch_size", 1: "sequence_length"} Key difference from Optimum's default: - sequence_length = max_position_embeddings (77 for CLIP) @@ -105,29 +106,6 @@ def inputs(self) -> dict[str, dict[int, str]]: "attention_mask": {0: "batch_size", 1: "sequence_length"}, } - @property - def outputs(self) -> dict[str, dict[int, str]]: - """Return output tensors, adapting to the model architecture. - - CLIPTextModelWithProjection produces ``text_embeds`` (projected) - and ``last_hidden_state``. CLIPTextModel (used in Stable Diffusion) - produces only ``last_hidden_state``. - - Default (architectures unset) assumes projection model for backward - compatibility. - """ - architectures = getattr(self._config, "architectures", None) or [] - # CLIPTextModel (non-projection, e.g. SD text_encoder) - if "CLIPTextModel" in architectures and "CLIPTextModelWithProjection" not in architectures: - return { - "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, - } - # Default: projection model outputs (backward compatible) - return { - "text_embeds": {0: "batch_size"}, - "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, - } - @register_onnx_overwrite("clip_vision_model", "feature-extraction", library_name="transformers") class CLIPVisionModelIOConfig(CLIPVisionModelOnnxConfig): From 44f1f20051daf386a9d1c396156531215908806e Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Fri, 17 Apr 2026 10:20:14 +0800 Subject: [PATCH 21/32] refactor: make hf_config and sub_model_kwargs explicit params in from_onnx --- .../modelkit/models/winml/composite_model.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index e2bc74b47..1975c4d50 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -180,6 +180,8 @@ def from_onnx( onnx_path: dict[str, str], *, task: str | None = None, + hf_config: PretrainedConfig | None = None, + sub_model_kwargs: dict[str, dict[str, Any]] | None = None, **kwargs: Any, ) -> WinMLCompositeModel: """Load composite model from pre-built ONNX files. @@ -193,15 +195,17 @@ def from_onnx( ``"decoder_prefill"``) to its ONNX file path. task: Pipeline task (e.g., ``"translation"``, ``"text-generation"``). - **kwargs: Must include ``hf_config`` (``PretrainedConfig``). - May include ``sub_model_kwargs`` for per-component - overrides. Remaining kwargs are forwarded to - ``WinMLAutoModel.from_onnx`` for every component. + hf_config: HF ``PretrainedConfig`` for the model. Used to + resolve the concrete class from the registry via + ``hf_config.model_type``. + sub_model_kwargs: Per-component kwargs merged on top of + ``**kwargs`` for each sub-model's ``from_onnx`` call. + **kwargs: Forwarded to ``WinMLAutoModel.from_onnx`` for every + component (overridden by ``sub_model_kwargs``). """ from pathlib import Path - hf_config = kwargs.pop("hf_config", None) - sub_model_kwargs = kwargs.pop("sub_model_kwargs", None) or {} + per_component = sub_model_kwargs or {} # Resolve concrete class from registry model_type = getattr(hf_config, "model_type", None) if hf_config else None @@ -220,7 +224,7 @@ def from_onnx( sub_models: dict[str, Any] = {} for name, path in onnx_path.items(): component_task = resolved_cls._SUB_MODEL_CONFIG.get(name) - merged = {**kwargs, "task": component_task, **sub_model_kwargs.get(name, {})} + merged = {**kwargs, "task": component_task, **per_component.get(name, {})} sub_models[name] = WinMLAutoModel.from_onnx(Path(path), **merged) return resolved_cls(sub_models=sub_models, config=hf_config) From 7b42d5b04c8cedcb8801d42df2370f26266f19ac Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Fri, 17 Apr 2026 13:27:24 +0800 Subject: [PATCH 22/32] refactor: move _pad_inputs to utils/data_utils.py with mode param - New pad_inputs() in utils/data_utils.py supports mode="right" (default) and mode="left" - Remove _pad_inputs static method from WinMLCompositeModel - Update encoder_decoder.py callers to use pad_inputs directly - Add device param + ort_device property to WinMLCompositeModel.__init__ Verified: Mu2 e2e translation (6 queries) still correct after refactor. --- .../modelkit/models/winml/composite_model.py | 35 +++-------- .../modelkit/models/winml/encoder_decoder.py | 7 ++- src/winml/modelkit/utils/data_utils.py | 58 +++++++++++++++++++ 3 files changed, 69 insertions(+), 31 deletions(-) create mode 100644 src/winml/modelkit/utils/data_utils.py diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index 1975c4d50..99b39838b 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -93,9 +93,11 @@ def __init__( self, sub_models: dict[str, Any], config: PretrainedConfig, + device: str = "cpu", ) -> None: self.sub_models = sub_models self.config = config + self._device = device @classmethod def from_pretrained( @@ -234,6 +236,11 @@ def device(self) -> torch.device: """Device (CPU — ORT handles actual placement).""" return torch.device("cpu") + @property + def ort_device(self) -> str: + """ORT execution provider target (e.g. "npu", "gpu", "cpu", "auto").""" + return self._device + @property def dtype(self) -> torch.dtype: """Model dtype for HF compatibility.""" @@ -250,31 +257,3 @@ def __call__(self, **kwargs: Any) -> Any: def forward(self, **kwargs: Any) -> Any: """Subclasses implement task-specific logic.""" raise NotImplementedError - - @staticmethod - def _pad_inputs( - source: dict[str, Any], - expected: dict[str, list[int]], - ) -> dict[str, Any]: - """Filter *source* to keys in *expected* and pad undersized tensors. - - For each name in *expected*, if *source* has a tensor for it, pad - any dimension smaller than the ONNX expected shape (skips batch dim). - Non-tensor values are passed through. Missing names are skipped. - """ - result: dict[str, Any] = {} - for name, expected_shape in expected.items(): - val = source.get(name) - if val is None: - continue - if isinstance(val, torch.Tensor): - # TODO: support dynamic shape ONNX models (None in expected_shape) - ndim = min(len(val.shape), len(expected_shape)) - pad: list[int] = [] - for dim in reversed(range(1, ndim)): - deficit = expected_shape[dim] - val.shape[dim] - pad.extend([0, max(deficit, 0)]) - if any(p > 0 for p in pad): - val = torch.nn.functional.pad(val, pad) - result[name] = val - return result diff --git a/src/winml/modelkit/models/winml/encoder_decoder.py b/src/winml/modelkit/models/winml/encoder_decoder.py index aecf11cc4..5fdb91b91 100644 --- a/src/winml/modelkit/models/winml/encoder_decoder.py +++ b/src/winml/modelkit/models/winml/encoder_decoder.py @@ -22,7 +22,7 @@ 3. Feeds are built from ``model_kwargs`` (decoder_input_ids, attention_mask) plus generated inputs (encoder_hidden_states, decoder_attention_mask, - position input, KV buffers). ``_pad_inputs`` filters to ONNX input + position input, KV buffers). ``pad_inputs`` filters to ONNX input names and pads undersized tensors. 4. After ONNX inference, ``cache.update_all_layers(outputs)`` writes @@ -58,6 +58,7 @@ from transformers.generation.utils import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from ...utils.data_utils import pad_inputs from .composite_model import WinMLCompositeModel @@ -212,7 +213,7 @@ def __init__(self, encoder: Any, expected: dict[str, list[int]]) -> None: self._expected = expected def forward(self, **kwargs: Any) -> BaseModelOutput: - feeds = WinMLCompositeModel._pad_inputs(kwargs, self._expected) + feeds = pad_inputs(kwargs, self._expected) return self._encoder(**feeds) def get_encoder(self) -> torch.nn.Module: @@ -310,7 +311,7 @@ def forward( feeds[f"past_{i}_value"] = cache.layers[i].values.detach() # Run decoder ONNX (pad_inputs filters to expected names + pads) - outputs = self._decoder(**self._pad_inputs(feeds, self._dec_expected)) + outputs = self._decoder(**pad_inputs(feeds, self._dec_expected)) # Write present KV back and advance step cache.update_all_layers(outputs) diff --git a/src/winml/modelkit/utils/data_utils.py b/src/winml/modelkit/utils/data_utils.py new file mode 100644 index 000000000..148b60688 --- /dev/null +++ b/src/winml/modelkit/utils/data_utils.py @@ -0,0 +1,58 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Data utilities for input preparation and padding.""" + +from __future__ import annotations + +from typing import Any + +import torch + + +def pad_inputs( + source: dict[str, Any], + expected: dict[str, list[int]], + mode: str = "right", +) -> dict[str, Any]: + """Filter *source* to keys in *expected* and pad undersized tensors. + + For each name in *expected*, if *source* has a tensor for it, pad any + dimension smaller than the ONNX expected shape (skips batch dim). + Non-tensor values are passed through. Missing names are skipped. + + Args: + source: Input tensors keyed by name. + expected: ONNX expected shapes keyed by input name. + mode: Padding side — ``"right"`` (default, pad at end) or + ``"left"`` (pad at start). + + Returns: + Filtered and padded tensors matching *expected* keys. + """ + if mode not in ("right", "left"): + raise ValueError(f"mode must be 'right' or 'left', got {mode!r}") + + result: dict[str, Any] = {} + for name, expected_shape in expected.items(): + val = source.get(name) + if val is None: + continue + if isinstance(val, torch.Tensor): + # TODO: support dynamic shape ONNX models (None in expected_shape) + ndim = min(len(val.shape), len(expected_shape)) + # torch.nn.functional.pad takes pairs (low, high) from the LAST + # dim backwards. Skip batch dim (dim 0). + pad: list[int] = [] + for dim in reversed(range(1, ndim)): + deficit = max(expected_shape[dim] - val.shape[dim], 0) + if mode == "right": + pad.extend([0, deficit]) + else: # left + pad.extend([deficit, 0]) + if any(p > 0 for p in pad): + val = torch.nn.functional.pad(val, pad) + result[name] = val + return result From ff3a8cd1a10354d5fc7bdbe90590241b113ebe53 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Fri, 17 Apr 2026 13:34:52 +0800 Subject: [PATCH 23/32] add comment on winml build --- scripts/e2e_eval/run_eval.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/e2e_eval/run_eval.py b/scripts/e2e_eval/run_eval.py index 9d586b855..dd508ba13 100644 --- a/scripts/e2e_eval/run_eval.py +++ b/scripts/e2e_eval/run_eval.py @@ -397,6 +397,7 @@ def _run_build( onnx_paths: list[str] = [] last_proc = config_proc + # TODO: remove for loop once wimnl build supports building composite model to multiple onnx files for sub_cfg in sub_configs: label = sub_cfg.stem.removeprefix(f"{config_path.stem}_") if len(sub_configs) > 1 else "" if label: From 50d29ce104c9ed0286a8fdc9b7505ce28811f4bc Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Fri, 17 Apr 2026 13:40:38 +0800 Subject: [PATCH 24/32] fix naming --- src/winml/modelkit/commands/config.py | 10 +++++----- src/winml/modelkit/models/hf/qwen.py | 2 +- src/winml/modelkit/models/hf/t5.py | 2 +- src/winml/modelkit/models/winml/composite_model.py | 10 +++++----- src/winml/modelkit/models/winml/decoder_only.py | 4 ++-- src/winml/modelkit/models/winml/encoder_decoder.py | 2 +- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index 7f45f5268..6fc7b6a09 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -328,12 +328,12 @@ def config( else: _is_onnx_mode = False - # Check pipeline model registry: (model_type, task) → multi-config - pipeline_components = _resolve_pipeline_components( + # Check composite model registry: (model_type, task) -> multi-config + pipeline_components = _resolve_composite_model_components( hf_model, model_type, task, trust_remote_code=trust_remote_code ) if pipeline_components: - # Pipeline model: generate one config per sub-component + # composite model: generate one config per sub-component _generate_pipeline_configs( pipeline_components, hf_model=hf_model, @@ -513,13 +513,13 @@ def config( raise click.ClickException(f"Unexpected error: {e}") from e -def _resolve_pipeline_components( +def _resolve_composite_model_components( hf_model: str | None, model_type: str | None, task: str | None, trust_remote_code: bool = False, ) -> dict[str, str] | None: - """Check if (model_type, task) is a registered pipeline model. + """Check if (model_type, task) is a registered composite model. Returns _SUB_MODEL_CONFIG dict if found, None otherwise. """ diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index 12c5fd717..74373bb76 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -318,7 +318,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 } # ============================================================================= -# WinMLQwen3Model — inference wrapper (registered as pipeline model) +# WinMLQwen3Model — inference wrapper (registered as composite model) # ============================================================================= diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 916e96b94..69160cadb 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -287,7 +287,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 # ============================================================================= -# WinMLT5Model — inference wrapper (registered as pipeline model) +# WinMLT5Model — inference wrapper (registered as composite model) # ============================================================================= diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index 99b39838b..b4fe1400b 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""WinML Pipeline Model base and registry. +"""WinML composite model base and registry. Provides ``WinMLCompositeModel`` — a base class for models composed of multiple ``WinMLAutoModel`` sub-components (e.g., encoder + decoder, @@ -31,7 +31,7 @@ "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, }) -Concrete pipeline models live alongside their export configs: +Concrete composite models live alongside their export configs: - ``models.hf.t5.WinMLT5Model`` (encoder-decoder, T5) - ``models.hf.mu2.WinMLMu2Model`` (encoder-decoder, Mu2) @@ -55,7 +55,7 @@ # ========================================================================= -# Pipeline Model Registry +# composite model Registry # ========================================================================= # Maps (model_type, task) → pipeline class with _SUB_MODEL_CONFIG. @@ -64,7 +64,7 @@ def register_composite_model(model_type: str, task: str): - """Class decorator that registers a pipeline model for `wmk config`.""" + """Class decorator that registers a composite model for `wmk config`.""" def decorator(cls: type) -> type: PIPELINE_MODEL_REGISTRY[(model_type, task)] = cls @@ -146,7 +146,7 @@ def from_pretrained( resolved_cls = PIPELINE_MODEL_REGISTRY.get((model_type, task)) if resolved_cls is None: raise ValueError( - f"No pipeline model registered for ({model_type!r}, {task!r}). " + f"No composite model registered for ({model_type!r}, {task!r}). " f"Registered: {list(PIPELINE_MODEL_REGISTRY.keys())}" ) return resolved_cls.from_pretrained( diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 3827c764c..4a3f4671b 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""WinML Decoder-Only Pipeline Model. +"""WinML Decoder-Only composite model. Class hierarchy:: @@ -157,7 +157,7 @@ class DecoderOnlyPrefillInputGenerator(DecoderOnlyInputGenerator): class WinMLDecoderOnlyModel(WinMLCompositeModel, GenerationMixin): - """Decoder-only pipeline model with HF GenerationMixin support. + """Decoder-only composite model with HF GenerationMixin support. Expects sub-components ``"decoder_prefill"`` and ``"decoder_gen"`` in ``_SUB_MODEL_CONFIG``. Provides the full interface required by diff --git a/src/winml/modelkit/models/winml/encoder_decoder.py b/src/winml/modelkit/models/winml/encoder_decoder.py index 5fdb91b91..0f49309b8 100644 --- a/src/winml/modelkit/models/winml/encoder_decoder.py +++ b/src/winml/modelkit/models/winml/encoder_decoder.py @@ -144,7 +144,7 @@ def generate( class WinMLEncoderDecoderModel(WinMLCompositeModel, GenerationMixin): - """Pipeline model with HF GenerationMixin support. + """composite model with HF GenerationMixin support. Expects sub-components ``"encoder"`` and ``"decoder"`` in ``_SUB_MODEL_CONFIG``. Provides the full interface required by From ede3764870abedae8abdcc7c957b612166136359 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Fri, 17 Apr 2026 14:15:48 +0800 Subject: [PATCH 25/32] refactor: _run_build returns onnx_paths as dict {label: path} - Composite model labels (encoder, decoder, decoder_gen, decoder_prefill) propagate to perf output for clearer logs - Single model uses {"": path} as degenerate case - Verified: T5 composite (decoder, encoder labels) and resnet-50 single both pass run_eval perf --- scripts/e2e_eval/run_eval.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/scripts/e2e_eval/run_eval.py b/scripts/e2e_eval/run_eval.py index dd508ba13..fe4b2a547 100644 --- a/scripts/e2e_eval/run_eval.py +++ b/scripts/e2e_eval/run_eval.py @@ -382,7 +382,7 @@ def _run_build( if config_proc["exit_code"] != 0: return { "success": False, - "onnx_paths": [], + "onnx_paths": {}, "stage": "config", "proc": config_proc, } @@ -394,7 +394,8 @@ def _run_build( sub_configs = [config_path] # Step 2: build each sub-config - onnx_paths: list[str] = [] + # Map component label → ONNX path. Single model uses "" as label. + onnx_paths: dict[str, str] = {} last_proc = config_proc # TODO: remove for loop once wimnl build supports building composite model to multiple onnx files @@ -427,7 +428,7 @@ def _run_build( task_hint = _extract_task_from_config(sub_cfg) or entry.task path = _extract_onnx_path(build_proc, entry.hf_id, task_hint) if path: - onnx_paths.append(path) + onnx_paths[label] = path return { "success": len(onnx_paths) == len(sub_configs), @@ -504,12 +505,12 @@ def run_model( entry: ModelEntry, device: str, timeout: int, - onnx_paths: list[str] | None = None, + onnx_paths: dict[str, str] | None = None, ) -> dict: """Execute winml perf for one or more ONNX models. Returns merged result dict. When onnx_paths is provided, benchmarks each pre-built ONNX directly. - Single model is just the list-of-1 case. Results are merged (worst exit + Single model is the {"": path} case. Results are merged (worst exit code, concatenated stdout/stderr, summed elapsed). """ if not onnx_paths: @@ -549,19 +550,18 @@ def run_model( any_timeout = False commands: list[str] = [] - for path in onnx_paths: - component = Path(path).parent.name if len(onnx_paths) > 1 else "" - if component: - safe_print(f" perf: {component}") + for label, path in onnx_paths.items(): + if label: + safe_print(f" perf: {label}") args = [*WINML_CLI, "perf", "-m", path, "--device", device] args += ["--iterations", "10", "--warmup", "2"] args += entry.perf_args proc = _run_subprocess(args, timeout) - if component: - all_stdout.append(f"=== {component} ===\n{proc['stdout']}") - all_stderr.append(f"=== {component} ===\n{proc['stderr']}") + if label: + all_stdout.append(f"=== {label} ===\n{proc['stdout']}") + all_stderr.append(f"=== {label} ===\n{proc['stderr']}") else: all_stdout.append(proc["stdout"]) all_stderr.append(proc["stderr"]) @@ -1276,7 +1276,9 @@ def main() -> None: args.timeout, model_dir, ) - onnx_paths = build_result["onnx_paths"] if build_result["success"] else [] + onnx_paths = build_result["onnx_paths"] if build_result["success"] else {} + # First ONNX path for accuracy phase (TODO: composite model support) + first_path = next(iter(onnx_paths.values()), None) if onnx_paths else None if not build_result["success"]: # Build failed — synthesize failed result for downstream phases @@ -1295,8 +1297,7 @@ def main() -> None: args.device, args.timeout, model_dir, - # TODO: fix for composite model once supported - onnx_paths[0] if onnx_paths else None, + first_path, ) elif args.eval_type == "perf": perf_proc = run_model(entry, args.device, args.timeout, onnx_paths) @@ -1311,8 +1312,7 @@ def main() -> None: args.device, args.timeout, model_dir, - # TODO: fix for composite model once supported - onnx_paths[0] if onnx_paths else None, + first_path, ) except KeyboardInterrupt: From c90ac1dce9735594ca8362a3a79121a4c74ecabb Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Fri, 17 Apr 2026 16:09:49 +0800 Subject: [PATCH 26/32] feat: add google/flan-t5-base to e2e eval registry (translation + summarization) --- scripts/e2e_eval/testsets/models_all.json | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/scripts/e2e_eval/testsets/models_all.json b/scripts/e2e_eval/testsets/models_all.json index 82809073c..06625abf3 100644 --- a/scripts/e2e_eval/testsets/models_all.json +++ b/scripts/e2e_eval/testsets/models_all.json @@ -629,6 +629,16 @@ "last_update_time": "2023-06-30T02:31:26+00:00", "optimum_supported": true }, + { + "hf_id": "google/flan-t5-base", + "task": "summarization", + "model_type": "t5", + "group": "Top200", + "priority": "P1", + "downloads": 1372124, + "last_update_time": "2023-07-17T12:48:39+00:00", + "optimum_supported": true + }, { "hf_id": "sshleifer/distilbart-cnn-12-6", "task": "summarization", @@ -719,6 +729,16 @@ "last_update_time": "2023-06-30T02:31:26+00:00", "optimum_supported": true }, + { + "hf_id": "google/flan-t5-base", + "task": "translation", + "model_type": "t5", + "group": "Top200", + "priority": "P1", + "downloads": 1372124, + "last_update_time": "2023-07-17T12:48:39+00:00", + "optimum_supported": true + }, { "hf_id": "Helsinki-NLP/opus-mt-nl-en", "task": "translation", From e970c8fb8b952f3cd2f061c4d63d2f3c220c10f9 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Fri, 17 Apr 2026 17:29:17 +0800 Subject: [PATCH 27/32] Revert "feat: add google/flan-t5-base to e2e eval registry (translation + summarization)" This reverts commit c90ac1dce9735594ca8362a3a79121a4c74ecabb. --- scripts/e2e_eval/testsets/models_all.json | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/scripts/e2e_eval/testsets/models_all.json b/scripts/e2e_eval/testsets/models_all.json index 06625abf3..82809073c 100644 --- a/scripts/e2e_eval/testsets/models_all.json +++ b/scripts/e2e_eval/testsets/models_all.json @@ -629,16 +629,6 @@ "last_update_time": "2023-06-30T02:31:26+00:00", "optimum_supported": true }, - { - "hf_id": "google/flan-t5-base", - "task": "summarization", - "model_type": "t5", - "group": "Top200", - "priority": "P1", - "downloads": 1372124, - "last_update_time": "2023-07-17T12:48:39+00:00", - "optimum_supported": true - }, { "hf_id": "sshleifer/distilbart-cnn-12-6", "task": "summarization", @@ -729,16 +719,6 @@ "last_update_time": "2023-06-30T02:31:26+00:00", "optimum_supported": true }, - { - "hf_id": "google/flan-t5-base", - "task": "translation", - "model_type": "t5", - "group": "Top200", - "priority": "P1", - "downloads": 1372124, - "last_update_time": "2023-07-17T12:48:39+00:00", - "optimum_supported": true - }, { "hf_id": "Helsinki-NLP/opus-mt-nl-en", "task": "translation", From 3c7660194fd075cbf5ac669d5b9f626b3503029b Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Mon, 20 Apr 2026 14:45:23 +0800 Subject: [PATCH 28/32] refactor: T5 uses WinMLSlidingWindowCache; const-fold relative-position bias - T5DecoderWrapper exports with WinMLSlidingWindowCache (Slice+Concat), replacing WinMLStaticCache's index_copy_. No ScatterElements in the graph. - Added abstract WinMLCache.get_query_cache_position(max_len, num_new_tokens): static returns [step..step+N), sliding returns [max_len-N..max_len). Keeps the invariant "cache_position is the query's BUFFER index", so HF's causal mask (kv_idx <= q_idx) and T5's compute_bias work for both cache classes with no compute_bias patch. - encoder_decoder.py feeds both "cache_position" (buffer idx) and "position_id" (seq pos); pad_inputs filters to whichever the decoder ONNX declares, so T5 and Mu2 share the wrapper code without either model taking inputs it can't use. - T5DecoderIOConfig drops "cache_position" as an input; T5DecoderWrapper.forward pins cache_position=[max_cache_len-1] internally. For sliding-window + single-token gen the query is permanently at the rightmost slot, so the relative-distance map is stationary across steps. ONNX constant-folds the entire compute_bias + causal-mask subgraphs: decoder drops 342 -> 316 nodes, and the relative_attention_bias Gather collapses into a single [1, n_heads, 1, W] initializer added to scores. The graph is now coupled to sliding-window semantics at build time; callers who want static-cache semantics subclass the wrapper and re-export (WinMLStaticCache remains available). Verified: t5-small translation "Bonjour, comment etes-vous ?" and summarization match the PR 334 reference; Mu2 e2e unchanged. --- src/winml/modelkit/models/hf/t5.py | 103 ++++++++++++------ .../modelkit/models/winml/encoder_decoder.py | 36 ++++-- src/winml/modelkit/models/winml/kv_cache.py | 56 ++++++++-- 3 files changed, 142 insertions(+), 53 deletions(-) diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 69160cadb..e93cb1534 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -39,7 +39,7 @@ from ...optim import WinMLOptimizationConfig from ..winml.composite_model import register_composite_model from ..winml.encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel -from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLStaticCache +from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache # ============================================================================= @@ -78,19 +78,31 @@ def forward( class T5DecoderWrapper(nn.Module): - """Wraps T5ForConditionalGeneration with static KV cache I/O. + """Wraps T5ForConditionalGeneration with sliding-window KV cache I/O. - Input: full static buffer ``[batch, heads, max_decode, d_kv]`` per layer. + Input: full buffer ``[batch, heads, max_decode, d_kv]`` per layer. Output: only the new token's KV ``[batch, heads, 1, d_kv]`` per layer. - Uses HF ``StaticCache`` (``index_copy_`` at ``cache_position``) wrapped - in ``EncoderDecoderCache`` (cross-attn empty → always recomputed from - ``encoder_hidden_states``). ``KV_index = sequence_position`` holds, so - T5's relative position bias computes correct distances. - - The inference wrapper (WinMLT5Model) uses the same - ``StaticCache`` class — it writes the single-token output KV back - into the buffer via ``cache.update()`` before the next step. + Uses ``WinMLSlidingWindowCache`` (Slice+Concat eviction) wrapped in + ``EncoderDecoderCache`` (cross-attn empty → always recomputed from + ``encoder_hidden_states``). + + ``cache_position`` is intentionally NOT an ONNX input — it is pinned to + ``[max_cache_len - 1]`` (the rightmost buffer slot) inside ``forward`` and + traced as a Constant. For single-token generation with a sliding window, + the new token is always written to the rightmost slot, so this value is + invariant. Baking it in lets ONNX constant-fold the entire + ``compute_bias`` subgraph (``memory_position - context_position`` is + constant → learned-bias Gather becomes a fixed tensor) and collapses the + causal mask ``kv_idx <= q_idx`` (all-True since ``q_idx == W-1``). + + This couples the exported graph to sliding-window semantics at build + time. ``WinMLStaticCache`` cannot be used as the *inference* cache for + this ONNX — its buffer layout (left-aligned, index_copy_) does not match + the graph's internal Slice+Concat. Callers who want static-cache + semantics must subclass ``T5DecoderWrapper``, take ``cache_position`` as + an input again, and re-export. ``WinMLStaticCache`` itself remains + fully functional for that path. """ def __init__(self, model: nn.Module, num_layers: int) -> None: @@ -102,7 +114,7 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> T5DecoderWrapper: - """Load full T5, wrap with static cache.""" + """Load full T5, wrap with sliding-window cache.""" full_model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, **kwargs) num_layers = full_model.config.num_layers wrapper = cls(full_model, num_layers) @@ -114,11 +126,11 @@ def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor return tuple(inputs.values()) def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: - """Run decoder with static KV cache. + """Run decoder with sliding-window KV cache. Positional args (order matches OnnxConfig.inputs): decoder_input_ids, encoder_hidden_states, attention_mask, - decoder_attention_mask, cache_position, + decoder_attention_mask, past_0_key, past_0_value, past_1_key, past_1_value, ... Returns: @@ -129,14 +141,15 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: encoder_hidden_states = args[1] attention_mask = args[2] decoder_attention_mask = args[3] - cache_position = args[4] - kv_start = 5 - - # Build WinMLStaticCache from input KV tensors. - # update() uses index_copy_ at cache_position for correct attention, - # and captures the incoming key/value states for direct output - # (eliminating the old scatter→gather round-trip in the ONNX graph). - self_attn_cache = WinMLStaticCache(self.config, max_cache_len=args[kv_start].size(2)) + kv_start = 4 + + # Build WinMLSlidingWindowCache from input KV tensors. + # update() does Slice+Concat (not index_copy_/ScatterElements) — evicting + # the N oldest entries and appending the N new ones at the right. The + # incoming key/value states are captured for direct ONNX output + # (avoiding a scatter→gather round-trip in the graph). + max_cache_len = args[kv_start].size(2) + self_attn_cache = WinMLSlidingWindowCache(self.config, max_cache_len=max_cache_len) self_attn_cache.early_initialization( batch_size=decoder_input_ids.size(0), num_heads=args[kv_start].size(1), @@ -148,6 +161,14 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: self_attn_cache.layers[i].keys = args[kv_start + i * 2] self_attn_cache.layers[i].values = args[kv_start + i * 2 + 1] + # Sliding window + single-token gen: the query is always at the + # rightmost slot. Constructing this constant inside forward traces it + # as a Constant node — downstream compute_bias and causal-mask subgraphs + # then constant-fold through ONNX optimization. + cache_position = torch.tensor( + [max_cache_len - 1], dtype=torch.int64, device=decoder_input_ids.device + ) + # EncoderDecoderCache is structurally required: T5Attention routes # self-attention → self_attention_cache, cross-attention → cross_attention_cache. # Without the wrapper, both would share the same cache + layer indices. @@ -168,7 +189,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: # Return new-token KV directly from the capturing cache. # The old approach did gather(ScatterElements output) — a round-trip. - # WinMLStaticCache already saved the incoming key/value states. + # The cache already saved the incoming key/value states. result: list[torch.Tensor] = [out.logits] for i in range(self.num_layers): k, v = self_attn_cache.captured[i] @@ -211,13 +232,19 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 @register_onnx_overwrite("t5", "text2text-generation", library_name="transformers") class T5DecoderIOConfig(OnnxConfig): - """ONNX config for T5 decoder with static KV cache. + """ONNX config for T5 decoder with sliding-window KV cache. Inputs: decoder_input_ids, encoder_hidden_states, attention_mask, - decoder_attention_mask, cache_position, past_{i}_key/value + decoder_attention_mask, past_{i}_key/value Outputs: logits, present_{i}_key/value - Input past KV: full static buffer [batch, heads, max_decode, d_kv]. + ``cache_position`` is *not* an input: ``T5DecoderWrapper.forward`` pins it + to ``[max_cache_len - 1]`` (rightmost buffer slot) as a Constant in the + graph. This couples the exported model to sliding-window semantics at + build time; see ``T5DecoderWrapper`` docstring for the static-cache + re-export path if needed. + + Input past KV: full buffer [batch, heads, max_decode, d_kv]. Output present KV: new token only [batch, heads, 1, d_kv]. """ @@ -246,7 +273,6 @@ def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 "encoder_hidden_states": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "decoder_attention_mask": {0: "batch_size"}, - "cache_position": {}, } num_layers = self._normalized_config.num_layers for i in range(num_layers): @@ -307,15 +333,22 @@ class WinMLT5Model(WinMLEncoderDecoderModel): @classmethod def get_cache_class(cls) -> type: - """T5 requires WinMLStaticCache (cannot use sliding window). - - T5's relative position bias (``T5Attention.compute_bias``) computes - ``memory_position = arange(key_length)`` — it assumes buffer - position == sequence position. With sliding window, KV entries - shift left each step, so buffer positions no longer correspond to - sequence positions, producing wrong relative distances. + """T5 defaults to ``WinMLSlidingWindowCache`` (Slice+Concat; no ScatterElements). + + Correctness with T5's learned relative position bias hinges on a single + invariant: ``cache_position`` is always the query's *buffer index*, not + its absolute sequence position. ``get_query_cache_position`` on each + cache class supplies the right value — ``[step]`` for static, + ``[max_cache_len-1]`` for sliding. Under that convention, + ``T5Attention.compute_bias`` computes ``memory_position - context_position + = j - (W-1)`` which gives correct relative distances regardless of + overflow, and HF's ``create_causal_mask`` (``kv_idx <= q_idx``) allows + every buffer slot while the 2D decoder mask selects the filled region. + + ``WinMLStaticCache`` remains fully supported — subclass ``WinMLT5Model`` + and override this method to get index_copy_ semantics instead. """ - return WinMLStaticCache + return WinMLSlidingWindowCache @property def generation_config(self): # noqa: D102 diff --git a/src/winml/modelkit/models/winml/encoder_decoder.py b/src/winml/modelkit/models/winml/encoder_decoder.py index 0f49309b8..f8b400a7d 100644 --- a/src/winml/modelkit/models/winml/encoder_decoder.py +++ b/src/winml/modelkit/models/winml/encoder_decoder.py @@ -38,14 +38,20 @@ column 0. For single-token KV-cached decoding, the decoder_attention_mask alone is sufficient — no tril needed. -- **RoPE position vs buffer position**: With ``WinMLSlidingWindowCache``, - the ONNX input is ``position_id`` (absolute sequence position for RoPE). - With ``WinMLStaticCache``, it's ``cache_position`` (= buffer position = - sequence position). - -- **T5 cannot use sliding window**: ``T5Attention.compute_bias`` assumes - ``buffer_position == sequence_position`` via ``arange(key_length)``. - See ``WinMLT5Model.get_cache_class()`` for details. +- **Position inputs, two roles**: ``forward`` seeds ``cache_position`` from + ``cache.get_query_cache_position(...)`` (the query's *buffer index* — used by + HF's causal mask ``kv_idx <= q_idx`` and by T5's ``compute_bias``) and + ``position_id`` from the absolute sequence step (used by RoPE models). + ``pad_inputs`` then filters to whatever the decoder ONNX actually declares, + so T5 (consumes ``cache_position``) and Mu2 (consumes ``position_id``) share + the same wrapper code. + +- **T5 on sliding window**: Works without any ``compute_bias`` patch because + ``WinMLSlidingWindowCache.get_query_cache_position`` returns + ``[max_cache_len - 1]`` (the rightmost buffer slot). With that value, + ``memory_position - context_position = j - (W-1)`` yields the correct + negative distances for all buffer slots, and the 2D right-aligned mask + selects the filled region. """ from __future__ import annotations @@ -305,7 +311,19 @@ def forward( feeds: dict[str, Any] = dict(model_kwargs) feeds.setdefault("encoder_hidden_states", enc_h.detach()) feeds.setdefault("decoder_attention_mask", dec_mask) - feeds.setdefault(cache.position_input_name, torch.tensor([fc], dtype=torch.int64)) + # Feed all position-like names; pad_inputs filters to self._dec_expected. + # Decouples the cache class from the decoder ONNX's chosen input name. + # + # "cache_position": buffer index of the query token — used by HF's + # create_causal_mask (``kv_idx <= q_idx``) and by T5.compute_bias. + # For WinMLStaticCache this equals ``step`` (buffer == seq position); + # for WinMLSlidingWindowCache it is the rightmost buffer slot(s). + # "position_id": absolute sequence position — used by RoPE-based models + # (Mu2) that compute positional encoding from the actual seq position. + cache_pos = cache.get_query_cache_position(self._max_dec).to(torch.int64) + seq_pos = torch.tensor([fc], dtype=torch.int64) + feeds.setdefault("cache_position", cache_pos) + feeds.setdefault("position_id", seq_pos) for i in range(self._num_kv_layers): feeds[f"past_{i}_key"] = cache.layers[i].keys.detach() feeds[f"past_{i}_value"] = cache.layers[i].values.detach() diff --git a/src/winml/modelkit/models/winml/kv_cache.py b/src/winml/modelkit/models/winml/kv_cache.py index beb83bc97..65609d977 100644 --- a/src/winml/modelkit/models/winml/kv_cache.py +++ b/src/winml/modelkit/models/winml/kv_cache.py @@ -13,19 +13,24 @@ Cache type compatibility: -- **WinMLStaticCache**: Required for models using learned relative position bias - (T5, mBART) where ``buffer_position == sequence_position`` must hold. - ``T5Attention.compute_bias`` uses ``memory_position = arange(key_length)`` - so KV entries must stay at their original buffer positions. - -- **WinMLSlidingWindowCache**: Compatible with models using RoPE (Mu2, Llama) - where position encoding is baked into K/V tensors. Buffer positions don't - matter — attention scores depend only on the RoPE embeddings in each K. +- **WinMLStaticCache**: ``index_copy_`` at ``cache_position`` keeps + ``buffer_position == sequence_position``. Cannot evict — ``max_cache_len`` + must be ≥ total generated tokens. + +- **WinMLSlidingWindowCache**: Slice+Concat eviction; works for RoPE models + (Mu2, Qwen, Llama) where position is baked into K when K is computed, and + for learned relative position bias (T5) as long as the wrapper feeds + ``cache_position`` as the query's *buffer index* (see + ``get_query_cache_position``). The invariant ``cache_position = buffer_idx + of query`` makes ``j - cache_position`` the correct relative distance for + both cache types, so no per-model compute_bias patch is required. Common interface (called by ``WinMLEncoderDecoderModel.forward``): - ``position_input_name``: ONNX input name (``"cache_position"`` or ``"position_id"``) -- ``build_decoder_mask(max_len)``: attention mask for current step +- ``build_decoder_mask(max_len)``: 2D attention mask for current step +- ``get_query_cache_position(max_len)``: buffer indices of query tokens + (used by HF's ``create_causal_mask`` and by T5's ``compute_bias``) - ``update_all_layers(outputs)``: write present KV from ONNX output, advance step - ``reset()``: zero out for new generation - ``create(config, kv_shape, dtype)``: factory from ONNX metadata @@ -88,6 +93,20 @@ def build_decoder_mask(self, max_len: int, num_new_tokens: int = 1) -> torch.Ten chunk_len for prefill). """ + @abstractmethod + def get_query_cache_position(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor: + """Buffer indices of the query tokens for HF's ``cache_position`` input. + + HF's ``create_causal_mask`` uses ``cache_position`` as the query's + *buffer index* (``kv_idx <= q_idx``). For static cache the buffer index + equals the sequence position (``step``); for sliding window it is the + rightmost slot(s) because new tokens are written at the right end. + + Returns: + ``[num_new_tokens]`` int64 tensor of buffer positions for the new + tokens being processed this step. + """ + @abstractmethod def prepare_prefill_chunk( self, @@ -192,6 +211,12 @@ def build_decoder_mask(self, max_len: int, num_new_tokens: int = 1) -> torch.Ten mask[0, : self.step + num_new_tokens] = 1 return mask + def get_query_cache_position(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor: + """Buffer index == sequence position for static cache: ``[step..step+N)``.""" + import torch + + return torch.arange(self.step, self.step + num_new_tokens, dtype=torch.int64) + def prepare_prefill_chunk( self, chunk_ids: torch.Tensor, @@ -262,6 +287,19 @@ def build_decoder_mask(self, max_len: int, num_new_tokens: int = 1) -> torch.Ten mask[0, max(0, max_len - filled) :] = 1 return mask + def get_query_cache_position(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor: + """Query tokens sit at the rightmost ``num_new_tokens`` buffer slots. + + Because new tokens are always written at the right end of the buffer + (Slice+Concat), the query's buffer index is ``[max_len-N..max_len)`` — + independent of the absolute sequence position. HF's causal mask + then allows attention to every prior buffer slot (``j <= max_len-1``), + and the 2D ``build_decoder_mask`` selects the filled region within that. + """ + import torch + + return torch.arange(max_len - num_new_tokens, max_len, dtype=torch.int64) + def prepare_prefill_chunk( self, chunk_ids: torch.Tensor, From d28efe8dcdc06859f4c4de9b85af585b854e6f19 Mon Sep 17 00:00:00 2001 From: Zac <1221537+tezheng@users.noreply.github.com> Date: Wed, 22 Apr 2026 15:06:29 +0800 Subject: [PATCH 29/32] fix(models): defensive fixes from PR #334 review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Applies localized fixes for 12 review findings + 1 partial, rebased onto 3c766019 (post T5 SlidingWindow refactor): - C3: register_composite_model duplicate-key ValueError - C6: WinMLCompositeModel.from_onnx unknown-component ValueError - C8: dedicated hf_config: PretrainedConfig on from_onnx(dict) dispatch (previously passed WinMLBuildConfig, causing registry miss; headline API was non-functional for real callers, hidden by mocked tests) - I1: kv_cache.WinMLCache.reset() clears self.captured - I2: rename PIPELINE_MODEL_REGISTRY to COMPOSITE_MODEL_REGISTRY - I9: pad_inputs mode: Literal["left", "right"] - I11-I14: docstring fixes on mu2, qwen, decoder_only, encoder_decoder (including StaticCache references now stale after the T5 refactor) - NI-6: remove phantom position_id from Qwen forward docstring - NI-8: un-mocked regression test + negative-path test for composite from_onnx(dict) dispatch - NM-2: pad_inputs explicit ValueError on invalid mode 10 files, +137/-31. ruff check + ruff format clean. 3869 unit tests pass. Companion inline-comment review posted to PR #334. Constraint: PR author's API must be preserved; only defensive fixes Rejected: Add abc.ABC to WinMLCache (I1-AB deferred — subclass coordination) Rejected: Expose composite classes in __init__.py (C1 deferred — public API decision) Rejected: Fix silent fp16->fp32 fallback (C4 deferred — metadata contract) Confidence: high Scope-risk: narrow Directive: C5 surgery precondition and C2 skip-list policy need PR author response before merge; see pr_334_verdicts.md section 10.2 Phase 2 Not-tested: Mu2 num_hidden_layers runtime path (NI-5) --- src/winml/modelkit/commands/config.py | 4 +- src/winml/modelkit/models/auto.py | 13 ++- src/winml/modelkit/models/hf/mu2.py | 4 +- src/winml/modelkit/models/hf/qwen.py | 7 +- .../modelkit/models/winml/composite_model.py | 26 ++++-- .../modelkit/models/winml/decoder_only.py | 23 ++--- .../modelkit/models/winml/encoder_decoder.py | 9 +- src/winml/modelkit/models/winml/kv_cache.py | 1 + src/winml/modelkit/utils/data_utils.py | 4 +- tests/unit/models/auto/test_auto_onnx.py | 84 ++++++++++++++++++- 10 files changed, 142 insertions(+), 33 deletions(-) diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index 6fc7b6a09..d53edc6b7 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -528,7 +528,7 @@ def _resolve_composite_model_components( import winml.modelkit.models.hf # noqa: F401 # trigger pipeline registrations - from ..models.winml.composite_model import PIPELINE_MODEL_REGISTRY + from ..models.winml.composite_model import COMPOSITE_MODEL_REGISTRY # Resolve model_type from HF config if not provided resolved_type = model_type @@ -542,7 +542,7 @@ def _resolve_composite_model_components( if resolved_type is None: return None - cls = PIPELINE_MODEL_REGISTRY.get((resolved_type, task)) + cls = COMPOSITE_MODEL_REGISTRY.get((resolved_type, task)) return cls._SUB_MODEL_CONFIG if cls is not None else None diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 6668bef7f..e9c0c01ad 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -36,6 +36,8 @@ if TYPE_CHECKING: + from transformers import PretrainedConfig + from ..config import WinMLBuildConfig from .winml.base import WinMLPreTrainedModel @@ -108,6 +110,7 @@ def from_onnx( force_rebuild: bool = False, skip_build: bool = False, session_options: Any | None = None, + hf_config: PretrainedConfig | None = None, **kwargs: Any, ) -> WinMLPreTrainedModel | WinMLCompositeModel: # noqa: F821 """Build from a pre-exported ONNX file. @@ -125,6 +128,10 @@ def from_onnx( cache_dir: Override cache directory. use_cache: Whether to use persistent cache. force_rebuild: Force rebuild even if cached. + hf_config: HF ``PretrainedConfig`` for composite (dict) dispatch only. + Required when ``onnx_path`` is a dict so the composite registry + lookup can resolve ``(model_type, task)``. Ignored for single-file + builds. **kwargs: Forwarded to ``build_onnx_model()``. Returns: @@ -136,7 +143,7 @@ def from_onnx( return WinMLCompositeModel.from_onnx( onnx_path, task=task, - config=config, + hf_config=hf_config, device=device, precision=precision, ep=ep, @@ -309,9 +316,9 @@ def from_pretrained( _hf_cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) _model_type = getattr(_hf_cfg, "model_type", None) - from .winml.composite_model import PIPELINE_MODEL_REGISTRY + from .winml.composite_model import COMPOSITE_MODEL_REGISTRY - if (_model_type, task) in PIPELINE_MODEL_REGISTRY: + if (_model_type, task) in COMPOSITE_MODEL_REGISTRY: from .winml.composite_model import WinMLCompositeModel return WinMLCompositeModel.from_pretrained( diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index 6f519f17c..54f09a5d1 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -145,8 +145,8 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: Returns: (logits, present_0_key, present_0_value, ...) where each - present KV is the full updated buffer [batch, n_kv_head, max_cache_len, head_dim] - (oldest entry evicted, new token appended at end). + present KV is the new-token slice only [batch, n_kv_head, seq_len, head_dim] + (raw key_states/value_states captured before Slice+Concat in WinMLSlidingWindowCache). """ decoder_input_ids = args[0] encoder_hidden_states = args[1] diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index 74373bb76..2bf727f7f 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -126,7 +126,8 @@ class QwenDecoderWrapper(nn.Module): Input KV: full static buffer ``[batch, kv_heads, max_cache_len, head_dim]``. Output KV: new positions only ``[batch, kv_heads, seq_len, head_dim]``. - Logits: last position only ``[batch, 1, vocab_size]`` (both prefill and gen). + Logits: all input positions ``[batch, seq_len, vocab_size]`` (both prefill and gen). + The caller selects the relevant position (last for gen, all for perplexity evaluation). """ def __init__(self, model: nn.Module, num_layers: int) -> None: @@ -151,12 +152,12 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: """Run decoder with static KV cache. Positional args (order matches OnnxConfig.inputs): - input_ids, attention_mask, position_ids, position_id, + input_ids, attention_mask, position_ids, past_0_key, past_0_value, past_1_key, past_1_value, ... Returns: (logits, present_0_key, present_0_value, ...) where: - - logits is ``[batch, 1, vocab_size]`` (last position only) + - logits is ``[batch, seq_len, vocab_size]`` (all positions) - present KV is ``[batch, kv_heads, seq_len, head_dim]`` """ input_ids = args[0] diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index b4fe1400b..f02d7ad4b 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -60,14 +60,21 @@ # Maps (model_type, task) → pipeline class with _SUB_MODEL_CONFIG. # Used by `wmk config` to generate one config file per sub-component. -PIPELINE_MODEL_REGISTRY: dict[tuple[str, str], type] = {} +COMPOSITE_MODEL_REGISTRY: dict[tuple[str, str], type] = {} def register_composite_model(model_type: str, task: str): """Class decorator that registers a composite model for `wmk config`.""" def decorator(cls: type) -> type: - PIPELINE_MODEL_REGISTRY[(model_type, task)] = cls + key = (model_type, task) + if key in COMPOSITE_MODEL_REGISTRY: + raise ValueError( + f"Composite model already registered for {key!r}: " + f"{COMPOSITE_MODEL_REGISTRY[key].__name__}. " + f"Cannot register {cls.__name__}." + ) + COMPOSITE_MODEL_REGISTRY[key] = cls return cls return decorator @@ -115,7 +122,7 @@ def from_pretrained( When called on ``WinMLCompositeModel`` directly (not a subclass), ``task`` is required to resolve the concrete class from - ``PIPELINE_MODEL_REGISTRY``. When called on a registered subclass + ``COMPOSITE_MODEL_REGISTRY``. When called on a registered subclass (e.g., ``WinMLT5Model``), ``task`` is optional. Args: @@ -143,11 +150,11 @@ def from_pretrained( if not cls._SUB_MODEL_CONFIG: # Resolve concrete class from registry when called on the base class - resolved_cls = PIPELINE_MODEL_REGISTRY.get((model_type, task)) + resolved_cls = COMPOSITE_MODEL_REGISTRY.get((model_type, task)) if resolved_cls is None: raise ValueError( f"No composite model registered for ({model_type!r}, {task!r}). " - f"Registered: {list(PIPELINE_MODEL_REGISTRY.keys())}" + f"Registered: {list(COMPOSITE_MODEL_REGISTRY.keys())}" ) return resolved_cls.from_pretrained( model_id, @@ -212,11 +219,11 @@ def from_onnx( # Resolve concrete class from registry model_type = getattr(hf_config, "model_type", None) if hf_config else None if not cls._SUB_MODEL_CONFIG: - resolved_cls = PIPELINE_MODEL_REGISTRY.get((model_type, task)) + resolved_cls = COMPOSITE_MODEL_REGISTRY.get((model_type, task)) if resolved_cls is None: raise ValueError( f"No composite model for ({model_type!r}, {task!r}). " - f"Registered: {list(PIPELINE_MODEL_REGISTRY.keys())}" + f"Registered: {list(COMPOSITE_MODEL_REGISTRY.keys())}" ) else: resolved_cls = cls @@ -226,6 +233,11 @@ def from_onnx( sub_models: dict[str, Any] = {} for name, path in onnx_path.items(): component_task = resolved_cls._SUB_MODEL_CONFIG.get(name) + if component_task is None: + valid = list(resolved_cls._SUB_MODEL_CONFIG.keys()) + raise ValueError( + f"Unknown component {name!r}. Valid names for {resolved_cls.__name__}: {valid}" + ) merged = {**kwargs, "task": component_task, **per_component.get(name, {})} sub_models[name] = WinMLAutoModel.from_onnx(Path(path), **merged) diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 4a3f4671b..6bc89c993 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -7,7 +7,7 @@ Class hierarchy:: WinMLCompositeModel(PreTrainedModel) — multi-component base - └─ WinMLDecoderOnlyModel(GenerationMixin) — prefill + gen with StaticCache + └─ WinMLDecoderOnlyModel(GenerationMixin) — prefill + gen with WinMLCache └─ WinMLQwen3Model — Qwen3 tasks + generation config How it works: @@ -35,17 +35,20 @@ - **Generation** (``input_ids`` has 1 token): runs the gen ONNX model with the single token + full KV cache buffer as input. -4. KV cache uses HF ``StaticCache`` — same class as T5. ``get_seq_length()`` - counts non-zero positions; ``cache.update()`` writes new KV via - ``index_copy_``. The cache persists across generate() steps via - ``CausalLMOutputWithPast.past_key_values``. +4. KV cache is cache-agnostic — ``WinMLDecoderOnlyModel`` delegates mask + construction, position encoding, and cache writes to the ``WinMLCache`` + subclass. Two implementations ship: + ``WinMLStaticCache`` (ScatterElements/``index_copy_``) and + ``WinMLSlidingWindowCache`` (Slice+Concat FIFO). ``WinMLQwen3Model`` + selects the sliding-window variant. The cache persists across + ``generate()`` steps via ``CausalLMOutputWithPast.past_key_values``. 5. ``prepare_inputs_for_generation()`` handles a subtle interaction with ``GenerationMixin``: on the FIRST call, GenerationMixin may pass an auto-created ``DynamicCache`` (empty). We detect this (not a - ``StaticCache`` or empty) and pass the full prompt through for prefill + ``WinMLCache`` or empty) and pass the full prompt through for prefill rather than trimming to the last token. On subsequent calls with a - populated ``StaticCache``, we trim to the last token as usual. + populated ``WinMLCache``, we trim to the last token as usual. Design principles (same as composite_model.py): @@ -152,7 +155,7 @@ class DecoderOnlyPrefillInputGenerator(DecoderOnlyInputGenerator): # ========================================================================= -# WinMLDecoderOnlyModel — prefill + gen with StaticCache +# WinMLDecoderOnlyModel — prefill + gen with WinMLCache # ========================================================================= @@ -273,12 +276,12 @@ def forward( Args: input_ids: Token IDs ``[batch, seq_len]``. - past_key_values: StaticCache from previous step (None on first call). + past_key_values: ``WinMLCache`` from previous step (None on first call). attention_mask: Not used directly — rebuilt from cache occupancy. **kwargs: Absorbed for GenerationMixin compatibility. Returns: - CausalLMOutputWithPast with logits and updated StaticCache. + CausalLMOutputWithPast with logits and updated ``WinMLCache``. """ cache = self._resolve_cache(past_key_values) diff --git a/src/winml/modelkit/models/winml/encoder_decoder.py b/src/winml/modelkit/models/winml/encoder_decoder.py index f8b400a7d..962b8ef4c 100644 --- a/src/winml/modelkit/models/winml/encoder_decoder.py +++ b/src/winml/modelkit/models/winml/encoder_decoder.py @@ -285,11 +285,14 @@ def forward( input_ids: torch.Tensor | None = None, **model_kwargs: Any, ) -> Seq2SeqLMOutput: - """Run decoder with static KV cache. + """Run decoder with a ``WinMLCache`` (``WinMLStaticCache`` or + ``WinMLSlidingWindowCache``, selected by the subclass via + ``get_cache_class()``). Args: encoder_outputs: Pre-computed encoder hidden states. - past_key_values: StaticCache (or wrapper) from previous step. + past_key_values: ``WinMLCache`` (or ``EncoderDecoderCache`` + wrapper) from previous step. input_ids: Fallback — run encoder if encoder_outputs is None. **model_kwargs: Remaining kwargs forwarded to the decoder ONNX (e.g., decoder_input_ids, attention_mask). Each tensor is @@ -302,7 +305,7 @@ def forward( raise ValueError("Either encoder_outputs or input_ids required") enc_h = encoder_outputs["last_hidden_state"] - # Resolve or create cache (subclasses override _create_cache). + # Resolve or create cache (subclasses override get_cache_class). cache = self._resolve_cache(past_key_values) fc = cache.step diff --git a/src/winml/modelkit/models/winml/kv_cache.py b/src/winml/modelkit/models/winml/kv_cache.py index 65609d977..ff4862c9b 100644 --- a/src/winml/modelkit/models/winml/kv_cache.py +++ b/src/winml/modelkit/models/winml/kv_cache.py @@ -149,6 +149,7 @@ def update_all_layers(self, outputs: dict[str, Any]) -> None: def reset(self) -> None: """Zero out all layers and reset step (start of new generation).""" self.step = 0 + self.captured.clear() for i in range(self.num_layers): self.layers[i].keys.zero_() self.layers[i].values.zero_() diff --git a/src/winml/modelkit/utils/data_utils.py b/src/winml/modelkit/utils/data_utils.py index 148b60688..8d6919ff2 100644 --- a/src/winml/modelkit/utils/data_utils.py +++ b/src/winml/modelkit/utils/data_utils.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal import torch @@ -15,7 +15,7 @@ def pad_inputs( source: dict[str, Any], expected: dict[str, list[int]], - mode: str = "right", + mode: Literal["left", "right"] = "right", ) -> dict[str, Any]: """Filter *source* to keys in *expected* and pad undersized tensors. diff --git a/tests/unit/models/auto/test_auto_onnx.py b/tests/unit/models/auto/test_auto_onnx.py index ba3a8b530..23aca3350 100644 --- a/tests/unit/models/auto/test_auto_onnx.py +++ b/tests/unit/models/auto/test_auto_onnx.py @@ -13,7 +13,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from unittest.mock import MagicMock, patch import pytest @@ -222,3 +222,85 @@ def test_dict_dispatches_to_composite(self, tmp_path: Path): call_kwargs = mock_from_onnx.call_args.kwargs assert call_kwargs["task"] == "translation" assert call_kwargs["skip_build"] is True + + def test_hf_config_dispatches_composite_via_registry(self, tmp_path: Path): + """hf_config kwarg threads through so model_type registry lookup works. + + Exercises the real WinMLCompositeModel.from_onnx body via a fake + subclass in a temporary registry slot. hf_config must be a dedicated + parameter on WinMLAutoModel.from_onnx (distinct from ``config``, which + is a WinMLBuildConfig and has no ``model_type`` attribute). + """ + from winml.modelkit.models.winml.composite_model import ( + COMPOSITE_MODEL_REGISTRY, + WinMLCompositeModel, + ) + + # Minimal HF-config stand-in: only attribute access (.model_type) is + # required; no isinstance check happens on hf_config in the dispatch. + class _FakeHFConfig: + model_type = "_test_dispatch_model_" + + enc_path = tmp_path / "enc.onnx" + dec_path = tmp_path / "dec.onnx" + enc_path.write_bytes(b"fake") + dec_path.write_bytes(b"fake") + + test_key = ("_test_dispatch_model_", "_test_task_") + + class _FakeComposite(WinMLCompositeModel): + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "encoder": "feature-extraction", + "decoder": "translation", + } + + def forward(self, **kwargs): # type: ignore[override] + pass + + assert test_key not in COMPOSITE_MODEL_REGISTRY + COMPOSITE_MODEL_REGISTRY[test_key] = _FakeComposite + try: + # Patch WinMLAutoModel.from_onnx: outer dict call falls through to + # the real implementation, inner per-component Path calls mocked. + _real_from_onnx = WinMLAutoModel.from_onnx + sub_mock = MagicMock() + sub_calls: list = [] + + def _side_effect(onnx_path, **kw): # type: ignore[no-untyped-def] + if isinstance(onnx_path, dict): + return _real_from_onnx(onnx_path, **kw) + sub_calls.append((onnx_path, kw)) + return sub_mock + + with patch.object(WinMLAutoModel, "from_onnx", side_effect=_side_effect): + result = WinMLAutoModel.from_onnx( + {"encoder": str(enc_path), "decoder": str(dec_path)}, + task="_test_task_", + hf_config=_FakeHFConfig(), + skip_build=True, + ) + + assert isinstance(result, _FakeComposite) + assert len(sub_calls) == 2 + tasks_called = {kw["task"] for _, kw in sub_calls} + assert tasks_called == {"feature-extraction", "translation"} + finally: + COMPOSITE_MODEL_REGISTRY.pop(test_key, None) + + def test_from_onnx_dict_without_hf_config_raises(self, tmp_path: Path): + """Dict dispatch without hf_config surfaces a clear registry-miss error. + + Guards against silent fallback: unregistered ``(model_type, task)`` must + raise ValueError immediately, not accept a wrong-typed kwarg and mis-dispatch. + """ + enc_path = tmp_path / "enc.onnx" + dec_path = tmp_path / "dec.onnx" + enc_path.write_bytes(b"fake") + dec_path.write_bytes(b"fake") + + with pytest.raises(ValueError, match="No composite model"): + WinMLAutoModel.from_onnx( + {"encoder": str(enc_path), "decoder": str(dec_path)}, + task="_unregistered_task_", + skip_build=True, + ) From e317536c379ac432ab06753ea434d938dcb252db Mon Sep 17 00:00:00 2001 From: Zac <1221537+tezheng@users.noreply.github.com> Date: Wed, 22 Apr 2026 23:19:07 +0800 Subject: [PATCH 30/32] =?UTF-8?q?fix(models):=20Phase=201=20follow-up=20?= =?UTF-8?q?=E2=80=94=206=20review=20findings=20+=202=20critic=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Applies Phase 1 quick-win fixes from pr_334_verdicts.md section 10.2, plus two bugs caught by an independent critic review of the initial pass: - C1: Export WinMLCompositeModel, COMPOSITE_MODEL_REGISTRY, register_composite_model, WinMLCache, WinMLStaticCache, WinMLSlidingWindowCache, WinMLEncoderDecoderModel, WinMLDecoderOnlyModel from models/winml/__init__.py - C4: Replace silent .get("past_0_key", np.float32) with explicit KeyError in encoder_decoder.py and decoder_only.py — stops fp16 models from silently being coerced to fp32 when ONNX metadata lacks the key - I1-AB: class WinMLCache(StaticCache, ABC) + ClassVar[str] default for position_input_name (subclass contract now enforced at instantiation time rather than via unreachable AttributeError) - I7: Narrow run_eval.py 'except Exception: pass' blocks to specific expected exceptions with diagnostic logging; fixed shadowing of loop variable (as e -> as exc) caught by critic — original fix would have crashed on corrupt result files - NI-4: Guard input_ids[:, -1:] slice in WinMLEncoderDecoderModel prepare_inputs_for_generation on cache occupancy; multi-token decoder prompts (e.g. forced BOS + prefix) no longer silently truncated on first decode step - NM-2: pad_inputs emits (0, 0) pair for non-int expected dims instead of skipping via continue — critic-caught dim-pair misalignment in 3D+ tensors when a dynamic dim sits between static dims 6 files, 3876 unit tests pass (unchanged from baseline), ruff check + ruff format clean on all changed files. Constraint: PR author's API preserved; defensive fixes only Rejected: Add logger.warning on NM-2 dim skip (pure polish, deferred) Confidence: high Scope-risk: narrow Directive: Remaining Phase 2 items (C2, C5, C9, I3, I10, NI-5) need PR author response before merge. See pr_334_verdicts.md section 10.2. Not-tested: NM-2 alignment on 5D+ tensors (project has no 5D+ ONNX inputs; verified behavioral for 3D + 4D interleaved patterns) --- scripts/e2e_eval/run_eval.py | 10 ++++---- src/winml/modelkit/models/winml/__init__.py | 20 ++++++++++++++++ .../modelkit/models/winml/decoder_only.py | 8 ++++++- .../modelkit/models/winml/encoder_decoder.py | 23 +++++++++++++++---- src/winml/modelkit/models/winml/kv_cache.py | 9 ++++---- src/winml/modelkit/utils/data_utils.py | 8 ++++++- 6 files changed, 63 insertions(+), 15 deletions(-) diff --git a/scripts/e2e_eval/run_eval.py b/scripts/e2e_eval/run_eval.py index fe4b2a547..2a56a165c 100644 --- a/scripts/e2e_eval/run_eval.py +++ b/scripts/e2e_eval/run_eval.py @@ -1069,8 +1069,8 @@ def main() -> None: if e.hf_id == args.hf_model: matched_entry = e break - except Exception: - pass # Registry is optional for single-model mode; proceed without enrichment + except Exception as e: + safe_print(f" [registry] Optional enrichment skipped: {e}") if matched_entry is not None: # Override task if explicitly provided on CLI if args.task and args.task != matched_entry.task: @@ -1133,8 +1133,10 @@ def main() -> None: if _should_skip_existing(existing, retry_types, args.eval_type): skipped_count += 1 continue - except Exception: - pass # Corrupt result file — include model for re-evaluation + except (OSError, json.JSONDecodeError, KeyError) as exc: + safe_print( + f" [continue] Corrupt result file {result_path}: {exc} — re-evaluating" + ) filtered.append(e) if skipped_count: safe_print( diff --git a/src/winml/modelkit/models/winml/__init__.py b/src/winml/modelkit/models/winml/__init__.py index f82959b09..4d57d3bdc 100644 --- a/src/winml/modelkit/models/winml/__init__.py +++ b/src/winml/modelkit/models/winml/__init__.py @@ -175,6 +175,13 @@ def register_specialization(model_type: str, task: str, class_name: str) -> None # ============================================================================= from .base import WinMLModelForGenericTask, WinMLPreTrainedModel +from .composite_model import ( + COMPOSITE_MODEL_REGISTRY, + WinMLCompositeModel, + register_composite_model, +) +from .decoder_only import WinMLDecoderOnlyModel +from .encoder_decoder import WinMLEncoderDecoderModel from .feature_extraction import WinMLModelForFeatureExtraction from .image_classification import WinMLModelForImageClassification from .image_segmentation import ( @@ -182,14 +189,24 @@ def register_specialization(model_type: str, task: str, class_name: str) -> None WinMLModelForImageSegmentation, WinMLModelForSemanticSegmentation, ) +from .kv_cache import ( + WinMLCache, + WinMLSlidingWindowCache, + WinMLStaticCache, +) from .object_detection import WinMLModelForObjectDetection from .sequence_classification import WinMLModelForSequenceClassification __all__ = [ + "COMPOSITE_MODEL_REGISTRY", "TASK_TO_WINML_CLASS", "WINML_MODEL_CLASS_MAPPING", "ImageSegmentationOutput", + "WinMLCache", + "WinMLCompositeModel", + "WinMLDecoderOnlyModel", + "WinMLEncoderDecoderModel", "WinMLModelForFeatureExtraction", "WinMLModelForGenericTask", "WinMLModelForImageClassification", @@ -198,7 +215,10 @@ def register_specialization(model_type: str, task: str, class_name: str) -> None "WinMLModelForSemanticSegmentation", "WinMLModelForSequenceClassification", "WinMLPreTrainedModel", + "WinMLSlidingWindowCache", + "WinMLStaticCache", "get_supported_tasks", "get_winml_class", + "register_composite_model", "register_specialization", ] diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 6bc89c993..6048fb077 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -210,7 +210,13 @@ def __init__( ) import numpy as np - _np_dtype = gen_type_map.get("past_0_key", np.float32) + if "past_0_key" not in gen_type_map: + raise KeyError( + "'past_0_key' is missing from the decoder ONNX input type map; " + "cannot derive KV cache dtype. Verify the decoder ONNX was built with " + "PastKeyValueInputGenerator." + ) + _np_dtype = gen_type_map["past_0_key"] self._kv_dtype = torch.from_numpy(np.zeros(1, dtype=_np_dtype)).dtype # Prefill chunk size diff --git a/src/winml/modelkit/models/winml/encoder_decoder.py b/src/winml/modelkit/models/winml/encoder_decoder.py index 962b8ef4c..ae1b669fe 100644 --- a/src/winml/modelkit/models/winml/encoder_decoder.py +++ b/src/winml/modelkit/models/winml/encoder_decoder.py @@ -199,7 +199,13 @@ def __init__( ) import numpy as np - _np_dtype = dec_type_map.get("past_0_key", np.float32) + if "past_0_key" not in dec_type_map: + raise KeyError( + "'past_0_key' is missing from the decoder ONNX input type map; " + "cannot derive KV cache dtype. Verify the decoder ONNX was built with " + "PastKeyValueInputGenerator." + ) + _np_dtype = dec_type_map["past_0_key"] self._kv_dtype = torch.from_numpy(np.zeros(1, dtype=_np_dtype)).dtype # ----- Encoder ----- @@ -238,8 +244,14 @@ def prepare_inputs_for_generation( **kwargs: Any, ) -> dict[str, Any]: """Build decoder inputs for each generate() step.""" + from .kv_cache import WinMLCache + + if isinstance(past_key_values, WinMLCache) and past_key_values.get_seq_length() > 0: + decoder_input_ids = input_ids[:, -1:] + else: + decoder_input_ids = input_ids return { - "decoder_input_ids": input_ids[:, -1:], + "decoder_input_ids": decoder_input_ids, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, "past_key_values": past_key_values, @@ -285,9 +297,10 @@ def forward( input_ids: torch.Tensor | None = None, **model_kwargs: Any, ) -> Seq2SeqLMOutput: - """Run decoder with a ``WinMLCache`` (``WinMLStaticCache`` or - ``WinMLSlidingWindowCache``, selected by the subclass via - ``get_cache_class()``). + """Run decoder with a WinML KV cache. + + Uses ``WinMLStaticCache`` or ``WinMLSlidingWindowCache``, selected by + the subclass via ``get_cache_class()``. Args: encoder_outputs: Pre-computed encoder hidden states. diff --git a/src/winml/modelkit/models/winml/kv_cache.py b/src/winml/modelkit/models/winml/kv_cache.py index ff4862c9b..b835281c5 100644 --- a/src/winml/modelkit/models/winml/kv_cache.py +++ b/src/winml/modelkit/models/winml/kv_cache.py @@ -41,8 +41,8 @@ from __future__ import annotations -from abc import abstractmethod -from typing import TYPE_CHECKING, Any +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar from optimum.utils.input_generators import DummyInputGenerator from transformers import StaticCache @@ -59,7 +59,7 @@ # ============================================================================= -class WinMLCache(StaticCache): +class WinMLCache(StaticCache, ABC): """Abstract base for WinML KV caches (export + inference). Subclasses set ``position_input_name``, implement ``build_decoder_mask``, @@ -71,7 +71,8 @@ class WinMLCache(StaticCache): """ #: ONNX input name for the position tensor (subclasses override). - position_input_name: str + #: Empty string is a sentinel — concrete subclasses must set a real value. + position_input_name: ClassVar[str] = "" def __init__(self, config: PretrainedConfig, *args: Any, **kwargs: Any) -> None: super().__init__(config, *args, **kwargs) diff --git a/src/winml/modelkit/utils/data_utils.py b/src/winml/modelkit/utils/data_utils.py index 8d6919ff2..bccf7acd1 100644 --- a/src/winml/modelkit/utils/data_utils.py +++ b/src/winml/modelkit/utils/data_utils.py @@ -47,7 +47,13 @@ def pad_inputs( # dim backwards. Skip batch dim (dim 0). pad: list[int] = [] for dim in reversed(range(1, ndim)): - deficit = max(expected_shape[dim] - val.shape[dim], 0) + exp = expected_shape[dim] + # Dynamic ONNX dims may be None or a string symbol; emit a + # (0, 0) pair so later pairs stay aligned with their dim index. + if not isinstance(exp, int): + pad.extend([0, 0]) + continue + deficit = max(exp - val.shape[dim], 0) if mode == "right": pad.extend([0, deficit]) else: # left From 3b1a98351e298232b956409a65c3824d92aa0c93 Mon Sep 17 00:00:00 2001 From: Zac <1221537+tezheng@users.noreply.github.com> Date: Thu, 23 Apr 2026 00:37:10 +0800 Subject: [PATCH 31/32] =?UTF-8?q?fix(models):=20Phase=202=20follow-up=20?= =?UTF-8?q?=E2=80=94=209=20review=20findings=20+=20critic=20regression?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses 9 Important+Minor findings from pr_334_verdicts.md section 10, including one regression caught during critic review of the initial pass: - I3: trust_remote_code → explicit keyword-only param on WinMLCompositeModel.from_pretrained (was untyped **kwargs lookup); docstring updated. - I4: cache KNOWN_COMPOSITE_TASKS from registry; gate AutoConfig probe on task matching a known composite task. Non-composite callers no longer pay for a redundant HF config load on every from_pretrained. - I6: skip accuracy phase for composite models with explicit skip_reason=composite_model_not_supported (was arbitrary sub-model pick via next(iter(...))). - I16: mirror encoder_decoder.py's EncoderDecoderCache unwrap in decoder_only.py._resolve_cache for defensive symmetry; documented. - NI-7: atomic per-sub-component config writes via tmp+replace. - NI-9: move stale-config cleanup BEFORE wmk config invocation. The initial fix ran the cleanup AFTER, which silently deleted freshly- written composite sub-configs — caught by critic review. - NM-1: demote .to() no-op log to DEBUG (HF pipelines routinely call model.to('cpu') as setup; WARNING would spam normal usage). - NM-5: document RoPE-at-position-0 safety for padding position_ids (covered by attention mask; doc-only change). - NM-6: replace terse 'see C9 review comment' with self-contained explanation in both _run_prefill and _run_gen dead branches. 6 files. 3876 unit tests pass (unchanged from baseline). ruff clean on all touched files. Constraint: Preserve PR author's API; defensive fixes only Rejected: Raise ValueError on trust_remote_code=True default (kept False default to preserve existing caller semantics) Rejected: Elevate .to() log to WARNING (would spam normal HF pipeline usage) Confidence: high Scope-risk: narrow Directive: Critical items C2/C5/C9 still need PR author response; posted as Phase A review questions (pullrequestreview-4156157489). Architectural deferrals (I5, I10, NI-1, NI-2, NI-3, NI-5, NM-3, NM-4) tracked as a general PR comment. Not-tested: I6 composite accuracy skip flow has no unit test; verified by inspection + existing single-model regression --- scripts/e2e_eval/run_eval.py | 28 +++++++++++++++++-- src/winml/modelkit/commands/config.py | 4 ++- src/winml/modelkit/models/auto.py | 18 ++++++++---- .../modelkit/models/winml/composite_model.py | 16 +++++++++-- .../modelkit/models/winml/decoder_only.py | 25 ++++++++++++++++- src/winml/modelkit/models/winml/kv_cache.py | 3 ++ 6 files changed, 83 insertions(+), 11 deletions(-) diff --git a/scripts/e2e_eval/run_eval.py b/scripts/e2e_eval/run_eval.py index 2a56a165c..52e2be3d3 100644 --- a/scripts/e2e_eval/run_eval.py +++ b/scripts/e2e_eval/run_eval.py @@ -362,6 +362,16 @@ def _run_build( config_path = model_dir / "build_config.json" model_dir.mkdir(parents=True, exist_ok=True) + # Remove any stale suffixed sub-configs BEFORE `wmk config` runs. + # For composite models `wmk config` writes files matching {stem}_*.json + # (e.g., build_config_encoder.json); cleaning those AFTER the command would + # delete the freshly-written configs and silently degrade composite builds + # to single-model. Running cleanup first removes prior-run artifacts without + # touching the current run's output. + for _stale in config_path.parent.glob(f"{config_path.stem}_*.json"): + safe_print(f" [config] Removing stale sub-config from prior run: {_stale.name}") + _stale.unlink(missing_ok=True) + # Step 1: winml config config_args = [ *WINML_CLI, @@ -1279,8 +1289,13 @@ def main() -> None: model_dir, ) onnx_paths = build_result["onnx_paths"] if build_result["success"] else {} - # First ONNX path for accuracy phase (TODO: composite model support) - first_path = next(iter(onnx_paths.values()), None) if onnx_paths else None + # Composite models produce multiple ONNX paths; accuracy phase requires a + # single path and is not yet supported for composite models. + # TODO: composite model accuracy support + is_composite = len(onnx_paths) > 1 + first_path = ( + next(iter(onnx_paths.values()), None) if onnx_paths and not is_composite else None + ) if not build_result["success"]: # Build failed — synthesize failed result for downstream phases @@ -1293,6 +1308,15 @@ def main() -> None: perf_proc = fail_proc if args.eval_type != "perf": accuracy_result = {"skipped": True, "skip_reason": "build_failed"} + elif is_composite and args.eval_type != "perf": + # Accuracy phase skipped for composite models (TODO: composite accuracy support) + safe_print( + f" [accuracy] Skipped for composite model {entry.hf_id} " + "(multiple ONNX paths; composite accuracy evaluation not yet implemented)" + ) + accuracy_result = {"skipped": True, "skip_reason": "composite_model_not_supported"} + if args.eval_type == "both": + perf_proc = run_model(entry, args.device, args.timeout, onnx_paths) elif args.eval_type == "accuracy": accuracy_result = _run_accuracy_phase( entry, diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index d53edc6b7..445ab782a 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -594,7 +594,9 @@ def _generate_pipeline_configs( out_path = Path(output) suffixed = out_path.with_stem(f"{out_path.stem}_{component_name}") suffixed.parent.mkdir(parents=True, exist_ok=True) - suffixed.write_text(config_json) + tmp = suffixed.with_suffix(".json.tmp") + tmp.write_text(config_json) + tmp.replace(suffixed) console.print(f"[green]Config saved to:[/green] {suffixed}") else: console.print(f"[bold]--- {component_name} ({component_task}) ---[/bold]") diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index e9c0c01ad..f0f55a4db 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -310,15 +310,23 @@ def from_pretrained( # COMPOSITE MODEL CHECK — delegate to WinMLCompositeModel.from_pretrained # when (model_type, task) is a registered composite (e.g., T5 translation, # Qwen text-generation). AutoConfig is lightweight (~config.json only). + # The registry probe (AutoConfig.from_pretrained) is gated on whether + # `task` appears in any registered composite entry, avoiding an + # unconditional network/disk round-trip for every non-composite call. # ===================================================================== if task is not None: - from transformers import AutoConfig - - _hf_cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) - _model_type = getattr(_hf_cfg, "model_type", None) from .winml.composite_model import COMPOSITE_MODEL_REGISTRY - if (_model_type, task) in COMPOSITE_MODEL_REGISTRY: + _known_composite_tasks = {t for (_, t) in COMPOSITE_MODEL_REGISTRY} + if task in _known_composite_tasks: + from transformers import AutoConfig + + _hf_cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + _model_type = getattr(_hf_cfg, "model_type", None) + else: + _model_type = None + + if _model_type is not None and (_model_type, task) in COMPOSITE_MODEL_REGISTRY: from .winml.composite_model import WinMLCompositeModel return WinMLCompositeModel.from_pretrained( diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index f02d7ad4b..3f4f5aa64 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -116,6 +116,7 @@ def from_pretrained( use_cache: bool = True, force_rebuild: bool = False, sub_model_kwargs: dict[str, dict[str, Any]] | None = None, + trust_remote_code: bool = False, **kwargs: Any, ) -> WinMLCompositeModel: """Build all sub-components and return ready-to-use model. @@ -139,12 +140,14 @@ def from_pretrained( ``"decoder_gen"``). Values are dicts merged on top of the shared ``**kwargs``. Use this to pass different ``shape_config`` per sub-model. + trust_remote_code: Forward to ``AutoConfig.from_pretrained`` + and each sub-model's ``WinMLAutoModel.from_pretrained``. + Required for custom-code HF models (e.g., Mu2). **kwargs: Forwarded to ``WinMLAutoModel.from_pretrained()`` for every sub-component (overridden by ``sub_model_kwargs``). """ from transformers import AutoConfig - trust_remote_code = kwargs.get("trust_remote_code", False) hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) model_type = hf_config.model_type @@ -163,6 +166,7 @@ def from_pretrained( use_cache=use_cache, force_rebuild=force_rebuild, sub_model_kwargs=sub_model_kwargs, + trust_remote_code=trust_remote_code, **kwargs, ) from ..auto import WinMLAutoModel @@ -178,6 +182,7 @@ def from_pretrained( device=device, use_cache=use_cache, force_rebuild=force_rebuild, + trust_remote_code=trust_remote_code, **merged, ) @@ -259,7 +264,14 @@ def dtype(self) -> torch.dtype: return torch.float32 def to(self, *args: Any, **kwargs: Any) -> WinMLCompositeModel: - """No-op for HF pipeline compatibility.""" + """No-op for HF pipeline compatibility; sub-models remain on their original device.""" + if args or kwargs: + # debug (not warning) — HF pipelines routinely call `.to("cpu")` as a + # setup step; surfacing that as a warning would spam normal usage. + logger.debug( + "WinMLCompositeModel.to(...) is a no-op; sub-models remain on their original " + "device. Use WinMLSession options to control device placement." + ) return self def __call__(self, **kwargs: Any) -> Any: diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 6048fb077..3bfa77700 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -230,9 +230,23 @@ def get_cache_class(cls) -> type: raise NotImplementedError def _resolve_cache(self, past_key_values: Any) -> Any: - """Unwrap or create WinMLCache for this generation step.""" + """Unwrap or create WinMLCache for this generation step. + + 1. Unwrap EncoderDecoderCache wrapper (GenerationMixin may add it even for + decoder-only models in rare paths; handled here for symmetry with + encoder_decoder.py). + 2. If already a WinMLCache, return directly. + 3. Otherwise create a fresh one and reset it. + """ from .kv_cache import WinMLCache + # (1) Unwrap EncoderDecoderCache — never received by decoder-only models + # under the current GenerationMixin flow, but mirroring encoder_decoder.py's + # defensive unwrap keeps the two _resolve_cache paths symmetric. + if hasattr(past_key_values, "self_attention_cache"): + past_key_values = past_key_values.self_attention_cache + + # (2) Already our cache — return as-is if isinstance(past_key_values, WinMLCache): return past_key_values @@ -328,6 +342,11 @@ def _run_prefill(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor: "attention_mask": attn_mask, "position_ids": position_ids, } + # NOTE: currently dead for Qwen3 (cache_position is not in the Qwen + # prefill ONNX inputs). Kept defensively for future decoder-only + # models whose OnnxConfig declares cache_position; see the + # StaticCache switching instructions at the top of hf/qwen.py for + # the position-alignment caveat before activating this branch. if "cache_position" in self._prefill_expected: feeds["cache_position"] = position_ids.squeeze(0) for i in range(self._num_kv_layers): @@ -363,6 +382,10 @@ def _run_gen(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor: "attention_mask": attn_mask, "position_ids": torch.tensor([[fc]], dtype=torch.int64), } + # NOTE: see the matching note in `_run_prefill` above. Currently dead + # for Qwen3 (cache_position is not in the gen ONNX inputs). Kept for + # future decoder-only models that declare cache_position in their + # OnnxConfig; activate with care re: the position-alignment caveat. if "cache_position" in self._gen_expected: feeds["cache_position"] = feeds["position_ids"].squeeze(0) for i in range(self._num_kv_layers): diff --git a/src/winml/modelkit/models/winml/kv_cache.py b/src/winml/modelkit/models/winml/kv_cache.py index b835281c5..870659b03 100644 --- a/src/winml/modelkit/models/winml/kv_cache.py +++ b/src/winml/modelkit/models/winml/kv_cache.py @@ -317,6 +317,9 @@ def prepare_prefill_chunk( padded_ids = torch.zeros(1, prefill_seq_len, dtype=chunk_ids.dtype) padded_ids[0, pad_len:] = chunk_ids[0] + # Padding positions get 0 — RoPE computes embeddings for position 0 on these, + # but the attention mask at build_decoder_mask masks them out before softmax, + # so the RoPE artifacts don't influence outputs. position_ids = torch.zeros(1, prefill_seq_len, dtype=torch.int64) position_ids[0, pad_len:] = torch.arange(start, start + chunk_len, dtype=torch.int64) From 4fa8c56abe5cfad3533206b590dd3eeda9063fd2 Mon Sep 17 00:00:00 2001 From: Zac <1221537+tezheng@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:28:47 +0800 Subject: [PATCH 32/32] =?UTF-8?q?fix(models):=20I8=20=E2=80=94=20widen=20W?= =?UTF-8?q?inMLCompositeModel.from=5Fonnx=20onnx=5Fpath=20type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Match the caller signature. WinMLAutoModel.from_onnx declares `onnx_path: str | Path | dict[str, str | Path]` at auto.py:99 but the composite callee declared `dict[str, str]`, so a caller passing a Path value through the dict branch triggered a type-checker complaint even though runtime Path(path) coercion inside the dispatch loop made it work. - composite_model.py:194 — widen `dict[str, str]` → `dict[str, str | Path]` - composite_model.py:210 — docstring note that str and Path are both accepted - TYPE_CHECKING import: `from pathlib import Path` (no runtime cost) Previously retracted as "cosmetic/non-actionable" on the grounds that the runtime handled both via `Path(path)` coercion. Reversing that retraction — a type-annotation mismatch between a public caller and callee is a real defect a strict type-checker (mypy/pyright) would flag. 1 file, ruff check + format clean, targeted pytest green (81/81). Confidence: high Scope-risk: narrow --- src/winml/modelkit/models/winml/composite_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index 3f4f5aa64..8a887cb2f 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -49,6 +49,8 @@ if TYPE_CHECKING: + from pathlib import Path + from transformers import PretrainedConfig logger = logging.getLogger(__name__) @@ -191,7 +193,7 @@ def from_pretrained( @classmethod def from_onnx( cls, - onnx_path: dict[str, str], + onnx_path: dict[str, str | Path], *, task: str | None = None, hf_config: PretrainedConfig | None = None, @@ -206,7 +208,9 @@ def from_onnx( Args: onnx_path: Maps component name (e.g., ``"encoder"``, - ``"decoder_prefill"``) to its ONNX file path. + ``"decoder_prefill"``) to its ONNX file path. Values may + be ``str`` or ``pathlib.Path``; coerced via ``Path(path)`` + inside the dispatch loop. task: Pipeline task (e.g., ``"translation"``, ``"text-generation"``). hf_config: HF ``PretrainedConfig`` for the model. Used to