Enhance the ORT node name checking (#1512)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
yiliu30 committed Jan 5, 2024
1 parent d69c552 commit f1597aa
Showing 3 changed files with 93 additions and 8 deletions.
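Why this change matters: the old lookups recovered a node's pre-quantization name with node.name.split("_quant")[0], which truncates any name that merely contains "_quant" instead of ending with it. A minimal illustration of the difference (the second node name comes from the new tests below):

    # Old name recovery vs. the suffix-aware helper this commit introduces.
    def old_original_name(name: str) -> str:
        return name.split("_quant")[0]  # cuts at the FIRST occurrence of "_quant"

    def new_original_name(name: str) -> str:
        suffix = "_quant"
        return name[: -len(suffix)] if name.endswith(suffix) else name

    assert old_original_name("conv1_quant") == "conv1"  # both agree on quantized names
    assert new_original_name("conv1_quant") == "conv1"
    assert old_original_name("post_quant_Matmul") == "post"  # wrong config key
    assert new_original_name("post_quant_Matmul") == "post_quant_Matmul"  # kept intact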
12 changes: 6 additions & 6 deletions neural_compressor/adaptor/ox_utils/quantizer.py
@@ -46,6 +46,7 @@
     dtype_mapping,
     dtype_to_name,
     find_by_name,
+    get_node_original_name,
     make_dquant_node,
     make_quant_node,
     quantize_data,
@@ -168,8 +169,8 @@ def should_quantize(self, node):
         if node.name in self.config and self.config[node.name] not in self.fallback_list:
             return True
         elif (
-            node.name.split("_quant")[0] in self.config
-            and self.config[node.name.split("_quant")[0]] not in self.fallback_list
+            get_node_original_name(node) in self.config
+            and self.config[get_node_original_name(node)] not in self.fallback_list
         ):
             return True
         else:
@@ -309,7 +310,7 @@ def insert_qdq(self):
 
     def should_convert(self, node):
         """Check if node should be converted."""
-        name = node.name.split("_quant")[0]
+        name = get_node_original_name(node)
         if (
             name in self.config
             and self.config[name] not in self.fallback_list
@@ -327,7 +328,7 @@ def convert_qdq_to_operator_oriented(self):
         for node in self.model.nodes():
             if node.op_type not in ["QuantizeLinear", "DequantizeLinear"] and self.should_convert(node):
                 op_converter = OPERATORS[node.op_type](self, node)
-                mode = self.config[node.name.split("_quant")[0]]["activation"]["quant_mode"]
+                mode = self.config[get_node_original_name(node)]["activation"]["quant_mode"]
                 if op_converter.convert_check(mode):
                     op_converter.convert(mode)
         self.model.graph().node.extend(self.new_nodes)
@@ -622,8 +623,7 @@ def quantize_inputs(self, node, indices=None, initializer_use_weight_qType=True,
                    data_found = True
                else:
                    data_found, scale_name, zp_name, _, _ = self._get_quantization_params(tensor_name)
-
-               if self.config[node.name.split("_quant")[0]]["activation"]["quant_mode"] != "dynamic":
+               if self.config[get_node_original_name(node)]["activation"]["quant_mode"] != "dynamic":
                    if data_found is False:
                        raise ValueError(
                            "Quantization parameters are not specified for param {}."
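To make the quantizer.py changes above concrete, here is a simplified, self-contained sketch of the config lookup that should_quantize and should_convert now perform. Plain name strings stand in for NodeProto objects, the config and fallback values are invented for illustration, and the real methods carry extra conditions elided from this diff:

    QUANT_OP_NAME_SUFFIX = "_quant"

    def get_node_original_name(node_name: str) -> str:
        # Strip a trailing "_quant" only; leave other occurrences untouched.
        if node_name.endswith(QUANT_OP_NAME_SUFFIX):
            return node_name[: -len(QUANT_OP_NAME_SUFFIX)]
        return node_name

    config = {"post_quant_Matmul": {"activation": {"quant_mode": "dynamic"}}}  # stand-in
    fallback_list = ["fp32"]

    def should_quantize(node_name: str) -> bool:
        # Try the name as-is first, then with the quantization suffix stripped.
        if node_name in config and config[node_name] not in fallback_list:
            return True
        original = get_node_original_name(node_name)
        return original in config and config[original] not in fallback_list

    assert should_quantize("post_quant_Matmul")        # matched directly
    assert should_quantize("post_quant_Matmul_quant")  # matched after stripping the suffix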
14 changes: 14 additions & 0 deletions neural_compressor/adaptor/ox_utils/util.py
@@ -95,6 +95,20 @@
 
 MAXIMUM_PROTOBUF = 2147483648
 
+# The quantized node will be renamed to original_name + QUANT_OP_NAME_SUFFIX, for example `conv1` -> `conv1_quant`.
+QUANT_OP_NAME_SUFFIX = "_quant"
+
+
+def get_node_original_name(node) -> str:
+    """Get the original name of the given node."""
+    node_name: str = node.name
+    # TODO: how to handle an unquantized node whose name already ends with `_quant`, such as `conv_quant`?
+    if node_name.endswith(QUANT_OP_NAME_SUFFIX):
+        return node_name[: -len(QUANT_OP_NAME_SUFFIX)]
+    else:
+        # Unquantized nodes keep their name unchanged.
+        return node_name
+
 
 def simple_progress_bar(total, i):
     """Progress bar for cases where tqdm can't be used."""
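A small usage sketch of the new helper; the import path matches the one used in the updated test file below, and the nodes are ordinary ONNX NodeProtos built with onnx.helper:

    from onnx import helper

    from neural_compressor.adaptor.ox_utils.util import get_node_original_name

    renamed = helper.make_node("Conv", ["x", "w"], ["y"], name="conv1_quant")
    untouched = helper.make_node("MatMul", ["a", "b"], ["c"], name="post_quant_Matmul")

    print(get_node_original_name(renamed))    # conv1
    print(get_node_original_name(untouched))  # post_quant_Matmul
    # Caveat flagged by the TODO above: an unquantized node whose user-given
    # name already ends in "_quant" (e.g. "conv_quant") would also be stripped.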
75 changes: 73 additions & 2 deletions test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py
@@ -452,6 +452,29 @@ def build_matmul_model2():
     return model
 
 
+def build_matmul_model3():
+    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 5, 5])
+    C = helper.make_tensor_value_info("C", TensorProto.FLOAT, [1, 5, 2])
+    D = helper.make_tensor_value_info("D", TensorProto.FLOAT, [1, 5, 2])
+    H = helper.make_tensor_value_info("H", TensorProto.FLOAT, [1, 5, 2])
+
+    e_value = np.random.randint(2, size=(10)).astype(np.float32)
+    B_init = helper.make_tensor("B", TensorProto.FLOAT, [5, 2], e_value.reshape(10).tolist())
+    E_init = helper.make_tensor("E", TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist())
+
+    # The node name deliberately contains "_quant" to exercise the name checking.
+    matmul_node = onnx.helper.make_node("MatMul", ["A", "B"], ["C"], name="post_quant_Matmul")
+    add = onnx.helper.make_node("Add", ["C", "E"], ["D"], name="add")
+
+    f_value = np.random.randint(2, size=(10)).astype(np.float32)
+    F_init = helper.make_tensor("F", TensorProto.FLOAT, [1, 5, 2], f_value.reshape(10).tolist())
+    add2 = onnx.helper.make_node("Add", ["D", "F"], ["H"], name="add2")
+
+    graph = helper.make_graph([matmul_node, add, add2], "test_graph_1", [A], [H], [B_init, E_init, F_init])
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+    return model
+
+
 def build_matmul_gather_model():
     input = helper.make_tensor_value_info("input0", TensorProto.INT64, [1, 1])
     output = helper.make_tensor_value_info("output0", TensorProto.FLOAT, [1, 1])
@@ -592,6 +615,41 @@ def build_conv_model2():
     return model
 
 
+def build_conv_model3():
+    initializers = []
+    input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 224, 224])
+    conv1_weight_initializer = numpy_helper.from_array(
+        np.random.randint(-1, 2, [3, 3, 3, 3]).astype(np.float32), name="conv1_weight"
+    )
+    conv1_node = helper.make_node("Conv", ["input", "conv1_weight"], ["conv1_output"], name="conv1")
+
+    conv2_weight_initializer = numpy_helper.from_array(
+        np.random.randint(-1, 2, [5, 3, 3, 3]).astype(np.float32), name="conv2_weight"
+    )
+    # The node name deliberately contains "_quant" to exercise the name checking.
+    conv2_node = helper.make_node("Conv", ["conv1_output", "conv2_weight"], ["conv2_output"], name="pre_quant_conv2")
+
+    conv3_weight_initializer = numpy_helper.from_array(
+        np.random.randint(-1, 2, [3, 3, 3, 3]).astype(np.float32), name="conv3_weight"
+    )
+    conv3_node = helper.make_node("Conv", ["input", "conv3_weight"], ["conv3_output"], name="conv3")
+
+    avg_args = {"kernel_shape": [3, 3]}
+    avgpool_node = helper.make_node("AveragePool", ["conv3_output"], ["avg_output"], name="AveragePool", **avg_args)
+
+    concat_node = helper.make_node("Concat", ["avg_output", "conv2_output"], ["concat_output"], name="Concat", axis=1)
+    output = helper.make_tensor_value_info("concat_output", TensorProto.FLOAT, [1, 8, 220, 220])
+    initializers = [conv1_weight_initializer, conv2_weight_initializer, conv3_weight_initializer]
+    # Nodes are listed in topological order (avgpool_node produces concat_node's input).
+    graph = helper.make_graph(
+        [conv1_node, conv2_node, conv3_node, avgpool_node, concat_node],
+        "test",
+        [input],
+        [output],
+        initializer=initializers,
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+    return model
+
+
 def build_gemm_model():
     initializers = []
     input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [-1, 2048])
@@ -800,10 +858,12 @@ def setUpClass(self):
         self.gather_model = build_model_with_gather()
         self.matmul_model = build_matmul_model()
         self.matmul_model2 = build_matmul_model2()
+        self.matmul_model3 = build_matmul_model3()
         self.rename_model = build_rename_model()
         self.conv_model = build_conv_model()
         self.gemm_model = build_gemm_model()
         self.conv_model2 = build_conv_model2()
+        self.conv_model3 = build_conv_model3()
         export_onnx_nlp_model(self.distilbert_model, self.distilbert_export_path, 14)
         export_onnx_nlp_model(self.albert_model, self.albert_export_path, 14)
         self.distilbert_model = onnx.load(self.distilbert_export_path)
@@ -975,7 +1035,7 @@ def test_inspect_tensor(self):
         self.assertTrue(sorted(fp32_tensor["weight"].keys()) == sorted(int8_tensor["weight"].keys()))
 
     def test_set_tensor(self):
-        from neural_compressor.adaptor.ox_utils.util import quantize_data_with_scale_zero
+        from neural_compressor.adaptor.ox_utils.util import get_node_original_name, quantize_data_with_scale_zero
 
         config = PostTrainingQuantConfig(
             approach="static", recipes={"gemm_to_matmul": False, "graph_optimization_level": "ENABLE_EXTENDED"}
@@ -996,7 +1056,7 @@ def test_set_tensor(self):
         framework = "onnxrt_qlinearops"
         adaptor = FRAMEWORKS[framework](framework_specific_info)
         q_config = {
-            q_model.nodes()[1].name.split("_quant")[0]: {
+            get_node_original_name(q_model.nodes()[1]): {
                 "weight": {"granularity": "per_channel", "dtype": onnx_proto.TensorProto.INT8, "scheme": "sym"}
             }
         }
@@ -1327,6 +1387,17 @@ def test_qdq_settings(self):
         q_model = quantization.fit(self.rn50_model, config, calib_dataloader=self.cv_dataloader)
         self.assertEqual(len([i for i in q_model.nodes() if i.op_type == "QuantizeLinear"]), 53)
 
+    def test_model_name_checking(self):
+        # Some nodes in these models have names that already include `_quant`.
+        # static
+        config = PostTrainingQuantConfig(approach="static", quant_format="QDQ", recipes={"dedicated_qdq_pair": True})
+        q_model = quantization.fit(self.conv_model3, config, calib_dataloader=self.cv_dataloader)
+        self.assertEqual(len([i for i in q_model.nodes() if i.op_type == "QuantizeLinear"]), 6)
+        # dynamic
+        config = PostTrainingQuantConfig(approach="dynamic")
+        q_model = quantization.fit(self.matmul_model3, config, calib_dataloader=self.matmul_dataloader)
+        self.assertTrue("MatMulInteger" in [i.op_type for i in q_model.nodes()])
+
     def test_lower_is_better_case(self):
         import time

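For completeness, a hypothetical way to run only the new test; the TestCase class name below is an assumption (the class definition sits outside the excerpted hunks), so substitute the actual class in test_adaptor_onnxrt.py:

    import unittest

    # "TestAdaptorONNXRT" is an assumed class name, not taken from this diff;
    # replace it with the real TestCase class before running.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "test.adaptor.onnxrt_adaptor.test_adaptor_onnxrt.TestAdaptorONNXRT.test_model_name_checking"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)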
