diff --git a/scripts/e2e_eval/cache/timeout_skip_list.json b/scripts/e2e_eval/cache/timeout_skip_list.json index 02e9e2caf..246344435 100644 --- a/scripts/e2e_eval/cache/timeout_skip_list.json +++ b/scripts/e2e_eval/cache/timeout_skip_list.json @@ -29,21 +29,6 @@ "task": "translation", "reason": "Build hangs >14000s, m2m_100 enc-dec export issue" }, - { - "hf_id": "google-t5/t5-3b", - "task": "translation", - "reason": "Build hangs >14000s, t5 enc-dec + network download issue" - }, - { - "hf_id": "google-t5/t5-base", - "task": "summarization", - "reason": "Build hangs >2000s, t5 enc-dec export issue" - }, - { - "hf_id": "google-t5/t5-small", - "task": "translation", - "reason": "Build hangs >8000s, t5 enc-dec export issue" - }, { "hf_id": "knkarthick/MEETING_SUMMARY", "task": "summarization", @@ -58,5 +43,25 @@ "hf_id": "philschmid/bart-large-cnn-samsum", "task": "summarization", "reason": "Build hangs >10000s, bart enc-dec export issue" + }, + { + "hf_id": "apple/DepthPro-hf", + "task": "depth-estimation", + "reason": "OOM in quantization (model too large for in-process quantize)" + }, + { + "hf_id": "Qwen/Qwen3-0.6B", + "task": "text-generation", + "reason": "OOM in quantization (segfault during quantize, 2.9GB model)" + }, + { + "hf_id": "Qwen/Qwen3-1.7B", + "task": "text-generation", + "reason": "OOM in quantization (model too large for in-process quantize)" + }, + { + "hf_id": "Qwen/Qwen3-8B", + "task": "text-generation", + "reason": "OOM in quantization (model too large for in-process quantize)" } ] diff --git a/scripts/e2e_eval/run_eval.py b/scripts/e2e_eval/run_eval.py index adb7bd858..52e2be3d3 100644 --- a/scripts/e2e_eval/run_eval.py +++ b/scripts/e2e_eval/run_eval.py @@ -353,11 +353,25 @@ def _run_build( ) -> dict: """Run winml config + winml build for one model. Returns build result dict. - Flow: winml config → config.json → winml build --use-cache → ONNX path. + Flow: winml config → list of config JSONs → winml build each → ONNX paths. 
+ + Single models produce one config; composite models (e.g., T5 translation) + produce one per sub-component (suffixed names). Both go through the same + build loop — single model is just the list-of-1 case. """ config_path = model_dir / "build_config.json" model_dir.mkdir(parents=True, exist_ok=True) + # Remove any stale suffixed sub-configs BEFORE `wmk config` runs. + # For composite models `wmk config` writes files matching {stem}_*.json + # (e.g., build_config_encoder.json); cleaning those AFTER the command would + # delete the freshly-written configs and silently degrade composite builds + # to single-model. Running cleanup first removes prior-run artifacts without + # touching the current run's output. + for _stale in config_path.parent.glob(f"{config_path.stem}_*.json"): + safe_print(f" [config] Removing stale sub-config from prior run: {_stale.name}") + _stale.unlink(missing_ok=True) + # Step 1: winml config config_args = [ *WINML_CLI, @@ -378,59 +392,93 @@ def _run_build( if config_proc["exit_code"] != 0: return { "success": False, - "onnx_path": None, + "onnx_paths": {}, "stage": "config", "proc": config_proc, } - # Step 2: winml build --use-cache - build_args = [ - *WINML_CLI, - "build", - "-c", - str(config_path), - "-m", - entry.hf_id, - "--use-cache", - ] + # Collect config files: composite models produce suffixed files + # (e.g., build_config_encoder.json); single models produce config_path itself. + sub_configs = sorted(config_path.parent.glob(f"{config_path.stem}_*.json")) + if not sub_configs: + sub_configs = [config_path] - build_proc = _run_subprocess(build_args, timeout) - if build_proc["exit_code"] != 0: - return { - "success": False, - "onnx_path": None, - "stage": "build", - "proc": build_proc, - } + # Step 2: build each sub-config + # Map component label → ONNX path. Single model uses "" as label. 
+ onnx_paths: dict[str, str] = {} + last_proc = config_proc - # Extract ONNX path from build output - # winml build prints "Final artifact: " in stderr - onnx_path = None - for line in build_proc["stderr"].splitlines(): - if "Final artifact:" in line: - onnx_path = line.split("Final artifact:")[-1].strip() - break + # TODO: remove for loop once wimnl build supports building composite model to multiple onnx files + for sub_cfg in sub_configs: + label = sub_cfg.stem.removeprefix(f"{config_path.stem}_") if len(sub_configs) > 1 else "" + if label: + safe_print(f" building component: {label}") - # Fallback: search cache for the built model - if not onnx_path: - for line in build_proc["stdout"].splitlines(): - if "Final artifact:" in line: - onnx_path = line.split("Final artifact:")[-1].strip() - break + build_args = [ + *WINML_CLI, + "build", + "-c", + str(sub_cfg), + "-m", + entry.hf_id, + "--use-cache", + ] - if not onnx_path or not Path(onnx_path).exists(): - # Last resort: find _model.onnx in the cache - onnx_path = _find_cached_model(entry.hf_id, build_proc, entry.task) + build_proc = _run_subprocess(build_args, timeout) + last_proc = build_proc + if build_proc["exit_code"] != 0: + stage = f"build_{label}" if label else "build" + return { + "success": False, + "onnx_paths": onnx_paths, + "stage": stage, + "proc": build_proc, + } + + task_hint = _extract_task_from_config(sub_cfg) or entry.task + path = _extract_onnx_path(build_proc, entry.hf_id, task_hint) + if path: + onnx_paths[label] = path return { - "success": onnx_path is not None, - "onnx_path": onnx_path, + "success": len(onnx_paths) == len(sub_configs), + "onnx_paths": onnx_paths, "stage": "complete", - "proc": build_proc, - "config_path": str(config_path), + "proc": last_proc, } +def _extract_onnx_path(build_proc: dict, hf_id: str, task: str | None) -> str | None: + """Extract ONNX path from build subprocess output.""" + # Patterns used by winml build to report the artifact path + markers = ("Final 
artifact:", "Existing artifact found:", "Artifact:") + onnx_path = None + for line in (build_proc["stderr"] + build_proc["stdout"]).splitlines(): + for marker in markers: + if marker in line: + candidate = line.split(marker)[-1].strip() + if candidate and Path(candidate).exists(): + onnx_path = candidate + break + if onnx_path: + break + + if not onnx_path or not Path(onnx_path).exists(): + onnx_path = _find_cached_model(hf_id, build_proc, task) + + return onnx_path + + +def _extract_task_from_config(config_path: Path) -> str | None: + """Read the task from a build config JSON file.""" + try: + data = json.loads(config_path.read_text(encoding="utf-8")) + loader = data.get("loader", {}) + return loader.get("task") + except (OSError, json.JSONDecodeError): + return None + + def _find_cached_model(hf_id: str, build_proc: dict, task: str | None = None) -> str | None: """Try to find the built ONNX model in the WinML cache. @@ -447,6 +495,7 @@ def _find_cached_model(hf_id: str, build_proc: dict, task: str | None = None) -> return None from winml.modelkit.loader.task import get_task_abbrev + prefix = get_task_abbrev(task) + "_" model_files = sorted( @@ -466,16 +515,16 @@ def run_model( entry: ModelEntry, device: str, timeout: int, - onnx_path: str | None = None, + onnx_paths: dict[str, str] | None = None, ) -> dict: - """Execute winml perf for one model. Returns raw subprocess result dict. + """Execute winml perf for one or more ONNX models. Returns merged result dict. - When onnx_path is provided, benchmarks the pre-built ONNX directly - (skips internal build). Otherwise falls back to HF model ID. + When onnx_paths is provided, benchmarks each pre-built ONNX directly. + Single model is the {"": path} case. Results are merged (worst exit + code, concatenated stdout/stderr, summed elapsed). 
""" - if onnx_path: - args = [*WINML_CLI, "perf", "-m", onnx_path, "--device", device] - else: + if not onnx_paths: + # No pre-built paths: fall back to HF model ID (single model only) args = [ *WINML_CLI, "perf", @@ -488,22 +537,68 @@ def run_model( ] if entry.task: args += ["--task", entry.task] + args += ["--iterations", "10", "--warmup", "2"] + args += entry.perf_args + + proc = _run_subprocess(args, timeout) + proc["device"] = device + proc["timestamp"] = _utc_now() + proc["error_summary"] = ( + "" + if proc["exit_code"] == 0 + else f"timeout ({timeout}s)" + if proc["timeout"] + else f"exit code {proc['exit_code']}" + ) + return proc + + # Run perf for each sub-model and merge results + all_stdout: list[str] = [] + all_stderr: list[str] = [] + total_elapsed = 0.0 + worst_exit = 0 + any_timeout = False + commands: list[str] = [] + + for label, path in onnx_paths.items(): + if label: + safe_print(f" perf: {label}") + + args = [*WINML_CLI, "perf", "-m", path, "--device", device] + args += ["--iterations", "10", "--warmup", "2"] + args += entry.perf_args + + proc = _run_subprocess(args, timeout) + if label: + all_stdout.append(f"=== {label} ===\n{proc['stdout']}") + all_stderr.append(f"=== {label} ===\n{proc['stderr']}") + else: + all_stdout.append(proc["stdout"]) + all_stderr.append(proc["stderr"]) + total_elapsed += proc["elapsed"] + commands.append(proc["command"]) + if proc["exit_code"] != 0: + worst_exit = proc["exit_code"] + if proc["timeout"]: + any_timeout = True - args += ["--iterations", "10", "--warmup", "2"] - args += entry.perf_args - - proc = _run_subprocess(args, timeout) - # Attach device and timestamp for build_eval_result - proc["device"] = device - proc["timestamp"] = _utc_now() - proc["error_summary"] = ( - "" - if proc["exit_code"] == 0 - else f"timeout ({timeout}s)" - if proc["timeout"] - else f"exit code {proc['exit_code']}" - ) - return proc + return { + "stdout": "\n".join(all_stdout), + "stderr": "\n".join(all_stderr), + "exit_code": 
worst_exit, + "elapsed": round(total_elapsed, 1), + "timeout": any_timeout, + "command": commands[0] if len(commands) == 1 else " | ".join(commands), + "device": device, + "timestamp": _utc_now(), + "error_summary": ( + "" + if worst_exit == 0 + else f"timeout ({timeout}s)" + if any_timeout + else f"exit code {worst_exit}" + ), + } # --------------------------------------------------------------------------- @@ -984,8 +1079,8 @@ def main() -> None: if e.hf_id == args.hf_model: matched_entry = e break - except Exception: - pass # Registry is optional for single-model mode; proceed without enrichment + except Exception as e: + safe_print(f" [registry] Optional enrichment skipped: {e}") if matched_entry is not None: # Override task if explicitly provided on CLI if args.task and args.task != matched_entry.task: @@ -1048,8 +1143,10 @@ def main() -> None: if _should_skip_existing(existing, retry_types, args.eval_type): skipped_count += 1 continue - except Exception: - pass # Corrupt result file — include model for re-evaluation + except (OSError, json.JSONDecodeError, KeyError) as exc: + safe_print( + f" [continue] Corrupt result file {result_path}: {exc} — re-evaluating" + ) filtered.append(e) if skipped_count: safe_print( @@ -1182,70 +1279,67 @@ def main() -> None: perf_proc: dict | None = None accuracy_result: dict | None = None - # Build phase: winml config + winml build → ONNX path + # Build phase: winml config + winml build → list of ONNX paths # Build is shared by perf and eval, avoiding redundant builds. 
- onnx_path: str | None = None - if args.eval_type in ("perf", "both"): - build_result = _run_build( - entry, - args.device, - _DEFAULT_PRECISION, - args.timeout, - model_dir, - ) - if build_result["success"]: - onnx_path = build_result["onnx_path"] + build_result = _run_build( + entry, + args.device, + _DEFAULT_PRECISION, + args.timeout, + model_dir, + ) + onnx_paths = build_result["onnx_paths"] if build_result["success"] else {} + # Composite models produce multiple ONNX paths; accuracy phase requires a + # single path and is not yet supported for composite models. + # TODO: composite model accuracy support + is_composite = len(onnx_paths) > 1 + first_path = ( + next(iter(onnx_paths.values()), None) if onnx_paths and not is_composite else None + ) + + if not build_result["success"]: + # Build failed — synthesize failed result for downstream phases + fail_proc = build_result["proc"] + fail_proc["device"] = args.device + fail_proc["timestamp"] = _utc_now() + fail_proc["error_summary"] = f"build_{build_result['stage']}_failed" - if args.eval_type == "accuracy": - # Accuracy-only: build + eval (no perf) - build_result = _run_build( + if args.eval_type != "accuracy": + perf_proc = fail_proc + if args.eval_type != "perf": + accuracy_result = {"skipped": True, "skip_reason": "build_failed"} + elif is_composite and args.eval_type != "perf": + # Accuracy phase skipped for composite models (TODO: composite accuracy support) + safe_print( + f" [accuracy] Skipped for composite model {entry.hf_id} " + "(multiple ONNX paths; composite accuracy evaluation not yet implemented)" + ) + accuracy_result = {"skipped": True, "skip_reason": "composite_model_not_supported"} + if args.eval_type == "both": + perf_proc = run_model(entry, args.device, args.timeout, onnx_paths) + elif args.eval_type == "accuracy": + accuracy_result = _run_accuracy_phase( entry, args.device, - _DEFAULT_PRECISION, args.timeout, model_dir, + first_path, ) - if build_result["success"]: - onnx_path = 
build_result["onnx_path"] + elif args.eval_type == "perf": + perf_proc = run_model(entry, args.device, args.timeout, onnx_paths) + else: + # "both": perf → eval + perf_proc = run_model(entry, args.device, args.timeout, onnx_paths) + if perf_proc["exit_code"] != 0: + accuracy_result = {"skipped": True, "skip_reason": "perf_failed"} + else: accuracy_result = _run_accuracy_phase( entry, args.device, args.timeout, model_dir, - onnx_path, + first_path, ) - else: - accuracy_result = {"skipped": True, "skip_reason": "build_failed"} - elif args.eval_type == "perf": - if onnx_path: - perf_proc = run_model(entry, args.device, args.timeout, onnx_path) - else: - # Build failed — synthesize a failed perf result - perf_proc = build_result["proc"] - perf_proc["device"] = args.device - perf_proc["timestamp"] = _utc_now() - perf_proc["error_summary"] = f"build_{build_result['stage']}_failed" - else: - # "both": build → perf → eval - if onnx_path: - perf_proc = run_model(entry, args.device, args.timeout, onnx_path) - if perf_proc["exit_code"] != 0: - accuracy_result = {"skipped": True, "skip_reason": "perf_failed"} - else: - accuracy_result = _run_accuracy_phase( - entry, - args.device, - args.timeout, - model_dir, - onnx_path, - ) - else: - # Build failed - perf_proc = build_result["proc"] - perf_proc["device"] = args.device - perf_proc["timestamp"] = _utc_now() - perf_proc["error_summary"] = f"build_{build_result['stage']}_failed" - accuracy_result = {"skipped": True, "skip_reason": "build_failed"} except KeyboardInterrupt: safe_print("\n\n[Ctrl+C] Interrupted — generating reports for completed models...") diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index d35ea0882..c25db4e42 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -295,6 +295,12 @@ def _build_modules( default=None, help="Maximum autoconf re-optimization rounds (default: 3). 
--no-analyze sets this to 0.", ) +@click.option( + "--trust-remote-code", + is_flag=True, + default=False, + help="Trust remote code for custom model architectures (e.g., Mu2).", +) @click.option( "-v", "--verbose", @@ -317,6 +323,7 @@ def build( device: str | None, no_analyze: bool, max_optim_iterations: int | None, + trust_remote_code: bool, verbose: bool, ) -> None: r"""Build a WinML-optimized ONNX model from a HuggingFace model or .onnx file. @@ -398,6 +405,8 @@ def build( extra_kwargs["hack_max_optim_iterations"] = 0 elif max_optim_iterations is not None: extra_kwargs["hack_max_optim_iterations"] = max_optim_iterations + if trust_remote_code: + extra_kwargs["trust_remote_code"] = True if is_module_mode: # ---- MODULE MODE: array config, one build per submodule ---- diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index e6191f83e..445ab782a 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -328,6 +328,31 @@ def config( else: _is_onnx_mode = False + # Check composite model registry: (model_type, task) -> multi-config + pipeline_components = _resolve_composite_model_components( + hf_model, model_type, task, trust_remote_code=trust_remote_code + ) + if pipeline_components: + # composite model: generate one config per sub-component + _generate_pipeline_configs( + pipeline_components, + hf_model=hf_model, + model_class=model_class, + model_type=model_type, + override=override, + shape_config=shape_config, + library_name=library_name, + device=device, + precision=precision, + trust_remote_code=trust_remote_code, + ep=ep, + no_quant=no_quant, + no_compile=no_compile, + output=output, + console=console, + ) + return + # Generate config(s) - returns single or list based on module parameter result = generate_hf_build_config( model_id=hf_model, @@ -486,3 +511,93 @@ def config( if verbose: logger.exception("Unexpected error during config generation") raise 
click.ClickException(f"Unexpected error: {e}") from e + + +def _resolve_composite_model_components( + hf_model: str | None, + model_type: str | None, + task: str | None, + trust_remote_code: bool = False, +) -> dict[str, str] | None: + """Check if (model_type, task) is a registered composite model. + + Returns _SUB_MODEL_CONFIG dict if found, None otherwise. + """ + if task is None: + return None + + import winml.modelkit.models.hf # noqa: F401 # trigger pipeline registrations + + from ..models.winml.composite_model import COMPOSITE_MODEL_REGISTRY + + # Resolve model_type from HF config if not provided + resolved_type = model_type + if resolved_type is None and hf_model is not None: + from transformers import AutoConfig + + resolved_type = AutoConfig.from_pretrained( + hf_model, trust_remote_code=trust_remote_code + ).model_type + + if resolved_type is None: + return None + + cls = COMPOSITE_MODEL_REGISTRY.get((resolved_type, task)) + return cls._SUB_MODEL_CONFIG if cls is not None else None + + +def _generate_pipeline_configs( + components: dict[str, str], + *, + hf_model: str | None, + model_class: str | None, + model_type: str | None, + override: Any, + shape_config: dict | None, + library_name: str, + device: str, + precision: str, + trust_remote_code: bool, + ep: str | None, + no_quant: bool, + no_compile: bool, + output: str | None, + console: Any, +) -> None: + """Generate and save one config file per pipeline sub-component.""" + from ..config import generate_hf_build_config + + for component_name, component_task in components.items(): + console.print( + f"[dim]Generating config for component '{component_name}' " + f"(task={component_task})...[/dim]" + ) + + cfg = generate_hf_build_config( + model_id=hf_model, + task=component_task, + model_class=model_class, + model_type=model_type, + override=override, + shape_config=shape_config, + library_name=library_name, + device=device, + precision=precision, + trust_remote_code=trust_remote_code, + ep=ep, + ) + 
_apply_stage_overrides(cfg, no_quant=no_quant, no_compile=no_compile) + + config_json = json.dumps(cfg.to_dict(), indent=2) + + if output: + out_path = Path(output) + suffixed = out_path.with_stem(f"{out_path.stem}_{component_name}") + suffixed.parent.mkdir(parents=True, exist_ok=True) + tmp = suffixed.with_suffix(".json.tmp") + tmp.write_text(config_json) + tmp.replace(suffixed) + console.print(f"[green]Config saved to:[/green] {suffixed}") + else: + console.print(f"[bold]--- {component_name} ({component_task}) ---[/bold]") + print(config_json) diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index 17da4ab0b..3f334f287 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -874,6 +874,7 @@ def _merge_export_config( dynamic_axes=( override.dynamic_axes if override.dynamic_axes is not None else base.dynamic_axes ), + dynamo=override.dynamo if override.dynamo else base.dynamo, ) diff --git a/src/winml/modelkit/export/htp/exporter.py b/src/winml/modelkit/export/htp/exporter.py index b844e7787..aa7c1fdeb 100644 --- a/src/winml/modelkit/export/htp/exporter.py +++ b/src/winml/modelkit/export/htp/exporter.py @@ -439,9 +439,15 @@ def _convert_model_to_onnx( if export_config.dynamic_axes: onnx_kwargs["dynamic_axes"] = export_config.dynamic_axes - tuple(inputs.values()) with self._get_optimum_patcher(model, task): - torch.onnx.export(model, (), output_path, kwargs=inputs, **onnx_kwargs) + # Models can override input binding by implementing + # get_export_args(inputs) → tuple of positional args. + # Default: pass inputs dict as kwargs. 
+ if hasattr(model, "get_export_args"): + export_args = model.get_export_args(inputs) + torch.onnx.export(model, export_args, output_path, **onnx_kwargs) + else: + torch.onnx.export(model, (), output_path, kwargs=inputs, **onnx_kwargs) @staticmethod def _get_optimum_patcher(model: nn.Module, task: str | None) -> Any: diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index ae21e1881..f0f55a4db 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -36,6 +36,8 @@ if TYPE_CHECKING: + from transformers import PretrainedConfig + from ..config import WinMLBuildConfig from .winml.base import WinMLPreTrainedModel @@ -96,7 +98,7 @@ def __init__(self) -> None: @classmethod def from_onnx( cls, - onnx_path: str | Path, + onnx_path: str | Path | dict[str, str | Path], *, task: str | None = None, config: WinMLBuildConfig | None = None, @@ -107,8 +109,10 @@ def from_onnx( use_cache: bool = True, force_rebuild: bool = False, skip_build: bool = False, + session_options: Any | None = None, + hf_config: PretrainedConfig | None = None, **kwargs: Any, - ) -> WinMLPreTrainedModel: + ) -> WinMLPreTrainedModel | WinMLCompositeModel: # noqa: F821 """Build from a pre-exported ONNX file. Runs optimize -> [quantize] -> [compile] via ``build_onnx_model()``. @@ -124,11 +128,33 @@ def from_onnx( cache_dir: Override cache directory. use_cache: Whether to use persistent cache. force_rebuild: Force rebuild even if cached. + hf_config: HF ``PretrainedConfig`` for composite (dict) dispatch only. + Required when ``onnx_path`` is a dict so the composite registry + lookup can resolve ``(model_type, task)``. Ignored for single-file + builds. **kwargs: Forwarded to ``build_onnx_model()``. Returns: WinMLPreTrainedModel inference wrapper. 
""" + if isinstance(onnx_path, dict): + from .winml.composite_model import WinMLCompositeModel + + return WinMLCompositeModel.from_onnx( + onnx_path, + task=task, + hf_config=hf_config, + device=device, + precision=precision, + ep=ep, + cache_dir=cache_dir, + use_cache=use_cache, + force_rebuild=force_rebuild, + skip_build=skip_build, + session_options=session_options, + **kwargs, + ) + onnx_path = Path(onnx_path) if not onnx_path.is_file(): raise FileNotFoundError( @@ -165,6 +191,7 @@ def from_onnx( onnx_path=onnx_path, config=None, device=device, + session_options=session_options, ) # Resolve output directory @@ -200,6 +227,7 @@ def from_onnx( onnx_path=result.final_onnx_path, config=None, # No HF PretrainedConfig for bare ONNX builds device=device, + session_options=session_options, ) @classmethod @@ -278,6 +306,43 @@ def from_pretrained( **kwargs, ) + # ===================================================================== + # COMPOSITE MODEL CHECK — delegate to WinMLCompositeModel.from_pretrained + # when (model_type, task) is a registered composite (e.g., T5 translation, + # Qwen text-generation). AutoConfig is lightweight (~config.json only). + # The registry probe (AutoConfig.from_pretrained) is gated on whether + # `task` appears in any registered composite entry, avoiding an + # unconditional network/disk round-trip for every non-composite call. 
+ # ===================================================================== + if task is not None: + from .winml.composite_model import COMPOSITE_MODEL_REGISTRY + + _known_composite_tasks = {t for (_, t) in COMPOSITE_MODEL_REGISTRY} + if task in _known_composite_tasks: + from transformers import AutoConfig + + _hf_cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + _model_type = getattr(_hf_cfg, "model_type", None) + else: + _model_type = None + + if _model_type is not None and (_model_type, task) in COMPOSITE_MODEL_REGISTRY: + from .winml.composite_model import WinMLCompositeModel + + return WinMLCompositeModel.from_pretrained( + model_id, + task, + device=device, + use_cache=use_cache, + force_rebuild=force_rebuild, + trust_remote_code=trust_remote_code, + shape_config=shape_config, + precision=precision, + config=config, + cache_dir=cache_dir, + **kwargs, + ) + # ===================================================================== # [1] CONFIG PHASE - Generate complete config with I/O specs (Lightweight, ~2s) # ===================================================================== @@ -292,6 +357,7 @@ def from_pretrained( shape_config=shape_config, device=device, precision=precision, + trust_remote_code=trust_remote_code, ep=kwargs.get("ep"), ) diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index d26685b12..6b46a6567 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -36,11 +36,23 @@ from .depth_anything import DepthAnythingIOConfig as _DepthAnythingIOConfig # triggers registration from .depth_pro import DepthProIOConfig as _DepthProIOConfig # triggers registration from .detr import DETR_CONFIG +from .mu2 import MODEL_CLASS_MAPPING as _MU2_CLASS_MAPPING +from .mu2 import MU2_CONFIG +from .mu2 import Mu2DecoderIOConfig as _Mu2DecoderIOConfig # triggers registration +from .mu2 import Mu2EncoderIOConfig as _Mu2EncoderIOConfig # triggers 
registration +from .qwen import MODEL_CLASS_MAPPING as _QWEN_CLASS_MAPPING +from .qwen import QWEN_CONFIG +from .qwen import QwenGenIOConfig as _QwenGenIOConfig +from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig from .roberta import ROBERTA_FAMILY_CONFIG from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration from .sam import MODEL_CLASS_MAPPING as _SAM2_CLASS_MAPPING from .segformer import MODEL_CLASS_MAPPING as _SEGFORMER_CLASS_MAPPING from .segformer import SegformerIOConfig as _SegformerIOConfig # triggers registration +from .t5 import MODEL_CLASS_MAPPING as _T5_CLASS_MAPPING +from .t5 import T5_CONFIG +from .t5 import T5DecoderIOConfig as _T5DecoderIOConfig # triggers registration +from .t5 import T5EncoderIOConfig as _T5EncoderIOConfig # triggers registration from .vision_encoder_decoder import VISION_ENCODER_DECODER_CONFIG from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration @@ -48,8 +60,11 @@ # Aggregated model class mappings: (model_type, task) -> HF model class MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { **_CLIP_CLASS_MAPPING, + **_MU2_CLASS_MAPPING, + **_QWEN_CLASS_MAPPING, **_SAM2_CLASS_MAPPING, **_SEGFORMER_CLASS_MAPPING, + **_T5_CLASS_MAPPING, } # Registry: model_type -> WinMLBuildConfig @@ -64,6 +79,9 @@ "clip-vision-model": CLIP_CONFIG, "detr": DETR_CONFIG, "roberta": ROBERTA_FAMILY_CONFIG, + "mu2": MU2_CONFIG, + "qwen3": QWEN_CONFIG, + "t5": T5_CONFIG, "vision-encoder-decoder": VISION_ENCODER_DECODER_CONFIG, "xlm-roberta": ROBERTA_FAMILY_CONFIG, } diff --git a/src/winml/modelkit/models/hf/depth_pro.py b/src/winml/modelkit/models/hf/depth_pro.py index d24850900..a5d53a770 100644 --- a/src/winml/modelkit/models/hf/depth_pro.py +++ b/src/winml/modelkit/models/hf/depth_pro.py @@ -27,6 +27,8 @@ from optimum.utils import NormalizedConfig from optimum.utils.input_generators import DummyVisionInputGenerator +from ...export import register_onnx_overwrite + class 
_DepthProNormalizedConfig(NormalizedConfig): """Normalized config for DepthPro with computed image_size. @@ -44,8 +46,7 @@ def image_size(self) -> int: return int(self.config.patch_size / min(self.config.scaled_images_ratios)) -# TODO: enable registration once quantization can be done with enough RAM -# @register_onnx_overwrite("depth_pro", "depth-estimation", library_name="transformers") +@register_onnx_overwrite("depth_pro", "depth-estimation", library_name="transformers") class DepthProIOConfig(OnnxConfig): """ONNX config for DepthPro depth estimation. diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py new file mode 100644 index 000000000..54f09a5d1 --- /dev/null +++ b/src/winml/modelkit/models/hf/mu2.py @@ -0,0 +1,335 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Mu2 encoder-decoder model with KV cache. + +Export wrappers, OnnxConfig registrations, and ``WinMLMu2Model`` inference +class for Mu2 (custom ``trust_remote_code`` model). + +Export Strategy: +- Mu2EncoderWrapper (``feature-extraction``): encoder-only ONNX. +- Mu2DecoderWrapper (``text2text-generation``): decoder with + ``WinMLSlidingWindowCache`` (Slice+Concat, no ScatterElements). + Present KV output is the new-token KV only. + +Custom model integration (``auto_map``): + The Mu2 model uses ``trust_remote_code=True`` with ``auto_map`` in + ``config.json`` pointing to ``modeling_mu.py`` / ``configuration_mu.py`` + alongside the weights. KV cache support was added to the model source + (``MuAttentionSDPA`` accepts ``past_key_value`` + ``cache_position``). + +Key decisions: +- Uses ``WinMLSlidingWindowCache`` (not Static) because Mu2 uses RoPE, + not learned relative position bias. 
RoPE is baked into K tensors, + so buffer positions don't affect attention — sliding window is safe. +- The decoder ONNX input is ``position_id`` (absolute seq position for + RoPE), not ``cache_position`` (which implies buffer-position indexing). +- Mu2's ``generate_sin_cos_pos_emb`` was patched for transformers < 5.x + compatibility (computes inv_freq directly instead of using + ``LlamaRotaryEmbedding.compute_default_rope_parameters``). +- Mu2's ``Mu2Config`` must pass ``pad_token_id`` / ``bos_token_id`` / + ``eos_token_id`` to ``super().__init__()`` or PretrainedConfig + overrides them to None. + +Cache type: + +The default configuration uses ``WinMLSlidingWindowCache`` (FIFO +Slice+Concat). ``WinMLEncoderDecoderModel`` is cache-agnostic — mask +construction and cache updates are delegated to the cache class via +``build_decoder_mask``, ``position_input_name``, and +``update_all_layers``. To switch to ``WinMLStaticCache`` (index_copy_): + +1. **Export wrapper**: change ``Mu2DecoderWrapper.forward()`` to use + ``WinMLStaticCache`` and rename the position arg from ``position_id`` + to ``cache_position``. +2. **OnnxConfig inputs**: change ``"position_id"`` to + ``"cache_position"`` in ``Mu2DecoderIOConfig.inputs``. +3. **Inference**: override ``get_cache_class()`` to return + ``WinMLStaticCache``. ``WinMLEncoderDecoderModel`` uses + ``cache.position_input_name`` to select the correct ONNX input name + automatically. 
+ +Usage:: + + wmk config -m path/to/mu2 --task translation --trust-remote-code -o mu2.json + wmk build -c mu2_encoder.json -m path/to/mu2 --trust-remote-code -o output/encoder + wmk build -c mu2_decoder.json -m path/to/mu2 --trust-remote-code -o output/decoder +""" + +from __future__ import annotations + +from typing import Any, ClassVar + +import torch +import torch.nn as nn +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedConfig +from optimum.utils.input_generators import DummyTextInputGenerator + +from ...config import WinMLBuildConfig +from ...export import register_onnx_overwrite +from ...optim import WinMLOptimizationConfig +from ..winml.composite_model import register_composite_model +from ..winml.encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel +from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache + + +# ============================================================================= +# Wrapper nn.Modules +# ============================================================================= + + +class Mu2EncoderWrapper(nn.Module): + """Wraps Mu2 encoder for standalone ONNX export.""" + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.encoder = model.encoder + self.config = model.config + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> Mu2EncoderWrapper: + """Load full Mu2, extract encoder.""" + from transformers import AutoModelForSeq2SeqLM + + full_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **kwargs) + wrapper = cls(full_model) + wrapper.eval() + return wrapper + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + """Return encoder last hidden state.""" + return self.encoder( + input_ids=input_ids, attention_mask=attention_mask.bool() + ).last_hidden_state + + +class Mu2DecoderWrapper(nn.Module): + """Wraps Mu2 decoder for ONNX export. 
+ + Delegates to the model's own decoder (which now accepts ``past_key_values`` + and ``cache_position``). This wrapper just builds the cache from flat + KV inputs, calls the decoder, and collects captured KV outputs. + + Same pattern as ``T5DecoderWrapper``. + """ + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + self.config = model.config + self.num_layers = model.config.n_decoder_layer + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> Mu2DecoderWrapper: + """Load full Mu2, wrap for cached decoder export.""" + from transformers import AutoModelForSeq2SeqLM + + full_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **kwargs) + wrapper = cls(full_model) + wrapper.eval() + return wrapper + + def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + """Convert dict inputs to positional args for torch.onnx.export.""" + return tuple(inputs.values()) + + def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Run decoder with FIFO KV cache (Slice+Concat). + + Positional args (order matches OnnxConfig.inputs): + decoder_input_ids, encoder_hidden_states, attention_mask (encoder), + decoder_attention_mask, position_id, + past_0_key, past_0_value, past_1_key, past_1_value, ... + + Returns: + (logits, present_0_key, present_0_value, ...) where each + present KV is the new-token slice only [batch, n_kv_head, seq_len, head_dim] + (raw key_states/value_states captured before Slice+Concat in WinMLSlidingWindowCache). 
+ """ + decoder_input_ids = args[0] + encoder_hidden_states = args[1] + encoder_attention_mask = args[2] # "attention_mask" in OnnxConfig + decoder_attention_mask = args[3] + position_id = args[4] # absolute sequence position for RoPE + kv_start = 5 + + # Build WinMLSlidingWindowCache (FIFO: Slice+Concat instead of ScatterElements) + cache = WinMLSlidingWindowCache(self.config, max_cache_len=args[kv_start].size(2)) + cache.early_initialization( + batch_size=decoder_input_ids.size(0), + num_heads=self.config.n_kv_head, + head_dim=self.config.head_dim, + dtype=args[kv_start].dtype, + device=decoder_input_ids.device, + ) + for i in range(self.num_layers): + cache.layers[i].keys = args[kv_start + i * 2] + cache.layers[i].values = args[kv_start + i * 2 + 1] + + # Delegate to model's decoder — position_id is passed as cache_position + # for RoPE computation (WinMLSlidingWindowCache.update ignores it for indexing) + hidden_states = self.model.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=cache, + cache_position=position_id, + ) + logits = self.model.lm_head(hidden_states) + + # Output new-token KV only (same as T5 — captured during update) + result: list[torch.Tensor] = [logits] + for i in range(self.num_layers): + k, v = cache.captured[i] + result.extend([k, v]) + return tuple(result) + + +# ============================================================================= +# OnnxConfig Registrations +# ============================================================================= + + +@register_onnx_overwrite("mu2", "feature-extraction", library_name="transformers") +class Mu2EncoderIOConfig(OnnxConfig): + """ONNX config for Mu2 encoder (feature-extraction task).""" + + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + vocab_size="vocab_size", + allow_new=True, + ) + DUMMY_INPUT_GENERATOR_CLASSES = 
(DummyTextInputGenerator,) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + } + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return { + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + } + + +@register_onnx_overwrite("mu2", "text2text-generation", library_name="transformers") +class Mu2DecoderIOConfig(OnnxConfig): + """ONNX config for Mu2 decoder with sliding-window KV cache.""" + + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + hidden_size="n_embd", + num_layers="n_decoder_layer", + num_attention_heads="n_kv_head", + head_dim="head_dim", + max_cache_len="block_size", + vocab_size="vocab_size", + allow_new=True, + ) + DUMMY_INPUT_GENERATOR_CLASSES = ( + EncoderDecoderInputGenerator, + PastKeyValueInputGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + result: dict[str, dict[int, str]] = { + "decoder_input_ids": {0: "batch_size"}, + "encoder_hidden_states": {0: "batch_size"}, + "attention_mask": {0: "batch_size"}, + "decoder_attention_mask": {0: "batch_size"}, + "position_id": {}, + } + num_layers = self._normalized_config.num_layers + for i in range(num_layers): + result[f"past_{i}_key"] = {0: "batch_size"} + result[f"past_{i}_value"] = {0: "batch_size"} + return result + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + result: dict[str, dict[int, str]] = {"logits": {0: "batch_size"}} + num_layers = self._normalized_config.num_layers + for i in range(num_layers): + result[f"present_{i}_key"] = {0: "batch_size"} + result[f"present_{i}_value"] = {0: "batch_size"} + return result + + +# ============================================================================= +# Model Class Mapping + WinML Inference Model +# ============================================================================= + 
+MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("mu2", "feature-extraction"): Mu2EncoderWrapper, + ("mu2", "text2text-generation"): Mu2DecoderWrapper, +} + +MU2_CONFIG = WinMLBuildConfig( + optim=WinMLOptimizationConfig( + gelu_fusion=True, + fuse_rmsnorm=True, + matmul_add_fusion=True, + clamp_constant_values=True, + remove_isnan_in_attention_mask=True, + ), +) + + +@register_composite_model("mu2", "translation") +class WinMLMu2Model(WinMLEncoderDecoderModel): + """Mu2 encoder-decoder model with sliding-window KV cache. + + Only differs from T5 in ``get_cache_class`` and ``_SUB_MODEL_CONFIG``. + All forward/cache logic lives in ``WinMLEncoderDecoderModel``. + """ + + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "encoder": "feature-extraction", + "decoder": "text2text-generation", + } + + @classmethod + def get_cache_class(cls) -> type: # noqa: D102 + return WinMLSlidingWindowCache + + @property + def generation_config(self): # noqa: D102 + if not hasattr(self, "_generation_config"): + from transformers import GenerationConfig + + gc_kw: dict[str, Any] = {} + if self.config is not None: + for attr in ( + "decoder_start_token_id", + "bos_token_id", + "eos_token_id", + "pad_token_id", + ): + val = getattr(self.config, attr, None) + if val is not None: + gc_kw[attr] = val + gc_kw.setdefault("max_new_tokens", self._max_dec - 1) + gc_kw.setdefault("num_beams", 1) + gc_kw.setdefault("do_sample", False) + self._generation_config = GenerationConfig(**gc_kw) + return self._generation_config + + @generation_config.setter + def generation_config(self, value: Any) -> None: + self._generation_config = value + + +__all__ = [ + "MODEL_CLASS_MAPPING", + "MU2_CONFIG", + "Mu2DecoderIOConfig", + "Mu2DecoderWrapper", + "Mu2EncoderIOConfig", + "Mu2EncoderWrapper", + "WinMLMu2Model", +] diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py new file mode 100644 index 000000000..2bf727f7f --- /dev/null +++ 
b/src/winml/modelkit/models/hf/qwen.py @@ -0,0 +1,371 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Qwen3 HuggingFace Model Configuration. + +Provides decoder export wrappers and OnnxConfig registrations for +Qwen3 decoder-only models with KV cache, split into prefill and +generation sub-models. + +Export Strategy (split by task): +- QwenDecoderWrapper + QwenPrefillIOConfig: ``feature-extraction`` task + → prefill ONNX (input_ids [1, 64] → logits [1, 64, vocab] + KV [1, kv_heads, 64, head_dim]) +- QwenDecoderWrapper + QwenGenIOConfig: ``text-generation`` task + → generation ONNX (input_ids [1, 1] → logits [1, 1, vocab] + KV [1, kv_heads, 1, head_dim]) + +Both tasks share the same wrapper class; OnnxConfig determines static shapes. +The wrapper captures new-token KV directly as ONNX outputs, eliminating the +scatter→gather round-trip. + +How it works: + +1. ``QwenDecoderWrapper.forward()`` takes positional args (order matches + OnnxConfig.inputs): input_ids, attention_mask, position_ids, + past_0_key, past_0_value, ... It builds a ``WinMLSlidingWindowCache`` + from the input KV buffers, computes right-aligned ``cache_position`` + internally, runs ``Qwen3ForCausalLM``, and returns logits + captured KV. + +2. Decoder-only models need NO ``EncoderDecoderCache`` wrapping — + the ``WinMLSlidingWindowCache`` is passed directly as ``past_key_values``. (Contrast with + T5 where ``EncoderDecoderCache`` is required to route self-attention and + cross-attention to separate caches.) + +3. Logits are returned for ALL input positions (not just last token). + This matches HF convention and enables both generation (last-token logits) + and perplexity evaluation (all-position logits with shifted labels). + +4. 
``dynamo=True`` is required for Qwen3 ONNX export — the TorchScript + exporter fails with an internal error. Dynamo produces opset 18 models; + opset 17 downconversion currently fails for these graphs. + +Cache type: + +The default configuration uses ``WinMLSlidingWindowCache`` (FIFO +Slice+Concat). ``WinMLDecoderOnlyModel`` is cache-agnostic — padding, +mask construction, and cache updates are all delegated to the cache class +via ``prepare_prefill_chunk``, ``build_decoder_mask``, and +``update_all_layers``. To switch to ``WinMLStaticCache`` (index_copy_): + +1. **Export wrapper**: change ``QwenDecoderWrapper.forward()`` to use + ``WinMLStaticCache``, take ``cache_position`` as an explicit ONNX + input (instead of computing it internally), and set ``kv_start = 4``. +2. **OnnxConfig inputs**: add ``"cache_position": {}`` to + ``_qwen_io_inputs`` (after ``position_ids``, before ``past_*``). +3. **Inference**: override ``get_cache_class()`` to return + ``WinMLStaticCache``. ``WinMLDecoderOnlyModel`` passes + ``cache_position`` in feeds automatically when the ONNX model + expects it. + +Task name constraints (Optimum compatibility): + +- Task names must exist in ``TasksManager.get_all_tasks()`` to pass + validation in ``register_onnx_overwrite``. Custom names like + ``"causal-lm-prefill"`` require pre-registration in + ``TasksManager._LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP``. +- ``"causal-lm"`` is a synonym for ``"text-generation"`` in Optimum's + ``_SYNONYM_TASK_MAP`` — registering an OnnxConfig under ``"causal-lm"`` + silently resolves to ``"text-generation"`` at lookup time. +- ``"text-generation-with-past"`` requires the OnnxConfig to implement + ``with_past`` support (raises ``ValueError`` otherwise). +- We use ``"feature-extraction"`` (prefill) and ``"text-generation"`` (gen) + as they are standard tasks with no normalization surprises. + +Model: Qwen/Qwen3-0.6B, Qwen/Qwen3-1.7B, etc. 
+ +Usage:: + + # Generate both configs (pipeline mode) + winml config -m Qwen/Qwen3-0.6B --task text-generation -o qwen.json + + # Build both sub-models + from winml.modelkit.models.hf.qwen import WinMLQwen3Model + model = WinMLQwen3Model.from_pretrained("Qwen/Qwen3-0.6B") + + # Or load pre-built ONNX directly (skip_build=True avoids re-optimization) + from winml.modelkit.models.auto import WinMLAutoModel + prefill = WinMLAutoModel.from_pretrained("prefill.onnx", skip_build=True) + gen = WinMLAutoModel.from_pretrained("gen.onnx", skip_build=True) + model = WinMLQwen3Model(sub_models={...}, config=hf_config) +""" + +from __future__ import annotations + +from typing import Any, ClassVar + +import torch +import torch.nn as nn +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedConfig +from transformers import AutoModelForCausalLM + +from ...config import WinMLBuildConfig +from ...export import register_onnx_overwrite +from ...export.config import WinMLExportConfig +from ...optim import WinMLOptimizationConfig +from ..winml import register_specialization +from ..winml.composite_model import register_composite_model +from ..winml.decoder_only import ( + DecoderOnlyInputGenerator, + DecoderOnlyPrefillInputGenerator, + WinMLDecoderOnlyModel, +) +from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache + + +# ============================================================================= +# Wrapper nn.Module +# ============================================================================= + + +class QwenDecoderWrapper(nn.Module): + """Wraps Qwen3ForCausalLM with static KV cache I/O. + + Used for both prefill and generation ONNX export — same forward logic, + different OnnxConfig determines the static input shapes. + + Input KV: full static buffer ``[batch, kv_heads, max_cache_len, head_dim]``. + Output KV: new positions only ``[batch, kv_heads, seq_len, head_dim]``. 
+ Logits: all input positions ``[batch, seq_len, vocab_size]`` (both prefill and gen). + The caller selects the relevant position (last for gen, all for perplexity evaluation). + """ + + def __init__(self, model: nn.Module, num_layers: int) -> None: + super().__init__() + self.model = model + self.num_layers = num_layers + self.config = model.config + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenDecoderWrapper: + """Load Qwen3ForCausalLM and wrap for export.""" + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs) + wrapper = cls(model, model.config.num_hidden_layers) + wrapper.eval() + return wrapper + + def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + """Convert dict inputs to positional args for torch.onnx.export.""" + return tuple(inputs.values()) + + def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Run decoder with static KV cache. + + Positional args (order matches OnnxConfig.inputs): + input_ids, attention_mask, position_ids, + past_0_key, past_0_value, past_1_key, past_1_value, ... + + Returns: + (logits, present_0_key, present_0_value, ...) where: + - logits is ``[batch, seq_len, vocab_size]`` (all positions) + - present KV is ``[batch, kv_heads, seq_len, head_dim]`` + """ + input_ids = args[0] + attention_mask = args[1] + position_ids = args[2] + kv_start = 3 + + seq_len = input_ids.size(1) + + # Build WinMLSlidingWindowCache from input KV tensors. 
+ cache = WinMLSlidingWindowCache(self.config, max_cache_len=args[kv_start].size(2)) + cache.early_initialization( + batch_size=input_ids.size(0), + num_heads=args[kv_start].size(1), + head_dim=args[kv_start].size(3), + dtype=args[kv_start].dtype, + device=input_ids.device, + ) + max_cache_len = args[kv_start].size(2) + for i in range(self.num_layers): + cache.layers[i].keys = args[kv_start + i * 2] + cache.layers[i].values = args[kv_start + i * 2 + 1] + + # Sliding window: tokens always append at the END of the buffer. + # cache_position = buffer positions (right-aligned) so HF's + # create_causal_mask builds correct kv_idx <= q_idx constraint. + # position_ids (separate) handles RoPE with absolute positions. + cache_position = torch.arange( + max_cache_len - seq_len, + max_cache_len, + dtype=torch.int64, + device=input_ids.device, + ) + + out = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=cache, + use_cache=True, + cache_position=cache_position, + ) + + # All logits + captured KV directly (no gather). + # forward() selects the right position for padded prefill inputs. + result: list[torch.Tensor] = [out.logits] + for i in range(self.num_layers): + k, v = cache.captured[i] + result.extend([k, v]) + return tuple(result) + + +# Sub-models must use GenericTask (raw ONNX outputs) — task-specific +# wrappers like WinMLModelForFeatureExtraction would discard KV outputs. 
+register_specialization("qwen3", "feature-extraction", "WinMLModelForGenericTask") +register_specialization("qwen3", "text-generation", "WinMLModelForGenericTask") + + +# ============================================================================= +# OnnxConfig Registrations (using standard Optimum task names) +# ============================================================================= + +_QWEN_NORMALIZED = NormalizedConfig.with_args( + hidden_size="hidden_size", + num_layers="num_hidden_layers", + num_attention_heads="num_key_value_heads", # KV cache uses GQA heads + head_dim="head_dim", + max_cache_len="max_position_embeddings", + vocab_size="vocab_size", + allow_new=True, +) + + +def _qwen_io_inputs(num_layers: int) -> dict[str, dict[int, str]]: + result: dict[str, dict[int, str]] = { + "input_ids": {0: "batch_size"}, + "attention_mask": {0: "batch_size"}, + "position_ids": {0: "batch_size"}, + } + for i in range(num_layers): + result[f"past_{i}_key"] = {0: "batch_size"} + result[f"past_{i}_value"] = {0: "batch_size"} + return result + + +def _qwen_io_outputs(num_layers: int) -> dict[str, dict[int, str]]: + result: dict[str, dict[int, str]] = {"logits": {0: "batch_size"}} + for i in range(num_layers): + result[f"present_{i}_key"] = {0: "batch_size"} + result[f"present_{i}_value"] = {0: "batch_size"} + return result + + +@register_onnx_overwrite("qwen3", "feature-extraction", library_name="transformers") +class QwenPrefillIOConfig(OnnxConfig): + """ONNX config for Qwen3 prefill (feature-extraction task). 
+ + Inputs: input_ids [1, 64], attention_mask [1, 256], position_ids [1, 64], + past_{i}_key/value [1, 8, 256, 128] (no ``cache_position`` input; computed in wrapper) + Outputs: logits [1, 64, vocab], present_{i}_key/value [1, 8, 64, 128] + """ + + NORMALIZED_CONFIG_CLASS = _QWEN_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = (DecoderOnlyPrefillInputGenerator, PastKeyValueInputGenerator) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return _qwen_io_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return _qwen_io_outputs(self._normalized_config.num_layers) + + +@register_onnx_overwrite("qwen3", "text-generation", library_name="transformers") +class QwenGenIOConfig(OnnxConfig): + """ONNX config for Qwen3 generation (text-generation task). + + Inputs: input_ids [1, 1], attention_mask [1, 256], position_ids [1, 1], + past_{i}_key/value [1, 8, 256, 128] (no ``cache_position`` input; computed in wrapper) + Outputs: logits [1, 1, vocab], present_{i}_key/value [1, 8, 1, 128] + """ + + NORMALIZED_CONFIG_CLASS = _QWEN_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = (DecoderOnlyInputGenerator, PastKeyValueInputGenerator) + + @property + def inputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return _qwen_io_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 + return _qwen_io_outputs(self._normalized_config.num_layers) + + +# ============================================================================= +# Build Config (dynamo=True required for Qwen3) +# ============================================================================= + +QWEN_CONFIG = WinMLBuildConfig( + export=WinMLExportConfig(dynamo=True, opset_version=18), + optim=WinMLOptimizationConfig( + gelu_fusion=True, + fuse_rmsnorm=True, + matmul_add_fusion=True, + clamp_constant_values=True, + remove_isnan_in_attention_mask=True, + ), +) + + +# 
============================================================================= +# Model Class Mapping +# ============================================================================= + +MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("qwen3", "feature-extraction"): QwenDecoderWrapper, + ("qwen3", "text-generation"): QwenDecoderWrapper, +} + +# ============================================================================= +# WinMLQwen3Model — inference wrapper (registered as composite model) +# ============================================================================= + + +@register_composite_model("qwen3", "text-generation") +class WinMLQwen3Model(WinMLDecoderOnlyModel): + """Qwen3 decoder-only model for text generation. + + Declares Qwen3 sub-component tasks and generation config defaults. + All forward/cache logic lives in ``WinMLDecoderOnlyModel``. + """ + + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "decoder_prefill": "feature-extraction", + "decoder_gen": "text-generation", + } + + @classmethod + def get_cache_class(cls) -> type: # noqa: D102 + return WinMLSlidingWindowCache + + @property + def generation_config(self): # noqa: D102 + if not hasattr(self, "_generation_config"): + from transformers import GenerationConfig + + gc_kw: dict[str, Any] = {} + for attr in ("bos_token_id", "eos_token_id", "pad_token_id"): + val = getattr(self.config, attr, None) + if val is not None: + gc_kw[attr] = val + gc_kw.setdefault("max_new_tokens", self._max_cache_len - self._prefill_seq_len) + gc_kw.setdefault("num_beams", 1) + gc_kw.setdefault("do_sample", False) + self._generation_config = GenerationConfig(**gc_kw) + return self._generation_config + + @generation_config.setter + def generation_config(self, value: Any) -> None: + self._generation_config = value + + +__all__ = [ + "MODEL_CLASS_MAPPING", + "QWEN_CONFIG", + "QwenDecoderWrapper", + "QwenGenIOConfig", + "QwenPrefillIOConfig", + "WinMLQwen3Model", +] diff --git a/src/winml/modelkit/models/hf/t5.py 
b/src/winml/modelkit/models/hf/t5.py new file mode 100644 index 000000000..e93cb1534 --- /dev/null +++ b/src/winml/modelkit/models/hf/t5.py @@ -0,0 +1,389 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""T5 HuggingFace Model Configuration. + +Provides encoder/decoder export wrappers and OnnxConfig registrations for +T5 encoder-decoder models with a sliding-window KV cache. + +Export Strategy (split by task): +- T5EncoderWrapper + T5EncoderIOConfig: ``feature-extraction`` task + → encoder-only ONNX (input_ids, attention_mask → encoder_hidden_states) +- T5DecoderWrapper + T5DecoderIOConfig: ``text2text-generation`` task + → decoder ONNX with static buffer input + single-token KV output. + Uses ``WinMLSlidingWindowCache`` (Slice+Concat eviction) for attention. + Output is only the new token's KV [batch, heads, 1, d_kv]. + +Model: google-t5/t5-small, google-t5/t5-base, etc. 
+ +Usage: + wmk config -m google-t5/t5-small --task feature-extraction → encoder + wmk config -m google-t5/t5-small --task text2text-generation → decoder +""" + +from __future__ import annotations + +from typing import Any, ClassVar + +import torch +import torch.nn as nn +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedConfig +from optimum.utils.input_generators import DummyTextInputGenerator +from transformers import T5ForConditionalGeneration +from transformers.cache_utils import DynamicCache, EncoderDecoderCache + +from ...config import WinMLBuildConfig +from ...export import register_onnx_overwrite +from ...optim import WinMLOptimizationConfig +from ..winml.composite_model import register_composite_model +from ..winml.encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel +from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache + + +# ============================================================================= +# Wrapper nn.Modules (with from_pretrained, like SAM2 wrappers) +# ============================================================================= + + +class T5EncoderWrapper(nn.Module): + """Wraps T5 encoder for standalone ONNX export. + + Loads the full T5ForConditionalGeneration and extracts the encoder. 
+ """ + + def __init__(self, encoder: nn.Module) -> None: + super().__init__() + self.encoder = encoder + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> T5EncoderWrapper: + """Load full T5, extract encoder.""" + full_model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, **kwargs) + wrapper = cls(full_model.encoder) + wrapper.eval() + return wrapper + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + """Return encoder last hidden state.""" + return self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state + + +class T5DecoderWrapper(nn.Module): + """Wraps T5ForConditionalGeneration with sliding-window KV cache I/O. + + Input: full buffer ``[batch, heads, max_decode, d_kv]`` per layer. + Output: only the new token's KV ``[batch, heads, 1, d_kv]`` per layer. + + Uses ``WinMLSlidingWindowCache`` (Slice+Concat eviction) wrapped in + ``EncoderDecoderCache`` (cross-attn empty → always recomputed from + ``encoder_hidden_states``). + + ``cache_position`` is intentionally NOT an ONNX input — it is pinned to + ``[max_cache_len - 1]`` (the rightmost buffer slot) inside ``forward`` and + traced as a Constant. For single-token generation with a sliding window, + the new token is always written to the rightmost slot, so this value is + invariant. Baking it in lets ONNX constant-fold the entire + ``compute_bias`` subgraph (``memory_position - context_position`` is + constant → learned-bias Gather becomes a fixed tensor) and collapses the + causal mask ``kv_idx <= q_idx`` (all-True since ``q_idx == W-1``). + + This couples the exported graph to sliding-window semantics at build + time. ``WinMLStaticCache`` cannot be used as the *inference* cache for + this ONNX — its buffer layout (left-aligned, index_copy_) does not match + the graph's internal Slice+Concat. 
Callers who want static-cache + semantics must subclass ``T5DecoderWrapper``, take ``cache_position`` as + an input again, and re-export. ``WinMLStaticCache`` itself remains + fully functional for that path. + """ + + def __init__(self, model: nn.Module, num_layers: int) -> None: + super().__init__() + self.model = model + self.num_layers = num_layers + # Expose config for OnnxConfig / NormalizedConfig access + self.config = model.config + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> T5DecoderWrapper: + """Load full T5, wrap with sliding-window cache.""" + full_model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, **kwargs) + num_layers = full_model.config.num_layers + wrapper = cls(full_model, num_layers) + wrapper.eval() + return wrapper + + def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + """Convert dict inputs to positional args for torch.onnx.export.""" + return tuple(inputs.values()) + + def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Run decoder with sliding-window KV cache. + + Positional args (order matches OnnxConfig.inputs): + decoder_input_ids, encoder_hidden_states, attention_mask, + decoder_attention_mask, + past_0_key, past_0_value, past_1_key, past_1_value, ... + + Returns: + (logits, present_0_key, present_0_value, ...) where each + present KV is [batch, heads, 1, d_kv] — the new token only. + """ + decoder_input_ids = args[0] + encoder_hidden_states = args[1] + attention_mask = args[2] + decoder_attention_mask = args[3] + kv_start = 4 + + # Build WinMLSlidingWindowCache from input KV tensors. + # update() does Slice+Concat (not index_copy_/ScatterElements) — evicting + # the N oldest entries and appending the N new ones at the right. The + # incoming key/value states are captured for direct ONNX output + # (avoiding a scatter→gather round-trip in the graph). 
+ max_cache_len = args[kv_start].size(2) + self_attn_cache = WinMLSlidingWindowCache(self.config, max_cache_len=max_cache_len) + self_attn_cache.early_initialization( + batch_size=decoder_input_ids.size(0), + num_heads=args[kv_start].size(1), + head_dim=args[kv_start].size(3), + dtype=args[kv_start].dtype, + device=decoder_input_ids.device, + ) + for i in range(self.num_layers): + self_attn_cache.layers[i].keys = args[kv_start + i * 2] + self_attn_cache.layers[i].values = args[kv_start + i * 2 + 1] + + # Sliding window + single-token gen: the query is always at the + # rightmost slot. Constructing this constant inside forward traces it + # as a Constant node — downstream compute_bias and causal-mask subgraphs + # then constant-fold through ONNX optimization. + cache_position = torch.tensor( + [max_cache_len - 1], dtype=torch.int64, device=decoder_input_ids.device + ) + + # EncoderDecoderCache is structurally required: T5Attention routes + # self-attention → self_attention_cache, cross-attention → cross_attention_cache. + # Without the wrapper, both would share the same cache + layer indices. + # DynamicCache for cross-attn is a no-op during export (each layer + # computes fresh from encoder_hidden_states, never reuses). + cross_attn_cache = DynamicCache() + cache = EncoderDecoderCache(self_attn_cache, cross_attn_cache) + + out = self.model( + decoder_input_ids=decoder_input_ids, + encoder_outputs=(encoder_hidden_states,), + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + past_key_values=cache, + use_cache=True, + cache_position=cache_position, + ) + + # Return new-token KV directly from the capturing cache. + # The old approach did gather(ScatterElements output) — a round-trip. + # The cache already saved the incoming key/value states. 
# =============================================================================
# OnnxConfig Registrations
# =============================================================================


@register_onnx_overwrite("t5", "feature-extraction", library_name="transformers")
class T5EncoderIOConfig(OnnxConfig):
    """ONNX I/O description for the T5 encoder (``feature-extraction`` task).

    Inputs:  ``input_ids``, ``attention_mask``
    Outputs: ``encoder_hidden_states``
    """

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        vocab_size="vocab_size",
        allow_new=True,
    )
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)

    @property
    def inputs(self) -> dict[str, dict[int, str]]:
        """Return ``{input name: {axis index: axis name}}`` for the encoder inputs."""
        # A fresh axes dict is built per name so no two entries share an object.
        return {
            name: {0: "batch_size", 1: "sequence_length"}
            for name in ("input_ids", "attention_mask")
        }

    @property
    def outputs(self) -> dict[str, dict[int, str]]:
        """Return ``{output name: {axis index: axis name}}`` for the encoder output."""
        hidden_axes = {0: "batch_size", 1: "sequence_length"}
        return {"encoder_hidden_states": hidden_axes}


@register_onnx_overwrite("t5", "text2text-generation", library_name="transformers")
class T5DecoderIOConfig(OnnxConfig):
    """ONNX I/O description for the T5 decoder with a sliding-window KV cache.

    Inputs:  ``decoder_input_ids``, ``encoder_hidden_states``, ``attention_mask``,
             ``decoder_attention_mask``, ``past_{i}_key`` / ``past_{i}_value``
    Outputs: ``logits``, ``present_{i}_key`` / ``present_{i}_value``

    ``cache_position`` is deliberately *not* an input: ``T5DecoderWrapper.forward``
    pins it to ``[max_cache_len - 1]`` (the rightmost buffer slot) as a graph
    Constant. That ties the exported model to sliding-window semantics at build
    time; the ``T5DecoderWrapper`` docstring describes the static-cache
    re-export path if it is ever needed.

    Past KV inputs carry the full buffer ``[batch, heads, max_decode, d_kv]``;
    present KV outputs carry only the new token ``[batch, heads, 1, d_kv]``.
    """

    # T5Config exposes d_model, num_layers, num_heads, d_kv, vocab_size, n_positions.
    # - sequence_length stays at the Optimum default (16), NOT n_positions
    #   (512 would be too large).
    # - head_dim is mapped to d_kv for PastKeyValueInputGenerator.
    # - max_cache_len is mapped to n_positions (decoder static buffer size).
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        hidden_size="d_model",
        num_layers="num_layers",
        num_attention_heads="num_heads",
        head_dim="d_kv",
        max_cache_len="n_positions",
        vocab_size="vocab_size",
        allow_new=True,
    )
    DUMMY_INPUT_GENERATOR_CLASSES = (
        EncoderDecoderInputGenerator,
        PastKeyValueInputGenerator,
    )

    @property
    def inputs(self) -> dict[str, dict[int, str]]:
        """Return the decoder inputs: four base tensors, then per-layer past KV."""
        layer_count = self._normalized_config.num_layers
        base_names = (
            "decoder_input_ids",
            "encoder_hidden_states",
            "attention_mask",
            "decoder_attention_mask",
        )
        # Order matters: key before value within each layer, layers ascending.
        past_names = (
            f"past_{layer}_{kind}" for layer in range(layer_count) for kind in ("key", "value")
        )
        return {name: {0: "batch_size"} for name in (*base_names, *past_names)}

    @property
    def outputs(self) -> dict[str, dict[int, str]]:
        """Return the decoder outputs: ``logits``, then per-layer present KV."""
        layer_count = self._normalized_config.num_layers
        present_names = (
            f"present_{layer}_{kind}" for layer in range(layer_count) for kind in ("key", "value")
        )
        return {name: {0: "batch_size"} for name in ("logits", *present_names)}


