-
Notifications
You must be signed in to change notification settings - Fork 284
Add multi context binary package support #2369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xiaoyu-work
wants to merge
5
commits into
main
Choose a base branch
from
xiaoyu/context_binary
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| import json | ||
| import logging | ||
| from argparse import ArgumentParser | ||
| from pathlib import Path | ||
|
|
||
| from olive.cli.base import BaseOliveCLICommand, add_logging_options, add_telemetry_options | ||
| from olive.common.utils import hardlink_copy_dir | ||
| from olive.telemetry import action | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
@action
class ModelPackageCommand(BaseOliveCLICommand):
    """Merge multiple single-target context binary outputs into a multi-target package.

    Each ``--source`` must be a directory containing a ``model_config.json``. The command
    produces the following layout under ``--output_path``::

        <output_path>/
            manifest.json
            <model_name>/
                metadata.json
                <target_name>/    # copy of one source directory

    ``<target_name>`` is the directory name of the corresponding source, so source
    directory names must be unique.
    """

    # Keys copied from a source's ``model_attributes`` into the variant constraints.
    _CONSTRAINT_KEYS = ("ep", "device", "architecture", "ep_compatibility_info")

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``model-package`` sub-command and its arguments on ``parser``."""
        sub_parser = parser.add_parser(
            "model-package",
            help="Merge multiple context binary outputs into a multi-target package with manifest.json",
        )

        sub_parser.add_argument(
            "-s",
            "--source",
            type=str,
            action="append",
            required=True,
            help="Source context binary output directory. Can be specified multiple times. ",
        )

        sub_parser.add_argument(
            "-o",
            "--output_path",
            type=str,
            required=True,
            help="Output directory for the merged multi-target package.",
        )

        sub_parser.add_argument(
            "--model_name",
            type=str,
            default=None,
            help="Model name for the manifest. If not set, derived from the output directory name.",
        )

        add_logging_options(sub_parser)
        add_telemetry_options(sub_parser)
        sub_parser.set_defaults(func=ModelPackageCommand)

    def run(self):
        """Copy every source into the package and write ``metadata.json`` and ``manifest.json``.

        Raises:
            ValueError: If the sources are invalid (see ``_parse_sources``).

        """
        sources = self._parse_sources()
        output_dir = Path(self.args.output_path)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Path(".").name is "", which would make the component directory collapse onto the
        # package root; fall back to the resolved directory name in that case.
        model_name = self.args.model_name or output_dir.name or output_dir.resolve().name

        # Create component model directory
        component_dir = output_dir / model_name
        component_dir.mkdir(parents=True, exist_ok=True)

        model_variants = {}
        for target_name, source_path in sources:
            # "config" may be missing or null in the file; treat both as empty.
            config = self._read_model_config(source_path).get("config") or {}
            model_attrs = config.get("model_attributes") or {}

            # Copy source directory into component_dir/{target_name}/
            hardlink_copy_dir(source_path, component_dir / target_name)

            constraints = {
                key: model_attrs[key] for key in self._CONSTRAINT_KEYS if model_attrs.get(key) is not None
            }

            model_variants[target_name] = {
                "file": config.get("model_path", f"{target_name}/"),
                "constraints": constraints,
            }

        # Write metadata.json in component directory
        metadata = {"name": model_name, "model_variants": model_variants}
        with open(component_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        # Write manifest.json at package root
        manifest = {
            "name": model_name,
            "component_models": {
                model_name: {"model_variants": model_variants},
            },
        }
        manifest_path = output_dir / "manifest.json"
        with open(manifest_path, "w", encoding="utf-8") as f:
            json.dump(manifest, f, indent=2)

        print(f"Merged {len(sources)} targets into {output_dir}")
        print(f"Manifest written to {manifest_path}")

    def _parse_sources(self) -> list[tuple[str, Path]]:
        """Validate ``--source`` arguments and return ``(target_name, path)`` pairs.

        Raises:
            ValueError: If a source is not a directory, lacks ``model_config.json``,
                yields a target name already used by another source, or if fewer
                than two sources were given.

        """
        sources = []
        seen: dict[str, Path] = {}
        for source in self.args.source:
            path = Path(source)
            if not path.is_dir():
                raise ValueError(f"Source path does not exist or is not a directory: {path}")

            if not (path / "model_config.json").exists():
                raise ValueError(
                    f"No model_config.json found in {path}. "
                    "Source must be an Olive output directory with model_config.json."
                )

            target_name = path.name
            if target_name in ("", ".."):
                # "." and ".." have no usable final component; use the real directory name.
                target_name = path.resolve().name

            if target_name in seen:
                # Two sources with the same name would overwrite each other in the package.
                raise ValueError(
                    f"Duplicate target name '{target_name}' derived from {seen[target_name]} and {path}. "
                    "Source directory names must be unique."
                )
            seen[target_name] = path

            sources.append((target_name, path))

        if len(sources) < 2:
            raise ValueError("At least two --source directories are required to merge.")

        return sources

    @staticmethod
    def _read_model_config(source_path: Path) -> dict:
        """Read and return model_config.json from a source directory."""
        config_path = source_path / "model_config.json"
        # JSON interchange is UTF-8; do not depend on the platform's locale encoding.
        with open(config_path, encoding="utf-8") as f:
            return json.load(f)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| import logging | ||
| from collections.abc import Iterator | ||
| from typing import Any, Optional, Union | ||
|
|
||
| from olive.common.config_utils import serialize_to_json, validate_config | ||
| from olive.common.utils import dict_diff | ||
| from olive.constants import Framework, ModelFileFormat | ||
| from olive.hardware.accelerator import Device | ||
| from olive.model.config.model_config import ModelConfig | ||
| from olive.model.config.registry import model_handler_registry | ||
| from olive.model.handler.base import OliveModelHandler | ||
| from olive.resource_path import OLIVE_RESOURCE_ANNOTATIONS | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
@model_handler_registry("MultiTargetModel")
class MultiTargetModelHandler(OliveModelHandler):
    """Container for one logical model compiled for several hardware targets.

    CompositeModelHandler groups different component models (e.g., split parts of a pipeline).
    This handler instead groups variants of the same logical model, each compiled for a
    different hardware target (e.g., different SoC models for QNN).

    When a pass encounters a MultiTargetModelHandler, it runs independently on each target model,
    preserving the multi-target structure through the pipeline.
    """

    resource_keys: tuple[str, ...] = ("model_path",)
    json_config_keys: tuple[str, ...] = ("target_names",)

    def __init__(
        self,
        target_models: list[Union[OliveModelHandler, dict[str, Any]]],
        target_names: list[str],
        model_path: OLIVE_RESOURCE_ANNOTATIONS = None,
        model_attributes: Optional[dict[str, Any]] = None,
    ):
        """Build the handler from per-target models (handlers or model config dicts) and their names."""
        super().__init__(
            model_path=model_path,
            framework=Framework.ONNX,
            model_file_format=ModelFileFormat.COMPOSITE_MODEL,
            model_attributes=model_attributes,
        )

        # Dict entries are model configs: validate them and instantiate the handler.
        handlers = []
        for entry in target_models:
            if isinstance(entry, dict):
                entry = validate_config(entry, ModelConfig).create_model()
            handlers.append(entry)
        self._target_models = handlers

        assert all(isinstance(handler, OliveModelHandler) for handler in self._target_models), (
            "All target models must be OliveModelHandler or dict"
        )
        assert len(self._target_models) == len(target_names), "Number of target models and names must match"
        self.target_names = target_names

    @property
    def target_models(self):
        """Yield each target model after merging this handler's attributes into it.

        The target's own attributes take precedence over the attributes of this handler.
        """
        for handler in self._target_models:
            inherited = self.model_attributes or {}
            own = handler.model_attributes or {}
            handler.model_attributes = {**inherited, **own}
            yield handler

    def to_json(self, check_object: bool = False):
        """Serialize the handler, nesting each target's JSON under ``config.target_models``."""
        json_dict = super().to_json(check_object)
        serialized_targets = []
        json_dict["config"]["target_models"] = serialized_targets
        for handler in self._target_models:
            entry = handler.to_json(check_object)
            entry_config = entry["config"]
            # Store only the per-target difference from this handler's attributes.
            entry_config["model_attributes"] = dict_diff(entry_config["model_attributes"], self.model_attributes)
            serialized_targets.append(entry)
        return serialize_to_json(json_dict, check_object)

    def get_target_models(self) -> Iterator[tuple[str, OliveModelHandler]]:
        """Pair every target name with its target model, in order."""
        return zip(self.target_names, self.target_models)

    def load_model(self, rank: int = None, cache_model: bool = True):
        """Not supported for the multi-target container."""
        raise NotImplementedError

    @property
    def size_on_disk(self) -> int:
        """Size of the model on disk; not implemented for the multi-target container."""
        raise NotImplementedError

    def prepare_session(
        self,
        inference_settings: Optional[dict[str, Any]] = None,
        device: Device = Device.CPU,
        execution_providers: Union[str, list[str]] = None,
        rank: Optional[int] = None,
    ):
        """Always raise: the container itself has no inference session."""
        raise RuntimeError("MultiTargetModelHandler doesn't have a session of its own")

    def run_session(
        self,
        session: Any = None,
        inputs: Union[dict[str, Any], list[Any], tuple[Any, ...]] = None,
        **kwargs: dict[str, Any],
    ) -> Any:
        """Always raise: the container itself has no inference session."""
        raise RuntimeError("MultiTargetModelHandler doesn't have a session of its own")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.