Skip to content

feat: multi-model support with KV cache (T5, Qwen, Mu2)#334

Merged
vortex-captain merged 38 commits into
mainfrom
reny/multi_model
Apr 23, 2026
Merged

feat: multi-model support with KV cache (T5, Qwen, Mu2)#334
vortex-captain merged 38 commits into
mainfrom
reny/multi_model

Conversation

@vortex-captain
Copy link
Copy Markdown
Contributor

@vortex-captain vortex-captain commented Apr 14, 2026

Summary

Multi-model composite pipeline: build, optimize, and run inference on models composed of multiple ONNX sub-components (encoder+decoder, prefill+gen) through a unified WinMLAutoModel interface.

Composite model framework

  • WinMLCompositeModel base class with _SUB_MODEL_CONFIG mapping component names to HF tasks
  • @register_composite_model(model_type, task) registry for wmk config multi-config generation
  • from_pretrained() builds all sub-components via WinMLAutoModel with per-component sub_model_kwargs
  • from_onnx(dict) loads pre-built ONNX files, resolves concrete class from registry

Polymorphic KV cache (WinMLCache hierarchy)

  • WinMLStaticCache (index_copy_, left-aligned mask) — T5, learned position bias models
  • WinMLSlidingWindowCache (FIFO slice+concat, right-aligned mask) — Qwen3, Mu2, RoPE models
  • Shared interface: build_decoder_mask(num_new_tokens), prepare_prefill_chunk(), update_all_layers()
  • Decoder-only model is fully cache-agnostic — padding, mask, and cache update all delegated to cache class
  • Both cache types verified with e2e generation on Qwen3

Models

  • T5 encoder-decoder (WinMLT5Model): registered for translation and summarization tasks
  • Qwen3 decoder-only (WinMLQwen3Model): prefill/gen split with chunked prefill, WinMLSlidingWindowCache
  • Mu2 encoder-decoder (WinMLMu2Model): custom trust_remote_code model with WinMLSlidingWindowCache

WinMLAutoModel routing

  • from_pretrained(): checks PIPELINE_MODEL_REGISTRY before config phase, delegates to WinMLCompositeModel.from_pretrained for composite models
  • from_onnx(dict): delegates to WinMLCompositeModel.from_onnx for dict onnx_path
  • Circular import fix: WinMLCompositeModel import moved under TYPE_CHECKING + lazy import

Optimization

  • remove_isnan_in_attention_mask surgery: removes dead Softmax→IsNaN→Where NaN guard patterns
  • Build configs for T5, Qwen3, Mu2: gelu_fusion, fuse_rmsnorm, matmul_add_fusion, clamp_constant_values, remove_isnan_in_attention_mask

run_eval.py composite support

  • Generalized _run_build: winml config produces list of config JSONs (1 for single, N for composite), common build loop
  • Generalized run_model (perf): takes list of ONNX paths, runs perf for each, merges results
  • Fixed _extract_onnx_path to match "Artifact:" and "Existing artifact found:" markers

Other

  • trust_remote_code forwarded through CLI, WinMLAutoModel, WinMLCompositeModel
  • session_options passthrough from WinMLAutoModel.from_onnx to WinMLSession
  • Architecture-agnostic KV dtype from ONNX metadata (supports fp16 models)
  • DepthPro ONNX registration enabled

Verified

  • T5-small translation: exact match vs PyTorch (Bonjour, comment êtes-vous ?) — build + e2e generation
  • T5-small summarization: correct summary generation with sequence_length=512 — build + e2e generation + run_eval perf PASS
  • Qwen3-0.6B text generation: correct (8×7=56, 9+6=15) with both WinMLSlidingWindowCache and WinMLStaticCache — build + e2e generation. run_eval perf skipped (OOM segfault in quantization stage)
  • Mu2 translation: correct on all 6 queries (EN/CN/ES/DE/FR) with both cache types, streaming generation — build + e2e generation
  • resnet-50: run_eval perf PASS (no regression for single-model path)
  • Unit tests: 15 new tests for KV cache methods and composite from_onnx dispatch (79 total in modified files, all passing)