# =============================================================================
# Model Class Mapping (same pattern as SAM2 and CLIP)
# =============================================================================

# (model_type, export task) -> torch wrapper class used for that sub-component.
MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = {
    ("t5", "feature-extraction"): T5EncoderWrapper,
    ("t5", "text2text-generation"): T5DecoderWrapper,
}

# Build configuration carrying the optimization switches enabled for T5.
T5_CONFIG = WinMLBuildConfig(
    optim=WinMLOptimizationConfig(
        gelu_fusion=True,
        fuse_rmsnorm=True,
        matmul_add_fusion=True,
        clamp_constant_values=True,
        remove_isnan_in_attention_mask=True,
    ),
)


# =============================================================================
# WinMLT5Model — inference wrapper (registered as composite model)
# =============================================================================


@register_composite_model("t5", "translation")
@register_composite_model("t5", "summarization")
class WinMLT5Model(WinMLEncoderDecoderModel):
    """T5 encoder-decoder model for seq2seq tasks (translation, summarization).

    Only declares the T5 sub-component tasks and the generation-config
    defaults; the encoder-decoder forward/cache logic is inherited from
    ``WinMLEncoderDecoderModel``.
    """

    # Component name -> export task used to build that component.
    _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = {
        "encoder": "feature-extraction",
        "decoder": "text2text-generation",
    }

    @classmethod
    def get_cache_class(cls) -> type:
        """Return ``WinMLSlidingWindowCache`` (Slice+Concat; no ScatterElements).

        Correctness with T5's learned relative position bias rests on one
        invariant: ``cache_position`` is always the query's *buffer index*, not
        its absolute sequence position. ``get_query_cache_position`` on each
        cache class supplies the right value — ``[step]`` for static,
        ``[max_cache_len-1]`` for sliding. Under that convention
        ``T5Attention.compute_bias`` computes ``memory_position - context_position
        = j - (W-1)``, which yields correct relative distances regardless of
        overflow, and HF's ``create_causal_mask`` (``kv_idx <= q_idx``) admits
        every buffer slot while the 2D decoder mask selects the filled region.

        ``WinMLStaticCache`` remains fully supported — subclass ``WinMLT5Model``
        and override this method to get index_copy_ semantics instead.
        """
        return WinMLSlidingWindowCache

    @property
    def generation_config(self):
        """Return the generation config, building and memoizing it on first access.

        Special token ids that are set on ``self.config`` are copied over;
        ``max_new_tokens`` is ``self._max_dec - 1``, and greedy single-beam
        decoding is selected.
        """
        if hasattr(self, "_generation_config"):
            return self._generation_config

        from transformers import GenerationConfig

        token_attrs = (
            "decoder_start_token_id",
            "bos_token_id",
            "eos_token_id",
            "pad_token_id",
        )
        settings: dict[str, Any] = {}
        if self.config is not None:
            settings = {
                name: token_id
                for name in token_attrs
                if (token_id := getattr(self.config, name, None)) is not None
            }
        settings["max_new_tokens"] = self._max_dec - 1
        # Static batch=1 ONNX models don't support beam search
        settings["num_beams"] = 1
        settings["do_sample"] = False
        self._generation_config = GenerationConfig(**settings)
        return self._generation_config

    @generation_config.setter
    def generation_config(self, value: Any) -> None:
        """Replace the memoized generation config with *value*."""
        self._generation_config = value


