Add more BF16 ops support on stock tensorflow (#792)
Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>
lvliang-intel committed Apr 28, 2023
1 parent 35a2cd2 commit 369b9d0
Showing 7 changed files with 122 additions and 7 deletions.
6 changes: 4 additions & 2 deletions neural_compressor/adaptor/tensorflow.py
@@ -722,6 +722,7 @@ def _query_quantizable_ops(self, matched_nodes):
        fp32_common_config = {'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}}
        uint8_type = self.query_handler.get_op_types_by_precision(precision='uint8')
        int8_type = self.query_handler.get_op_types_by_precision(precision='int8')
        bf16_type = self.query_handler.get_op_types_by_precision(precision='bf16')
        tf_quantizable_op_type = list(set(uint8_type).union(set(int8_type)))

        valid_precision = self.query_handler.get_mixed_precision_combination()
@@ -792,7 +793,8 @@ def _query_quantizable_ops(self, matched_nodes):
                self.quantizable_op_details[(
                    node_name, self.unify_op_type_mapping[node_op]
                )] = [copy.deepcopy(other_config), fp32_common_config]
                if ('bf16' in valid_precision and CpuInfo().bf16) or os.getenv('FORCE_BF16') == '1':
                if node_op in bf16_type and (('bf16' in valid_precision and CpuInfo().bf16) \
                        or os.getenv('FORCE_BF16') == '1'):
                    self.quantizable_op_details[(
                        node_name, self.unify_op_type_mapping[node_op]
                    )].insert(1, bf16_common_config)
@@ -2228,7 +2230,7 @@ def get_op_types_by_precision(self, precision):
                return self.cur_config[precision]
            if version1_gte_version2(tf.version.VERSION, '2.1.0') or \
                    version1_eq_version2(tf.version.VERSION, '1.15.0-up3'):
                return ['Conv2D']
                return self.cur_config[precision]
            return []

    def get_mixed_precision_combination(self):
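Taken together, the hunks above make the BF16 path data-driven: an op is offered a bf16 config only if its type appears in the framework yaml's bf16 list and BF16 is actually usable (the CPU reports native support, or FORCE_BF16=1 is set). A minimal standalone sketch of that gate; the helper name and the cpu_has_bf16 flag are illustrative stand-ins, not the adaptor's real API:

import os

def should_offer_bf16(node_op, bf16_ops, cpu_has_bf16):
    # Sketch of the condition added in _query_quantizable_ops: the op type must be
    # declared BF16-capable in the yaml, and BF16 must be usable on this machine
    # (native support or the FORCE_BF16 escape hatch).
    return node_op in bf16_ops and (cpu_has_bf16 or os.getenv('FORCE_BF16') == '1')

yaml_bf16_ops = ['Conv2D', 'MatMul', 'AvgPool']
print(should_offer_bf16('MatMul', yaml_bf16_ops, cpu_has_bf16=True))  # True
print(should_offer_bf16('Relu', yaml_bf16_ops, cpu_has_bf16=True))    # False: not in the bf16 list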
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/tensorflow.yaml
@@ -153,7 +153,7 @@
  version:
    name: ['2.1.0', '2.2.0', '2.3.0', '2.4.0', '2.5.0', '2.6.0', '2.6.1', '2.6.2', '2.7.0', '2.8.0', '2.9.0', '2.9.1', '2.10.0', '2.11.0', '1.15.0-up1', '1.15.0-up2', '1.15.0-up3']

  bf16: ['Conv2D']
  bf16: ['Conv2D', 'Conv3D', 'MatMul', 'BatchMatMul', 'MaxPool', 'MaxPool3D', 'AvgPool', 'AvgPool3D', 'DepthwiseConv2dNative']
  fp32: ['*'] # '*' means all op types

  int8: {
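With get_op_types_by_precision now returning self.cur_config[precision], the expanded list above is exactly what callers receive for precision='bf16' on supported TensorFlow versions. A rough sketch of that yaml-driven lookup, using a trimmed inline snippet and plain yaml.safe_load rather than the adaptor's own config loader:

import yaml

# Trimmed stand-in for one version entry of tensorflow.yaml (not the full file).
CONFIG_SNIPPET = """
bf16: ['Conv2D', 'Conv3D', 'MatMul', 'BatchMatMul', 'MaxPool', 'MaxPool3D',
       'AvgPool', 'AvgPool3D', 'DepthwiseConv2dNative']
fp32: ['*']
"""

def get_op_types_by_precision(cur_config, precision):
    # Mirrors the patched method: read the op list straight from the loaded
    # config instead of returning a hardcoded ['Conv2D'].
    return cur_config.get(precision, [])

cur_config = yaml.safe_load(CONFIG_SNIPPET)
print(get_op_types_by_precision(cur_config, 'bf16'))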
@@ -34,6 +34,8 @@
from neural_compressor.adaptor.tf_utils.graph_util import GraphRewriterHelper as Helper
from ..generic.graph_cse_optimizer import GraphCseOptimizer
from ..generic.dequantize_cast_optimizer import DequantizeCastOptimizer
import tensorflow as tf
from neural_compressor.adaptor.tf_utils.util import TF_SPR_BASE_VERSIONS

DT_FLOAT32 = attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum)
DT_BFLOAT16 = attr_value_pb2.AttrValue(type=dtypes.bfloat16.as_datatype_enum)
@@ -179,7 +181,8 @@ def _bf16_convert(self, bf16_node_name):
                    tensor=tensor_util.make_tensor_proto(
                        fp32_value, dtypes.bfloat16, fp32_value.shape)))
            elif 'Dequantize' == input_node.op and len(input_node_outputs) == 1 \
                    and input_node.attr['mode'].s != b'MIN_FIRST':
                    and input_node.attr['mode'].s != b'MIN_FIRST' \
                    and tf.version.VERSION in TF_SPR_BASE_VERSIONS:
                # Dequantize with mode MIN_FIRST does not support bf16 in either eigen or mkl
                _, outputs_dt_input_node = self._dtype(input_node)
                allowed_input_node_dt_val = self._allowed_dtype_val(input_node)
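The extra tf.version.VERSION in TF_SPR_BASE_VERSIONS condition keeps Dequantize out of the BF16 conversion path on stock TensorFlow, where _MklDequantize cannot emit bfloat16. A hedged sketch of the guard in isolation; the version string below is a placeholder, not the real contents of TF_SPR_BASE_VERSIONS:

import tensorflow as tf

# Placeholder stand-in for neural_compressor.adaptor.tf_utils.util.TF_SPR_BASE_VERSIONS,
# which enumerates the Intel spr-base TensorFlow builds.
SPR_BASE_VERSIONS = ('x.y.z-spr-base',)  # hypothetical value

def dequantize_can_stay_bf16(dequantize_mode, num_consumers):
    # Mirror of the guard in _bf16_convert: only a single-consumer Dequantize,
    # not in MIN_FIRST mode, and only on an spr-base build, is kept in BF16.
    return (num_consumers == 1
            and dequantize_mode != b'MIN_FIRST'
            and tf.version.VERSION in SPR_BASE_VERSIONS)

print(dequantize_can_stay_bf16(b'SCALED', 1))  # False on a stock TensorFlow build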
@@ -21,8 +21,10 @@

from ..graph_base import GraphRewriterBase
from neural_compressor.adaptor.tf_utils.graph_util import GraphAnalyzer
from neural_compressor.adaptor.tf_utils.graph_util import GraphRewriterHelper as Helper
from neural_compressor.utils.utility import dump_elapsed_time
import tensorflow as tf
from neural_compressor.adaptor.tf_utils.util import TF_SPR_BASE_VERSIONS

class DequantizeCastOptimizer(GraphRewriterBase):
    """Remove the Cast OP and set Dequantize output to BF16 if the Cast OP output is BF16."""

@@ -36,6 +38,11 @@ def do_transformation(self):
        Returns:
            [graphdef]: optimized graph
        """
        # Stock TF _MklDequantize doesn't support BF16 currently.
        # TODO: remove this guard once spr-base is upstreamed to stock TF.
        if not tf.version.VERSION in TF_SPR_BASE_VERSIONS:
            return self.model

        DT_BFLOAT16 = attr_value_pb2.AttrValue(type=dtypes.bfloat16.as_datatype_enum)
        cur_graph = GraphAnalyzer()
        cur_graph.graph = self.model
@@ -67,7 +67,10 @@ def test_dequantize_cast_normal(self):
        graph_def = build_fake_graphdef()
        converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
        for i in converted_graph_def.node:
            self.assertNotEqual(i.op, 'Cast')
            if i.op == 'Cast':
                hasCast = True
                break
        self.assertEqual(hasCast, True)

    @disable_random()
    def test_dequantize_cast_min_first(self):
@@ -52,7 +52,9 @@ def build_fake_framework_yaml():
---
-
  version:
    name: ['2.1.0', '2.2.0', '2.3.0', '2.4.0', '2.5.0', '2.6.0', '2.7.0']
    name: ['2.1.0', '2.2.0', '2.3.0', '2.4.0', '2.5.0', '2.6.0', '2.7.0']
  bf16: ['Conv2D', 'MatMul', 'ConcatV2', 'MaxPool', 'AvgPool', 'DepthwiseConv2dNative']
  int8: {
    'static': {
@@ -93,6 +95,8 @@ def build_fake_framework_yaml():
  version:
    name: ['default']

  bf16: ['Conv2D', 'MatMul', 'ConcatV2', 'MaxPool', 'AvgPool', 'DepthwiseConv2dNative']
  int8: {
    'static': {
      'Conv2D': {
@@ -0,0 +1,96 @@
import unittest
import os
import yaml
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import dtypes
from neural_compressor.adaptor.tf_utils.util import disable_random
from neural_compressor.adaptor.tf_utils.graph_util import GraphRewriterHelper as Helper
from neural_compressor.adaptor.tf_utils.graph_rewriter.generic.dequantize_cast_optimizer import DequantizeCastOptimizer

def build_fake_graphdef(set_min_first=False, dq_multi_outputs=False):
    tf.compat.v1.disable_eager_execution()

    input = tf.compat.v1.placeholder(tf.float32, shape=(32, 224, 224, 3), name='input')
    graph_def = tf.compat.v1.get_default_graph().as_graph_def(add_shapes=True)

    min_input = Helper.create_constant_node(
        'test_min',
        value=0.,
        dtype=dtypes.float32)

    max_input = Helper.create_constant_node(
        'test_max',
        value=[1],
        dtype=dtypes.float32)

    quant_v2_node = Helper.create_node("QuantizeV2", 'test_quantize',
                                       [input.name, min_input.name, max_input.name])

    dequantize_node = Helper.create_node(
        "Dequantize", 'test_dequantize',
        [quant_v2_node.name, quant_v2_node.name + ':1', quant_v2_node.name + ':2'])
    if set_min_first:
        Helper.set_attr_string(dequantize_node, "mode", b'MIN_FIRST')

    cast_node = Helper.create_node(
        "Cast", 'test_cast', [dequantize_node.name])
    Helper.set_attr_dtype(cast_node, "DstT", dtypes.bfloat16)
    Helper.set_attr_dtype(cast_node, "SrcT", dtypes.float32)
    Helper.set_attr_bool(cast_node, "Truncate", False)

    dentity_node = Helper.create_node(
        "Identity", 'output', [cast_node.name])
    Helper.set_attr_dtype(dentity_node, "T", dtypes.bfloat16)

    graph_def.node.extend([
        min_input,
        max_input,
        quant_v2_node,
        dequantize_node,
        cast_node,
        dentity_node,
    ])

    if dq_multi_outputs:
        dentity_node_2 = Helper.create_node(
            "Identity", 'id_1', [dequantize_node.name])
        Helper.set_attr_dtype(dentity_node_2, "T", dtypes.float32)
        graph_def.node.extend([dentity_node_2])

    return graph_def

class TestDequantizeCastOptimizer(unittest.TestCase):

    @disable_random()
    def test_dequantize_cast_normal(self):
        graph_def = build_fake_graphdef()
        converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
        for i in converted_graph_def.node:
            self.assertNotEqual(i.op, 'Cast')

    @disable_random()
    def test_dequantize_cast_min_first(self):
        graph_def = build_fake_graphdef(set_min_first=True)
        converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
        hasCast = False
        for i in converted_graph_def.node:
            if i.op == 'Cast':
                hasCast = True
                break
        self.assertEqual(hasCast, True)

    @disable_random()
    def test_dequantize_cast_multiple_outputs(self):
        graph_def = build_fake_graphdef(dq_multi_outputs=True)
        converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
        hasCast = False
        for i in converted_graph_def.node:
            if i.op == 'Cast':
                hasCast = True
                break
        self.assertEqual(hasCast, True)


if __name__ == "__main__":
    unittest.main()