@vortex-captain vortex-captain requested a review from a team as a code owner April 14, 2026 03:57
Derive KV cache dtype from ONNX model metadata instead of hardcoding
float32, enabling fp16 models to use fp16 StaticCache. Thread
session_options from WinMLAutoModel.from_onnx through to WinMLSession.
Comment thread docs/design/pipeline-model.md Outdated
@vortex-captain vortex-captain marked this pull request as draft April 14, 2026 05:03
Yi Ren added 2 commits April 14, 2026 13:20
Introduce WinMLCache abstract base (step, num_layers, build_decoder_mask,
update_all_layers, reset, create). Two concrete implementations:
- WinMLStaticCache (ScatterElements/index_copy_, T5/Qwen)
- WinMLSlidingWindowCache (Slice+Concat FIFO, Mu2)

WinMLEncoderDecoderModel.forward is now cache-agnostic — calls only
WinMLCache interface methods. Each model subclass declares get_cache_class().
No isinstance checks in the forward path.
Comment thread src/winml/modelkit/models/hf/kv_cache.py Fixed
Yi Ren added 5 commits April 15, 2026 14:47
Move content from docs/design/pipeline-model.md into the files that
own the code: kv_cache.py (cache compatibility), encoder_decoder.py
(forward gotchas), mu2.py (custom model integration), pipeline_model.py
(registry + sub_model_kwargs). Add T5 get_cache_class with explanation
of why sliding window is incompatible. Remove the now-redundant design doc.
WinMLSlidingWindowCache.update() now captures new-token KV (like
WinMLStaticCache) and stores in self.captured. Export wrappers output
captured[i] for both cache types. update_all_layers() does Slice+Concat
on the inference side. Move captured dict to WinMLCache base class.
- Generalize WinMLSlidingWindowCache.update() for N tokens (prefill+gen)
- Qwen3 uses sliding window: cache_position computed internally as
  right-aligned buffer positions, position_ids handles RoPE separately
- Left-pad prefill chunks so real tokens are at END (matches causal mask)
- Add _resolve_cache to WinMLDecoderOnlyModel (same pattern as enc-dec)
- Make get_cache_class abstract in both base classes
- Rename pipeline_model.py -> composite_model.py, register_composite_model
- Remove position_id ONNX input from Qwen (no longer needed)
- update_all_layers moved to WinMLCache base (calls subclass update())
Comment thread src/winml/modelkit/models/winml/decoder_only.py Fixed
Yi Ren added 4 commits April 16, 2026 11:04
Make WinMLDecoderOnlyModel cache-agnostic by delegating padding,
mask construction, and cache updates to the WinMLCache subclass.

kv_cache.py:
- build_decoder_mask: add num_new_tokens param (default 1)
- prepare_prefill_chunk: new abstract method — left-pad (sliding
  window) vs right-pad (static cache)
- update_all_layers: cache_position as range instead of scalar so
  StaticCache.index_copy_ works with multi-token prefill KV

decoder_only.py:
- _run_prefill: delegates to cache.prepare_prefill_chunk and
  cache.build_decoder_mask; slices padding from outputs before
  update_all_layers
- _run_gen: uses cache.build_decoder_mask instead of inline mask
- Both pass cache_position in feeds when the ONNX model expects it

Verified: Qwen3-0.6B e2e with both WinMLSlidingWindowCache and
WinMLStaticCache produces correct results.
Fixes CodeQL finding: global variable '_pad_inputs' is not used.
The refactoring to polymorphic cache methods removed the last call site.
Document how to switch Mu2 from WinMLSlidingWindowCache to
WinMLStaticCache (3 changes: wrapper, OnnxConfig, get_cache_class).
Verified: Mu2 e2e correct with both cache types (6/6 queries).
auto.py:
- Fix circular import: move WinMLCompositeModel import under
  TYPE_CHECKING, lazy import in from_onnx