__all__ = [
    "MODEL_CLASS_MAPPING",
    "T5_CONFIG",
    "T5DecoderIOConfig",
    "T5DecoderWrapper",
    "T5EncoderIOConfig",
    "T5EncoderWrapper",
    "WinMLT5Model",
]
WinMLModelForSemanticSegmentation, ) +from .kv_cache import ( + WinMLCache, + WinMLSlidingWindowCache, + WinMLStaticCache, +) from .object_detection import WinMLModelForObjectDetection from .sequence_classification import WinMLModelForSequenceClassification __all__ = [ + "COMPOSITE_MODEL_REGISTRY", "TASK_TO_WINML_CLASS", "WINML_MODEL_CLASS_MAPPING", "ImageSegmentationOutput", + "WinMLCache", + "WinMLCompositeModel", + "WinMLDecoderOnlyModel", + "WinMLEncoderDecoderModel", "WinMLModelForFeatureExtraction", "WinMLModelForGenericTask", "WinMLModelForImageClassification", @@ -199,7 +216,10 @@ def register_specialization(model_type: str, task: str, class_name: str) -> None "WinMLModelForSemanticSegmentation", "WinMLModelForSequenceClassification", "WinMLPreTrainedModel", + "WinMLSlidingWindowCache", + "WinMLStaticCache", "get_supported_tasks", "get_winml_class", + "register_composite_model", "register_specialization", ] diff --git a/src/winml/modelkit/models/winml/base.py b/src/winml/modelkit/models/winml/base.py index 7f5e55c70..be6e2e075 100644 --- a/src/winml/modelkit/models/winml/base.py +++ b/src/winml/modelkit/models/winml/base.py @@ -64,6 +64,7 @@ def __init__( onnx_path: str | Path, config: PretrainedConfig | None = None, device: str = "auto", + session_options: Any | None = None, ) -> None: """Initialize inference model. @@ -71,6 +72,7 @@ def __init__( onnx_path: Path to ONNX model file config: HuggingFace PretrainedConfig (num_labels, id2label, etc.) 
device: Target device ("auto", "npu", "gpu", "cpu") + session_options: ORT SessionOptions (e.g., for graph_optimization_level) """ self._onnx_path = Path(onnx_path) self.config = config @@ -83,6 +85,7 @@ def __init__( self._session = WinMLSession( onnx_path=self._onnx_path, device=device, + session_options=session_options, ) @property diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py new file mode 100644 index 000000000..8a887cb2f --- /dev/null +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -0,0 +1,287 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""WinML composite model base and registry. + +Provides ``WinMLCompositeModel`` — a base class for models composed of +multiple ``WinMLAutoModel`` sub-components (e.g., encoder + decoder, +prefill + gen). Each subclass declares ``_SUB_MODEL_CONFIG`` mapping +component names to HF tasks; ``from_pretrained()`` builds them all. + +Registry +-------- +``@register_composite_model(model_type, task)`` registers a pipeline class. 
+``wmk config`` checks the registry to generate one config file per component:: + + wmk config -m google-t5/t5-small --task translation -o t5.json + # → t5_encoder.json (feature-extraction) + t5_decoder.json (text2text-generation) + + wmk build -c t5_encoder.json -m google-t5/t5-small -o output/encoder + wmk build -c t5_decoder.json -m google-t5/t5-small -o output/decoder + +Per-component kwargs +-------------------- +``sub_model_kwargs`` in ``from_pretrained`` allows different ``shape_config`` +per sub-model (e.g., different ``max_cache_len`` for prefill vs gen):: + + WinMLCompositeModel.from_pretrained(model_id, task="text-generation", + sub_model_kwargs={ + "decoder_prefill": {"shape_config": {"max_cache_len": 256, "seq_len": 64}}, + "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, + }) + +Concrete composite models live alongside their export configs: + +- ``models.hf.t5.WinMLT5Model`` (encoder-decoder, T5) +- ``models.hf.mu2.WinMLMu2Model`` (encoder-decoder, Mu2) +- ``models.hf.qwen.WinMLQwen3Model`` (decoder-only, Qwen3) +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, ClassVar + +import torch + +from .base import PreTrainedModel + + +if TYPE_CHECKING: + from pathlib import Path + + from transformers import PretrainedConfig + +logger = logging.getLogger(__name__) + + +# ========================================================================= +# composite model Registry +# ========================================================================= + +# Maps (model_type, task) → pipeline class with _SUB_MODEL_CONFIG. +# Used by `wmk config` to generate one config file per sub-component. 
# (model_type, task) -> registered composite model class.
COMPOSITE_MODEL_REGISTRY: dict[tuple[str, str], type] = {}


def register_composite_model(model_type: str, task: str):
    """Class decorator that registers a composite model for `wmk config`.

    Args:
        model_type: HF ``config.model_type`` the class handles (e.g. ``"t5"``).
        task: Pipeline task the class handles (e.g. ``"translation"``).

    Returns:
        A decorator that stores the class in ``COMPOSITE_MODEL_REGISTRY`` under
        ``(model_type, task)`` and returns the class unchanged.

    Raises:
        ValueError: (from the decorator) if the key is already registered.
    """

    def decorator(cls: type) -> type:
        key = (model_type, task)
        if key in COMPOSITE_MODEL_REGISTRY:
            raise ValueError(
                f"Composite model already registered for {key!r}: "
                f"{COMPOSITE_MODEL_REGISTRY[key].__name__}. "
                f"Cannot register {cls.__name__}."
            )
        COMPOSITE_MODEL_REGISTRY[key] = cls
        return cls

    return decorator


# =========================================================================
# WinMLCompositeModel — multi-component base
# =========================================================================


class WinMLCompositeModel(PreTrainedModel):
    """Base class for models composed of multiple WinMLAutoModel sub-components.

    Subclasses declare ``_SUB_MODEL_CONFIG``: a mapping of component name to
    the HF task used to build it via ``WinMLAutoModel.from_pretrained``.

    After construction, sub-components are available in ``self.sub_models``.
    """

    # Empty on the base class; a non-empty mapping marks a concrete subclass.
    _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = {}

    def __init__(
        self,
        sub_models: dict[str, Any],
        config: PretrainedConfig,
        device: str = "cpu",
    ) -> None:
        """Store the sub-models, the HF config and the requested ORT device.

        Args:
            sub_models: Component name -> loaded sub-model.
            config: HF config of the underlying model.
            device: ORT target reported by ``ort_device``.
        """
        self.sub_models = sub_models
        self.config = config
        self._device = device

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        task: str | None = None,
        *,
        device: str = "cpu",
        use_cache: bool = True,
        force_rebuild: bool = False,
        sub_model_kwargs: dict[str, dict[str, Any]] | None = None,
        trust_remote_code: bool = False,
        **kwargs: Any,
    ) -> WinMLCompositeModel:
        """Build all sub-components and return ready-to-use model.

        When called on ``WinMLCompositeModel`` directly (not a subclass),
        ``task`` is required to resolve the concrete class from
        ``COMPOSITE_MODEL_REGISTRY``. When called on a registered subclass
        (e.g., ``WinMLT5Model``), ``task`` is optional and ignored.

        Args:
            model_id: HuggingFace model ID or local path.
            task: Pipeline task name (e.g., ``"translation"``,
                ``"text-generation"``). Required when calling on the base
                class; ignored when calling on a registered subclass.
            device: Target device. Forwarded to every sub-model and recorded
                on the returned model so ``ort_device`` reflects it.
            use_cache: Use persistent cache.
            force_rebuild: Force rebuild even if cached.
            sub_model_kwargs: Per-component kwargs forwarded to
                ``WinMLAutoModel.from_pretrained()``. Keys are component
                names from ``_SUB_MODEL_CONFIG`` (e.g., ``"decoder_prefill"``,
                ``"decoder_gen"``). Values are dicts merged on top of the
                shared ``**kwargs``. Use this to pass different
                ``shape_config`` per sub-model.
            trust_remote_code: Forward to ``AutoConfig.from_pretrained``
                and each sub-model's ``WinMLAutoModel.from_pretrained``.
                Required for custom-code HF models (e.g., Mu2).
            **kwargs: Forwarded to ``WinMLAutoModel.from_pretrained()``
                for every sub-component (overridden by ``sub_model_kwargs``).

        Returns:
            An instance of the resolved composite class.

        Raises:
            ValueError: If called on the base class and no composite model is
                registered for ``(model_type, task)``.
        """
        from transformers import AutoConfig

        hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
        model_type = hf_config.model_type

        if not cls._SUB_MODEL_CONFIG:
            # Resolve concrete class from registry when called on the base class
            resolved_cls = COMPOSITE_MODEL_REGISTRY.get((model_type, task))
            if resolved_cls is None:
                raise ValueError(
                    f"No composite model registered for ({model_type!r}, {task!r}). "
                    f"Registered: {list(COMPOSITE_MODEL_REGISTRY.keys())}"
                )
            return resolved_cls.from_pretrained(
                model_id,
                task,
                device=device,
                use_cache=use_cache,
                force_rebuild=force_rebuild,
                sub_model_kwargs=sub_model_kwargs,
                trust_remote_code=trust_remote_code,
                **kwargs,
            )
        from ..auto import WinMLAutoModel

        per_component = sub_model_kwargs or {}
        sub_models: dict[str, Any] = {}
        for name, component_task in cls._SUB_MODEL_CONFIG.items():
            logger.info("Building %s for %s...", name, model_id)
            # Component-specific kwargs win over the shared ones.
            merged = {**kwargs, **per_component.get(name, {})}
            sub_models[name] = WinMLAutoModel.from_pretrained(
                model_id,
                task=component_task,
                device=device,
                use_cache=use_cache,
                force_rebuild=force_rebuild,
                trust_remote_code=trust_remote_code,
                **merged,
            )

        model = cls(sub_models=sub_models, config=hf_config)
        # Subclass constructors need not accept ``device`` (they may only take
        # ``sub_models`` and ``config``), so record the requested target after
        # construction; otherwise ``ort_device`` would always report "cpu".
        model._device = device
        return model

    @classmethod
    def from_onnx(
        cls,
        onnx_path: dict[str, str | Path],
        *,
        task: str | None = None,
        hf_config: PretrainedConfig | None = None,
        sub_model_kwargs: dict[str, dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> WinMLCompositeModel:
        """Load composite model from pre-built ONNX files.

        Resolves the concrete model class from the registry using *task*
        and ``hf_config.model_type`` (only when called on the base class),
        then builds each sub-component via ``WinMLAutoModel.from_onnx``.

        Args:
            onnx_path: Maps component name (e.g., ``"encoder"``,
                ``"decoder_prefill"``) to its ONNX file path. Values may
                be ``str`` or ``pathlib.Path``; coerced via ``Path(path)``
                inside the dispatch loop.
            task: Pipeline task (e.g., ``"translation"``,
                ``"text-generation"``).
            hf_config: HF ``PretrainedConfig`` for the model. Used to
                resolve the concrete class from the registry via
                ``hf_config.model_type``.
            sub_model_kwargs: Per-component kwargs merged on top of
                ``**kwargs`` for each sub-model's ``from_onnx`` call.
            **kwargs: Forwarded to ``WinMLAutoModel.from_onnx`` for every
                component (overridden by ``sub_model_kwargs``). A ``device``
                entry, if present, is also recorded on the returned model.

        Returns:
            An instance of the resolved composite class.

        Raises:
            ValueError: If no composite model is registered for
                ``(model_type, task)``, or if ``onnx_path`` names a component
                that the resolved class does not declare.
        """
        from pathlib import Path

        per_component = sub_model_kwargs or {}

        # Resolve concrete class from registry
        model_type = getattr(hf_config, "model_type", None) if hf_config else None
        if not cls._SUB_MODEL_CONFIG:
            resolved_cls = COMPOSITE_MODEL_REGISTRY.get((model_type, task))
            if resolved_cls is None:
                raise ValueError(
                    f"No composite model for ({model_type!r}, {task!r}). "
                    f"Registered: {list(COMPOSITE_MODEL_REGISTRY.keys())}"
                )
        else:
            resolved_cls = cls

        from ..auto import WinMLAutoModel

        sub_models: dict[str, Any] = {}
        for name, path in onnx_path.items():
            component_task = resolved_cls._SUB_MODEL_CONFIG.get(name)
            if component_task is None:
                valid = list(resolved_cls._SUB_MODEL_CONFIG.keys())
                raise ValueError(
                    f"Unknown component {name!r}. Valid names for {resolved_cls.__name__}: {valid}"
                )
            # The declared component task overrides any shared ``task`` kwarg;
            # per-component kwargs are applied last.
            merged = {**kwargs, "task": component_task, **per_component.get(name, {})}
            sub_models[name] = WinMLAutoModel.from_onnx(Path(path), **merged)

        model = resolved_cls(sub_models=sub_models, config=hf_config)
        # Keep ``ort_device`` in sync with a shared ``device`` kwarg that was
        # forwarded to the sub-models; without one the constructor default stays.
        if "device" in kwargs:
            model._device = kwargs["device"]
        return model

    @property
    def device(self) -> torch.device:
        """Device (CPU — ORT handles actual placement)."""
        return torch.device("cpu")

    @property
    def ort_device(self) -> str:
        """ORT execution provider target (e.g. "npu", "gpu", "cpu", "auto")."""
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        """Model dtype for HF compatibility."""
        return torch.float32

    def to(self, *args: Any, **kwargs: Any) -> WinMLCompositeModel:
        """No-op for HF pipeline compatibility; sub-models remain on their original device."""
        if args or kwargs:
            # debug (not warning) — HF pipelines routinely call `.to("cpu")` as a
            # setup step; surfacing that as a warning would spam normal usage.
            logger.debug(
                "WinMLCompositeModel.to(...) is a no-op; sub-models remain on their original "
                "device. Use WinMLSession options to control device placement."
            )
        return self

    def __call__(self, **kwargs: Any) -> Any:
        """Inference entry point; delegates to ``forward`` with the same kwargs."""
        return self.forward(**kwargs)

    def forward(self, **kwargs: Any) -> Any:
        """Subclasses implement task-specific logic.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
Returns logits for ALL real + positions ``[1, seq_len, vocab]`` — matches HF convention, enabling + both generation (last-token selection) and perplexity evaluation + (shifted cross-entropy over all positions). + + - **Generation** (``input_ids`` has 1 token): runs the gen ONNX model + with the single token + full KV cache buffer as input. + +4. KV cache is cache-agnostic — ``WinMLDecoderOnlyModel`` delegates mask + construction, position encoding, and cache writes to the ``WinMLCache`` + subclass. Two implementations ship: + ``WinMLStaticCache`` (ScatterElements/``index_copy_``) and + ``WinMLSlidingWindowCache`` (Slice+Concat FIFO). ``WinMLQwen3Model`` + selects the sliding-window variant. The cache persists across + ``generate()`` steps via ``CausalLMOutputWithPast.past_key_values``. + +5. ``prepare_inputs_for_generation()`` handles a subtle interaction with + ``GenerationMixin``: on the FIRST call, GenerationMixin may pass an + auto-created ``DynamicCache`` (empty). We detect this (not a + ``WinMLCache`` or empty) and pass the full prompt through for prefill + rather than trimming to the last token. On subsequent calls with a + populated ``WinMLCache``, we trim to the last token as usual. + +Design principles (same as composite_model.py): + +- ONNX I/O names and shapes are read from ``io_config``, never hardcoded. 
class DecoderOnlyInputGenerator(DummyInputGenerator):
    """Generates base inputs for decoder-only models with static KV cache.

    Produces ``input_ids``, ``attention_mask``, ``position_ids``, and
    ``cache_position``. Reads ``vocab_size``, ``max_cache_len``, and
    ``seq_len`` from the ``NormalizedConfig``.

    ``seq_len`` controls the input token count and is read from
    ``normalized_config.seq_len`` (falls back to ``_default_seq_len``).
    Subclasses override the default for prefill vs generation:

    - ``DecoderOnlyPrefillInputGenerator``: ``_default_seq_len = 64``
    - ``DecoderOnlyInputGenerator`` (base / gen): ``_default_seq_len = 1``

    To override at config time, set ``config.seq_len = N`` on the HF config.
    """

    SUPPORTED_INPUT_NAMES = (
        "input_ids",
        "attention_mask",
        "position_ids",
        "cache_position",
        "position_id",
    )

    _default_seq_len: int = 1

    def __init__(
        self,
        task: str,
        normalized_config: Any,
        batch_size: int = 1,
        seq_len: int | None = None,
        max_cache_len: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Capture the dimensions used by ``generate``.

        Args:
            task: Export task name (accepted for interface compatibility; unused).
            normalized_config: Source of ``vocab_size``, ``max_cache_len`` and
                the optional ``seq_len``.
            batch_size: Batch dimension of the generated tensors.
            seq_len: Explicit token count; falsy values fall back to
                ``normalized_config.seq_len`` and then ``_default_seq_len``.
            max_cache_len: Explicit attention-mask length; falsy values fall
                back to ``normalized_config.max_cache_len``.
            **kwargs: Ignored.
        """
        self.batch_size = batch_size
        self.vocab_size = normalized_config.vocab_size
        self.max_cache_len = max_cache_len or normalized_config.max_cache_len
        self.seq_len: int = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len)

    def generate(
        self,
        input_name: str,
        framework: str = "pt",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
    ) -> torch.Tensor:
        """Generate a dummy tensor for the given input name.

        Raises:
            ValueError: If *input_name* is not one of ``SUPPORTED_INPUT_NAMES``.
        """
        if input_name == "input_ids":
            return self.random_int_tensor(
                (self.batch_size, self.seq_len),
                max_value=self.vocab_size,
                framework=framework,
                dtype=int_dtype,
            )
        if input_name == "attention_mask":
            # Mask spans the whole cache buffer; only the first seq_len slots are on.
            mask = torch.zeros(self.batch_size, self.max_cache_len, dtype=torch.int64)
            mask[:, : self.seq_len] = 1
            return mask
        if input_name == "position_ids":
            return torch.arange(self.seq_len, dtype=torch.int64).unsqueeze(0)
        if input_name in ("cache_position", "position_id"):
            # Both names get the same 1-D ``[0, seq_len)`` range.
            return torch.arange(self.seq_len, dtype=torch.int64)
        raise ValueError(f"Unknown input: {input_name}")


class DecoderOnlyPrefillInputGenerator(DecoderOnlyInputGenerator):
    """Prefill variant with ``_default_seq_len = 64``."""

    _default_seq_len: int = 64


# =========================================================================
# WinMLDecoderOnlyModel — prefill + gen with WinMLCache
# =========================================================================


class WinMLDecoderOnlyModel(WinMLCompositeModel, GenerationMixin):
    """Decoder-only composite model with HF GenerationMixin support.

    Expects sub-components ``"decoder_prefill"`` and ``"decoder_gen"`` in
    ``_SUB_MODEL_CONFIG``. Provides the full interface required by
    ``GenerationMixin.generate()`` for decoder-only models with static KV cache.

    Input/output names and shapes are read from ONNX I/O metadata.
    """

    main_input_name = "input_ids"
    base_model_prefix = ""
    _is_stateful = False
    _supports_cache_class = False

    def __init__(
        self,
        sub_models: dict[str, Any],
        config: PretrainedConfig,
    ) -> None:
        """Wire up the two sub-models and derive cache geometry from ONNX metadata.

        Args:
            sub_models: Must contain ``"decoder_prefill"`` and ``"decoder_gen"``.
            config: HF config of the underlying model.

        Raises:
            KeyError: If a required sub-model is missing, or if the gen model's
                ONNX metadata has no ``past_0_key`` input (shape or type).
        """
        super().__init__(sub_models, config)
        self._prefill_model = sub_models["decoder_prefill"]
        self._gen_model = sub_models["decoder_gen"]

        # Build {name: shape} lookups from ONNX I/O metadata
        prefill_io = self._prefill_model.io_config
        self._prefill_expected = dict(
            zip(
                prefill_io.get("input_names", []),
                prefill_io.get("input_shapes", []),
                strict=False,
            )
        )
        gen_io = self._gen_model.io_config
        self._gen_expected = dict(
            zip(gen_io.get("input_names", []), gen_io.get("input_shapes", []), strict=False)
        )
        gen_type_map = dict(
            zip(gen_io.get("input_names", []), gen_io.get("input_types", []), strict=False)
        )

        # Validate BEFORE the first lookup: the geometry below indexes
        # ``past_0_key`` and would otherwise fail with a bare KeyError that
        # never reaches the descriptive message.
        if "past_0_key" not in self._gen_expected:
            raise KeyError(
                "'past_0_key' is missing from the decoder ONNX input shape map; "
                "cannot derive KV cache geometry. Verify the decoder ONNX was built with "
                "PastKeyValueInputGenerator."
            )
        if "past_0_key" not in gen_type_map:
            raise KeyError(
                "'past_0_key' is missing from the decoder ONNX input type map; "
                "cannot derive KV cache dtype. Verify the decoder ONNX was built with "
                "PastKeyValueInputGenerator."
            )

        # Cache geometry from gen model's KV input shape
        kv_shape = self._gen_expected["past_0_key"]
        self._max_cache_len = kv_shape[2]
        self._num_kv_heads = kv_shape[1]
        self._head_dim = kv_shape[3]
        self._num_kv_layers = sum(
            1 for n in self._gen_expected if n.startswith("past_") and n.endswith("_key")
        )

        # Resolve KV cache dtype from ONNX input types (fp32 or fp16)
        import numpy as np

        _np_dtype = gen_type_map["past_0_key"]
        self._kv_dtype = torch.from_numpy(np.zeros(1, dtype=_np_dtype)).dtype

        # Prefill chunk size
        self._prefill_seq_len = self._prefill_expected["input_ids"][1]

    # ----- Cache + GenerationMixin interface -----

    @classmethod
    def get_cache_class(cls) -> type:
        """Return the WinMLCache subclass. Subclasses must override."""
        raise NotImplementedError

    def _resolve_cache(self, past_key_values: Any) -> Any:
        """Unwrap or create WinMLCache for this generation step.

        1. Unwrap EncoderDecoderCache wrapper (GenerationMixin may add it even for
           decoder-only models in rare paths; handled here for symmetry with
           encoder_decoder.py).
        2. If already a WinMLCache, return directly.
        3. Otherwise create a fresh one and reset it.
        """
        from .kv_cache import WinMLCache

        # (1) Unwrap EncoderDecoderCache — never received by decoder-only models
        # under the current GenerationMixin flow, but mirroring encoder_decoder.py's
        # defensive unwrap keeps the two _resolve_cache paths symmetric.
        if hasattr(past_key_values, "self_attention_cache"):
            past_key_values = past_key_values.self_attention_cache

        # (2) Already our cache — return as-is
        if isinstance(past_key_values, WinMLCache):
            return past_key_values

        # (3) Anything else (None, a foreign cache): start from a fresh buffer.
        kv_shape = [1, self._num_kv_heads, self._max_cache_len, self._head_dim]
        cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype)
        cache.reset()
        return cache

    def can_generate(self) -> bool:
        """Report that this model supports ``generate()``."""
        return True

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build inputs for each generate() step.

        With a populated ``WinMLCache`` only the last token is passed on;
        otherwise the full ``input_ids`` go through and ``past_key_values`` is
        dropped so ``forward`` starts from a fresh cache.
        """
        from .kv_cache import WinMLCache

        if isinstance(past_key_values, WinMLCache) and past_key_values.get_seq_length() > 0:
            input_ids = input_ids[:, -1:]
        else:
            past_key_values = None
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
        }

    # ----- Forward -----

    def forward(
        self,
        *,
        input_ids: torch.Tensor,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> CausalLMOutputWithPast:
        """Run prefill or gen with static KV cache.

        Called by ``GenerationMixin.generate()`` on each step:
        - First call: ``input_ids`` is the full prompt → prefill (chunked).
        - Subsequent calls: ``input_ids`` is 1 token → gen.

        Args:
            input_ids: Token IDs ``[batch, seq_len]``.
            past_key_values: ``WinMLCache`` from previous step (None on first call).
            attention_mask: Not used directly — rebuilt from cache occupancy.
            **kwargs: Absorbed for GenerationMixin compatibility.

        Returns:
            CausalLMOutputWithPast with logits and updated ``WinMLCache``.
        """
        cache = self._resolve_cache(past_key_values)

        # More than one token selects the prefill model; otherwise the gen model.
        seq_len = input_ids.shape[1]
        if seq_len > 1:
            logits = self._run_prefill(input_ids, cache)
        else:
            logits = self._run_gen(input_ids, cache)

        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=cache,
        )

    # ----- Prefill (chunked) -----

    def _run_prefill(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor:
        """Run prefill model in a loop over chunks of ``prefill_seq_len``.

        Returns logits for ALL real input positions ``[1, seq_len, vocab_size]``.
        """
        seq_len = input_ids.shape[1]
        all_logits: list[torch.Tensor] = []

        for start in range(0, seq_len, self._prefill_seq_len):
            end = min(start + self._prefill_seq_len, seq_len)
            chunk_len = end - start

            padded_ids, position_ids, pad_len = cache.prepare_prefill_chunk(
                input_ids[:, start:end],
                start,
                self._prefill_seq_len,
            )
            attn_mask = cache.build_decoder_mask(self._max_cache_len, chunk_len)

            feeds: dict[str, Any] = {
                "input_ids": padded_ids,
                "attention_mask": attn_mask,
                "position_ids": position_ids,
            }
            # NOTE: currently dead for Qwen3 (cache_position is not in the Qwen
            # prefill ONNX inputs). Kept defensively for future decoder-only
            # models whose OnnxConfig declares cache_position; see the
            # StaticCache switching instructions at the top of hf/qwen.py for
            # the position-alignment caveat before activating this branch.
            if "cache_position" in self._prefill_expected:
                feeds["cache_position"] = position_ids.squeeze(0)
            for i in range(self._num_kv_layers):
                feeds[f"past_{i}_key"] = cache.layers[i].keys.detach()
                feeds[f"past_{i}_value"] = cache.layers[i].values.detach()

            outputs = self._prefill_model(**feeds)

            # Slice out padding — real tokens are at [pad_len : pad_len+chunk_len]
            real = slice(pad_len, pad_len + chunk_len)
            all_logits.append(outputs["logits"][:, real, :])

            # Strip padding KV before updating cache so step advances by
            # chunk_len (not prefill_seq_len).
            real_outputs = {k: v for k, v in outputs.items() if not k.startswith("present_")}
            for k, v in outputs.items():
                if k.startswith("present_"):
                    # Outputs may not be torch tensors; convert before slicing.
                    t = v if isinstance(v, torch.Tensor) else torch.tensor(v)
                    real_outputs[k] = t[:, :, real, :]
            cache.update_all_layers(real_outputs)

        return torch.cat(all_logits, dim=1)

    # ----- Generation (single token) -----

    def _run_gen(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor:
        """Run gen model for a single token. Returns logits ``[1, 1, vocab_size]``."""
        # ``cache.step`` is fed as the position of the token being decoded.
        fc = cache.step
        attn_mask = cache.build_decoder_mask(self._max_cache_len)

        feeds: dict[str, Any] = {
            "input_ids": input_ids,
            "attention_mask": attn_mask,
            "position_ids": torch.tensor([[fc]], dtype=torch.int64),
        }
        # NOTE: same situation as the ``cache_position`` branch in
        # ``_run_prefill``. Currently dead for Qwen3 (cache_position is not in
        # the gen ONNX inputs). Kept for future decoder-only models that declare
        # cache_position in their OnnxConfig; activate with care re: the
        # position-alignment caveat.
        if "cache_position" in self._gen_expected:
            feeds["cache_position"] = feeds["position_ids"].squeeze(0)
        for i in range(self._num_kv_layers):
            feeds[f"past_{i}_key"] = cache.layers[i].keys.detach()
            feeds[f"past_{i}_value"] = cache.layers[i].values.detach()

        outputs = self._gen_model(**feeds)
        cache.update_all_layers(outputs)

        return outputs["logits"]
+ +Class hierarchy:: + + WinMLCompositeModel — multi-component base + └─ WinMLEncoderDecoderModel(GenerationMixin) — encoder-decoder inference + ├─ WinMLT5Model (t5.py) — WinMLSlidingWindowCache (default) + └─ WinMLMu2Model (mu2.py) — WinMLSlidingWindowCache + +How ``forward()`` works: + +1. Encoder runs once (via ``get_encoder()``), hidden states cached by + GenerationMixin across decode steps. + +2. Each decode step: ``_resolve_cache`` unwraps GenerationMixin's + ``EncoderDecoderCache`` wrapper (or creates a fresh ``WinMLCache`` + on first call). Cache type is determined by ``get_cache_class()``. + +3. Feeds are built from ``model_kwargs`` (decoder_input_ids, attention_mask) + plus generated inputs (encoder_hidden_states, decoder_attention_mask, + position input, KV buffers). ``pad_inputs`` filters to ONNX input + names and pads undersized tensors. + +4. After ONNX inference, ``cache.update_all_layers(outputs)`` writes + present KV back and advances step — fully polymorphic, no isinstance. + +Cache-type gotchas (lessons learned): + +- **GenerationMixin wraps cache**: On the first decode call, GenerationMixin + may pass an ``EncoderDecoderCache`` (not None). ``_resolve_cache`` must + unwrap it, and cache reset must check ``not isinstance(WinMLCache)``. + +- **Causal mask with seq_len=1**: ``torch.tril(ones(1, N))`` only keeps + column 0. For single-token KV-cached decoding, the decoder_attention_mask + alone is sufficient — no tril needed. + +- **Position inputs, two roles**: ``forward`` seeds ``cache_position`` from + ``cache.get_query_cache_position(...)`` (the query's *buffer index* — used by + HF's causal mask ``kv_idx <= q_idx`` and by T5's ``compute_bias``) and + ``position_id`` from the absolute sequence step (used by RoPE models). + ``pad_inputs`` then filters to whatever the decoder ONNX actually declares, + so T5 (consumes ``cache_position``) and Mu2 (consumes ``position_id``) share + the same wrapper code.
# NOTE: T5 on a sliding window works without any ``compute_bias`` patch because
# ``WinMLSlidingWindowCache.get_query_cache_position`` returns
# ``[max_cache_len - 1]`` (the rightmost buffer slot). With that value,
# ``memory_position - context_position = j - (W-1)`` yields the correct
# negative distances for all buffer slots, and the 2D right-aligned mask
# selects the filled region.

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

import torch
from optimum.utils.input_generators import DummyInputGenerator
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

from ...utils.data_utils import pad_inputs
from .composite_model import WinMLCompositeModel


if TYPE_CHECKING:
    from optimum.utils import NormalizedConfig
    from transformers import Cache, PretrainedConfig

logger = logging.getLogger(__name__)


# =============================================================================
# EncoderDecoderInputGenerator — shared dummy input generator
# =============================================================================


class EncoderDecoderInputGenerator(DummyInputGenerator):
    """Generate decoder base inputs for encoder-decoder models.

    Produces ``decoder_input_ids``, ``encoder_hidden_states``,
    ``attention_mask`` (encoder), ``decoder_attention_mask``,
    ``cache_position`` and ``position_id``. Dimensions are read from the
    ``NormalizedConfig`` unless overridden through the constructor.
    """

    SUPPORTED_INPUT_NAMES = (
        "decoder_input_ids",
        "encoder_hidden_states",
        "attention_mask",
        "decoder_attention_mask",
        "cache_position",
        "position_id",
    )

    #: Preferred dummy position fed while tracing. The value itself is
    #: arbitrary; it is clamped to the last valid buffer slot so that a small
    #: ``max_cache_len`` never produces an out-of-range index.
    _TRACE_POSITION = 5

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedConfig,
        batch_size: int = 1,
        max_cache_len: int | None = None,
        sequence_length: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Capture the dimensions needed to size the dummy tensors.

        Args:
            task: Export task name (accepted for interface compatibility; unused).
            normalized_config: Source of ``hidden_size``, ``vocab_size`` and the
                fallback ``sequence_length`` / ``max_cache_len``.
            batch_size: Leading dimension of every generated tensor.
            max_cache_len: Decoder KV buffer length; falls back to the config.
            sequence_length: Encoder sequence length; falls back to the config,
                then to 16.
            **kwargs: Ignored; accepted so shared generator kwargs can be passed.
        """
        self.batch_size = batch_size
        self.d_model = normalized_config.hidden_size
        self.enc_seq = sequence_length or getattr(normalized_config, "sequence_length", 16)
        self.max_cache_len = max_cache_len or normalized_config.max_cache_len
        self.vocab_size = normalized_config.vocab_size

    def _trace_position(self) -> torch.Tensor:
        """Return a one-element int64 position that fits inside the KV buffer."""
        position = min(self._TRACE_POSITION, max(self.max_cache_len - 1, 0))
        return torch.tensor([position], dtype=torch.int64)

    def generate(
        self,
        input_name: str,
        framework: str = "pt",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
    ) -> torch.Tensor:
        """Generate a dummy tensor for the given input name.

        Raises:
            ValueError: If ``input_name`` is not one of ``SUPPORTED_INPUT_NAMES``.
        """
        if input_name == "decoder_input_ids":
            return self.random_int_tensor(
                (self.batch_size, 1),
                max_value=self.vocab_size,
                framework=framework,
                dtype=int_dtype,
            )
        if input_name == "encoder_hidden_states":
            return self.random_float_tensor(
                (self.batch_size, self.enc_seq, self.d_model),
                framework=framework,
                dtype=float_dtype,
            )
        if input_name == "attention_mask":
            return torch.ones(self.batch_size, self.enc_seq, dtype=torch.int64)
        if input_name == "decoder_attention_mask":
            return torch.ones(self.batch_size, self.max_cache_len, dtype=torch.int64)
        if input_name in ("cache_position", "position_id"):
            # cache_position: buffer index used for tracing.
            # position_id: absolute sequence position for RoPE models.
            return self._trace_position()
        raise ValueError(f"Unknown input: {input_name}")


