feat: multi-model support with KV cache (T5, Qwen, Mu2) #334
Derive KV cache dtype from ONNX model metadata instead of hardcoding float32, enabling fp16 models to use fp16 StaticCache. Thread session_options from WinMLAutoModel.from_onnx through to WinMLSession.
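A minimal sketch of deriving the cache dtype from model metadata instead of hardcoding float32. The plain dict stands in for the input metadata an ONNX Runtime session reports (type strings in the `tensor(...)` form), the helper name `kv_cache_dtype` is hypothetical, and raising on a missing key rather than falling back to float32 is an assumption of this sketch:

```python
# Stand-in for ONNX Runtime input metadata: input name -> type string.
# The mapping covers only the two float types discussed in this PR.
ONNX_TYPE_TO_DTYPE = {
    "tensor(float)": "float32",
    "tensor(float16)": "float16",
}


def kv_cache_dtype(input_types: dict, key_name: str = "past_0_key") -> str:
    """Return the dtype name for the KV cache, read from model metadata."""
    if key_name not in input_types:
        # Fail loudly instead of silently assuming float32.
        raise KeyError(f"ONNX model has no {key_name!r} input; cannot derive KV cache dtype")
    onnx_type = input_types[key_name]
    if onnx_type not in ONNX_TYPE_TO_DTYPE:
        raise KeyError(f"unsupported KV cache type {onnx_type!r}")
    return ONNX_TYPE_TO_DTYPE[onnx_type]


fp16_dtype = kv_cache_dtype({"past_0_key": "tensor(float16)"})
fp32_dtype = kv_cache_dtype({"past_0_key": "tensor(float)"})
```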
Introduce WinMLCache abstract base (step, num_layers, build_decoder_mask, update_all_layers, reset, create). Two concrete implementations:
- WinMLStaticCache (ScatterElements/index_copy_, T5/Qwen)
- WinMLSlidingWindowCache (Slice+Concat FIFO, Mu2)
WinMLEncoderDecoderModel.forward is now cache-agnostic: it calls only WinMLCache interface methods. Each model subclass declares get_cache_class(). No isinstance checks in the forward path.
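The two update strategies can be illustrated on plain Python lists. This is a sketch of the general techniques (positional write versus FIFO slice+concat), not the PR's tensor code; the function names and the `PAD` placeholder are hypothetical:

```python
PAD = None  # placeholder for an empty cache slot


def static_update(buffer: list, positions: list, new_items: list) -> list:
    """Write new entries at explicit positions (the index_copy_ style)."""
    out = list(buffer)
    for pos, item in zip(positions, new_items):
        out[pos] = item
    return out


def sliding_window_update(buffer: list, new_items: list) -> list:
    """Drop the oldest entries and append the new ones (Slice + Concat FIFO)."""
    n = len(new_items)
    if n > len(buffer):
        raise ValueError("more new entries than cache slots")
    return buffer[n:] + list(new_items)


# Same two steps (2 tokens, then 1 token) applied to a 4-slot cache.
static = static_update([PAD] * 4, [0, 1], ["k0", "k1"])
static = static_update(static, [2], ["k2"])

sliding = sliding_window_update([PAD] * 4, ["k0", "k1"])
sliding = sliding_window_update(sliding, ["k2"])
```

In the positional variant the real entries accumulate from the front of the buffer; in the FIFO variant they stay packed against the end, which is why the two caches need differently aligned masks.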
Move content from docs/design/pipeline-model.md into the files that own the code: kv_cache.py (cache compatibility), encoder_decoder.py (forward gotchas), mu2.py (custom model integration), pipeline_model.py (registry + sub_model_kwargs). Add T5 get_cache_class with explanation of why sliding window is incompatible. Remove the now-redundant design doc.
WinMLSlidingWindowCache.update() now captures new-token KV (like WinMLStaticCache) and stores it in self.captured. Export wrappers output captured[i] for both cache types. update_all_layers() does Slice+Concat on the inference side. Move captured dict to WinMLCache base class.
- Generalize WinMLSlidingWindowCache.update() for N tokens (prefill+gen)
- Qwen3 uses sliding window: cache_position computed internally as right-aligned buffer positions, position_ids handles RoPE separately
- Left-pad prefill chunks so real tokens are at END (matches causal mask)
- Add _resolve_cache to WinMLDecoderOnlyModel (same pattern as enc-dec)
- Make get_cache_class abstract in both base classes
- Rename pipeline_model.py -> composite_model.py, register_composite_model
- Remove position_id ONNX input from Qwen (no longer needed)
- update_all_layers moved to WinMLCache base (calls subclass update())
Make WinMLDecoderOnlyModel cache-agnostic by delegating padding, mask construction, and cache updates to the WinMLCache subclass.
kv_cache.py:
- build_decoder_mask: add num_new_tokens param (default 1)
- prepare_prefill_chunk: new abstract method, left-pad (sliding window) vs right-pad (static cache)
- update_all_layers: cache_position as range instead of scalar so StaticCache.index_copy_ works with multi-token prefill KV
decoder_only.py:
- _run_prefill: delegates to cache.prepare_prefill_chunk and cache.build_decoder_mask; slices padding from outputs before update_all_layers
- _run_gen: uses cache.build_decoder_mask instead of inline mask
- Both pass cache_position in feeds when the ONNX model expects it
Verified: Qwen3-0.6B e2e with both WinMLSlidingWindowCache and WinMLStaticCache produces correct results.
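The alignment difference between the two cache types can be sketched with plain lists. This is a simplified illustration, not the PR's build_decoder_mask or prepare_prefill_chunk: the names `build_mask` and `pad_chunk`, their parameters, and the 1/0 mask encoding are assumptions, and the real mask also accounts for num_new_tokens:

```python
def build_mask(window: int, num_valid: int, align: str) -> list:
    """1 marks a slot holding a real token, 0 marks an empty or padded slot."""
    if not 0 <= num_valid <= window:
        raise ValueError("num_valid must be within the window")
    empty = window - num_valid
    if align == "left":   # static cache: valid slots at the front
        return [1] * num_valid + [0] * empty
    if align == "right":  # sliding window: valid slots at the end
        return [0] * empty + [1] * num_valid
    raise ValueError(f"invalid align: {align!r}")


def pad_chunk(tokens: list, chunk_len: int, pad_id: int, mode: str) -> list:
    """Pad a prefill chunk to a fixed length on the requested side."""
    if mode not in ("left", "right"):
        raise ValueError(f"invalid mode: {mode!r}")
    pad_len = chunk_len - len(tokens)
    if pad_len < 0:
        raise ValueError("chunk is longer than chunk_len")
    padding = [pad_id] * pad_len
    return padding + tokens if mode == "left" else tokens + padding


static_mask = build_mask(4, 2, "left")
sliding_mask = build_mask(4, 2, "right")
static_chunk = pad_chunk([5, 6], 4, 0, "right")
sliding_chunk = pad_chunk([5, 6], 4, 0, "left")
```

Padding on the same side as the mask's empty slots keeps the real tokens and the valid mask entries lined up, which is the pairing the commit describes (left-pad with right-aligned mask, right-pad with left-aligned mask).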
Fixes CodeQL finding: global variable '_pad_inputs' is not used. The refactoring to polymorphic cache methods removed the last call site.
Document how to switch Mu2 from WinMLSlidingWindowCache to WinMLStaticCache (3 changes: wrapper, OnnxConfig, get_cache_class). Verified: Mu2 e2e correct with both cache types (6/6 queries).
auto.py:
- Fix circular import: move WinMLCompositeModel import under TYPE_CHECKING, lazy import in from_onnx
- from_onnx: delegate dict onnx_path to WinMLCompositeModel.from_onnx
- from_pretrained: check PIPELINE_MODEL_REGISTRY before config phase, delegate to WinMLCompositeModel.from_pretrained for composite models
composite_model.py:
- Implement from_onnx: resolves concrete class from registry using task + hf_config.model_type, builds each sub-component via WinMLAutoModel.from_onnx with per-component task from _SUB_MODEL_CONFIG
Verified: T5 translation and Mu2 translation via both from_onnx(dict) and from_pretrained produce correct results.
run_eval.py:
- Generalize _run_build: winml config produces a list of config JSONs
(1 for single model, N for composite), build loop handles both
- Generalize run_model (perf): takes list of ONNX paths, runs perf
for each, merges results — single model is list-of-1 case
t5.py:
- Register WinMLT5Model for ("t5", "summarization") in addition to
("t5", "translation")
timeout_skip_list.json:
- Remove T5 entries (t5-small, t5-base, t5-3b) — composite model
build now works
Verified: T5-small summarization + translation via both from_pretrained
and run_eval.py perf pipeline.
…wen/Mu2 surgery capability:
- Add REMOVE_ISNAN_IN_ATTENTION_MASK: removes Softmax->IsNaN->Where NaN guard patterns (dead code when clamp_constant_values replaces -inf with a finite value)
model configs (gelu_fusion, fuse_rmsnorm, matmul_add_fusion, clamp_constant_values, remove_isnan_in_attention_mask):
- T5_CONFIG: registered for model_type "t5"
- MU2_CONFIG: registered for model_type "mu2"
- QWEN_CONFIG: add optim flags (already had export config)
run_eval.py:
- Fix _extract_onnx_path: match "Artifact:" and "Existing artifact found:" markers in addition to "Final artifact:"
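A sketch of matching all three artifact markers with one regular expression. Only the marker strings come from the commit message; the surrounding log format, the "last match wins" rule, and the function name `extract_onnx_path` (without the leading underscore) are assumptions made for illustration:

```python
import re
from typing import Optional

# The three markers named above; everything after the colon on the same
# line is treated as the path (an assumption about the log format).
_ARTIFACT_RE = re.compile(
    r"(?:Final artifact|Existing artifact found|Artifact):[ \t]*(.+?)[ \t\r]*$",
    re.MULTILINE,
)


def extract_onnx_path(log_text: str) -> Optional[str]:
    """Return the last artifact path mentioned in the log, or None."""
    matches = _ARTIFACT_RE.findall(log_text)
    return matches[-1] if matches else None


sample_log = (
    "INFO starting build\n"
    "Existing artifact found: out/encoder.onnx\n"
    "Final artifact: out/decoder.onnx\n"
)
found = extract_onnx_path(sample_log)
missing = extract_onnx_path("INFO nothing built\n")
```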
- Uncomment @register_onnx_overwrite for depth_pro (was disabled pending quantization fix)
- Add apple/DepthPro-hf and Qwen3 models to timeout skip list (OOM segfault during in-process quantization)
test_io.py (14 new tests):
- TestStaticCacheBuildDecoderMask: left-aligned mask, num_new_tokens
- TestSlidingWindowCacheBuildDecoderMask: right-aligned mask, saturation
- TestStaticCachePreparePrefillChunk: right-pad, pad_len=0
- TestSlidingWindowCachePreparePrefillChunk: left-pad, pad_len>0
test_auto_onnx.py (1 new test):
- TestFromOnnxDictDispatch: dict onnx_path delegates to WinMLCompositeModel.from_onnx with correct kwargs
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Applies localized fixes for 12 review findings + 1 partial, rebased onto 3c76601 (post T5 SlidingWindow refactor):
- C3: register_composite_model duplicate-key ValueError
- C6: WinMLCompositeModel.from_onnx unknown-component ValueError
- C8: dedicated hf_config: PretrainedConfig on from_onnx(dict) dispatch (previously passed WinMLBuildConfig, causing registry miss; headline API was non-functional for real callers, hidden by mocked tests)
- I1: kv_cache.WinMLCache.reset() clears self.captured
- I2: rename PIPELINE_MODEL_REGISTRY to COMPOSITE_MODEL_REGISTRY
- I9: pad_inputs mode: Literal["left", "right"]
- I11-I14: docstring fixes on mu2, qwen, decoder_only, encoder_decoder (including StaticCache references now stale after the T5 refactor)
- NI-6: remove phantom position_id from Qwen forward docstring
- NI-8: un-mocked regression test + negative-path test for composite from_onnx(dict) dispatch
- NM-2: pad_inputs explicit ValueError on invalid mode
10 files, +137/-31. ruff check + ruff format clean. 3869 unit tests pass. Companion inline-comment review posted to PR #334.
Constraint: PR author's API must be preserved; only defensive fixes
Rejected: Add abc.ABC to WinMLCache (I1-AB deferred: subclass coordination)
Rejected: Expose composite classes in __init__.py (C1 deferred: public API decision)
Rejected: Fix silent fp16->fp32 fallback (C4 deferred: metadata contract)
Confidence: high
Scope-risk: narrow
Directive: C5 surgery precondition and C2 skip-list policy need PR author response before merge; see pr_334_verdicts.md section 10.2 Phase 2
Not-tested: Mu2 num_hidden_layers runtime path (NI-5)
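The duplicate-key guard from C3 follows a common decorator-registry pattern. The sketch below is a generic illustration under assumed names (`REGISTRY`, `register`, `T5Sketch`); it is not the PR's register_composite_model implementation:

```python
from typing import Callable

# Hypothetical stand-in for the composite registry: (model_type, task) -> class.
REGISTRY: dict = {}


def register(model_type: str, task: str) -> Callable[[type], type]:
    """Class decorator that records cls under (model_type, task)."""

    def decorator(cls: type) -> type:
        key = (model_type, task)
        if key in REGISTRY:
            # Refuse to overwrite silently; a second registration is a bug.
            raise ValueError(f"{key} is already registered to {REGISTRY[key].__name__}")
        REGISTRY[key] = cls
        return cls

    return decorator


@register("t5", "translation")
@register("t5", "summarization")
class T5Sketch:
    """Placeholder class registered for two tasks."""


def try_duplicate() -> str:
    """Attempt a second registration for an existing key and report the outcome."""
    try:
        register("t5", "translation")(object)
    except ValueError as exc:
        return f"rejected: {exc}"
    return "accepted"


duplicate_result = try_duplicate()
```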
Applies Phase 1 quick-win fixes from pr_334_verdicts.md section 10.2,
plus two bugs caught by an independent critic review of the initial pass:
- C1: Export WinMLCompositeModel, COMPOSITE_MODEL_REGISTRY,
register_composite_model, WinMLCache, WinMLStaticCache,
WinMLSlidingWindowCache, WinMLEncoderDecoderModel,
WinMLDecoderOnlyModel from models/winml/__init__.py
- C4: Replace silent .get("past_0_key", np.float32) with explicit
KeyError in encoder_decoder.py and decoder_only.py — stops
fp16 models from silently being coerced to fp32 when ONNX
metadata lacks the key
- I1-AB: class WinMLCache(StaticCache, ABC) + ClassVar[str] default
for position_input_name (subclass contract now enforced at
instantiation time rather than via unreachable AttributeError)
- I7: Narrow run_eval.py 'except Exception: pass' blocks to specific
expected exceptions with diagnostic logging; fixed shadowing of
loop variable (as e -> as exc) caught by critic — original fix
would have crashed on corrupt result files
- NI-4: Guard input_ids[:, -1:] slice in WinMLEncoderDecoderModel
prepare_inputs_for_generation on cache occupancy; multi-token
decoder prompts (e.g. forced BOS + prefix) no longer silently
truncated on first decode step
- NM-2: pad_inputs emits (0, 0) pair for non-int expected dims
instead of skipping via continue — critic-caught dim-pair
misalignment in 3D+ tensors when a dynamic dim sits between
static dims
6 files, 3876 unit tests pass (unchanged from baseline), ruff
check + ruff format clean on all changed files.
Constraint: PR author's API preserved; defensive fixes only
Rejected: Add logger.warning on NM-2 dim skip (pure polish, deferred)
Confidence: high
Scope-risk: narrow
Directive: Remaining Phase 2 items (C2, C5, C9, I3, I10, NI-5) need
PR author response before merge. See pr_334_verdicts.md section 10.2.
Not-tested: NM-2 alignment on 5D+ tensors (project has no 5D+ ONNX
inputs; verified behavioral for 3D + 4D interleaved patterns)
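The NM-2 dim-pair misalignment can be shown with a small helper that computes one (before, after) pad pair per dimension. This is a sketch assuming a per-dimension pair list; the name `pad_widths` and the shape encoding (a string for a dynamic dim) are hypothetical, and the PR's pad_inputs may represent padding differently:

```python
def pad_widths(actual: tuple, expected: tuple, mode: str = "right") -> list:
    """Return one (before, after) pad pair per dimension."""
    if mode not in ("left", "right"):
        raise ValueError(f"invalid mode: {mode!r}")
    if len(actual) != len(expected):
        raise ValueError("rank mismatch between actual and expected shapes")
    pairs = []
    for have, want in zip(actual, expected):
        if not isinstance(want, int):
            # Dynamic dim: emit a no-op pair so later pairs keep their index.
            pairs.append((0, 0))
            continue
        gap = want - have
        if gap < 0:
            raise ValueError(f"dimension of size {have} exceeds expected {want}")
        pairs.append((gap, 0) if mode == "left" else (0, gap))
    return pairs


# A dynamic dim ("seq") sits between two static dims.
right_pairs = pad_widths((2, 5, 3), (4, "seq", 8), mode="right")
left_pairs = pad_widths((2, 5, 3), (4, "seq", 8), mode="left")
```

Had the dynamic dim been skipped with a bare `continue`, the list would hold two pairs for a rank-3 tensor and the padding meant for the last dim would land on the dynamic one.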
tezheng
left a comment
Second follow-up review: Phase 1 quick wins from pr_334_verdicts.md section 10.2.
This commit addresses 6 findings plus 2 bugs caught by an independent
critic pass on the initial fix attempt (I7 variable shadowing and NM-2
dim-pair misalignment — both would have crashed in production).
3876 unit tests pass (unchanged from baseline). ruff check + ruff format
clean on all 6 touched files. See inline comments for finding-by-finding
diffs.
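The I7 shadowing bug is a general Python pitfall: `except ... as e` unbinds `e` when the handler exits, so reusing a loop variable's name there breaks any later use of it. A minimal reproduction with hypothetical function names and data, unrelated to the actual run_eval.py code:

```python
def summarize_buggy(entries: list) -> list:
    out = []
    for e in entries:
        try:
            value = int(e)
        except ValueError as e:  # rebinds the loop variable, then deletes it
            value = -1
        out.append((e, value))   # fails after any handled exception
    return out


def summarize_fixed(entries: list) -> list:
    out = []
    for e in entries:
        try:
            value = int(e)
        except ValueError as exc:  # distinct name leaves the loop variable intact
            value = -1
        out.append((e, value))
    return out


fixed_result = summarize_fixed(["1", "corrupt"])

try:
    summarize_buggy(["1", "corrupt"])
    buggy_error = None
except UnboundLocalError as err:
    buggy_error = type(err).__name__
```

The buggy variant only fails on the error path, which is consistent with the commit's note that the crash would have appeared on corrupt result files rather than in the normal flow.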
Remaining work (Phase 2 — needs author input): C2 skip-list policy,
C5 surgery precondition, C9 position-encoding footgun, I3 trust_remote_code
typing, I10 cache in-place semantics, NI-5 Mu2 num_hidden_layers alias.
Full list in pr_334_verdicts.md section 10.
tezheng
left a comment
Three questions on the open Critical items — each has multiple acceptable
paths and needs your design call. No fixes in this review, just questions.
The other 17 remaining items (Important + Minor) will be addressed in a
follow-up commit alongside this review; those don't require your input.
Deferred items: tracked for follow-up / design discussion
The following 7 findings from the review pass are acknowledged but not addressed in the companion fix commits. Each needs design input, new test infrastructure, or a project-wide decision outside this PR's scope:
These are all real findings, just not mechanically fixable in a review-follow-up branch. Tracking as follow-up issues (or resolving inline here) is equally fine with me.
tezheng
left a comment
Second-follow-up review (Phase B + C): flags 9 Important+Minor findings on the current HEAD. A companion commit in this review addresses each of them. All changes have been verified (3876 unit tests pass, ruff clean) and an independent critic caught and corrected one regression before this commit.
Addresses 9 Important+Minor findings from pr_334_verdicts.md section 10,
including one regression caught during critic review of the initial pass:
- I3: trust_remote_code → explicit keyword-only param on
WinMLCompositeModel.from_pretrained (was untyped **kwargs lookup);
docstring updated.
- I4: cache KNOWN_COMPOSITE_TASKS from registry; gate AutoConfig probe
on task matching a known composite task. Non-composite callers no
longer pay for a redundant HF config load on every from_pretrained.
- I6: skip accuracy phase for composite models with explicit
skip_reason=composite_model_not_supported (was arbitrary sub-model
pick via next(iter(...))).
- I16: mirror encoder_decoder.py's EncoderDecoderCache unwrap in
decoder_only.py._resolve_cache for defensive symmetry; documented.
- NI-7: atomic per-sub-component config writes via tmp+replace.
- NI-9: move stale-config cleanup BEFORE wmk config invocation. The
initial fix ran the cleanup AFTER, which silently deleted freshly-
written composite sub-configs — caught by critic review.
- NM-1: demote .to() no-op log to DEBUG (HF pipelines routinely call
model.to('cpu') as setup; WARNING would spam normal usage).
- NM-5: document RoPE-at-position-0 safety for padding position_ids
(covered by attention mask; doc-only change).
- NM-6: replace terse 'see C9 review comment' with self-contained
explanation in both _run_prefill and _run_gen dead branches.
6 files. 3876 unit tests pass (unchanged from baseline). ruff clean on
all touched files.
Constraint: Preserve PR author's API; defensive fixes only
Rejected: Raise ValueError on trust_remote_code=True default (kept
False default to preserve existing caller semantics)
Rejected: Elevate .to() log to WARNING (would spam normal HF pipeline
usage)
Confidence: high
Scope-risk: narrow
Directive: Critical items C2/C5/C9 still need PR author response;
posted as Phase A review questions (pullrequestreview-4156157489).
Architectural deferrals (I5, I10, NI-1, NI-2, NI-3, NI-5, NM-3,
NM-4) tracked as a general PR comment.
Not-tested: I6 composite accuracy skip flow has no unit test;
verified by inspection + existing single-model regression
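The tmp+replace pattern named in NI-7 can be sketched with the standard library. The helper name `write_json_atomic` and the JSON payload are assumptions; only the technique (write to a temporary file in the same directory, then `os.replace` onto the target) is taken from the commit message:

```python
import json
import os
import tempfile
from pathlib import Path


def write_json_atomic(path: Path, payload: dict) -> None:
    """Write payload to path so readers never observe a partial file."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # Same directory as the target so os.replace stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(
        dir=str(path.parent), prefix=path.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2)
        os.replace(tmp_name, path)  # atomic rename over any existing file
    except BaseException:
        # Leave no stray temp file behind if the write or rename failed.
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise


with tempfile.TemporaryDirectory() as workdir:
    target = Path(workdir) / "configs" / "encoder.json"
    write_json_atomic(target, {"component": "encoder", "version": 1})
    write_json_atomic(target, {"component": "encoder", "version": 2})
    final_payload = json.loads(target.read_text(encoding="utf-8"))
    leftover_names = sorted(p.name for p in target.parent.iterdir())
```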
Match the caller signature. WinMLAutoModel.from_onnx declares `onnx_path: str | Path | dict[str, str | Path]` at auto.py:99 but the composite callee declared `dict[str, str]`, so a caller passing a Path value through the dict branch triggered a type-checker complaint even though runtime Path(path) coercion inside the dispatch loop made it work.
- composite_model.py:194 - widen `dict[str, str]` → `dict[str, str | Path]`
- composite_model.py:210 - docstring note that str and Path are both accepted
- TYPE_CHECKING import: `from pathlib import Path` (no runtime cost)
Previously retracted as "cosmetic/non-actionable" on the grounds that the runtime handled both via `Path(path)` coercion. Reversing that retraction: a type-annotation mismatch between a public caller and callee is a real defect a strict type-checker (mypy/pyright) would flag.
1 file, ruff check + format clean, targeted pytest green (81/81).
Confidence: high
Scope-risk: narrow
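A small sketch of the widened signature, assuming a hypothetical helper `normalize_onnx_paths` rather than the PR's from_onnx. `Union[str, Path]` is used here as the spelling of `str | Path` that also evaluates on older Python versions:

```python
from pathlib import Path
from typing import Dict, Union


def normalize_onnx_paths(onnx_path: Dict[str, Union[str, Path]]) -> Dict[str, Path]:
    """Accept str or Path values and coerce every entry to Path."""
    return {name: Path(value) for name, value in onnx_path.items()}


# Mixed str and Path values now satisfy the annotation as well as the runtime.
normalized = normalize_onnx_paths(
    {"encoder": "build/encoder.onnx", "decoder": Path("build/decoder.onnx")}
)
```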
zhenchaoni
left a comment
Re-approve after comments are addressed.
Summary
Multi-model composite pipeline: build, optimize, and run inference on models composed of multiple ONNX sub-components (encoder+decoder, prefill+gen) through a unified `WinMLAutoModel` interface.
Composite model framework
- `WinMLCompositeModel` base class with `_SUB_MODEL_CONFIG` mapping component names to HF tasks
- `@register_composite_model(model_type, task)` registry for `wmk config` multi-config generation
- `from_pretrained()` builds all sub-components via `WinMLAutoModel` with per-component `sub_model_kwargs`
- `from_onnx(dict)` loads pre-built ONNX files, resolves concrete class from registry
Polymorphic KV cache (`WinMLCache` hierarchy)
- `WinMLStaticCache` (`index_copy_`, left-aligned mask): T5, learned position bias models
- `WinMLSlidingWindowCache` (FIFO slice+concat, right-aligned mask): Qwen3, Mu2, RoPE models
- `build_decoder_mask(num_new_tokens)`, `prepare_prefill_chunk()`, `update_all_layers()`
Models
- `WinMLT5Model`: registered for `translation` and `summarization` tasks
- `WinMLQwen3Model`: prefill/gen split with chunked prefill, `WinMLSlidingWindowCache`
- `WinMLMu2Model`: custom `trust_remote_code` model with `WinMLSlidingWindowCache`
WinMLAutoModel routing
- `from_pretrained()`: checks `PIPELINE_MODEL_REGISTRY` before config phase, delegates to `WinMLCompositeModel.from_pretrained` for composite models
- `from_onnx(dict)`: delegates to `WinMLCompositeModel.from_onnx` for dict `onnx_path`
- `WinMLCompositeModel` import moved under `TYPE_CHECKING` + lazy import
Optimization
- `remove_isnan_in_attention_mask` surgery: removes dead `Softmax→IsNaN→Where` NaN guard patterns
- `gelu_fusion`, `fuse_rmsnorm`, `matmul_add_fusion`, `clamp_constant_values`, `remove_isnan_in_attention_mask`
run_eval.py composite support
- `_run_build`: `winml config` produces list of config JSONs (1 for single, N for composite), common build loop
- `run_model` (perf): takes list of ONNX paths, runs perf for each, merges results
- `_extract_onnx_path` to match `"Artifact:"` and `"Existing artifact found:"` markers
Other
- `trust_remote_code` forwarded through CLI, `WinMLAutoModel`, `WinMLCompositeModel`
- `session_options` passthrough from `WinMLAutoModel.from_onnx` to `WinMLSession`
Verified
- …Bonjour, comment êtes-vous ?): build + e2e generation
- …`sequence_length=512`: build + e2e generation + `run_eval` perf PASS
- …`8×7=56`, `9+6=15`) with both `WinMLSlidingWindowCache` and `WinMLStaticCache`: build + e2e generation. `run_eval` perf skipped (OOM segfault in quantization stage)
- …`run_eval` perf PASS (no regression for single-model path)
- …`from_onnx` dispatch (79 total in modified files, all passing)