- from_onnx: delegate dict onnx_path to WinMLCompositeModel.from_onnx
- from_pretrained: check PIPELINE_MODEL_REGISTRY before config phase,
  delegate to WinMLCompositeModel.from_pretrained for composite models

composite_model.py:
- Implement from_onnx: resolves concrete class from registry using
  task + hf_config.model_type, builds each sub-component via
  WinMLAutoModel.from_onnx with per-component task from _SUB_MODEL_CONFIG

Verified: T5 translation and Mu2 translation via both from_onnx(dict)
and from_pretrained produce correct results.
@vortex-captain vortex-captain changed the title feat: multi-model pipeline with KV cache (T5, Qwen, Mu2) feat: multi-model support with KV cache (T5, Qwen, Mu2) Apr 16, 2026
Comment thread src/winml/modelkit/models/auto.py Fixed
Comment thread src/winml/modelkit/models/auto.py Dismissed
Comment thread src/winml/modelkit/models/winml/composite_model.py Dismissed
Yi Ren and others added 6 commits April 16, 2026 16:13
run_eval.py:
- Generalize _run_build: winml config produces a list of config JSONs
  (1 for single model, N for composite), build loop handles both
- Generalize run_model (perf): takes list of ONNX paths, runs perf
  for each, merges results — single model is list-of-1 case

t5.py:
- Register WinMLT5Model for ("t5", "summarization") in addition to
  ("t5", "translation")

timeout_skip_list.json:
- Remove T5 entries (t5-small, t5-base, t5-3b) — composite model
  build now works

Verified: T5-small summarization + translation via both from_pretrained
and run_eval.py perf pipeline.
…wen/Mu2

surgery capability:
- Add REMOVE_ISNAN_IN_ATTENTION_MASK: removes Softmax->IsNaN->Where
  NaN guard patterns (dead code when clamp_constant_values replaces
  -inf with finite value)

model configs (gelu_fusion, fuse_rmsnorm, matmul_add_fusion,
clamp_constant_values, remove_isnan_in_attention_mask):
- T5_CONFIG: registered for model_type "t5"
- MU2_CONFIG: registered for model_type "mu2"
- QWEN_CONFIG: add optim flags (already had export config)

run_eval.py:
- Fix _extract_onnx_path: match "Artifact:" and "Existing artifact
  found:" markers in addition to "Final artifact:"
- Uncomment @register_onnx_overwrite for depth_pro (was disabled
  pending quantization fix)
- Add apple/DepthPro-hf and Qwen3 models to timeout skip list
  (OOM segfault during in-process quantization)
test_io.py (14 new tests):
- TestStaticCacheBuildDecoderMask: left-aligned mask, num_new_tokens
- TestSlidingWindowCacheBuildDecoderMask: right-aligned mask, saturation
- TestStaticCachePreparePrefillChunk: right-pad, pad_len=0
- TestSlidingWindowCachePreparePrefillChunk: left-pad, pad_len>0

test_auto_onnx.py (1 new test):
- TestFromOnnxDictDispatch: dict onnx_path delegates to
  WinMLCompositeModel.from_onnx with correct kwargs
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
@vortex-captain vortex-captain marked this pull request as ready for review April 16, 2026 10:00
tezheng added 2 commits April 22, 2026 21:31
Applies localized fixes for 12 review findings + 1 partial, rebased
onto 3c76601 (post T5 SlidingWindow refactor):

- C3: register_composite_model duplicate-key ValueError
- C6: WinMLCompositeModel.from_onnx unknown-component ValueError
- C8: dedicated hf_config: PretrainedConfig on from_onnx(dict) dispatch
  (previously passed WinMLBuildConfig, causing registry miss; headline
  API was non-functional for real callers, hidden by mocked tests)
