
Support FP16/BF16 for onnxrt adaptor (#273)
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
mengniwang95 committed Feb 26, 2023
1 parent ba42d00 commit 108c245
Showing 30 changed files with 283 additions and 125 deletions.
1 change: 1 addition & 0 deletions .azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
@@ -554,6 +554,7 @@ entrypoint
enum
env
environ
ep
eq
erf
Erf
53 changes: 39 additions & 14 deletions docs/source/mixed_precision.md
@@ -20,38 +20,63 @@ The recently launched 3rd Gen Intel® Xeon® Scalable processor (codenamed Coope

## Mixed Precision Support Matrix

|Framework |BF16 |
|--------------|:-----------:|
|TensorFlow |&#10004; |
|PyTorch |&#10004; |
|ONNX |plan to support in the future |
|MXNet |&#10004; |
|Framework |BF16 |FP16 |
|--------------|:-----------:|:-----------:|
|TensorFlow |&#10004; |:x: |
|PyTorch |&#10004; |:x: |
|ONNX Runtime |&#10004; |&#10004; |
|MXNet |&#10004; |:x: |

> **During quantization, BF16 conversion is default enabled. Please refer to this [document](./quantization_mixed_precision.md) for its workflow.**
> **During quantization, BF16 conversion is enabled by default; FP16 conversion is performed when the 'device' of the config is 'gpu'. Please refer to this [document](./quantization_mixed_precision.md) for its workflow.**
## Get Started with Mixed Precision API

To get a bf16 model, users can use the Mixed Precision API as follows.
To get a bf16/fp16 model, users can use the Mixed Precision API as follows.


Supported precisions for mixed precision include bf16 and fp16. If users want a pure fp16 or bf16 model, they should add the other precision to excluded_precisions.

- BF16:

```python
from neural_compressor import mix_precision
from neural_compressor.config import MixedPrecisionConfig

conf = MixedPrecisionConfig()
conf = MixedPrecisionConfig(excluded_precisions=['fp16'])
converted_model = mix_precision.fit(model, config=conf)
converted_model.save('./path/to/save/')
```

- FP16:

```python
from neural_compressor import mix_precision
from neural_compressor.config import MixedPrecisionConfig

conf = MixedPrecisionConfig(
    backend='onnxrt_cuda_ep',
    device='gpu',
    excluded_precisions=['bf16'])
converted_model = mix_precision.fit(model, config=conf)
converted_model.save('./path/to/save/')
```

> **BF16 conversion may lead to accuracy drop. Intel® Neural Compressor provides an accuracy-aware tuning function to reduce accuracy loss, which will fallback converted ops to FP32 automatically to get better accuracy. To enable this function, users only need to provide an evaluation function (or dataloader + metric).**
> **BF16/FP16 conversion may lead to an accuracy drop. Intel® Neural Compressor provides an accuracy-aware tuning function to reduce accuracy loss, which automatically falls back converted ops to FP32 to get better accuracy. To enable this function, users only need to provide an evaluation function (or dataloader + metric).**
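
The following is a minimal sketch (not part of the original document) of how accuracy-aware tuning could be driven through the same API by supplying an evaluation function; the `eval_func` keyword is assumed to match the `mix_precision.fit` signature, and `run_validation` is a hypothetical placeholder for a user-supplied validation loop.

```python
from neural_compressor import mix_precision
from neural_compressor.config import MixedPrecisionConfig

def eval_func(model):
    # Hypothetical helper: run your own validation loop and return a scalar
    # accuracy. Accuracy-aware tuning falls converted ops back to FP32 until
    # this metric satisfies the tuning criterion.
    return run_validation(model)

conf = MixedPrecisionConfig(excluded_precisions=['fp16'])
converted_model = mix_precision.fit(model, config=conf, eval_func=eval_func)
converted_model.save('./path/to/save/')
```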

## Examples

There are 2 pre-requirements to run BF16 mixed precision examples:
- BF16:

There are two prerequisites for running the BF16 mixed precision examples:

1. Hardware: CPU supports `avx512_bf16` instruction set.
2. Software: intel-tensorflow >= [2.3.0](https://pypi.org/project/intel-tensorflow/2.3.0/) or torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html).

If either pre-requirement can't be met, the program would exit consequently.

- Hardware: CPU supports `avx512_bf16` instruction set.
- Software: intel-tensorflow >= [2.3.0](https://pypi.org/project/intel-tensorflow/2.3.0/) or torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html).
- FP16:

If either prerequisite is not met, the program will exit.
Currently, Intel® Neural Compressor only supports FP16 mixed precision for ONNX models.

To run the FP16 mixed precision examples, users need to set the 'device' of the config to 'gpu' and the 'backend' to 'onnxrt_cuda_ep'.
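
Below is a hedged sketch (added here for illustration, not from the original docs) that checks whether the ONNX Runtime CUDA execution provider is available before requesting FP16 conversion; `model` is assumed to be an already-loaded ONNX model as in the examples above.

```python
import onnxruntime as ort
from neural_compressor import mix_precision
from neural_compressor.config import MixedPrecisionConfig

# FP16 conversion is routed through onnxrt_cuda_ep, so bail out early if the
# local onnxruntime build has no CUDAExecutionProvider.
if 'CUDAExecutionProvider' in ort.get_available_providers():
    conf = MixedPrecisionConfig(
        backend='onnxrt_cuda_ep',
        device='gpu',
        excluded_precisions=['bf16'])
    converted_model = mix_precision.fit(model, config=conf)
    converted_model.save('./path/to/save/')
else:
    print('CUDAExecutionProvider not found; install onnxruntime-gpu to run FP16 conversion.')
```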
18 changes: 14 additions & 4 deletions neural_compressor/adaptor/onnxrt.py
@@ -86,12 +86,13 @@ def __init__(self, framework_specific_info):
logger.warning("Dynamic approach doesn't support QDQ format.")

# get quantization config file according to backend
config_file = None
if self.backend == 'CPUExecutionProvider':
config_file = 'onnxrt.yaml'
elif self.backend == 'TensorrtExecutionProvider':
config_file = 'onnxrt_trt.yaml'
elif self.backend == 'CUDAExecutionProvider':
config_file == 'onnxrt_cuda.yaml'
config_file = 'onnxrt_cuda.yaml'
else: # pragma: no cover
assert False, "{} provider is not supported in current environment, " \
"supported providers: {}".format(self.backend,
@@ -128,6 +129,8 @@ def __init__(self, framework_specific_info):

for precision in self.query_handler.get_precisions():
if precision != 'fp32':
if self.device == 'cpu' and precision == 'fp16':
continue
self.quantizable_op_types += \
self.query_handler.get_op_types_by_precision(precision=precision)

@@ -930,6 +933,8 @@ def query_fw_capability(self, model):
precisions = query.get_precisions()

for precision in precisions:
if precision == 'fp16' and self.device == 'cpu':
continue
# get supported optype for target precision
optypes = query.get_op_types_by_precision(precision) if \
query.get_op_types_by_precision(precision) != ['*'] else \
@@ -1046,7 +1051,7 @@ def query_fw_capability(self, model):
else: # pragma: no cover
op_wise.update(
{(node.name, node.op_type): copy.deepcopy(optype_wise[node.op_type])})

return {'optypewise': optype_wise, 'opwise': op_wise}

def _optypewise_filter_for_qdq(self, optype_wise):
@@ -1411,12 +1416,17 @@ def _compare(version1, version2):
config['capabilities'] = {}

# generate other config content including precisions and ops
precisions = [key for key in config['capabilities'].keys()]
precisions = list(version_config.keys() - {'version', 'recipes'})
if 'fp32' not in precisions:
precisions.append('fp32')
config['precisions'] = {'names': ','.join(precisions)}

op_types = {}
for precision in precisions:
if precision in config['capabilities']:
op_types[precision] = [op_type for op_type in config['capabilities'][precision].keys()]
elif precision in version_config:
op_types[precision] = version_config[precision]
for precision, precision_config in config['capabilities'].items():
op_types[precision] = [op_type for op_type in precision_config.keys()]
if 'fp32' not in op_types:
@@ -1485,4 +1495,4 @@ def get_fallback_list(self):

def get_specific_cfg_version(self):
"""Get version of the specific config."""
return self.config_version
return self.config_version
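
The device check added above can be summarized with a small illustrative sketch (not the adaptor's actual code): fp16-capable op types reported by the query handler are skipped when the configured device is 'cpu'.

```python
# Illustrative only: mirrors the filtering added in onnxrt.py, where fp16 op
# types are ignored on CPU because the CPU execution provider has no fp16 path.
def collect_quantizable_op_types(op_types_by_precision, device):
    quantizable = []
    for precision, op_types in op_types_by_precision.items():
        if precision == 'fp32':
            continue
        if device == 'cpu' and precision == 'fp16':
            continue
        quantizable.extend(op_types)
    return quantizable

caps = {'int8': ['Conv', 'MatMul'], 'fp16': ['Relu', 'Gemm'], 'fp32': ['*']}
assert 'Relu' not in collect_quantizable_op_types(caps, 'cpu')
assert 'Relu' in collect_quantizable_op_types(caps, 'gpu')
```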
21 changes: 20 additions & 1 deletion neural_compressor/adaptor/onnxrt_cuda.yaml
@@ -97,6 +97,11 @@
'LSTM': *default_dynamic,
}
}
fp16: &common_fp16 ['Concat', 'Gather', 'Reshape', 'Squeeze', 'Transpose', 'Unsqueeze',
'EmbedLayerNormalization', 'Attention', 'Split', 'Sigmoid', 'Relu', 'Mul', 'Pad', 'MaxPool',
'MatMul', 'LeakyRelu', 'GlobalAveragePool', 'Gemm', 'Conv', 'AveragePool', 'Add', 'Clip']
bf16: &common_bf16 ['Concat', 'Gather', 'Reshape', 'Squeeze', 'Transpose', 'Unsqueeze',
'Split', 'Sigmoid', 'Relu', 'Mul', 'MatMul', 'Gemm', 'Add']
recipes: &default_optimization
graph_optimization: # from onnxruntime graph_optimization_level
level: ['DISABLE_ALL', 'ENABLE_BASIC', 'ENABLE_EXTENDED', 'ENABLE_ALL']
@@ -137,6 +142,8 @@
},
'dynamic': *ref_1_6_dynamic
}
fp16: *common_fp16
bf16: *common_bf16
recipes:
<<: *default_optimization

@@ -204,6 +211,8 @@
'LSTM': *default_dynamic,
}
}
fp16: *common_fp16
bf16: *common_bf16
recipes:
<<: *default_optimization

@@ -278,6 +287,8 @@
'LSTM': *default_dynamic,
}
}
fp16: *common_fp16
bf16: *common_bf16
recipes:
<<: *default_optimization

@@ -332,6 +343,8 @@
},
'dynamic': *ref_1_9_dynamic
}
fp16: *common_fp16
bf16: *common_bf16
recipes:
<<: *default_optimization

@@ -393,19 +406,25 @@
},
'dynamic': *ref_1_9_dynamic
}
fp16: *common_fp16
bf16: *common_bf16
recipes:
<<: *default_optimization

-
version:
name: '1.12.0'
int8: *ref_1_11
fp16: *common_fp16
bf16: *common_bf16
recipes:
<<: *default_optimization

-
version:
name: 'default'
int8: *ref_1_6
fp16: *common_fp16
bf16: *common_bf16
recipes:
<<: *default_optimization
<<: *default_optimization
3 changes: 2 additions & 1 deletion neural_compressor/adaptor/ox_utils/calibration.py
@@ -426,7 +426,8 @@ def calculate_quantization_params(self, q_config, quantization_thresholds):
qType = 2 # uint8
if tensor_name in output_name_to_nodes:
parent = output_name_to_nodes[tensor_name]
if parent and parent.name in q_config and q_config[parent.name] not in ['fp32']:
if parent and parent.name in q_config and \
q_config[parent.name] not in ['fp32', 'fp16']:
scheme = q_config[parent.name]['activation']['scheme']
qType = q_config[parent.name]['activation']['dtype']
elif self.backend in ['TensorrtExecutionProvider']:
17 changes: 1 addition & 16 deletions neural_compressor/adaptor/ox_utils/operators/direct_q8.py
@@ -81,25 +81,10 @@ def cast(self): # pragma: no cover
return
self.quantizer.dtype_cast(self.node, self.dtype)

@op_registry(op_types="Shape, Loop, Slice")
class DirectCastOperator(Operator): # pragma: no cover
"""Direct8bit Operator Cast."""

def __init__(self, onnx_quantizer, onnx_node):
"""Initialization."""
super(DirectCastOperator, self).__init__(onnx_quantizer, onnx_node)

def cast(self):
"""Cast node."""
node = self.node
if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]:
return
self.quantizer.dtype_cast(self.node, self.dtype)

@qop_registry(op_types="Reshape, Transpose, Squeeze, Unsqueeze")
class QDirectOperator(QOperator):
"""QDirect Operator."""

def __init__(self, onnx_node, children, initializers):
"""Initialization."""
super().__init__(onnx_node, children, initializers)
super().__init__(onnx_node, children, initializers)
4 changes: 2 additions & 2 deletions neural_compressor/adaptor/ox_utils/operators/ops.py
@@ -70,7 +70,7 @@ def __init__(self, onnx_quantizer, onnx_node):
self.activation_dtype = None
self.activation_scheme = 'asym'
if self.node.name in self.quantizer.config:
if self.quantizer.config[self.node.name] != 'fp32':
if self.quantizer.config[self.node.name] not in self.quantizer.fallback_list:
if 'weight' in self.quantizer.config[self.node.name].keys():
self.per_channel = self.quantizer.config[self.node.name]\
['weight']['granularity'] == 'per_channel'
@@ -162,4 +162,4 @@ def convert(self):
node.op_type, inputs,
outputs, node.name + '_convert', **kwargs)
add_nodes.append(new_node)
return True, add_nodes, inits
return True, add_nodes, inits
10 changes: 3 additions & 7 deletions neural_compressor/adaptor/ox_utils/quantizer.py
@@ -320,7 +320,7 @@ def dfs(match_nodes, node, pattern):
if len(outs) > 0:
output_dtype = str(self.new_value_info[outs[0]].new_dtype)
break
if len(outs) == 0 or all([not self.should_convert(i) for i in children]):
if len(outs) == 0 or all([not self.should_cast(i) for i in children]):
return
if input_dtype == str(match_nodes[1].attribute[0].i) and \
output_dtype == str(match_nodes[0].attribute[0].i) and \
@@ -355,17 +355,13 @@ def dfs(match_nodes, node, pattern):

def dtype_cast(self, node, cfg, keep_io_types=True): # pragma: no cover
"""Cast node dtype."""
min_positive_val = 1e-7
max_finite_val = 1e4
for idx, tensor_name in enumerate(node.input):
initializer = find_by_name(tensor_name, self.model.initializer())
if initializer is not None:
if initializer.data_type != onnx_proto.TensorProto.FLOAT:
continue
new_tensor = cast_tensor(initializer, cfg)
if new_tensor:
self.model.remove_initializer(initializer)
self.model.add_initializer(new_tensor)
do_cast = cast_tensor(initializer, cfg)
if do_cast:
self.new_value_info[tensor_name] = ValueInfo(tensor_name,
TensorProto.FLOAT, dtype_mapping[cfg])
else:
48 changes: 40 additions & 8 deletions neural_compressor/adaptor/ox_utils/util.py
@@ -33,10 +33,16 @@
ms_domain = "com.microsoft"

support_pair = {
'float32 bfloat16': True,
'1 16': True,
'bfloat16 float32': True,
'16 1': True,
'uint8 uint8': True,
'2 2': True,
'float16 float16': True,
'10 10': True,
'bfloat16 bfloat16': True,
'16 16': True,
'float32 float16': True,
'1 10': True,
'float16 float32': True,
@@ -59,6 +65,7 @@
'uint64': 13,
'complex64': 14,
'complex128': 15,
'bf16': 16
}

PROVIDERS = {
@@ -135,6 +142,26 @@ def split_shared_bias(model):
node.input[2] = new_input_name
return model

def float_to_float16(tensor):
"""Convert float to float16."""
min_val = 5.96e-08
max_val = 65504.0
tensor[(tensor > max_val) & (tensor < float('inf'))] = max_val
tensor[(tensor < min_val) & (tensor > 0)] = min_val
tensor[(tensor > -min_val) & (tensor < 0)] = -min_val
tensor[(tensor < -max_val) & (tensor > float('-inf'))] = -max_val
return np.float16(tensor)

def float_to_bfloat16(tensor):
"""Convert float to bfloat16."""
min_val = 9.2e-41
max_val = 3.38953139e38
tensor[(tensor > max_val) & (tensor < float('inf'))] = max_val
tensor[(tensor < min_val) & (tensor > 0)] = min_val
tensor[(tensor > -min_val) & (tensor < 0)] = -min_val
tensor[(tensor < -max_val) & (tensor > float('-inf'))] = -max_val
return tensor

def cast_tensor(tensor, dtype): # pragma: no cover
"""Convert tensor float to target dtype.
@@ -146,14 +173,19 @@ def cast_tensor(tensor, dtype): # pragma: no cover
raise ValueError('Expected input type is an ONNX TensorProto but got %s' % type(tensor))

if tensor.data_type == onnx_proto.TensorProto.FLOAT:
new_tensor = helper.make_tensor(
name=tensor.name,
data_type=dtype_mapping[dtype],
dims=numpy_helper.to_array(tensor).shape,
vals=numpy_helper.to_array(tensor)
)
return new_tensor
return None
val = numpy_helper.to_array(tensor).copy()
if dtype == 'fp16':
new_val = float_to_float16(val)
elif dtype == 'bf16':
new_val = float_to_bfloat16(val)
else:
raise ValueError('Expect fp16 or bf16 but get {}.'.format(dtype))
tensor.float_data[:] = []
tensor.int32_data[:] = []
tensor.raw_data = new_val.tostring()
tensor.data_type = dtype_mapping[dtype]
return True
return False

def remove_init_from_model_input(model):
"""Remove initializer from model input."""
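
To make the saturation behavior of the new conversion helpers concrete, here is a self-contained sketch that re-defines `float_to_float16` locally (for illustration, rather than importing it from neural_compressor) and shows that finite out-of-range values clamp to the FP16 limits while tiny nonzero values clamp to the smallest positive value used by the helper.

```python
import numpy as np

def float_to_float16(tensor):
    # Local copy of the helper added in ox_utils/util.py, for illustration.
    min_val = 5.96e-08   # smallest positive FP16 subnormal
    max_val = 65504.0    # largest finite FP16 value
    tensor[(tensor > max_val) & (tensor < float('inf'))] = max_val
    tensor[(tensor < min_val) & (tensor > 0)] = min_val
    tensor[(tensor > -min_val) & (tensor < 0)] = -min_val
    tensor[(tensor < -max_val) & (tensor > float('-inf'))] = -max_val
    return np.float16(tensor)

vals = np.array([1e-9, 1.0, 1e6, -1e6, np.inf], dtype=np.float32)
print(float_to_float16(vals))
# Finite out-of-range values saturate to +/-65504, tiny values clamp to the
# smallest subnormal, and infinities pass through unchanged.
```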
