Closed as not planned
Labels
core runtime (issues related to core runtime), stale (issues that have not been addressed in a while; categorized by a bot)
Description
We have a sequence of back-to-back Cast operators casting from float to bfloat16 and then back to float, so we expect values to be truncated or rounded to bfloat16 precision. However, ORT does this only for graph initializers, not for graph inputs, which pass through unchanged (neither truncated nor rounded after the f32 -> bf16 -> f32 cast sequence).
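For reference, the precision loss we expect from the round trip can be simulated in plain NumPy. This is a sketch that truncates the low 16 bits of the float32 representation; an actual Cast implementation may use round-to-nearest-even instead, but either way the round-tripped value should differ from the input:

```python
import numpy as np

def f32_to_bf16_roundtrip(x):
    """Simulate float32 -> bfloat16 -> float32 by zeroing the low
    16 bits of the float32 bit pattern (truncation, no rounding)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

x = np.array([0.333333343267440796], dtype=np.float32)
y = f32_to_bf16_roundtrip(x)
# y[0] == 0.33203125, i.e. the value was truncated to bfloat16 precision
```

This is the behavior we observe for graph initializers, but not for graph inputs.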
System information
- OS Platform: Windows 10
- ONNX version: onnxruntime==1.8.1, onnx==1.9.0
- Python version: 3.7
Reproduction instructions
import sys
import pathlib
import onnxruntime
import numpy as np

# Disable all graph optimizations so the Cast sequence runs as-is.
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

# Path to the attached ONNX model is passed as the first argument.
session = onnxruntime.InferenceSession(sys.argv[1], sess_options)

model_inputs = {'cast0_input': np.array([0.333333343267440796], dtype=np.float32)}
results = session.run([], model_inputs)

for i, output in enumerate(session.get_outputs()):
    pathlib.Path("output.bin." + str(i)).write_bytes(results[i].tobytes())
- ONNX model attached
Expected behavior
We expect graph input values to be truncated or rounded to bfloat16 precision; however, this does not happen. The round trip only changes values for graph initializers.
Workaround
If an Identity node is inserted between the two Cast nodes, the Cast ops work as expected.