- I1: kv_cache.WinMLCache.reset() clears self.captured
- I2: rename PIPELINE_MODEL_REGISTRY to COMPOSITE_MODEL_REGISTRY
- I9: pad_inputs mode: Literal["left", "right"]
- I11-I14: docstring fixes on mu2, qwen, decoder_only, encoder_decoder
  (including StaticCache references now stale after the T5 refactor)
- NI-6: remove phantom position_id from Qwen forward docstring
- NI-8: un-mocked regression test + negative-path test for composite
  from_onnx(dict) dispatch
- NM-2: pad_inputs explicit ValueError on invalid mode

10 files, +137/-31. ruff check + ruff format clean. 3869 unit tests pass.
Companion inline-comment review posted to PR #334.

Constraint: PR author's API must be preserved; only defensive fixes
Rejected: Add abc.ABC to WinMLCache (I1-AB deferred — subclass coordination)
Rejected: Expose composite classes in __init__.py (C1 deferred — public API decision)
Rejected: Fix silent fp16->fp32 fallback (C4 deferred — metadata contract)
Confidence: high
Scope-risk: narrow
Directive: C5 surgery precondition and C2 skip-list policy need PR author response before merge; see pr_334_verdicts.md section 10.2 Phase 2
Not-tested: Mu2 num_hidden_layers runtime path (NI-5)
Applies Phase 1 quick-win fixes from pr_334_verdicts.md section 10.2,
plus two bugs caught by an independent critic review of the initial pass:

- C1: Export WinMLCompositeModel, COMPOSITE_MODEL_REGISTRY,
  register_composite_model, WinMLCache, WinMLStaticCache,
  WinMLSlidingWindowCache, WinMLEncoderDecoderModel,
  WinMLDecoderOnlyModel from models/winml/__init__.py
- C4: Replace silent .get("past_0_key", np.float32) with explicit
  KeyError in encoder_decoder.py and decoder_only.py — stops
  fp16 models from silently being coerced to fp32 when ONNX
  metadata lacks the key
- I1-AB: class WinMLCache(StaticCache, ABC) + ClassVar[str] default
  for position_input_name (subclass contract now enforced at
  instantiation time rather than via unreachable AttributeError)
- I7: Narrow run_eval.py 'except Exception: pass' blocks to specific
  expected exceptions with diagnostic logging; fixed shadowing of
  loop variable (as e -> as exc) caught by critic — original fix
  would have crashed on corrupt result files
- NI-4: Guard input_ids[:, -1:] slice in WinMLEncoderDecoderModel
  prepare_inputs_for_generation on cache occupancy; multi-token
  decoder prompts (e.g. forced BOS + prefix) no longer silently
  truncated on first decode step
- NM-2: pad_inputs emits (0, 0) pair for non-int expected dims
  instead of skipping via continue — critic-caught dim-pair
  misalignment in 3D+ tensors when a dynamic dim sits between
  static dims

6 files, 3876 unit tests pass (unchanged from baseline), ruff
check + ruff format clean on all changed files.

Constraint: PR author's API preserved; defensive fixes only
Rejected: Add logger.warning on NM-2 dim skip (pure polish, deferred)
Confidence: high
Scope-risk: narrow
Directive: Remaining Phase 2 items (C2, C5, C9, I3, I10, NI-5) need
  PR author response before merge. See pr_334_verdicts.md section 10.2.
Not-tested: NM-2 alignment on 5D+ tensors (project has no 5D+ ONNX
  inputs; verified behavioral for 3D + 4D interleaved patterns)
Copy link
Copy Markdown
Collaborator

@tezheng tezheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second follow-up review: Phase 1 quick wins from pr_334_verdicts.md section 10.2.

This commit addresses 6 findings plus 2 bugs caught by an independent
critic pass on the initial fix attempt (I7 variable shadowing and NM-2
dim-pair misalignment — both would have crashed in production).

3876 unit tests pass (unchanged from baseline). ruff check + ruff format
clean on all 6 touched files. See inline comments for finding-by-finding
diffs.