# =============================================================================
# WinMLEncoderDecoderModel — encoder-decoder with StaticCache
# =============================================================================


class WinMLEncoderDecoderModel(WinMLCompositeModel, GenerationMixin):
    """Composite model with HF GenerationMixin support.

    Expects sub-components ``"encoder"`` and ``"decoder"`` in
    ``_SUB_MODEL_CONFIG``. Provides the full interface required by
    ``GenerationMixin.generate()`` for encoder-decoder models with
    static KV cache.

    Input/output names and shapes are read from ONNX I/O metadata — no
    model-specific names are assumed.
    """

    main_input_name = "input_ids"
    base_model_prefix = ""
    _is_stateful = False
    _supports_cache_class = False

    def __init__(
        self,
        sub_models: dict[str, Any],
        config: PretrainedConfig,
    ) -> None:
        """Wire up encoder/decoder sub-models and read KV metadata from ONNX I/O.

        Raises:
            KeyError: If the decoder ONNX metadata has no ``past_0_key`` input,
                so neither the KV buffer length nor its dtype can be derived.
        """
        import numpy as np

        super().__init__(sub_models, config)
        raw_encoder = sub_models["encoder"]
        self._decoder = sub_models["decoder"]

        # Build {name: shape} lookups from ONNX I/O metadata
        enc_io = raw_encoder.io_config
        enc_expected = dict(
            zip(enc_io.get("input_names", []), enc_io.get("input_shapes", []), strict=False)
        )
        # Wrap encoder with auto-padding so all callsites just use self._encoder(...)
        self._encoder = self._EncoderWithInputPadding(raw_encoder, enc_expected)

        dec_io = self._decoder.io_config
        dec_names = dec_io.get("input_names", [])
        self._dec_expected = dict(zip(dec_names, dec_io.get("input_shapes", []), strict=False))
        dec_type_map = dict(zip(dec_names, dec_io.get("input_types", []), strict=False))

        # Validate BEFORE the first lookup so a decoder without KV inputs gets
        # an actionable message instead of a bare KeyError('past_0_key').
        if "past_0_key" not in self._dec_expected:
            raise KeyError(
                "'past_0_key' is missing from the decoder ONNX input shapes; "
                "cannot derive the KV cache length. Verify the decoder ONNX was built with "
                "PastKeyValueInputGenerator."
            )
        if "past_0_key" not in dec_type_map:
            raise KeyError(
                "'past_0_key' is missing from the decoder ONNX input type map; "
                "cannot derive KV cache dtype. Verify the decoder ONNX was built with "
                "PastKeyValueInputGenerator."
            )

        # Max decode length comes from the cache-length axis of the KV shape.
        self._max_dec = self._dec_expected["past_0_key"][2]
        self._num_kv_layers = sum(
            1 for n in self._dec_expected if n.startswith("past_") and n.endswith("_key")
        )
        # Resolve KV cache dtype from the ONNX input type (fp32 or fp16) by
        # round-tripping a one-element numpy array through torch.
        _np_dtype = dec_type_map["past_0_key"]
        self._kv_dtype = torch.from_numpy(np.zeros(1, dtype=_np_dtype)).dtype

    # ----- Encoder -----

    class _EncoderWithInputPadding(torch.nn.Module):
        """Wraps an encoder sub-model with auto-padding to ONNX expected shapes.

        Matches kwargs against ONNX input names, pads undersized tensors,
        and forwards to the underlying WinMLAutoModel. Used as both
        ``self._encoder`` (direct calls) and the return value of
        ``get_encoder()`` (GenerationMixin contract).
        """

        def __init__(self, encoder: Any, expected: dict[str, list[int]]) -> None:
            super().__init__()
            self._encoder = encoder
            self._expected = expected

        def forward(self, **kwargs: Any) -> BaseModelOutput:
            """Filter/pad ``kwargs`` via ``pad_inputs`` and run the encoder."""
            feeds = pad_inputs(kwargs, self._expected)
            return self._encoder(**feeds)

    def get_encoder(self) -> torch.nn.Module:
        """Return encoder for GenerationMixin (already wrapped with padding)."""
        return self._encoder

    def can_generate(self) -> bool:  # noqa: D102
        return True

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        encoder_outputs: BaseModelOutput | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build decoder inputs for each generate() step.

        Once a ``WinMLCache`` already holds tokens, only the newest token is
        fed to the decoder; otherwise the full ``input_ids`` are passed.
        """
        from .kv_cache import WinMLCache

        if isinstance(past_key_values, WinMLCache) and past_key_values.get_seq_length() > 0:
            decoder_input_ids = input_ids[:, -1:]
        else:
            decoder_input_ids = input_ids
        return {
            "decoder_input_ids": decoder_input_ids,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }

    # ----- Cache management -----

    @classmethod
    def get_cache_class(cls) -> type:
        """Return the WinMLCache subclass. Subclasses must override."""
        raise NotImplementedError

    def _resolve_cache(self, past_key_values: Any) -> Any:
        """Unwrap or create the WinMLCache for this generation step.

        1. Unwrap EncoderDecoderCache wrapper (GenerationMixin may add it).
        2. If already a WinMLCache, return directly.
        3. Otherwise create a fresh one and reset it.
        """
        from .kv_cache import WinMLCache

        # (1) Unwrap EncoderDecoderCache
        if hasattr(past_key_values, "self_attention_cache"):
            past_key_values = past_key_values.self_attention_cache

        # (2) Already our cache — return as-is
        if isinstance(past_key_values, WinMLCache):
            return past_key_values

        # (3) Create fresh cache and reset
        kv_shape = self._dec_expected["past_0_key"]
        cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype)
        cache.reset()
        return cache

    # ----- Forward (decoder via WinMLAutoModel + KV cache) -----

    def forward(
        self,
        *,
        encoder_outputs: BaseModelOutput | tuple | None = None,
        past_key_values: Cache | None = None,
        input_ids: torch.Tensor | None = None,
        **model_kwargs: Any,
    ) -> Seq2SeqLMOutput:
        """Run decoder with a WinML KV cache.

        Uses ``WinMLStaticCache`` or ``WinMLSlidingWindowCache``, selected by
        the subclass via ``get_cache_class()``.

        Args:
            encoder_outputs: Pre-computed encoder hidden states, either a
                mapping exposing ``"last_hidden_state"`` or a tuple whose
                first element is the last hidden state.
            past_key_values: ``WinMLCache`` (or ``EncoderDecoderCache``
                wrapper) from previous step.
            input_ids: Fallback — run encoder if encoder_outputs is None.
            **model_kwargs: Remaining kwargs forwarded to the decoder ONNX
                (e.g., decoder_input_ids, attention_mask). Each tensor is
                auto-padded to match the ONNX model's expected input shape.

        Raises:
            ValueError: If neither ``encoder_outputs`` nor ``input_ids`` is given.
        """
        # Encoder hidden states
        if encoder_outputs is None and input_ids is not None:
            encoder_outputs = self._encoder(input_ids=input_ids, **model_kwargs)
        if encoder_outputs is None:
            raise ValueError("Either encoder_outputs or input_ids required")
        # The signature admits a plain tuple, which cannot be indexed by name.
        if isinstance(encoder_outputs, tuple):
            enc_h = encoder_outputs[0]
        else:
            enc_h = encoder_outputs["last_hidden_state"]

        # Resolve or create cache (subclasses override get_cache_class).
        cache = self._resolve_cache(past_key_values)

        fc = cache.step
        dec_mask = cache.build_decoder_mask(self._max_dec)

        feeds: dict[str, Any] = dict(model_kwargs)
        feeds.setdefault("encoder_hidden_states", enc_h.detach())
        feeds.setdefault("decoder_attention_mask", dec_mask)
        # Feed all position-like names; pad_inputs filters to self._dec_expected.
        # Decouples the cache class from the decoder ONNX's chosen input name.
        #
        # "cache_position": buffer index of the query token — used by HF's
        #   create_causal_mask (``kv_idx <= q_idx``) and by T5.compute_bias.
        #   For WinMLStaticCache this equals ``step`` (buffer == seq position);
        #   for WinMLSlidingWindowCache it is the rightmost buffer slot(s).
        # "position_id": absolute sequence position — used by RoPE-based models
        #   (Mu2) that compute positional encoding from the actual seq position.
        cache_pos = cache.get_query_cache_position(self._max_dec).to(torch.int64)
        seq_pos = torch.tensor([fc], dtype=torch.int64)
        feeds.setdefault("cache_position", cache_pos)
        feeds.setdefault("position_id", seq_pos)
        for i in range(self._num_kv_layers):
            feeds[f"past_{i}_key"] = cache.layers[i].keys.detach()
            feeds[f"past_{i}_value"] = cache.layers[i].values.detach()

        # Run decoder ONNX (pad_inputs filters to expected names + pads)
        outputs = self._decoder(**pad_inputs(feeds, self._dec_expected))

        # Write present KV back and advance step
        cache.update_all_layers(outputs)

        return Seq2SeqLMOutput(
            logits=outputs["logits"],
            past_key_values=cache,
        )
# --------------------------------------------------------------------------
"""WinML KV cache classes for ONNX export and inference.

