Enable INC Distributed Tuning (#470)

* draft multinode code

* fix wrong cur_cfg_id

* add more details for pipeline

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* fix end tag and boundary overflow

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* remove tune cfg list limitation

* draft distributed basic strategy

Signed-off-by: spycsh <sihan.chen@intel.com>

* fix syntax bugs

* set q model

* add some debug info

* fix bug

* fix best_tune_cfg_id bug

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* add multi stage detail

* remove debug info

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* add UT flow

Signed-off-by: Spycsh <sihan.chen@intel.com>

* add UT example

Signed-off-by: Spycsh <sihan.chen@intel.com>

* distributed config

* fix use_distributed_tuning config bugs

* updated the accuracy comparison

Signed-off-by: yiliu30 <yi4.liu@intel.com>

* logger info

* add UT stage 2

* add more UT without coordination on fake_eval index

Signed-off-by: spycsh <sihan.chen@intel.com>

* add UTs

Signed-off-by: yiliu30 <yi4.liu@intel.com>

* Add WA for UTs

Signed-off-by: yiliu30 <yi4.liu@intel.com>

* revert wrongly commented next_tune_cfg code

* fix fake eval in UT

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* fix bug in tuning config gap on single-node and multi-node

Signed-off-by: spycsh <sihan.chen@intel.com>

* set tune_cfg_lst as a class attribute & delete debug info

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* bypass onnx exporting on slave nodes

* enable example on multi nodes

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* remove UT tmp files

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* update op_fallback_acc_impact after stage 3

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* remove redundant code for strategy

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* set use distributed tuning in example

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* skip UT if no mpi4py

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* add docstring

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* use LazyImport and code check

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* add a blank line between summary line and description

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* code check

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

* refine document

Signed-off-by: spycsh <sihan.chen@intel.com>

* remove Lazyimport mpi4py line

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

---------

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Signed-off-by: spycsh <sihan.chen@intel.com>
Signed-off-by: Spycsh <sihan.chen@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Kaihui-intel <kaihui.tang@intel.com>
Co-authored-by: yiliu30 <yi4.liu@intel.com>
3 people committed Feb 23, 2023
1 parent 08e2551 commit e1fe50e
Showing 11 changed files with 785 additions and 46 deletions.
@@ -52,7 +52,30 @@ python -u ./run_glue.py \
--overwrite_output_dir
```

### 2. Get the benchmark of tuned model, including Batch_size and Throughput:
You can also try INC distributed tuning (taking the MRPC task as an example) as follows:

In `run_glue.py`, enable `use_distributed_tuning` in the `PostTrainingQuantConfig` with the following statement:

```python
conf = PostTrainingQuantConfig(approach="static", tuning_criterion=tuning_criterion, use_distributed_tuning=True)
```

Then run the following command (a filled-in example follows the parameter descriptions below):

```bash
mpirun -np <NUM_PROCESS> -mca btl_tcp_if_include <NETWORK_INTERFACE> -x OMP_NUM_THREADS=<MAX_NUM_THREADS> --host <HOSTNAME1>,<HOSTNAME2>,<HOSTNAME3> bash run_distributed_tuning.sh
```

* *`<NUM_PROCESS>`* is the number of processes; it is recommended to set it equal to the number of hosts.

* *`<MAX_NUM_THREADS>`* is the number of threads per process (set via `OMP_NUM_THREADS`); it is recommended to set it equal to the number of physical cores on one node.

* *`<HOSTNAME>`* is a host name. The argument `--host <HOSTNAME>,<HOSTNAME>...` can be replaced with `--hostfile <HOSTFILE>`, where each line in *`<HOSTFILE>`* is a host name.

* `-mca btl_tcp_if_include <NETWORK_INTERFACE>` sets the network interface used for MPI communication between hosts. For example, *`<NETWORK_INTERFACE>`* can be set to 192.168.20.0/24 to allow MPI communication between all hosts in the 192.168.20.* subnet.
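
For illustration, a fully filled-in launch might look like the sketch below. The host names (`node1`–`node3`), the subnet, and the thread count are placeholders for your own cluster, not values taken from this repository:

```bash
# Hypothetical 3-node launch: one process per host, 56 OpenMP threads per process.
mpirun -np 3 \
    -mca btl_tcp_if_include 192.168.20.0/24 \
    -x OMP_NUM_THREADS=56 \
    --host node1,node2,node3 \
    bash run_distributed_tuning.sh

# Equivalent launch using a hostfile with one host name per line.
printf "node1\nnode2\nnode3\n" > hostfile
mpirun -np 3 -mca btl_tcp_if_include 192.168.20.0/24 -x OMP_NUM_THREADS=56 \
    --hostfile hostfile bash run_distributed_tuning.sh
```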

### 2. Get the benchmark of the tuned model

```bash
python -u ./run_glue.py \
@@ -0,0 +1,4 @@
source ~/miniconda3/etc/profile.d/conda.sh
conda activate MPIENV
cd /YOURWORKDIR/examples/pytorch/nlp/huggingface_models/text-classification/quantization/ptq_static/fx
python -u ./run_glue.py --model_name_or_path distilbert_mrpc --task_name mrpc --do_eval --max_seq_length 128 --per_device_eval_batch_size 16 --no_cuda --output_dir ./int8_model_dir --tune --overwrite_output_dir
@@ -510,47 +510,57 @@ def eval_func(model):
from neural_compressor.quantization import fit
from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
tuning_criterion = TuningCriterion(max_trials=600)
conf = PostTrainingQuantConfig(approach="static", tuning_criterion=tuning_criterion)
conf = PostTrainingQuantConfig(approach="static", tuning_criterion=tuning_criterion, use_distributed_tuning=False)
q_model = fit(model, conf=conf, calib_dataloader=eval_dataloader, eval_func=eval_func)
from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
save_for_huggingface_upstream(q_model, tokenizer, training_args.output_dir)

if model_args.onnx:
it = iter(eval_dataloader)
input = next(it)
input.pop('labels')
symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
dynamic_axes = {k: symbolic_names for k in input.keys()}
from neural_compressor.config import Torch2ONNXConfig
fp32_onnx_config = Torch2ONNXConfig(
dtype="fp32",
opset_version=14,
example_inputs=tuple(input.values()),
input_names=list(input.keys()),
output_names=['labels'],
dynamic_axes=dynamic_axes,
)
q_model.export('fp32-model.onnx', fp32_onnx_config)
int8_onnx_config = Torch2ONNXConfig(
dtype="int8",
opset_version=14,
quant_format="QDQ",
example_inputs=tuple(input.values()),
input_names=list(input.keys()),
output_names=['labels'],
dynamic_axes=dynamic_axes,
)
q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
int8_onnx_config = Torch2ONNXConfig(
dtype="int8",
opset_version=14,
quant_format="QLinear",
example_inputs=tuple(input.values()),
input_names=list(input.keys()),
output_names=['labels'],
dynamic_axes=dynamic_axes,
)
q_model.export('int8-nlp-qlinear-model.onnx', int8_onnx_config)
# whether to use distributed tuning
if conf.use_distributed_tuning == True:
from mpi4py import MPI
comm = MPI.COMM_WORLD
size = comm.Get_size()
assert size > 1
rank = comm.Get_rank()
else:
rank = -1
if rank == 0 or conf.use_distributed_tuning == False:
from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
save_for_huggingface_upstream(q_model, tokenizer, training_args.output_dir)

if model_args.onnx:
it = iter(eval_dataloader)
input = next(it)
input.pop('labels')
symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
dynamic_axes = {k: symbolic_names for k in input.keys()}
from neural_compressor.config import Torch2ONNXConfig
fp32_onnx_config = Torch2ONNXConfig(
dtype="fp32",
opset_version=14,
example_inputs=tuple(input.values()),
input_names=list(input.keys()),
output_names=['labels'],
dynamic_axes=dynamic_axes,
)
q_model.export('fp32-model.onnx', fp32_onnx_config)
int8_onnx_config = Torch2ONNXConfig(
dtype="int8",
opset_version=14,
quant_format="QDQ",
example_inputs=tuple(input.values()),
input_names=list(input.keys()),
output_names=['labels'],
dynamic_axes=dynamic_axes,
)
q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
int8_onnx_config = Torch2ONNXConfig(
dtype="int8",
opset_version=14,
quant_format="QLinear",
example_inputs=tuple(input.values()),
input_names=list(input.keys()),
output_names=['labels'],
dynamic_axes=dynamic_axes,
)
q_model.export('int8-nlp-qlinear-model.onnx', int8_onnx_config)
return

if model_args.benchmark or model_args.accuracy_only:
1 change: 1 addition & 0 deletions neural_compressor/conf/config.py
@@ -1417,6 +1417,7 @@ def map_pyconfig_to_cfg(self, pythonic_config):
'tuning.exit_policy.timeout': pythonic_config.quantization.timeout,
'tuning.exit_policy.max_trials': pythonic_config.quantization.max_trials,
'tuning.exit_policy.performance_only': pythonic_config.quantization.performance_only,
'tuning.use_distributed_tuning': pythonic_config.quantization.use_distributed_tuning,
'use_bf16': pythonic_config.quantization.use_bf16,
'quantization.quant_level': pythonic_config.quantization.quant_level,
'reduce_range': pythonic_config.quantization.reduce_range
6 changes: 4 additions & 2 deletions neural_compressor/conf/pythonic_config.py
@@ -42,7 +42,8 @@ def __init__(self,
reduce_range=None,
use_bf16=True,
quant_level=1,
accuracy_criterion=accuracy_criterion):
accuracy_criterion=accuracy_criterion,
use_distributed_tuning=False):
excluded_precisions = ["bf16"] if not use_bf16 else []
super().__init__(
inputs=inputs,
@@ -61,7 +62,8 @@ def __init__(self,
reduce_range=reduce_range,
excluded_precisions=excluded_precisions,
accuracy_criterion=accuracy_criterion,
quant_level=quant_level
quant_level=quant_level,
use_distributed_tuning=use_distributed_tuning
)
self.approach = approach

17 changes: 15 additions & 2 deletions neural_compressor/config.py
@@ -364,7 +364,8 @@ def __init__(self,
reduce_range=None,
excluded_precisions=[],
quant_level=1,
accuracy_criterion=accuracy_criterion):
accuracy_criterion=accuracy_criterion,
use_distributed_tuning=False):
"""Initialize _BaseQuantizationConfig class.
Args:
@@ -426,6 +427,7 @@ def __init__(self,
self.accuracy_criterion = accuracy_criterion
self.calibration_sampling_size = calibration_sampling_size
self.quant_level = quant_level
self.use_distributed_tuning=use_distributed_tuning

@property
def domain(self):
@@ -536,6 +538,15 @@ def quant_level(self):
def quant_level(self, quant_level):
self._quant_level = quant_level

@property
def use_distributed_tuning(self):
return self._use_distributed_tuning

@use_distributed_tuning.setter
def use_distributed_tuning(self, use_distributed_tuning):
if check_value('use_distributed_tuning', use_distributed_tuning, bool):
self._use_distributed_tuning = use_distributed_tuning

@property
def reduce_range(self):
return self._reduce_range
@@ -779,6 +790,7 @@ def __init__(self,
quant_level=1,
tuning_criterion=tuning_criterion,
accuracy_criterion=accuracy_criterion,
use_distributed_tuning=False,
):
"""Init a PostTrainingQuantConfig object."""
self.tuning_criterion = tuning_criterion
@@ -800,7 +812,8 @@ def __init__(self,
reduce_range=reduce_range,
excluded_precisions=excluded_precisions,
quant_level=quant_level,
accuracy_criterion=accuracy_criterion)
accuracy_criterion=accuracy_criterion,
use_distributed_tuning=use_distributed_tuning)
self.approach = approach

@property
33 changes: 33 additions & 0 deletions neural_compressor/experimental/quantization.py
@@ -172,6 +172,10 @@ def pre_process(self):

def execute(self):
"""Quantization execute routinue based on strategy design."""
# check here the distributed flag
logger.info("..............use_distributed_tuning: {}".format(self.conf.usr_cfg.tuning.use_distributed_tuning))
if self.conf.usr_cfg.tuning.use_distributed_tuning:
return self.distributed_execute()
try:
with time_limit(self.conf.usr_cfg.tuning.exit_policy.timeout):
logger.debug("Dump user yaml configuration:")
@@ -195,6 +199,35 @@ def execute(self):
"Not found any quantized model which meet accuracy goal. Exit.")

return self.strategy.best_qmodel

def distributed_execute(self):
"""Quantization distributed execute routinue based on strategy design."""
from ..utils.utility import LazyImport
MPI = LazyImport("mpi4py.MPI")
comm = MPI.COMM_WORLD
try:
with time_limit(self.conf.usr_cfg.tuning.exit_policy.timeout):
self.strategy.traverse()
except KeyboardInterrupt:
pass
except Exception as e:
logger.error("Unexpected exception {} happened during tuning.".format(repr(e)))
import traceback
traceback.print_exc()
finally:
if self.strategy.best_qmodel:
logger.info(
"Specified timeout or max trials is reached! "
"Found a quantized model which meet accuracy goal. Exit.")
self.strategy.deploy_config()
else:
if comm.Get_rank() != 0: # slaves have no q model
return None
logger.error(
"Specified timeout or max trials is reached! "
"Not found any quantized model which meet accuracy goal. Exit.")

return self.strategy.best_qmodel

def __call__(self):
"""Automatic quantization tuning main entry point.
25 changes: 25 additions & 0 deletions neural_compressor/objective.py
@@ -24,6 +24,8 @@
import numpy as np
from copy import deepcopy

from typing import List, Tuple

import tracemalloc
from .utils.utility import get_size

@@ -325,6 +327,29 @@ def accuracy_meets(self):
all_lower = all([_last < _target for _last, _target in zip(last_acc, self.accuracy_target) ])
got_better_result = (all_higher and self.higher_is_better) or (all_lower and not self.higher_is_better)
return got_better_result

def accuracy_meet_req(self, last_result: Tuple[float, List[float]]) -> bool:
"""Compare the result of last tuning with baseline to check whether the result meet requirements.
Args:
last_result: The evaluation result of the last tuning.
Returns:
check_result: Return True if the accuracy meets requirements else False.
"""
check_result = False
last_acc, _ = last_result
if not isinstance(last_acc, list):
last_acc = [last_acc]

if self.metric_weight is not None and len(last_acc) > 1:
last_acc = [np.mean(np.array(last_acc) * self.metric_weight)]
if not self._accuracy_target:
self.accuracy_target = self._get_accuracy_target()
all_higher = all([_last > _target for _last, _target in zip(last_acc, self.accuracy_target) ])
all_lower = all([_last < _target for _last, _target in zip(last_acc, self.accuracy_target) ])
check_result = (all_higher and self.higher_is_better) or (all_lower and not self.higher_is_better)
return check_result

def evaluate(self, eval_func, model):
"""The interface of calculating the objective.