
ONNX transformation to cast int64 constants to int32 when possible #655

Merged
merged 8 commits into from Jan 16, 2023

Conversation

fxmarty
Collaborator

@fxmarty fxmarty commented Dec 30, 2022

As per title.

Partially fixes #627; we still need to integrate this in the CLI, and document and test it.

Try with:

import onnx
from pathlib import Path
from optimum.onnx import model_to_int32

path = "/path/to/decoder_model.onnx"
model = onnx.load(path)

# Cast int64 constants to int32 where the values allow it.
model = model_to_int32(model)

# Save with external data so large models stay under the 2 GB protobuf limit.
onnx.save(
    model,
    path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=Path(path).name + "_data",
)

onnx.checker.check_model(path)

Inspect the "Slice" nodes of the original and transformed models.
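The "when possible" in the title presumably means a constant is only downcast when every value fits in the int32 range. A minimal pure-Python sketch of that range check (a hypothetical helper for illustration, not the actual `optimum` implementation):

```python
# Hypothetical sketch: decide whether int64 constant values can be
# safely downcast to int32, i.e. no value overflows the int32 range.
INT32_MIN = -(2**31)      # -2147483648
INT32_MAX = 2**31 - 1     # 2147483647

def fits_in_int32(values):
    """Return True if every value lies within the int32 range."""
    return all(INT32_MIN <= v <= INT32_MAX for v in values)

print(fits_in_int32([0, 12, 2**31 - 1]))  # True: safe to downcast
print(fits_in_int32([2**63 - 1]))         # False: would overflow int32
```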

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Dec 30, 2022

The documentation is not available anymore as the PR was closed or merged.

@fxmarty
Collaborator Author

fxmarty commented Dec 30, 2022

We should test with ONNX Runtime; it's possible there's an issue (we could avoid converting int64 to int32 and instead just clamp the values):

(hf-inf) fxmarty@huggingface:~/hf_internship/optimum/gptj_tiny_bool_int32$ onnxsim decoder_model_int32.onnx decoder_model_int32_sim.onnx --input-shape input_ids:1,12 attention_mask:1,12
Your model contains "Tile" ops or/and "ConstantOfShape" ops. Folding these ops can make the simplified model much larger. If it is not expected, please 
specify "--no-large-tensor" (which will lose some optimization chances)
Simplifying...
Traceback (most recent call last):
  File "/home/fxmarty/anaconda3/envs/hf-inf/bin/onnxsim", line 8, in <module>
    sys.exit(main())
  File "/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/onnxsim/onnx_simplifier.py", line 434, in main
    model_opt, check_ok = simplify(
  File "/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/onnxsim/onnx_simplifier.py", line 186, in simplify
    model_opt_bytes = C.simplify(
  File "/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/onnxsim/onnx_simplifier.py", line 239, in Run
    sess = rt.InferenceSession(
  File "/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 347, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 386, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (Tind) of Optype (Slice) bound to different types (tensor(int32) and tensor(int64) in node (/transformer/h.0/attn/Slice_4).
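The failure above comes from ONNX's type constraint on `Slice`: its index inputs (`starts`, `ends`, `axes`, `steps`) are all bound to the single type parameter `Tind`, so within one node they must all be int32 or all int64. A hypothetical consistency check over a toy dtype list (plain Python for illustration, not the real ONNX API):

```python
# Hypothetical sketch: Slice binds starts/ends/axes/steps to a single
# type parameter Tind, so mixing int32 and int64 inputs in one node
# is invalid and ONNX Runtime rejects the model at session creation.
def slice_tind_consistent(index_input_dtypes):
    """index_input_dtypes: dtypes of a Slice node's index inputs
    (starts, ends, axes, steps). All must share one integer type."""
    return len(set(index_input_dtypes)) == 1

# A node where a transform downcast only some of the index inputs:
print(slice_tind_consistent(["int32", "int32", "int64", "int64"]))  # False
# A node where all index inputs were downcast together:
print(slice_tind_consistent(["int32", "int32", "int32", "int32"]))  # True
```

This is why the transformation has to downcast all index inputs of a given node together, or leave all of them as int64.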

@fxmarty
Collaborator Author

fxmarty commented Jan 3, 2023

Should be all good now, feel free to review.

@fxmarty fxmarty added gpu-test trigger GPU tests and removed gpu-test trigger GPU tests labels Jan 5, 2023
Member

@michaelbenayoun michaelbenayoun left a comment


LGTM

@@ -92,7 +59,7 @@ def remove_duplicate_weights(model: ModelProto, inplace: bool = False) -> ModelP
return model


def replace_atenops_to_gather(model: ModelProto):
def replace_atenops_to_gather(model: ModelProto) -> ModelProto:
Member


Does this replace ATen ops related to Gather with a working version?
If so, I would make that more explicit in the name.

Collaborator Author

@fxmarty fxmarty Jan 13, 2023


No idea, I did not write it. I'm not sure in which case it is a useful transform.

optimum/onnx/graph_transformations.py (outdated; conversation resolved)
@fxmarty fxmarty merged commit 4016c17 into huggingface:main Jan 16, 2023
Labels
gpu-test trigger GPU tests
Development

Successfully merging this pull request may close these issues.

INT64 clamping to INT32 creates overhead while using TensorRT
3 participants