Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 50 additions & 29 deletions onnxruntime/python/tools/transformers/float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx import onnx_pb as onnx_proto
from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper
from onnx.shape_inference import infer_shapes, infer_shapes_path
from packaging import version

Expand Down Expand Up @@ -87,11 +86,11 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit
TensorProto: the converted tensor.
"""

if not isinstance(tensor, onnx_proto.TensorProto):
if not isinstance(tensor, TensorProto):
raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")

if tensor.data_type == onnx_proto.TensorProto.FLOAT:
tensor.data_type = onnx_proto.TensorProto.FLOAT16
if tensor.data_type == TensorProto.FLOAT:
tensor.data_type = TensorProto.FLOAT16
# convert float_data (float type) to float16 and write to int32_data
if tensor.float_data:
float16_data = convert_np_to_float16(np.array(tensor.float_data), min_positive_val, max_finite_val)
Expand Down Expand Up @@ -152,12 +151,12 @@ def make_value_info_from_tensor(tensor):
class InitializerTracker:
"""Class for keeping track of initializer."""

def __init__(self, initializer: onnx_proto.TensorProto):
def __init__(self, initializer: TensorProto):
self.initializer = initializer
self.fp32_nodes = []
self.fp16_nodes = []

def add_node(self, node: onnx_proto.NodeProto, is_node_blocked):
def add_node(self, node: NodeProto, is_node_blocked):
if is_node_blocked:
self.fp32_nodes.append(node)
else:
Expand Down Expand Up @@ -219,7 +218,7 @@ def convert_float_to_float16(
else:
model = onnx.load(model_path)

if not isinstance(model, onnx_proto.ModelProto):
if not isinstance(model, ModelProto):
raise ValueError(f"Expected an ONNX ModelProto but got {type(model)}")

func_infer_shape = None
Expand Down Expand Up @@ -259,8 +258,8 @@ def convert_float_to_float16(
graph_io_to_skip = set()
io_casts = set()

fp32_inputs = [n.name for n in model.graph.input if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT]
fp32_outputs = [n.name for n in model.graph.output if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT]
fp32_inputs = [n.name for n in model.graph.input if n.type.tensor_type.elem_type == TensorProto.FLOAT]
fp32_outputs = [n.name for n in model.graph.output if n.type.tensor_type.elem_type == TensorProto.FLOAT]
if isinstance(keep_io_types, list):
fp32_inputs = [n for n in fp32_inputs if n in keep_io_types]
fp32_outputs = [n for n in fp32_outputs if n in keep_io_types]
Expand All @@ -278,9 +277,9 @@ def convert_float_to_float16(
new_value_info = model.graph.value_info.add()
new_value_info.CopyFrom(n)
new_value_info.name = output_name
new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16
# add Cast node (from tensor(float) to tensor(float16) after graph input
new_node = [helper.make_node("Cast", [n.name], [output_name], to=10, name=node_name)]
new_node = [helper.make_node("Cast", [n.name], [output_name], to=TensorProto.FLOAT16, name=node_name)]
model.graph.node.extend(new_node)
value_info_list.append(new_value_info)
io_casts.add(node_name)
Expand All @@ -296,7 +295,7 @@ def convert_float_to_float16(
new_value_info = model.graph.value_info.add()
new_value_info.CopyFrom(n)
new_value_info.name = input_name
new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16
new_node = [helper.make_node("Cast", [input_name], [n.name], to=1, name=node_name)]
model.graph.node.extend(new_node)
value_info_list.append(new_value_info)
Expand All @@ -307,12 +306,12 @@ def convert_float_to_float16(
next_level = []
for q in queue:
# if q is model, push q.graph (GraphProto)
if isinstance(q, onnx_proto.ModelProto):
if isinstance(q, ModelProto):
next_level.append(q.graph)
# if q is model.graph, push q.node.attribute (AttributeProto)
if isinstance(q, onnx_proto.GraphProto):
if isinstance(q, GraphProto):
for n in q.initializer: # TensorProto type
if n.data_type == onnx_proto.TensorProto.FLOAT:
if n.data_type == TensorProto.FLOAT:
assert n.name not in fp32_initializers
fp32_initializers[n.name] = InitializerTracker(n)

Expand Down Expand Up @@ -343,10 +342,32 @@ def convert_float_to_float16(
else:
if n.op_type == "Cast":
for attr in n.attribute:
if attr.name == "to" and attr.i == 1:
attr.i = 10
if attr.name == "to" and attr.i == TensorProto.FLOAT:
attr.i = TensorProto.FLOAT16
break

if n.op_type in [
"EyeLike",
"Multinomial",
"RandomNormal",
"RandomNormalLike",
"RandomUniform",
"RandomUniformLike",
"SequenceEmpty",
"Bernoulli",
]:
has_dtype = False
for attr in n.attribute:
if attr.name == "dtype":
has_dtype = True
if attr.i == TensorProto.FLOAT:
attr.i = TensorProto.FLOAT16

# The dtype attribute is optional and default is FLOAT in the following operators
# so we need add dtype attribute to specify the data type float16
if (n.op_type in ["RandomNormal", "RandomUniform", "SequenceEmpty"]) and not has_dtype:
n.attribute.extend([helper.make_attribute("dtype", TensorProto.FLOAT16)])

# For Resize/GroupNorm, attribute data type cannot be changed
if n.op_type not in ALWAYS_FLOAT_INPUTS or n.op_type in force_fp16_inputs_dict:
for attr in n.attribute:
Expand All @@ -356,27 +377,27 @@ def convert_float_to_float16(

# if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto)
# and process node.attribute.t and node.attribute.tensors (TensorProto)
if isinstance(q, onnx_proto.AttributeProto):
if isinstance(q, AttributeProto):
next_level.append(q.g)
for n in q.graphs:
next_level.append(n) # noqa: PERF402
q.t.CopyFrom(convert_tensor_float_to_float16(q.t, min_positive_val, max_finite_val))
for n in q.tensors:
n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val) # noqa: PLW2901
# if q is graph, process input, output and value_info (ValueInfoProto)
if isinstance(q, onnx_proto.GraphProto):
if isinstance(q, GraphProto):
# Note that float initializers tracked by fp32_initializers will be processed later.
# for all ValueInfoProto with tensor(float) type in input, output and value_info, convert them to
# tensor(float16) except map and seq(map). And save them in value_info_list for further processing
for n in itertools.chain(q.input, q.output, q.value_info):
if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
if n.type.tensor_type.elem_type == TensorProto.FLOAT:
if n.name not in graph_io_to_skip:
n.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
n.type.tensor_type.elem_type = TensorProto.FLOAT16
value_info_list.append(n)
if n.type.HasField("sequence_type"):
if n.type.sequence_type.elem_type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
if n.type.sequence_type.elem_type.tensor_type.elem_type == TensorProto.FLOAT:
if n.name not in graph_io_to_skip:
n.type.sequence_type.elem_type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
n.type.sequence_type.elem_type.tensor_type.elem_type = TensorProto.FLOAT16
value_info_list.append(n)

queue = next_level
Expand Down Expand Up @@ -405,7 +426,7 @@ def convert_float_to_float16(
new_value_info.CopyFrom(value_info)
output_name = node.name + "_input_cast_" + str(i)
new_value_info.name = output_name
new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
# add Cast node (from tensor(float16) to tensor(float) before current node
node_name = node.name + "_input_cast" + str(i)
new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]
Expand All @@ -428,7 +449,7 @@ def convert_float_to_float16(
new_value_info.CopyFrom(value_info)
output_name = node.name + "_input_cast_" + str(i)
new_value_info.name = output_name
new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
# add Cast node (from tensor(float16) to tensor(float) before current node
node_name = node.name + "_input_cast" + str(i)
new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]
Expand All @@ -447,7 +468,7 @@ def convert_float_to_float16(
new_value_info.CopyFrom(value_info)
input_name = node.name + "_output_cast_" + str(i)
new_value_info.name = input_name
new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
# add Cast node (from tensor(float) to tensor(float16) after current node
node_name = node.name + "_output_cast" + str(i)
new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)]
Expand All @@ -460,9 +481,9 @@ def convert_float_to_float16(

def float_to_float16_max_diff(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0):
"""Measure the maximum absolute difference after converting a float tensor to float16."""
if not isinstance(tensor, onnx_proto.TensorProto):
if not isinstance(tensor, TensorProto):
raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
if tensor.data_type != onnx_proto.TensorProto.FLOAT:
if tensor.data_type != TensorProto.FLOAT:
raise ValueError("Expected tensor data type is float.")

float32_data = None
Expand Down