Remaining work (Phase 2 — needs author input): C2 skip-list policy,
C5 surgery precondition, C9 position-encoding footgun, I3 trust_remote_code
typing, I10 cache in-place semantics, NI-5 Mu2 num_hidden_layers alias.
Full list in pr_334_verdicts.md section 10.

Comment thread src/winml/modelkit/models/winml/__init__.py
Comment thread src/winml/modelkit/models/winml/encoder_decoder.py
Comment thread src/winml/modelkit/models/winml/decoder_only.py
Comment thread src/winml/modelkit/models/winml/kv_cache.py
Comment thread scripts/e2e_eval/run_eval.py
Comment thread src/winml/modelkit/models/winml/encoder_decoder.py
Comment thread src/winml/modelkit/utils/data_utils.py
Copy link
Copy Markdown
Collaborator

@tezheng tezheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Three questions on the open Critical items — each has multiple acceptable
paths and needs your design call. No fixes in this review, just questions.

The other 17 remaining items (Important + Minor) will be addressed in a
follow-up commit alongside this review; those don't require your input.

Comment thread scripts/e2e_eval/cache/timeout_skip_list.json
Comment thread src/winml/modelkit/optim/pipes/surgery.py
Comment thread src/winml/modelkit/models/winml/decoder_only.py
Comment thread scripts/e2e_eval/cache/timeout_skip_list.json
Comment thread src/winml/modelkit/optim/pipes/surgery.py
Comment thread src/winml/modelkit/models/winml/decoder_only.py
@tezheng
Copy link
Copy Markdown
Collaborator

tezheng commented Apr 22, 2026

Deferred items — tracked for follow-up / design discussion

The following 7 findings from the review pass are acknowledged but not addressed in the companion fix commits. Each needs design input, new test infrastructure, or a project-wide decision outside this PR's scope:

  • I5 — mixed torch.nn.Module / PreTrainedModel paradigm inside WinMLEncoderDecoderModel. Major refactor; not PR-specific.
  • I10WinMLSlidingWindowCache.update() reassigns .keys/.values (non-in-place), deviating from HF StaticCache in-place contract. Touches cache semantics — needs design call before changing.
  • NI-1 — No ORT InferenceSession lifecycle (no close/__del__/context manager) on WinMLCompositeModel. Project-wide gap, not specific to this PR.
  • NI-2 + NI-3 — Surgery-pass graph hygiene: _remove_isnan_in_attention_mask leaves orphan graph.value_info / dead initializers, and rewires all Where consumers without verifying input position (condition vs. value_if_true vs. value_if_false). Needs ONNX domain expertise; not low-risk.
  • NI-5WinMLCache.__init__ reads config.num_hidden_layers, while Mu2DecoderWrapper uses config.n_decoder_layer. If the remote-code Mu2Config doesn't expose both attributes, Mu2 inference crashes at cache construction. Needs author confirmation on the Mu2Config shape.
  • NM-3Path.with_stem() usage in src/winml/modelkit/commands/config.py requires Python 3.9+. Project-wide Python version floor decision.
  • NM-4 — KV cache unit tests assert shapes only; no value-level tests for update / captured / reset. Needs new test infrastructure.

These are all real findings, just not mechanically fixable in a review-follow-up branch. Tracking as follow-up issues (or resolving inline here) is equally fine with me.

Copy link
Copy Markdown
Collaborator

@tezheng tezheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second-follow-up review (Phase B + C): flags 9 Important+Minor findings on the current HEAD. A companion commit in this review addresses each of them. All changes have been verified (3876 unit tests pass, ruff clean) and an independent critic caught and corrected one regression before this commit.

Comment thread src/winml/modelkit/models/winml/composite_model.py Outdated
Comment thread src/winml/modelkit/models/auto.py Outdated
Comment thread scripts/e2e_eval/run_eval.py Outdated
Comment thread src/winml/modelkit/models/winml/decoder_only.py
Comment thread src/winml/modelkit/commands/config.py Outdated
Comment thread scripts/e2e_eval/run_eval.py
Comment thread src/winml/modelkit/models/winml/composite_model.py
Comment thread src/winml/modelkit/models/winml/kv_cache.py
Comment thread src/winml/modelkit/models/winml/decoder_only.py
tezheng and others added 2 commits April 23, 2026 00:37
Addresses 9 Important+Minor findings from pr_334_verdicts.md section 10,
including one regression caught during critic review of the initial pass:

