-
Notifications
You must be signed in to change notification settings - Fork 86
[API] Create stable APIs for PyTorch 2.5 #1832
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
d5633f7
[API] Create stable APIs for PyTorch 2.5
justinchuby 94335d5
external tensor
justinchuby dd3df10
lint
justinchuby d84a8ca
sort
justinchuby 871621d
rename
justinchuby 2cfc439
order
justinchuby 48b7ec5
import
justinchuby 1effffb
lint
justinchuby c827bf3
typing
justinchuby 1914c99
lint
justinchuby ef254c0
Merge branch 'main' into justinchu/stable-torch
justinchuby 9c5b826
Update torch_2_5.py
justinchuby 5ba6919
fix qualify name
justinchuby b4bbd82
fix
justinchuby 5e5bb6b
convert_version
justinchuby 014c676
lint
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Semi-private stable APIs for framework-specific usage only.""" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
"""Stable APIs for PyTorch 2.5.""" | ||
|
||
|
||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"check_model", | ||
"convert_version", | ||
"get_torchlib_ops", | ||
"optimize", | ||
"save_model_with_external_data", | ||
] | ||
|
||
import dataclasses | ||
|
||
import os | ||
import pathlib | ||
|
||
from typing import Callable | ||
|
||
import onnx | ||
|
||
from onnxscript import ir | ||
from onnxscript.function_libs.torch_lib import registration | ||
from onnxscript.ir import _external_data | ||
|
||
# Internal flag. Will go away. | ||
_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = ( | ||
os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") == "1" | ||
) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class _OnnxFunctionMeta: | ||
"""A wrapper of onnx-script function with additional metadata. | ||
|
||
qualified_name: The qualified name of the aten operator. | ||
function: The onnx-script function. | ||
domain: The domain of the function. | ||
name: The name of the function. | ||
is_complex: Whether the function is a complex function. | ||
gramalingam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
|
||
qualified_name: str | ||
function: Callable | ||
domain: str | ||
name: str | ||
is_complex: bool = False | ||
|
||
|
||
def optimize(model: ir.Model) -> ir.Model:
    """Optimize the model.

    This is currently a placeholder: no optimization pass is applied and the
    input is handed back untouched.

    Args:
        model: The model to optimize.

    Returns:
        The very same ``model`` object that was passed in.
    """
    # TODO(justinchuby): Use the optimizer
    return model
|
||
|
||
def convert_version(model: ir.Model, target_version: int) -> ir.Model:
    """Convert the model to the specified ONNX opset version.

    This is currently a placeholder: ``target_version`` is ignored and the
    input model is handed back without any conversion.

    Args:
        model: The model to convert.
        target_version: The requested ONNX opset version (not used yet).

    Returns:
        The very same ``model`` object that was passed in.
    """
    # FIXME(justinchuby): The sketched implementation is disabled because
    # version_converter does not support functions. The sketch was: return
    # early when ``model.opset_import.get("")`` already equals
    # ``target_version``; otherwise serialize with
    # ``ir.serde.serialize_model``, run
    # ``onnx.version_converter.convert_version`` on the proto, and read the
    # result back with ``ir.serde.deserialize_model``.
    # TODO(justinchuby): This function needs to be carefully implemented
    # to handle large models. For now, we just return the model.
    del target_version  # Unused
    return model
|
||
|
||
def check_model(model: ir.Model) -> None:
    """Check the model.

    This is currently a placeholder: no validation is performed and nothing
    is raised, whatever ``model`` is.

    Args:
        model: The model to check (not used yet).
    """
    # The explicit ``del`` marks the parameter as intentionally unused, the
    # same way the other placeholder functions in this module do.
    del model  # Unused yet
Check warning — Code scanning / CodeQL: Unnecessary delete statement in function
Unnecessary deletion of local variable [model](1) in function [check_model](2).
|
||
|
||
|
||
def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
    """Save the model with external data. The model is unchanged after saving.

    The external data location is ``<model file name>.data``. Which saving
    path is taken depends on the module-level flag
    ``_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR``:

    * flag off: the model is serialized to a proto and written with
      ``onnx.save_model(..., save_as_external_data=True)``;
    * flag on: every initializer is temporarily replaced by an external
      tensor, the model is written with ``ir.save``, and the original
      initializer values are put back afterwards.

    Args:
        model: The model to save.
        model_path: Where to write the model file.

    Raises:
        ValueError: On the IR path, if any initializer has no ``const_value``.
    """
    destination_path = pathlib.Path(model_path)
    data_path = f"{destination_path.name}.data"

    if not _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
        # Create the directory if it does not exist
        destination_path.parent.mkdir(parents=True, exist_ok=True)
        proto = ir.serde.serialize_model(model)
        onnx.save_model(
            proto,
            model_path,
            save_as_external_data=True,
            location=data_path,
        )
        return

    # TODO(#1835): Decide if we want to externalize large attributes as well
    initializer_values = tuple(model.graph.initializers.values())
    tensors = [v.const_value for v in initializer_values]
    if any(tensor is None for tensor in tensors):
        raise ValueError(
            "The model contains uninitialized initializer values. "
            "Please make sure all initializer values are initialized."
        )

    external_tensors = _external_data.convert_tensors_to_external(
        tensors,  # type: ignore[arg-type]
        destination_path.parent,
        data_path,
    )

    try:
        # Replace the initializer values with external tensors and save the model
        for initializer, external_tensor in zip(initializer_values, external_tensors):
            initializer.const_value = external_tensor
        ir.save(model, model_path)
    finally:
        # Restore the original initializer values so the model is unchanged,
        # even when swapping in the external tensors or saving fails part-way.
        for initializer, tensor in zip(initializer_values, tensors):
            initializer.const_value = tensor
|
||
|
||
def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
    """Collect metadata for the functions in torchlib's default registry.

    Entries whose qualified name starts with ``"internal::"`` are skipped.
    For each remaining entry, one ``_OnnxFunctionMeta`` is produced per
    function in its ``overloads`` (``is_complex=False``), followed by one per
    function in its ``complex`` collection (``is_complex=True``).

    Returns:
        The collected metadata, in registry iteration order.
    """
    # Trigger op registration
    from onnxscript.function_libs.torch_lib import (  # pylint: disable=import-outside-toplevel
        ops,
    )

    del ops  # Unused

    function_metas = []
    for qualified_name, aten_overloads_func in registration.default_registry.items():
        if qualified_name.startswith("internal::"):
            # Skip the custom defined internal functions
            continue

        # Regular overloads first, then the complex variants.
        variants = (
            (aten_overloads_func.overloads, False),
            (aten_overloads_func.complex, True),
        )
        for functions, is_complex in variants:
            function_metas.extend(
                _OnnxFunctionMeta(
                    qualified_name=qualified_name,
                    function=func,
                    domain=func.function_ir.domain,
                    name=func.name,
                    is_complex=is_complex,
                )
                for func in functions
            )

    return function_metas
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.