Fix mix precision (#628)
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
mengniwang95 committed Mar 13, 2023
1 parent 387bd10 commit 4b71a82
Showing 5 changed files with 60 additions and 15 deletions.
28 changes: 20 additions & 8 deletions docs/source/mixed_precision.md
@@ -42,7 +42,7 @@ Supported precisions for mix precision include bf16 and fp16. If users want to g
from neural_compressor import mix_precision
from neural_compressor.config import MixedPrecisionConfig

-conf = MixedPrecisionConfig(excluded_precisions=['fp16'])
+conf = MixedPrecisionConfig(precision='bf16')
converted_model = mix_precision.fit(model, config=conf)
converted_model.save('./path/to/save/')
```
@@ -56,7 +56,7 @@ from neural_compressor.config import MixedPrecisionConfig
conf = MixedPrecisionConfig(
    backend='onnxrt_cuda_ep',
    device='gpu',
-    excluded_precisions=['bf16'])
+    precision='fp16')
converted_model = mix_precision.fit(model, config=conf)
converted_model.save('./path/to/save/')
```
@@ -66,17 +66,29 @@ converted_model.save('./path/to/save/')

## Examples

+There are some prerequisites for running the mixed precision examples with each framework. If the hardware requirements cannot be met, the program will exit.

- BF16:

-There are 2 pre-requirements to run BF16 mixed precision examples:

+### TensorFlow

1. Hardware: CPU supports `avx512_bf16` instruction set.
-2. Software: intel-tensorflow >= [2.3.0](https://pypi.org/project/intel-tensorflow/2.3.0/) or torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html).
+2. Software: intel-tensorflow >= [2.3.0](https://pypi.org/project/intel-tensorflow/2.3.0/).

-If either pre-requirement can't be met, the program would exit consequently.
+### PyTorch

+1. Hardware: CPU supports `avx512_bf16` instruction set.
+2. Software: torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html).

+### ONNX Runtime

+1. Hardware: GPU, set 'device' of config to 'gpu' and 'backend' to 'onnxrt_cuda_ep'.
+2. Software: onnxruntime-gpu.

- FP16

Currently Intel® Neural Compressor only supports FP16 mixed precision for ONNX models.

-To run FP16 mixed precision examples, users need to set 'device' of config to 'gpu' and 'backend' to 'onnxrt_cuda_ep'.
+### ONNX Runtime

+1. Hardware: GPU, set 'device' of config to 'gpu' and 'backend' to 'onnxrt_cuda_ep'.
+2. Software: onnxruntime-gpu.
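As a quick sanity check of the prerequisites above, something like the following sketch can be run up front (illustrative only, not part of the repository; it assumes a Linux host where CPU flags appear in `/proc/cpuinfo`, and `onnxruntime` may or may not be installed):

```python
# Hypothetical pre-flight check for the prerequisites listed above.
# Assumes Linux: CPU instruction-set flags are read from /proc/cpuinfo.

def cpu_supports_bf16():
    """True if the CPU advertises the avx512_bf16 instruction set."""
    try:
        with open("/proc/cpuinfo") as f:
            return "avx512_bf16" in f.read()
    except OSError:
        return False  # non-Linux host: cannot tell from /proc/cpuinfo

def cuda_ep_available():
    """True if onnxruntime-gpu exposes the CUDA execution provider."""
    try:
        import onnxruntime as ort
        return "CUDAExecutionProvider" in ort.get_available_providers()
    except ImportError:
        return False  # onnxruntime(-gpu) not installed

print("BF16 on CPU (TensorFlow/PyTorch):", cpu_supports_bf16())
print("FP16/BF16 via onnxrt_cuda_ep:", cuda_ep_available())
```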
4 changes: 3 additions & 1 deletion neural_compressor/adaptor/onnxrt.py
@@ -942,7 +942,9 @@ def query_fw_capability(self, model):
        precisions = query.get_precisions()

        for precision in precisions:
-            if precision == 'fp16' and self.device == 'cpu':
+            if precision in ['fp16', 'bf16'] and (self.device == 'cpu' or self.backend != 'CUDAExecutionProvider'):
                continue
+            elif precision == 'bf16' and 'CUDAExecutionProvider' not in ort.get_available_providers():
+                continue
            # get supported optype for target precision
            optypes = query.get_op_types_by_precision(precision) if \
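Restated outside the adaptor, the new rule above skips the 'fp16' and 'bf16' capability queries unless the device is a GPU using the CUDA execution provider, and additionally skips 'bf16' when that provider is not actually available. A standalone sketch of that logic (the function and its parameters are stand-ins for the adaptor's `self.device`, `self.backend`, and `ort.get_available_providers()`, not the library's API):

```python
def precision_is_queryable(precision, device, backend, available_providers):
    """Mirror the skip conditions in query_fw_capability above:
    return False for precisions whose op capabilities should not be queried."""
    if precision in ("fp16", "bf16"):
        # Both low precisions need a GPU device and the CUDA EP backend.
        if device == "cpu" or backend != "CUDAExecutionProvider":
            return False
        # bf16 additionally needs the CUDA EP to be installed.
        if precision == "bf16" and "CUDAExecutionProvider" not in available_providers:
            return False
    return True

# Example: fp16 on a CPU-only session is filtered out.
assert not precision_is_queryable("fp16", "cpu", "CPUExecutionProvider", ["CPUExecutionProvider"])
assert precision_is_queryable("fp16", "gpu", "CUDAExecutionProvider", ["CUDAExecutionProvider"])
```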
17 changes: 17 additions & 0 deletions neural_compressor/config.py
@@ -1200,6 +1200,7 @@ class MixedPrecisionConfig(PostTrainingQuantConfig):
    def __init__(self,
                 device="cpu",
                 backend="default",
+                 precision="bf16",
                 inputs=[],
                 outputs=[],
                 tuning_criterion=tuning_criterion,
@@ -1214,7 +1215,23 @@ def __init__(self,
            accuracy_criterion=accuracy_criterion,
            excluded_precisions=excluded_precisions,
        )
+        self.precision = precision

+    @property
+    def precision(self):
+        """Get precision."""
+        return self._precision
+
+    @precision.setter
+    def precision(self, precision):
+        """Set precision."""
+        if isinstance(precision, str):
+            assert precision in ["fp16", "bf16"], "Only support 'fp16' and 'bf16' for mix precision."
+            self._precision = [precision]
+        elif isinstance(precision, list):
+            assert all([i in ["fp16", "bf16"] for i in precision]), "Only " \
+                "support 'fp16' and 'bf16' for mix precision."
+            self._precision = precision

class ExportConfig:
    """Config Class for Export."""
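For reference, the new `precision` property normalizes a single string to a one-element list and validates list inputs; anything outside 'fp16'/'bf16' trips the assertion. A brief usage sketch (assumes a build containing this commit):

```python
from neural_compressor.config import MixedPrecisionConfig

conf = MixedPrecisionConfig(precision="bf16")
print(conf.precision)              # ['bf16'] -- a string is wrapped in a list

conf.precision = ["bf16", "fp16"]  # a list is validated and stored as-is
print(conf.precision)              # ['bf16', 'fp16']

try:
    conf.precision = "int8"        # unsupported value
except AssertionError as err:
    print(err)                     # Only support 'fp16' and 'bf16' for mix precision.
```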
22 changes: 18 additions & 4 deletions neural_compressor/mix_precision.py
@@ -370,18 +370,32 @@ def fit(model,
        converted_model = mix_precision.fit(model, config=conf)
    """
    converter = MixedPrecision(config)
-    precisions = ["bf16", "fp16", "fp32"]
-    precisions = list(set(precisions) - set(config.excluded_precisions))
+    if config.precision in config.excluded_precisions:
+        logger.warning("Target precision is in excluded_precisions, "\
+                       "please modify precision or excluded_precisions to make it understandable.")
+        sys.exit(0)
+    precisions = list(set(config.precision) - set(config.excluded_precisions))
    converter.precisions = precisions
-    if 'bf16' in precisions and not CpuInfo().bf16:
+    converter.model = model

+    if ('bf16' in precisions or 'fp16' in precisions) and converter.model.framework() == "onnxruntime":
+        if config.device == "cpu":
+            logger.warning("Mix precision exits due to device isn't gpu for onnx models.")
+            sys.exit(0)
+        elif config.backend != "onnxrt_cuda_ep":
+            logger.warning("Mix precision exits due to backend isn't onnxrt_cuda_ep for onnx models.")
+            sys.exit(0)
+    elif 'bf16' in precisions and not CpuInfo().bf16 and converter.model.framework() != "onnxruntime":
        if os.getenv('FORCE_BF16') == '1':
            logger.warning("Mix precision will generate bf16 graph although " \
                           "the hardware doesn't support bf16 instruction.")
        else:
            logger.warning("Mix precision exits due to the hardware " \
                           "doesn't support bf16 instruction.")
            sys.exit(0)
-    converter.model = model
+    elif 'fp16' in precisions and converter.model.framework() != "onnxruntime":
+        logger.warning("Currently mix precision only supports fp16 for onnx models.")
+        sys.exit(0)
    if eval_func is not None:
        converter.eval_func = eval_func
    if eval_dataloader is not None:
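Note that all of the new guards report failure through `logger.warning` plus `sys.exit(0)`, so callers can detect an unsupported combination by catching `SystemExit`, exactly as the updated tests below do. A minimal sketch (`my_onnx_model` is a placeholder for a loaded model):

```python
from neural_compressor import mix_precision
from neural_compressor.config import MixedPrecisionConfig

# FP16 for an ONNX model requires device='gpu' and backend='onnxrt_cuda_ep';
# with the default device='cpu', fit() logs a warning and exits with code 0.
conf = MixedPrecisionConfig(precision="fp16")
try:
    converted_model = mix_precision.fit(my_onnx_model, config=conf)  # placeholder model
except SystemExit as e:
    print("mix_precision.fit exited with code", e.code)  # 0
```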
4 changes: 2 additions & 2 deletions test/mixed_precision/test_mixed_precision.py
@@ -274,7 +274,7 @@ def test_on_non_enabled_dtype(self):
            output_model = mix_precision.fit(self.onnx_model, conf)
        self.assertEqual(cm.exception.code, 0)

-        conf = MixedPrecisionConfig(excluded_precisions=["fp16"])
+        conf = MixedPrecisionConfig(precision="fp16")
        with self.assertRaises(SystemExit) as cm:
            output_model = mix_precision.fit(self.tf_model, conf)
        self.assertEqual(cm.exception.code, 0)
@@ -309,7 +309,7 @@ def test_mixed_precision_with_evaluation(self):
        #self.assertTrue(any([i.op_type == 'Cast' for i in output_model.nodes()]))

        tuning_criterion = TuningCriterion(max_trials=3, timeout=1000000)
-        conf = MixedPrecisionConfig(device='gpu', tuning_criterion=tuning_criterion, backend='onnxrt_cuda_ep', excluded_precisions=['bf16'])
+        conf = MixedPrecisionConfig(device='gpu', tuning_criterion=tuning_criterion, backend='onnxrt_cuda_ep', precision="fp16")
        output_model = mix_precision.fit(self.onnx_model,
                                         conf,
                                         eval_dataloader=self.matmul_dataloader,