- I3: trust_remote_code → explicit keyword-only param on
  WinMLCompositeModel.from_pretrained (was untyped **kwargs lookup);
  docstring updated.
- I4: cache KNOWN_COMPOSITE_TASKS from registry; gate AutoConfig probe
  on task matching a known composite task. Non-composite callers no
  longer pay for a redundant HF config load on every from_pretrained.
- I6: skip accuracy phase for composite models with explicit
  skip_reason=composite_model_not_supported (was arbitrary sub-model
  pick via next(iter(...))).
- I16: mirror encoder_decoder.py's EncoderDecoderCache unwrap in
  decoder_only.py._resolve_cache for defensive symmetry; documented.
- NI-7: atomic per-sub-component config writes via tmp+replace.
- NI-9: move stale-config cleanup BEFORE wmk config invocation. The
  initial fix ran the cleanup AFTER, which silently deleted freshly-
  written composite sub-configs — caught by critic review.
- NM-1: demote .to() no-op log to DEBUG (HF pipelines routinely call
  model.to('cpu') as setup; WARNING would spam normal usage).
- NM-5: document RoPE-at-position-0 safety for padding position_ids
  (covered by attention mask; doc-only change).
- NM-6: replace terse 'see C9 review comment' with self-contained
  explanation in both _run_prefill and _run_gen dead branches.

6 files. 3876 unit tests pass (unchanged from baseline). ruff clean on
all touched files.

Constraint: Preserve PR author's API; defensive fixes only
Rejected: Raise ValueError on trust_remote_code=True default (kept
  False default to preserve existing caller semantics)
Rejected: Elevate .to() log to WARNING (would spam normal HF pipeline
  usage)
Confidence: high
Scope-risk: narrow
Directive: Critical items C2/C5/C9 still need PR author response;
  posted as Phase A review questions (pullrequestreview-4156157489).
  Architectural deferrals (I5, I10, NI-1, NI-2, NI-3, NI-5, NM-3,
  NM-4) tracked as a general PR comment.
Not-tested: I6 composite accuracy skip flow has no unit test;
  verified by inspection + existing single-model regression
@vortex-captain vortex-captain requested a review from tezheng April 23, 2026 03:27
Match the caller signature. WinMLAutoModel.from_onnx declares
`onnx_path: str | Path | dict[str, str | Path]` at auto.py:99 but the
composite callee declared `dict[str, str]`, so a caller passing a Path
value through the dict branch triggered a type-checker complaint even
though runtime Path(path) coercion inside the dispatch loop made it work.

- composite_model.py:194 — widen `dict[str, str]` → `dict[str, str | Path]`
- composite_model.py:210 — docstring note that str and Path are both accepted
- TYPE_CHECKING import: `from pathlib import Path` (no runtime cost)

Previously retracted as "cosmetic/non-actionable" on the grounds that the
runtime handled both via `Path(path)` coercion. Reversing that retraction
— a type-annotation mismatch between a public caller and callee is a
real defect a strict type-checker (mypy/pyright) would flag.

1 file, ruff check + format clean, targeted pytest green (81/81).

Confidence: high
Scope-risk: narrow
Comment thread src/winml/modelkit/models/winml/composite_model.py
Comment thread src/winml/modelkit/models/winml/composite_model.py Outdated
Copy link
Copy Markdown
Member

@zhenchaoni zhenchaoni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re approve after comments are addressed.

@vortex-captain vortex-captain merged commit 38775de into main Apr 23, 2026
9 checks passed
@vortex-captain vortex-captain deleted the reny/multi_model branch April 23, 2026 04:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants