
Cast f32 -> bf16 -> f32 does not work as expected for graph inputs #9915

@max-ku

Description


We have two back-to-back Cast operators that cast from float to bfloat16 and then back to float, and we expect values to be truncated or rounded to bfloat16 precision. However, ORT does this only for graph initializers, not for graph inputs: input values come out of the f32 -> bf16 -> f32 cast sequence unchanged (neither truncated nor rounded).
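For reference, here is a minimal sketch of how such a model could be built with the onnx helper API. The actual model is in the attached repro.zip; the graph input name cast0_input comes from the repro script below, while the intermediate/output tensor names and the opset version are assumptions.

import onnx
from onnx import TensorProto, helper

# Two back-to-back Casts: float32 -> bfloat16 -> float32, fed by a graph input.
cast0 = helper.make_node("Cast", ["cast0_input"], ["cast0_output"], to=TensorProto.BFLOAT16)
cast1 = helper.make_node("Cast", ["cast0_output"], ["cast1_output"], to=TensorProto.FLOAT)

graph = helper.make_graph(
    [cast0, cast1],
    "cast_f32_bf16_f32",
    inputs=[helper.make_tensor_value_info("cast0_input", TensorProto.FLOAT, [1])],
    outputs=[helper.make_tensor_value_info("cast1_output", TensorProto.FLOAT, [1])],
)
# bfloat16 Cast requires opset 13 or later.
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.save(model, "cast_repro.onnx")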

System information

  • OS Platform: Windows 10
  • ONNX version: onnxruntime==1.8.1, onnx==1.9.0
  • Python version: 3.7

Reproduction instructions

import sys
import pathlib

import numpy as np
import onnxruntime

# Disable all graph optimizations so the back-to-back Casts are executed as-is.
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

# The path to the attached model is passed as the first command-line argument.
session = onnxruntime.InferenceSession(sys.argv[1], sess_options)

# Feed the float32 value nearest to 1/3 through the graph input.
model_inputs = {}
model_inputs['cast0_input'] = np.array([0.333333343267440796], dtype=np.float32)

results = session.run([], model_inputs)

# Dump each output tensor to a binary file for inspection.
for i, output in enumerate(session.get_outputs()):
    pathlib.Path("output.bin." + str(i)).write_bytes(results[i].tobytes())
  • ONNX model attached: repro.zip

Expected behavior

We expect graph input values to be truncated or rounded to bfloat16 precision, but this does not happen; the precision loss is applied only to graph initializers.
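To make the expected precision loss concrete, here is a small numpy sketch that emulates the f32 -> bf16 -> f32 round trip by zeroing the low 16 bits of the float32 bit pattern. Whether ORT's Cast truncates or rounds to nearest even is not assumed here; either way, the round-tripped value should differ from the original input.

import numpy as np

def f32_to_bf16_to_f32(x):
    # Emulate f32 -> bf16 -> f32 by zeroing the low 16 mantissa bits of each
    # float32 value (simple truncation; a real Cast may round to nearest even).
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

x = np.array([0.333333343267440796], dtype=np.float32)
print(x, f32_to_bf16_to_f32(x))  # the round-tripped value differs from x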

Workaround

If an Identity node is inserted between the two Cast nodes, the Cast ops work as expected.
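A rough sketch of that workaround, assuming the node and tensor names from the model-construction sketch above (the actual model in repro.zip may use different names):

import onnx
from onnx import helper

model = onnx.load("cast_repro.onnx")
nodes = list(model.graph.node)

# Insert an Identity node between the two Casts and rewire the second Cast
# to read from its output.
identity = helper.make_node("Identity", ["cast0_output"], ["identity_output"])
nodes[1].input[0] = "identity_output"
nodes.insert(1, identity)

del model.graph.node[:]
model.graph.node.extend(nodes)
onnx.save(model, "cast_repro_workaround.onnx")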

Labels

core runtime (issues related to core runtime); stale (issues that have not been addressed in a while; categorized by a bot)