Hierarchy::

    StaticCache (HF transformers)
    └─ WinMLCache                      — common interface
        ├─ WinMLStaticCache            — ScatterElements (index_copy_), T5/Qwen
        └─ WinMLSlidingWindowCache     — Slice+Concat (FIFO), Mu2

Cache type compatibility:

- **WinMLStaticCache**: ``index_copy_`` at ``cache_position`` keeps
  ``buffer_position == sequence_position``. Cannot evict — ``max_cache_len``
  must be ≥ total generated tokens.

- **WinMLSlidingWindowCache**: Slice+Concat eviction; works for RoPE models
  (Mu2, Qwen, Llama) where position is baked into K when K is computed, and
  for learned relative position bias (T5) as long as the wrapper feeds
  ``cache_position`` as the query's *buffer index* (see
  ``get_query_cache_position``). The invariant ``cache_position = buffer_idx
  of query`` makes ``j - cache_position`` the correct relative distance for
  both cache types, so no per-model compute_bias patch is required.

Common interface (called by ``WinMLEncoderDecoderModel.forward``):

- ``position_input_name``: ONNX input name (``"cache_position"`` or ``"position_id"``)
- ``build_decoder_mask(max_len)``: 2D attention mask for current step
- ``get_query_cache_position(max_len)``: buffer indices of query tokens
  (used by HF's ``create_causal_mask`` and by T5's ``compute_bias``)
