Commit
Add distributed fallback by blockwise (#1179)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Kaihui-intel committed Aug 25, 2023
1 parent 04884ed commit ea309f5
Showing 2 changed files with 35 additions and 6 deletions.
neural_compressor/strategy/basic.py (35 changes: 30 additions & 5 deletions)
@@ -43,8 +43,9 @@ def distributed_next_tune_cfg_lst(self, comm):
"""Generate and yield the next tuning config list with below order.
1. OP Type Wise Tuning
2. Fallback OP One by One
3. Fallback Multiple OPs Accumulated
2. Fallback OPs Block by Block
3. Fallback OP One by One
4. Fallback Multiple OPs Accumulated
Yields:
tuning_config_list (list): A list containing dicts of the tuning configuration for quantization.
@@ -62,6 +63,18 @@ def distributed_next_tune_cfg_lst(self, comm):
quant_ops = quant_mode_wise_items['static'] if 'static' in quant_mode_wise_items else []
quant_ops += quant_mode_wise_items['dynamic'] if 'dynamic' in quant_mode_wise_items else []
stage1_max = 1e9 # TODO set a more appropriate value
if not self.cur_best_tuning_cfg:
self.cur_best_tuning_cfg = deepcopy(initial_op_tuning_cfg)

# try to tune sq alpha
op_tuning_cfg_lst_stage_sq = []
if self._should_tuning_sq_alpha(self.config.recipes):
for tune_cfg in self.tuning_sq_alpha(tuning_space, \
deepcopy(self.cur_best_tuning_cfg), self.config.recipes):
op_tuning_cfg_lst_stage_sq.append(tune_cfg)
yield op_tuning_cfg_lst_stage_sq

# op type-wise tuning
op_type_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [],
op_item_dtype_dict, initial_op_tuning_cfg)
# stage 1: yield op_tune_cfg_lst
@@ -83,6 +96,7 @@ def distributed_next_tune_cfg_lst(self, comm):
else:
self.cur_best_tuning_cfg = comm.bcast(cur_best_tuning_cfg, root=0)


# stage 2: yield new_op_tuning_cfg_lst (length of stage 1)
# Fallback the ops supported both static and dynamic from static to dynamic
# Tuning items: None
@@ -113,12 +127,25 @@ def distributed_next_tune_cfg_lst(self, comm):
best_op_tuning_cfg_stage1 = deepcopy(self.cur_best_tuning_cfg)

# Fallback
# Fallback block after stage (1, 2) and before stage (3, 4)
# stage 3, 4: yield op_tuning_cfg_lst
op_tuning_cfg_lst_stage_block = []
op_tuning_cfg_lst_stage_3 = []
op_tuning_cfg_lst_stage_4 = []
for target_dtype in ['bf16', 'fp32']:
for target_dtype in PRECISION_LIST:
target_type_lst = set(tuning_space.query_items_by_quant_mode(target_dtype))
fallback_items_lst = [item for item in quant_ops if item in target_type_lst]

# Fallback block by block
for op_tuning_cfg in self.fallback_by_block(fallback_items_lst, best_op_tuning_cfg_stage1,
target_dtype,
tuning_space,
calib_sampling_size):
op_tuning_cfg_lst_stage_block.append(deepcopy(op_tuning_cfg))
logger.info("yield op_tuning_cfg_lst_stage_block with length {}"\
.format(len(op_tuning_cfg_lst_stage_block)))
yield op_tuning_cfg_lst_stage_block

if fallback_items_lst:
logger.info(f"Start to fallback op to {target_dtype} one by one.")
self._fallback_started()
@@ -273,8 +300,6 @@ def next_tune_cfg(self):
op_item_dtype_dict, initial_op_tuning_cfg)

for index, op_tuning_cfg in enumerate(op_type_wise_tuning_sampler):
if not self.cur_best_tuning_cfg:
self.cur_best_tuning_cfg = deepcopy(initial_op_tuning_cfg)
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
# try to quantize ops into lower bits, such as int4,
# if accuracy meets the requirements after first trial and max_trials > 1
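
The updated docstring spells out the new stage order: op type-wise tuning first, then the new block-by-block fallback, then the existing one-by-one and accumulated fallback stages. To make that ordering concrete, here is a minimal, self-contained sketch of the idea, not the library's implementation: the op names, the BLOCKS grouping, and the fallback_by_block_sketch / next_tune_cfg_lst_sketch helpers are all hypothetical, and the real strategy works with OpTuningConfig objects and model-derived blocks rather than plain dicts.

from copy import deepcopy

# Hypothetical data: op name -> current quantization mode.
BASE_CFG = {"conv1": "static", "conv2": "static", "linear1": "static", "linear2": "static"}
# Hypothetical block grouping; the real strategy derives blocks from the model structure.
BLOCKS = [["conv1", "conv2"], ["linear1", "linear2"]]


def fallback_by_block_sketch(base_cfg, blocks, target_dtype):
    """Yield one candidate config per block, with every op in that block
    dropped to target_dtype at once instead of one op at a time."""
    for block in blocks:
        cfg = deepcopy(base_cfg)
        for op in block:
            cfg[op] = target_dtype
        yield cfg


def next_tune_cfg_lst_sketch(target_dtype="fp32"):
    """Yield one list of candidate configs per stage, in the documented order."""
    yield [deepcopy(BASE_CFG)]                                            # 1. op type-wise tuning (stand-in)
    yield list(fallback_by_block_sketch(BASE_CFG, BLOCKS, target_dtype))  # 2. fallback ops block by block
    yield [dict(BASE_CFG, **{op: target_dtype}) for op in BASE_CFG]       # 3. fallback ops one by one
    yield [{op: target_dtype for op in BASE_CFG}]                         # 4. accumulated fallback (simplified: all ops at once)


if __name__ == "__main__":
    for stage, cfg_lst in enumerate(next_tune_cfg_lst_sketch(), start=1):
        print(f"stage {stage}: {len(cfg_lst)} candidate config(s)")

Yielding whole blocks before the one-by-one stage lets the distributed workers try coarse fallbacks cheaply before moving on to the finer-grained candidates.
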
neural_compressor/strategy/strategy.py (6 changes: 5 additions & 1 deletion)
@@ -446,13 +446,17 @@ def traverse(self):
from mpi4py import MPI
if MPI.COMM_WORLD.Get_size() > 2:
logger.info("Use distributed tuning on {} nodes".format(MPI.COMM_WORLD.Get_size()))
return self.distributed_traverse()
elif MPI.COMM_WORLD.Get_size() == 2:
logger.info("Use distributed tuning on {} nodes, will be fallback to normal tuning."\
.format(MPI.COMM_WORLD.Get_size()))
MPI_INSTALLED=True
except (ImportError, AttributeError) as e:
logger.warning("[Strategy] Please install `mpi4py` correctly if using distributed tuning;" + \
" otherwise, ignore this warning.")
MPI_INSTALLED=False
if MPI_INSTALLED:
if MPI.COMM_WORLD.Get_size() > 2:
return self.distributed_traverse()
self._setup_pre_tuning_algo_scheduler()
self._prepare_tuning()
# import pdb;pdb.set_trace()
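
The hunk above changes only when distributed tuning is actually entered: if mpi4py is importable and more than two MPI ranks are present, traverse() dispatches to distributed_traverse(); with exactly two ranks it logs the fallback and continues with normal tuning, and a failed import only produces a warning. Below is a condensed, standalone sketch of that dispatch; choose_traverse_mode and its print statements are illustrative stand-ins, not the library API.

def choose_traverse_mode():
    """Return "distributed" only when mpi4py is usable and more than two
    MPI ranks are available; otherwise fall back to normal tuning."""
    try:
        from mpi4py import MPI  # optional dependency; needs an MPI runtime
        world_size = MPI.COMM_WORLD.Get_size()
        mpi_installed = True
    except (ImportError, AttributeError):
        print("mpi4py not available; ignore this if not using distributed tuning.")
        mpi_installed = False

    if mpi_installed:
        if world_size > 2:
            print(f"Use distributed tuning on {world_size} nodes")
            return "distributed"
        if world_size == 2:
            print("Only 2 nodes available; falling back to normal tuning.")
    return "normal"


if __name__ == "__main__":
    print(choose_traverse_mode())

Without mpi4py installed the sketch still runs and simply reports normal tuning, mirroring how the strategy degrades gracefully on a single node.
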