- ``prepare_prefill_chunk(chunk_ids, start, prefill_seq_len)``: pad one prefill
  chunk and build its position IDs
- ``update_all_layers(outputs)``: write present KV from ONNX output, advance step
- ``reset()``: zero out for new generation
- ``create(config, kv_shape, dtype)``: factory from ONNX metadata

Also provides ``PastKeyValueInputGenerator`` — a reusable ``DummyInputGenerator``
for static KV cache inputs (``past_{i}_key``, ``past_{i}_value``).
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar

from optimum.utils.input_generators import DummyInputGenerator
from transformers import StaticCache


if TYPE_CHECKING:
    import torch
    from optimum.utils import NormalizedConfig
    from transformers import PretrainedConfig


# =============================================================================
# WinMLCache — common interface
# =============================================================================


class WinMLCache(StaticCache, ABC):
    """Abstract base for WinML KV caches (export + inference).

    Subclasses set ``position_input_name``, implement ``build_decoder_mask``,
    and override ``update()`` for cache-specific write logic.

    ``step`` tracks the absolute generation position
    (used for RoPE and mask construction).
    ``num_layers`` is set from ``config.num_hidden_layers``.
    """

    #: ONNX input name for the position tensor (subclasses override).
    #: Empty string is a sentinel — concrete subclasses must set a real value.
    position_input_name: ClassVar[str] = ""

    def __init__(self, config: PretrainedConfig, *args: Any, **kwargs: Any) -> None:
        super().__init__(config, *args, **kwargs)
        self.step: int = 0
        self.num_layers: int = config.num_hidden_layers
        #: New-token KV captured during ``update()``, keyed by layer index.
        #: Export wrappers read ``captured[i]`` to build ONNX present outputs.
        self.captured: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}

    # ----- Interface for WinMLEncoderDecoderModel.forward -----

    @abstractmethod
    def build_decoder_mask(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor:
        """Build the decoder attention mask for the current step.

        Args:
            max_len: Total cache buffer length.
            num_new_tokens: Number of new tokens being added (1 for gen,
                chunk_len for prefill).
        """

    @abstractmethod
    def get_query_cache_position(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor:
        """Buffer indices of the query tokens for HF's ``cache_position`` input.

        HF's ``create_causal_mask`` uses ``cache_position`` as the query's
        *buffer index* (``kv_idx <= q_idx``). For static cache the buffer index
        equals the sequence position (``step``); for sliding window it is the
        rightmost slot(s) because new tokens are written at the right end.

        Returns:
            ``[num_new_tokens]`` int64 tensor of buffer positions for the new
            tokens being processed this step.
        """

    @abstractmethod
    def prepare_prefill_chunk(
        self,
        chunk_ids: torch.Tensor,
        start: int,
        prefill_seq_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Pad tokens and build position IDs for one prefill chunk.

        Args:
            chunk_ids: ``[1, chunk_len]`` — real tokens for this chunk.
            start: Absolute position of the first real token.
            prefill_seq_len: ONNX model's fixed prefill input length.

        Returns:
            padded_ids: ``[1, prefill_seq_len]`` — padded input token IDs.
            position_ids: ``[1, prefill_seq_len]`` — position encoding input.
            pad_len: Number of leading padding positions (0 for right-pad).

        Raises:
            ValueError: If the chunk is longer than ``prefill_seq_len``.
        """

    @staticmethod
    def _validate_chunk_len(chunk_len: int, prefill_seq_len: int) -> None:
        """Reject a prefill chunk that cannot fit the fixed-size ONNX input.

        Without this check the oversized chunk only surfaces later as an
        opaque tensor shape-mismatch error from the slice assignment.
        """
        if chunk_len > prefill_seq_len:
            raise ValueError(
                f"Prefill chunk has {chunk_len} tokens but the prefill input holds "
                f"only {prefill_seq_len}; split the prompt into smaller chunks."
            )

    def update_all_layers(self, outputs: dict[str, Any]) -> None:
        """Write present KV for all layers via ``update()`` and advance step.

        Step advances by N where N is the seq_len of the present KV tensors
        (1 for gen, chunk_len for prefill).
        """
        import torch

        n = 0
        for i in range(self.num_layers):
            k = outputs[f"present_{i}_key"]
            v = outputs[f"present_{i}_value"]
            # Runtime outputs may be non-tensor arrays; normalise to tensors.
            k = k if isinstance(k, torch.Tensor) else torch.tensor(k)
            v = v if isinstance(v, torch.Tensor) else torch.tensor(v)
            n = k.size(2)
            ck = {"cache_position": torch.arange(self.step, self.step + n, dtype=torch.int64)}
            self.update(k, v, i, cache_kwargs=ck)
        # Advance once, after every layer has been written at the same offset.
        self.step += n

    def reset(self) -> None:
        """Zero out all layers and reset step (start of new generation)."""
        self.step = 0
        self.captured.clear()
        for i in range(self.num_layers):
            self.layers[i].keys.zero_()
            self.layers[i].values.zero_()

    @classmethod
    def create(
        cls, config: PretrainedConfig, kv_shape: list[int], dtype: torch.dtype
    ) -> WinMLCache:
        """Create and initialize a cache from ONNX KV shape metadata.

        Args:
            config: HF model config (must have ``num_hidden_layers``).
            kv_shape: ``[batch, heads, max_cache_len, head_dim]`` from ONNX.
            dtype: KV dtype (fp32 or fp16).
        """
        import torch

        cache = cls(config, max_cache_len=kv_shape[2])
        cache.early_initialization(
            batch_size=1,
            num_heads=kv_shape[1],
            head_dim=kv_shape[3],
            dtype=dtype,
            device=torch.device("cpu"),
        )
        return cache


# =============================================================================
# WinMLStaticCache — ScatterElements (index_copy_)
# =============================================================================


class WinMLStaticCache(WinMLCache):
    """Cache using ``index_copy_`` at ``cache_position`` (ScatterElements).

    **Export**: intercepts ``update()`` to capture incoming KV for ONNX output.
    **Inference**: ``update_all_layers`` writes new-token KV at the current step.
    Mask is left-aligned: ``[1, 1, ..., 1, 0, 0, ..., 0]``.
    """

    # Declared as ClassVar to match the base-class declaration it overrides.
    position_input_name: ClassVar[str] = "cache_position"

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Capture new-token KV, then delegate to parent ``index_copy_``."""
        self.captured[layer_idx] = (key_states, value_states)
        return super().update(key_states, value_states, layer_idx, cache_kwargs)

    def build_decoder_mask(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor:
        """Left-aligned: first ``step + num_new_tokens`` positions are 1."""
        import torch

        mask = torch.zeros(1, max_len, dtype=torch.int64)
        mask[0, : self.step + num_new_tokens] = 1
        return mask

    def get_query_cache_position(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor:
        """Buffer index == sequence position for static cache: ``[step..step+N)``."""
        import torch

        return torch.arange(self.step, self.step + num_new_tokens, dtype=torch.int64)

    def prepare_prefill_chunk(
        self,
        chunk_ids: torch.Tensor,
        start: int,
        prefill_seq_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Right-pad: real tokens at START, padding at end.

        Raises:
            ValueError: If the chunk is longer than ``prefill_seq_len``.
        """
        import torch

        chunk_len = chunk_ids.shape[1]
        self._validate_chunk_len(chunk_len, prefill_seq_len)

        padded_ids = torch.zeros(1, prefill_seq_len, dtype=chunk_ids.dtype)
        padded_ids[0, :chunk_len] = chunk_ids[0]

        # Positions simply continue past the real tokens into the padded tail.
        position_ids = torch.arange(start, start + prefill_seq_len, dtype=torch.int64).unsqueeze(0)

        return padded_ids, position_ids, 0


# =============================================================================
# WinMLSlidingWindowCache — Slice + Concat (FIFO)
# =============================================================================


class WinMLSlidingWindowCache(WinMLCache):
    """FIFO cache: evict oldest, append new at end (Slice+Concat).

    **Export**: ``update()`` does Slice+Concat on the buffer and captures
    the new-token KV (same as ``WinMLStaticCache.captured``). Present KV
    output is the new token only ``[batch, heads, 1, head_dim]``.
    **Inference**: ``update_all_layers`` does Slice+Concat from present KV.
    Mask is right-aligned: ``[0, 0, ..., 0, 1, 1, ..., 1]``.
    """

    # Declared as ClassVar to match the base-class declaration it overrides.
    position_input_name: ClassVar[str] = "position_id"

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Drop N oldest, append N new KV at end (N = key_states.size(2)).

        Works for both single-token gen (N=1) and multi-token prefill (N>1).
        ``cache_kwargs`` is accepted for interface compatibility and unused.
        """
        import torch

        self.captured[layer_idx] = (key_states, value_states)

        n = key_states.size(2)
        old_k = self.layers[layer_idx].keys[:, :, n:, :]
        new_k = torch.cat([old_k, key_states], dim=2)
        self.layers[layer_idx].keys = new_k

        old_v = self.layers[layer_idx].values[:, :, n:, :]
        new_v = torch.cat([old_v, value_states], dim=2)
        self.layers[layer_idx].values = new_v

        return new_k, new_v

    def build_decoder_mask(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor:
        """Right-aligned: rightmost ``step + num_new_tokens`` positions are 1."""
        import torch

        # Once the window is full every slot is valid, hence the clamp.
        filled = min(self.step + num_new_tokens, max_len)
        mask = torch.zeros(1, max_len, dtype=torch.int64)
        mask[0, max(0, max_len - filled) :] = 1
        return mask

    def get_query_cache_position(self, max_len: int, num_new_tokens: int = 1) -> torch.Tensor:
        """Query tokens sit at the rightmost ``num_new_tokens`` buffer slots.

        Because new tokens are always written at the right end of the buffer
        (Slice+Concat), the query's buffer index is ``[max_len-N..max_len)`` —
        independent of the absolute sequence position. HF's causal mask
        then allows attention to every prior buffer slot (``j <= max_len-1``),
        and the 2D ``build_decoder_mask`` selects the filled region within that.
        """
        import torch

        return torch.arange(max_len - num_new_tokens, max_len, dtype=torch.int64)

    def prepare_prefill_chunk(
        self,
        chunk_ids: torch.Tensor,
        start: int,
        prefill_seq_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Left-pad: padding at start, real tokens at END.

        Raises:
            ValueError: If the chunk is longer than ``prefill_seq_len``.
        """
        import torch

        chunk_len = chunk_ids.shape[1]
        self._validate_chunk_len(chunk_len, prefill_seq_len)
        pad_len = prefill_seq_len - chunk_len

        padded_ids = torch.zeros(1, prefill_seq_len, dtype=chunk_ids.dtype)
        padded_ids[0, pad_len:] = chunk_ids[0]

        # Padding positions get 0 — RoPE computes embeddings for position 0 on these,
        # but the attention mask at build_decoder_mask masks them out before softmax,
        # so the RoPE artifacts don't influence outputs.
        position_ids = torch.zeros(1, prefill_seq_len, dtype=torch.int64)
        position_ids[0, pad_len:] = torch.arange(start, start + chunk_len, dtype=torch.int64)

        return padded_ids, position_ids, pad_len

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Filled positions: ``min(step, max_cache_len)``."""
        max_len = self.layers[layer_idx].keys.shape[2]
        return min(self.step, max_len)


# =============================================================================
# PastKeyValueInputGenerator
# =============================================================================


class PastKeyValueInputGenerator(DummyInputGenerator):
    """Generates ``past_{i}_key`` / ``past_{i}_value`` tensors for static KV cache.

    Reads ``num_layers``, ``num_attention_heads``, ``head_dim``, and
    ``max_cache_len`` from the ``NormalizedConfig``.
    """

    SUPPORTED_INPUT_NAMES = ()  # dynamic — built in __init__

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedConfig,
        batch_size: int = 1,
        max_cache_len: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Read KV dimensions and build the per-layer input-name tuple.

        Args:
            task: Export task name (accepted for interface compatibility; unused).
            normalized_config: Source of layer count, head count and head size.
            batch_size: Leading dimension of every generated tensor.
            max_cache_len: KV buffer length; falls back to the config value.
            **kwargs: Ignored; accepted so shared generator kwargs can be passed.
        """
        self.batch_size = batch_size
        self.num_layers: int = normalized_config.num_layers
        self.num_heads: int = normalized_config.num_attention_heads
        self.head_dim: int = normalized_config.head_dim
        self.max_cache_len: int = max_cache_len or normalized_config.max_cache_len
        # Instance-level override of the class attribute: key then value, per layer.
        self.SUPPORTED_INPUT_NAMES = tuple(
            name for i in range(self.num_layers) for name in (f"past_{i}_key", f"past_{i}_value")
        )

    def generate(
        self,
        input_name: str,
        framework: str = "pt",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
    ) -> torch.Tensor:
        """Return a random float tensor of shape ``[batch, heads, max_cache_len, head_dim]``.

        The same shape is produced for every ``input_name``.
        """
        return self.random_float_tensor(
            (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            framework=framework,
            dtype=float_dtype,
        )
+REMOVE_ISNAN_IN_ATTENTION_MASK = BoolCapability( + name="remove-isnan-in-attention-mask", + ort_name=None, # Custom implementation, not ORT optimizer + description="Remove Softmax->IsNaN->Where NaN guard patterns in attention", + category=CapabilityCategory.SURGERY, + default=False, +) diff --git a/src/winml/modelkit/optim/pipes/surgery.py b/src/winml/modelkit/optim/pipes/surgery.py index 0a90227b2..fa4fa6bcf 100644 --- a/src/winml/modelkit/optim/pipes/surgery.py +++ b/src/winml/modelkit/optim/pipes/surgery.py @@ -37,6 +37,7 @@ SURGERY_CAPABILITIES: dict[str, Any] = caps_dict( surgery.CLAMP_CONSTANT_VALUES, + surgery.REMOVE_ISNAN_IN_ATTENTION_MASK, ) @@ -53,12 +54,16 @@ class SurgeryPipeConfig(PipeConfig): clamp_constant_values: Whether to clamp extreme float constants clamp_min: Minimum value for constant clamping (default: -1e3) clamp_max: Maximum value for constant clamping (default: 1e3) + fix_nan_attention_mask: Replace -inf attention mask with finite value + and remove Softmax->IsNaN->Where NaN guard patterns + mask_value: Replacement value for -inf (default: -1e3) verbose: Enable verbose logging """ clamp_constant_values: bool = False clamp_min: float = -1e3 clamp_max: float = 1e3 + remove_isnan_in_attention_mask: bool = False verbose: bool = False @@ -90,6 +95,7 @@ def build_config(cls, **kwargs: Any) -> SurgeryPipeConfig: - clamp_constant_values: Enable/disable constant clamping - clamp_min: Minimum value for clamping (default: -1e3) - clamp_max: Maximum value for clamping (default: 1e3) + - remove_isnan_in_attention_mask: Remove IsNaN guard patterns - verbose: Enable verbose logging Returns: @@ -99,6 +105,7 @@ def build_config(cls, **kwargs: Any) -> SurgeryPipeConfig: clamp_constant_values=kwargs.get("clamp_constant_values", False), clamp_min=kwargs.get("clamp_min", -1e3), clamp_max=kwargs.get("clamp_max", 1e3), + remove_isnan_in_attention_mask=kwargs.get("remove_isnan_in_attention_mask", False), verbose=kwargs.get("verbose", False), ) @@ -112,7 +119,7 
def _remove_isnan_in_attention_mask(
    self,
    model: onnx.ModelProto,
    verbose: bool = False,
) -> onnx.ModelProto:
    """Remove Softmax → IsNaN → Where NaN guard patterns in attention.

    Pattern: Softmax → IsNaN → Where(isnan, 0, softmax_out)
    The IsNaN node and the guard Where are dropped and every consumer of
    the Where output is rewired to read the Softmax output directly.

    These guards are dead code when clamp_constant_values has already
    replaced -inf with a finite value (Softmax never produces NaN).

    A candidate is only removed when doing so cannot change or break the
    graph:

    * the IsNaN input is produced by a ``Softmax`` node;
    * the IsNaN output is not a graph output and has exactly one consumer,
      so deleting IsNaN leaves no dangling tensor reference;
    * that consumer is a ``Where`` whose condition (input 0) is the IsNaN
      output and whose "else" branch (input 2) is the Softmax output, so
      the Where really is an identity on non-NaN values.

    NOTE(review): the "then" branch (input 1) is not checked to be a zero
    constant; that is irrelevant as long as Softmax cannot emit NaN.

    Args:
        model: ONNX model (modified in place).
        verbose: Log details about each removal.

    Returns:
        The same model object with the matched guard patterns removed.
    """
    graph = model.graph
    nodes = list(graph.node)

    # One pass builds both lookup tables, so matching is O(nodes) instead of
    # rescanning the whole graph for every IsNaN node.
    producer_of: dict[str, int] = {}
    consumers_of: dict[str, list[int]] = {}
    for idx, node in enumerate(nodes):
        for out_name in node.output:
            producer_of[out_name] = idx
        for in_name in node.input:
            consumers_of.setdefault(in_name, []).append(idx)
    graph_output_names = {out.name for out in graph.output}

    # Nodes are tracked by position rather than by id() so removal does not
    # depend on protobuf handing back the same wrapper objects.
    drop: set[int] = set()
    rewire_map: dict[str, str] = {}
    guard_count = 0

    for idx, node in enumerate(nodes):
        if node.op_type != "IsNaN" or not node.input or not node.output:
            continue
        producer_idx = producer_of.get(node.input[0])
        if producer_idx is None or nodes[producer_idx].op_type != "Softmax":
            continue
        softmax_out = nodes[producer_idx].output[0]
        isnan_out = node.output[0]

        # IsNaN may only be deleted if nothing but the guard reads it.
        if isnan_out in graph_output_names:
            continue
        readers = consumers_of.get(isnan_out, [])
        if len(readers) != 1:
            continue
        guard_idx = readers[0]
        guard_where = nodes[guard_idx]
        if guard_where.op_type != "Where" or len(guard_where.input) != 3:
            continue
        # ONNX Where(condition, X, Y): the guard must select the Softmax
        # output whenever the value is *not* NaN, i.e. Y == softmax_out.
        if guard_where.input[0] != isnan_out or guard_where.input[2] != softmax_out:
            continue

        guard_out = guard_where.output[0]
        drop.update((idx, guard_idx))
        rewire_map[guard_out] = softmax_out
        guard_count += 1
        if verbose:
            logger.info(
                " remove-isnan: remove %s + %s, rewire %s -> %s",
                node.name,
                guard_where.name,
                guard_out,
                softmax_out,
            )

    if not drop:
        return model

    # Point every surviving consumer of a guard output at the Softmax output.
    for idx, node in enumerate(nodes):
        if idx in drop:
            continue
        for pos, in_name in enumerate(node.input):
            if in_name in rewire_map:
                node.input[pos] = rewire_map[in_name]
    for out in graph.output:
        if out.name in rewire_map:
            out.name = rewire_map[out.name]

    # Rebuild the node list without the dead IsNaN/Where pairs.
    kept = [node for idx, node in enumerate(nodes) if idx not in drop]
    del graph.node[:]
    graph.node.extend(kept)

    logger.info(
        "SurgeryPipe: remove-isnan-in-attention-mask: %d IsNaN+Where guards removed",
        guard_count,
    )

    return model
def pad_inputs(
    source: dict[str, Any],
    expected: dict[str, list[int]],
    mode: Literal["left", "right"] = "right",
) -> dict[str, Any]:
    """Filter *source* to keys in *expected* and pad undersized tensors.

    For each name in *expected*, if *source* has a tensor for it, zero-pad
    every non-batch dimension that is smaller than the expected size.
    Dimensions that are already large enough are left alone (nothing is
    ever truncated). Non-tensor values are passed through unchanged and
    names missing from *source* (or mapped to ``None``) are skipped.

    A dimension is left unpadded when its expected size is not an ``int``
    (dynamic ONNX dims given as ``None`` or a symbolic string) or when the
    tensor has more dimensions than the expected shape lists.

    Args:
        source: Input values keyed by name.
        expected: Expected shapes keyed by input name; index 0 is treated
            as the batch dimension and never padded.
        mode: Padding side — ``"right"`` (default, pad at end) or
            ``"left"`` (pad at start).

    Returns:
        A new dict holding the filtered and padded values, in the key
        order of *expected*.

    Raises:
        ValueError: If *mode* is neither ``"right"`` nor ``"left"``.
    """
    if mode not in ("right", "left"):
        raise ValueError(f"mode must be 'right' or 'left', got {mode!r}")

    result: dict[str, Any] = {}
    for name, expected_shape in expected.items():
        val = source.get(name)
        if val is None:
            continue
        if isinstance(val, torch.Tensor):
            # torch.nn.functional.pad consumes (low, high) pairs starting at
            # the tensor's LAST dim and walking backwards, so one pair must be
            # emitted for every trailing dim of the tensor itself — including
            # dims the expected shape does not describe — or the pairs would
            # be applied to the wrong axes. The batch dim (dim 0) is skipped.
            pad: list[int] = []
            for dim in reversed(range(1, val.dim())):
                exp = expected_shape[dim] if dim < len(expected_shape) else None
                if not isinstance(exp, int):
                    # Dynamic/unknown dim: placeholder pair keeps alignment.
                    pad.extend((0, 0))
                    continue
                deficit = max(exp - val.shape[dim], 0)
                pad.extend((0, deficit) if mode == "right" else (deficit, 0))
            if any(pad):
                val = torch.nn.functional.pad(val, pad)
        result[name] = val
    return result
class TestPastKeyValueInputGenerator:
    """Exercise PastKeyValueInputGenerator directly against a stub config."""

    def test_supported_input_names(self) -> None:
        # Three layers -> key/value name pair per layer, in layer order.
        generator = PastKeyValueInputGenerator(
            "text-generation", _make_normalized_config(num_layers=3)
        )
        wanted = (
            "past_0_key",
            "past_0_value",
            "past_1_key",
            "past_1_value",
            "past_2_key",
            "past_2_value",
        )
        assert wanted == generator.SUPPORTED_INPUT_NAMES

    def test_generate_key_shape(self) -> None:
        stub = _make_normalized_config(
            num_layers=2,
            num_attention_heads=4,
            head_dim=16,
            max_cache_len=64,
        )
        generator = PastKeyValueInputGenerator("text-generation", stub, batch_size=2)
        key = generator.generate("past_0_key")
        assert key.shape == (2, 4, 64, 16)

    def test_generate_value_shape(self) -> None:
        stub = _make_normalized_config(
            num_layers=2,
            num_attention_heads=4,
            head_dim=16,
            max_cache_len=64,
        )
        generator = PastKeyValueInputGenerator("text-generation", stub, batch_size=1)
        value = generator.generate("past_1_value")
        assert value.shape == (1, 4, 64, 16)

    def test_generate_returns_float_tensor(self) -> None:
        generator = PastKeyValueInputGenerator("text-generation", _make_normalized_config())
        produced = generator.generate("past_0_key")
        assert isinstance(produced, torch.Tensor)
        assert produced.dtype == torch.float32

    def test_single_layer(self) -> None:
        generator = PastKeyValueInputGenerator(
            "text-generation", _make_normalized_config(num_layers=1)
        )
        assert generator.SUPPORTED_INPUT_NAMES == ("past_0_key", "past_0_value")

    def test_batch_size_propagated(self) -> None:
        generator = PastKeyValueInputGenerator(
            "text-generation", _make_normalized_config(), batch_size=8
        )
        assert generator.batch_size == 8
        produced = generator.generate("past_0_key")
        assert produced.shape[0] == 8
class TestQwenPrefillKVInputs:
    """Qwen3 prefill dummy inputs use PastKeyValueInputGenerator."""

    def test_kv_input_names(self, qwen_config) -> None:
        # Every hidden layer must contribute both a key and a value input.
        dummy = generate_dummy_inputs("qwen3", "feature-extraction", qwen_config)
        for layer in range(qwen_config.num_hidden_layers):
            assert f"past_{layer}_key" in dummy
            assert f"past_{layer}_value" in dummy

    def test_kv_shape(self, qwen_config) -> None:
        dummy = generate_dummy_inputs("qwen3", "feature-extraction", qwen_config)
        # [batch=1, kv_heads, max_cache_len=256 (fixture override), head_dim]
        wanted = (1, qwen_config.num_key_value_heads, 256, qwen_config.head_dim)
        assert dummy["past_0_key"].shape == wanted

    def test_attention_mask_matches_cache_len(self, qwen_config) -> None:
        dummy = generate_dummy_inputs("qwen3", "feature-extraction", qwen_config)
        mask = dummy["attention_mask"]
        assert mask.shape[1] == 256
def _make_cache(cls, num_layers=2, num_heads=2, max_cache_len=16, head_dim=8):
    """Build and reset a small cache of type *cls* for mask/prefill tests.

    A real ``PretrainedConfig`` is used because HF ``StaticCache.__init__``
    calls ``config.get_text_config()``.
    """
    from transformers import PretrainedConfig

    kv_shape = [1, num_heads, max_cache_len, head_dim]
    instance = cls.create(
        PretrainedConfig(num_hidden_layers=num_layers), kv_shape, torch.float32
    )
    instance.reset()
    return instance


class TestStaticCacheBuildDecoderMask:
    """WinMLStaticCache.build_decoder_mask — left-aligned mask."""

    def test_default_single_token(self) -> None:
        from winml.modelkit.models.winml.kv_cache import WinMLStaticCache

        static_cache = _make_cache(WinMLStaticCache)
        static_cache.step = 3
        built = static_cache.build_decoder_mask(16)
        assert built.shape == (1, 16)
        # step=3 plus the one new token -> first four positions attended.
        assert built[0, :4].tolist() == [1] * 4
        assert built[0, 4:].sum().item() == 0

    def test_num_new_tokens(self) -> None:
        from winml.modelkit.models.winml.kv_cache import WinMLStaticCache

        static_cache = _make_cache(WinMLStaticCache)
        static_cache.step = 2
        built = static_cache.build_decoder_mask(16, num_new_tokens=4)
        assert built[0, :6].tolist() == [1] * 6
        assert built[0, 6:].sum().item() == 0
"""WinMLSlidingWindowCache.build_decoder_mask — right-aligned mask.""" + + def test_default_single_token(self) -> None: + from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache + + cache = _make_cache(WinMLSlidingWindowCache) + cache.step = 3 + mask = cache.build_decoder_mask(16) + # rightmost 4 positions should be 1 + assert mask[0, -4:].tolist() == [1, 1, 1, 1] + assert mask[0, :-4].sum().item() == 0 + + def test_num_new_tokens(self) -> None: + from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache + + cache = _make_cache(WinMLSlidingWindowCache) + cache.step = 2 + mask = cache.build_decoder_mask(16, num_new_tokens=4) + # rightmost 6 positions + assert mask[0, -6:].tolist() == [1, 1, 1, 1, 1, 1] + assert mask[0, :-6].sum().item() == 0 + + def test_saturates_at_max_len(self) -> None: + from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache + + cache = _make_cache(WinMLSlidingWindowCache, max_cache_len=8) + cache.step = 10 + mask = cache.build_decoder_mask(8, num_new_tokens=4) + # min(10+4, 8)=8 → all 1s + assert mask[0].sum().item() == 8 + + +class TestStaticCachePreparePrefillChunk: + """WinMLStaticCache.prepare_prefill_chunk — right-pad.""" + + def test_full_chunk_no_padding(self) -> None: + from winml.modelkit.models.winml.kv_cache import WinMLStaticCache + + cache = _make_cache(WinMLStaticCache) + chunk = torch.tensor([[10, 20, 30, 40]]) + padded_ids, pos_ids, pad_len = cache.prepare_prefill_chunk( + chunk, start=0, prefill_seq_len=4 + ) + assert pad_len == 0 + assert padded_ids[0].tolist() == [10, 20, 30, 40] + assert pos_ids[0].tolist() == [0, 1, 2, 3] + + def test_partial_chunk_right_padded(self) -> None: + from winml.modelkit.models.winml.kv_cache import WinMLStaticCache + + cache = _make_cache(WinMLStaticCache) + chunk = torch.tensor([[10, 20]]) + padded_ids, pos_ids, pad_len = cache.prepare_prefill_chunk( + chunk, start=4, prefill_seq_len=4 + ) + assert pad_len == 0 + assert padded_ids[0, :2].tolist() 
class TestSlidingWindowCachePreparePrefillChunk:
    """WinMLSlidingWindowCache.prepare_prefill_chunk — left-pad."""

    def test_full_chunk_no_padding(self) -> None:
        from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache

        window_cache = _make_cache(WinMLSlidingWindowCache)
        tokens = torch.tensor([[10, 20, 30, 40]])
        ids, positions, n_pad = window_cache.prepare_prefill_chunk(
            tokens, start=0, prefill_seq_len=4
        )
        # A chunk that already fills the prefill window needs no padding.
        assert n_pad == 0
        assert ids[0].tolist() == [10, 20, 30, 40]
        assert positions[0].tolist() == [0, 1, 2, 3]

    def test_partial_chunk_left_padded(self) -> None:
        from winml.modelkit.models.winml.kv_cache import WinMLSlidingWindowCache

        window_cache = _make_cache(WinMLSlidingWindowCache)
        tokens = torch.tensor([[10, 20]])
        ids, positions, n_pad = window_cache.prepare_prefill_chunk(
            tokens, start=4, prefill_seq_len=4
        )
        # Two real tokens in a window of four -> two pad slots on the left.
        assert n_pad == 2
        assert ids[0].tolist() == [0, 0, 10, 20]
        assert positions[0].tolist() == [0, 0, 4, 5]
WinMLCompositeModel.from_onnx.""" + + def test_dict_dispatches_to_composite(self, tmp_path: Path): + """Dict onnx_path calls WinMLCompositeModel.from_onnx.""" + with patch( + "winml.modelkit.models.winml.composite_model.WinMLCompositeModel.from_onnx" + ) as mock_from_onnx: + mock_from_onnx.return_value = MagicMock() + + WinMLAutoModel.from_onnx( + {"encoder": str(tmp_path / "enc.onnx"), "decoder": str(tmp_path / "dec.onnx")}, + task="translation", + skip_build=True, + ) + + mock_from_onnx.assert_called_once() + call_kwargs = mock_from_onnx.call_args.kwargs + assert call_kwargs["task"] == "translation" + assert call_kwargs["skip_build"] is True + + def test_hf_config_dispatches_composite_via_registry(self, tmp_path: Path): + """hf_config kwarg threads through so model_type registry lookup works. + + Exercises the real WinMLCompositeModel.from_onnx body via a fake + subclass in a temporary registry slot. hf_config must be a dedicated + parameter on WinMLAutoModel.from_onnx (distinct from ``config``, which + is a WinMLBuildConfig and has no ``model_type`` attribute). + """ + from winml.modelkit.models.winml.composite_model import ( + COMPOSITE_MODEL_REGISTRY, + WinMLCompositeModel, + ) + + # Minimal HF-config stand-in: only attribute access (.model_type) is + # required; no isinstance check happens on hf_config in the dispatch. 
+ class _FakeHFConfig: + model_type = "_test_dispatch_model_" + + enc_path = tmp_path / "enc.onnx" + dec_path = tmp_path / "dec.onnx" + enc_path.write_bytes(b"fake") + dec_path.write_bytes(b"fake") + + test_key = ("_test_dispatch_model_", "_test_task_") + + class _FakeComposite(WinMLCompositeModel): + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "encoder": "feature-extraction", + "decoder": "translation", + } + + def forward(self, **kwargs): # type: ignore[override] + pass + + assert test_key not in COMPOSITE_MODEL_REGISTRY + COMPOSITE_MODEL_REGISTRY[test_key] = _FakeComposite + try: + # Patch WinMLAutoModel.from_onnx: outer dict call falls through to + # the real implementation, inner per-component Path calls mocked. + _real_from_onnx = WinMLAutoModel.from_onnx + sub_mock = MagicMock() + sub_calls: list = [] + + def _side_effect(onnx_path, **kw): # type: ignore[no-untyped-def] + if isinstance(onnx_path, dict): + return _real_from_onnx(onnx_path, **kw) + sub_calls.append((onnx_path, kw)) + return sub_mock + + with patch.object(WinMLAutoModel, "from_onnx", side_effect=_side_effect): + result = WinMLAutoModel.from_onnx( + {"encoder": str(enc_path), "decoder": str(dec_path)}, + task="_test_task_", + hf_config=_FakeHFConfig(), + skip_build=True, + ) + + assert isinstance(result, _FakeComposite) + assert len(sub_calls) == 2 + tasks_called = {kw["task"] for _, kw in sub_calls} + assert tasks_called == {"feature-extraction", "translation"} + finally: + COMPOSITE_MODEL_REGISTRY.pop(test_key, None) + + def test_from_onnx_dict_without_hf_config_raises(self, tmp_path: Path): + """Dict dispatch without hf_config surfaces a clear registry-miss error. + + Guards against silent fallback: unregistered ``(model_type, task)`` must + raise ValueError immediately, not accept a wrong-typed kwarg and mis-dispatch. 
+ """ + enc_path = tmp_path / "enc.onnx" + dec_path = tmp_path / "dec.onnx" + enc_path.write_bytes(b"fake") + dec_path.write_bytes(b"fake") + + with pytest.raises(ValueError, match="No composite model"): + WinMLAutoModel.from_onnx( + {"encoder": str(enc_path), "decoder": str(dec_path)}, + task="_unregistered_task_", + skip_build=True, + )