Support Quantization of Big Saved Model for TF Backend (#1396)
Signed-off-by: zehao-intel <zehao.huang@intel.com>
zehao-intel committed Nov 22, 2023
1 parent 173c188 commit 3b29252
Showing 10 changed files with 1,095 additions and 24 deletions.
98 changes: 91 additions & 7 deletions neural_compressor/adaptor/tensorflow.py
@@ -648,6 +648,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
fp32_ops=self.fp32_ops,
bf16_ops=self.bf16_ops,
data_loader=data_loader,
calib_func=q_func,
qdq_enabled=self.qdq_enabled,
new_api=self.new_api,
performance_only=self.performance_only,
@@ -670,6 +671,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
fp32_ops=self.fp32_ops,
bf16_ops=self.bf16_ops,
data_loader=data_loader,
calib_func=q_func,
qdq_enabled=self.qdq_enabled,
new_api=self.new_api,
performance_only=self.performance_only,
@@ -693,6 +695,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
fp32_ops=self.fp32_ops,
bf16_ops=self.bf16_ops,
data_loader=data_loader,
calib_func=q_func,
qdq_enabled=self.qdq_enabled,
new_api=self.new_api,
performance_only=self.performance_only,
@@ -761,15 +764,15 @@ def _dump_model_op_stats(self, model_graphdef):
if i.op in fp32_op_list:
if "T" not in i.attr and i.op != "Cast":
continue
if i.attr["T"].type == dtypes.bfloat16:
res[i.op]["BF16"] += 1
elif i.attr["T"].type in (dtypes.quint8, dtypes.qint8):
res[i.op]["INT8"] += 1
elif i.op == "Cast":
if i.op == "Cast":
if i.attr["DstT"].type == dtypes.bfloat16:
res[i.op]["BF16"] += 1
elif i.attr["DstT"].type == dtypes.float32:
res[i.op]["FP32"] += 1
elif i.attr["T"].type == dtypes.bfloat16:
res[i.op]["BF16"] += 1
elif i.attr["T"].type in (dtypes.quint8, dtypes.qint8):
res[i.op]["INT8"] += 1
else:
res[i.op]["FP32"] += 1

@@ -1815,7 +1818,6 @@ def smooth_quant(
model,
dataloader,
calib_iter=1,
tune_cfg=None,
alpha=0.5,
folding=False,
percentile=99.999,
@@ -1832,7 +1834,6 @@ def smooth_quant(
model: original model
dataloader: the calibration dataloader
calib_iter: how many steps of iterations on the dataloader to move forward
tune_cfg: quantization config
alpha: smooth alpha in SmoothQuant, 1.0 will fallback to SPIQ
folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant
percentile: percentile of calibration to remove outliers
@@ -1852,6 +1853,11 @@ def smooth_quant(
if self.smooth_quant_model is not None:
return self.smooth_quant_model

if model.model_type == "llm_saved_model":
return self.smooth_quant_LLM(
model, dataloader, calib_iter, alpha, folding, percentile, op_types, scales_per_op
)

# Do a pre-optimization before smooth quant
from .tf_utils.graph_rewriter.generic.pre_optimize import PreOptimization

@@ -1860,6 +1866,7 @@ def smooth_quant(
model.graph_def = self.pre_optimized_model.graph_def

# Get the nodes list which can't be quantized from tune_cfg
tune_cfg = None
black_nodes = []
if tune_cfg is not None:
self._tuning_cfg_to_fw(tune_cfg)
@@ -1887,6 +1894,81 @@ def smooth_quant(
self.smooth_quant_model = model
return self.smooth_quant_model

def smooth_quant_LLM(
self,
model,
dataloader,
calib_iter=1,
alpha=0.5,
folding=False,
percentile=99.999,
op_types=["MatMul", "Conv2D"],
scales_per_op=True,
):
"""Convert the model by smooth quant.
Args:
model: the original TensorflowLLMModel object.
dataloader: the calibration dataloader.
calib_iter: how many steps of iterations on the dataloader to move forward.
alpha: smooth alpha in SmoothQuant; 1.0 will fall back to SPIQ.
folding: whether to insert Mul ops (False) or only allow foldable layers (True) for SmoothQuant.
percentile: percentile of calibration to remove outliers.
op_types: The op types whose input tensor will be dumped.
scales_per_op: True, each op will have an individual scale, mainly for accuracy.
False, ops with the same input will share a scale, mainly for performance.
Returns:
model: A smoothed Tensorflow model.
"""
# Do a pre-optimization before smooth quant
from .tf_utils.graph_rewriter.generic.pre_optimize import PreOptimization

self.pre_optimizer_handle = PreOptimization(model, self.new_api, self.device)
self.pre_optimized_model = self.pre_optimizer_handle.get_optimized_model(self.itex_mode)
model.graph_def = self.pre_optimized_model.graph_def

# Get the nodes list which can't be quantized from tune_cfg
tune_cfg = None
black_nodes = []
if tune_cfg is not None:
self._tuning_cfg_to_fw(tune_cfg)
black_nodes = [node for node in self.quantize_config if self.quantize_config[node] == "fp32"]

# only support per-tensor MatMul now
op_types = ["MatMul"]
llm_temp_dir = self.work_dir + "/temp_saved_model"
# Run calibration to get max values per channel
from .tf_utils.smooth_quant_calibration import SmoothQuantCalibrationLLM

calibration = SmoothQuantCalibrationLLM(
model._model,
dataloader,
calib_iter,
op_types,
percentile,
black_nodes,
llm_temp_dir,
model.weight_name_mapping,
)
max_vals_per_channel, sq_target_node_names, sq_weight_tensor_dict, sq_graph_def = calibration(
model.input_node_names, model.output_node_names
)

# Calculate the smooth quant scaler and insert Mul op into the graph
from .tf_utils.smooth_quant_scaler import SmoothQuantScalerLLM

scaler = SmoothQuantScalerLLM(sq_graph_def, alpha, scales_per_op, op_types)
sq_graph_def, sq_weight_scale_dict, mul_list = scaler.transform(
max_vals_per_channel, sq_weight_tensor_dict, sq_target_node_names
)
model.graph_def = sq_graph_def
model.model_path = llm_temp_dir
model.sq_weight_scale_dict = sq_weight_scale_dict
self.smooth_quant_mul_ops.extend(mul_list)
self.smooth_quant_model = model
return self.smooth_quant_model
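
For readers unfamiliar with the scaler step, the snippet below is a minimal NumPy sketch of the standard SmoothQuant per-channel scale computation; the actual SmoothQuantScalerLLM implementation is not part of this diff, so the function name, arguments, and epsilon here are illustrative assumptions.

import numpy as np

def smooth_quant_scales(act_max_per_channel, weight_max_per_channel, alpha=0.5, eps=1e-5):
    """Per-channel smoothing scales: s_j = max|X_j|**alpha / max|W_j|**(1 - alpha)."""
    # Clip to avoid division by zero on dead channels.
    act_max = np.clip(np.abs(act_max_per_channel), eps, None)
    w_max = np.clip(np.abs(weight_max_per_channel), eps, None)
    return np.power(act_max, alpha) / np.power(w_max, 1.0 - alpha)

# Activations are divided by s (the inserted Mul applies 1/s) and the matching
# weight rows (input channels) are multiplied by s, shifting quantization
# difficulty from activations to weights. Illustrative values only:
scales = smooth_quant_scales(np.array([10.0, 0.2, 3.0]), np.array([0.5, 0.8, 0.1]), alpha=0.5)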


@adaptor_registry
class Tensorflow_ITEXAdaptor(TensorFlowAdaptor):
@@ -1945,6 +2027,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
fp32_ops=self.fp32_ops,
bf16_ops=self.bf16_ops,
data_loader=data_loader,
calib_func=q_func,
itex_mode=self.itex_mode,
qdq_enabled=self.qdq_enabled,
new_api=self.new_api,
@@ -1992,6 +2075,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
fp32_ops=self.fp32_ops,
bf16_ops=self.bf16_ops,
data_loader=data_loader,
calib_func=q_func,
itex_mode=self.itex_mode,
qdq_enabled=self.qdq_enabled,
new_api=self.new_api,
34 changes: 34 additions & 0 deletions neural_compressor/adaptor/tf_utils/graph_converter.py
@@ -102,6 +102,7 @@ def __init__(
fp32_ops=[],
bf16_ops=[],
data_loader=None,
calib_func=None,
fake_quant=False,
itex_mode=False,
qdq_enabled=False,
@@ -116,6 +117,7 @@
:param fp32_ops: fall back to fp32 dtype op list
:param bf16_ops: fall back to bf16 dtype op list
:param data_loader: for calibration phase used dataloader
:param calib_func: for calibration phase used function
:param fake_quant: for quantization-aware training model conversion to default model
"""
self.model = model
@@ -139,6 +141,7 @@ def __init__(
self._calibration_data = []
self._fp32_print_data = []
self.data_loader = data_loader
self.calib_func = calib_func
self._check_tf_version()
self._check_args()

@@ -157,6 +160,7 @@ def __init__(
self._gen_tmp_filenames()
self._kl_op_dict = {}
self._kl_keys = []
self._llm_weight_minmax = {}
self._print_node_mapping = {}
self._enable_kl_op_names = [k for k in self.op_wise_config if self.op_wise_config[k][1] == "kl"]
self.scale_info = {}
@@ -193,6 +197,14 @@ def _inference(self, model):
Args:
model(TensorflowBaseModel): input TensorflowBaseModel
"""
if self.calib_func:
self.calib_func(model.model)
return

if model.model_type == "llm_saved_model":
self._inference_llm(model)
return

# ITEX optimization has broken INC calibration process.
# INC needs turn off ITEX optimization pass in calibration stage.
# TODO ITEX will provide API to replace setting environment variable.
@@ -281,6 +293,24 @@ def check_shape(tensor, data):
break
os.environ["ITEX_REMAPPER"] = "1"

def _inference_llm(self, model):
input_tensor_names = model.input_tensor_names
auto_trackable = model.model
infer = auto_trackable.signatures["serving_default"]
for idx, (inputs, _) in enumerate(self.data_loader):
feed_dict = {}
if len(input_tensor_names) == 1:
feed_dict[input_tensor_names[0]] = inputs
else:
assert len(input_tensor_names) == len(inputs), "the number of inputs must match the number of input tensor names"
for i, input_tensor_name in enumerate(input_tensor_names):
feed_dict[input_tensor_name] = inputs[i]

_ = infer(**feed_dict)

if idx >= self.calib_iteration:
break
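
With calib_func now threaded into GraphConverter, _inference above simply hands model.model to the user-supplied function and skips the dataloader loop. Below is a minimal sketch of such a function, assuming the wrapped object is a loaded SavedModel (AutoTrackable) whose serving signature accepts the batches yielded by a hypothetical calib_dataset; it would be supplied as q_func to the adaptor's quantize(), which forwards it as calib_func.

def my_calib_func(loaded_model):
    # loaded_model corresponds to model.model in _inference above (assumed to be
    # a SavedModel AutoTrackable exposing a "serving_default" signature).
    infer = loaded_model.signatures["serving_default"]
    for step, feed_dict in enumerate(calib_dataset):  # calib_dataset is hypothetical; yields dicts keyed by input names
        infer(**feed_dict)                            # forward pass only; outputs are discarded
        if step >= 20:                                # a few batches suffice for calibration
            break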

def _check_tf_version(self):
"""Check if the installed tensorflow version is supported."""
is_supported_version = False
@@ -849,6 +879,9 @@ def _insert_qdq_pairs(self):
self._inference(self._sampling_model)
self._calibration_data = Helper.gen_valid_sampling_log(tmp_dump_file)

if hasattr(self._sampling_model, "_weight_tensor_minmax_dict"):
self._llm_weight_minmax = self._sampling_model.weight_tensor_minmax_dict

del sampling_graph_def
del output_tensor_names
del self._sampling_model
@@ -868,6 +901,7 @@ def _insert_qdq_pairs(self):
self.device,
self.performance_only,
self.itex_mode,
self._llm_weight_minmax,
).do_transformation()

def _convert_qdq(self):
@@ -46,6 +46,7 @@ def __init__(
device,
performance_only,
itex_mode,
llm_weight_minmax,
):
"""Initialization."""
super().__init__(model)
@@ -58,6 +59,7 @@ def __init__(
self.device = device
self.performance_only = performance_only
self.itex_mode = itex_mode
self.llm_weight_minmax = llm_weight_minmax
self.node_details = namedtuple("node_details", ["node", "output"])
self.node_name_mapping = {}
self.check_op_list = {
@@ -548,6 +550,24 @@ def _insert_qdq_pattern_for_weight_node(
# qint8_tensor = np.clip(qint8_tensor, -127, 127).astype(np.int8)
min_value = -range_value
max_value = range_value
elif weight_node.op == "ReadVariableOp":
min_value = self.llm_weight_minmax[weight_node.name][0]
max_value = self.llm_weight_minmax[weight_node.name][1]
min_value *= range_coefficent
max_value *= range_coefficent
min_value = min(min_value, 0.0)
if min_value == max_value:
if abs(min_value) < 0.000001:
max_value = min_value + 1.0
elif min_value > 0:
max_value = 2 * min_value
else:
max_value = min_value / 2.0
range_value = np.max(np.abs([min_value, max_value]))
# qint8_tensor = (np.around(float_tensor * 127.0 / range_value)).astype(np.int8)
# qint8_tensor = np.clip(qint8_tensor, -127, 127).astype(np.int8)
min_value = -range_value
max_value = range_value
elif host_op_type == "DepthwiseConv2dNative":
float_tensor = tensor_util.MakeNdarray(weight_node.attr["value"].tensor)
# get the max values based on dim 0 and 1 for depthwise conv
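
As a side note on the ReadVariableOp branch added above: it widens the calibrated weight min/max by range_coefficent (a multiplier already in scope in that function, not shown in this diff), forces the range to include zero, guards against a degenerate zero-width range, and then symmetrizes it for qint8. A standalone sketch of that arithmetic with made-up inputs:

import numpy as np

def symmetric_weight_range(min_value, max_value, range_coefficent=1.0):
    min_value *= range_coefficent
    max_value *= range_coefficent
    min_value = min(min_value, 0.0)               # the range must contain zero
    if min_value == max_value:                    # avoid a zero-width range
        if abs(min_value) < 0.000001:
            max_value = min_value + 1.0
        elif min_value > 0:
            max_value = 2 * min_value
        else:
            max_value = min_value / 2.0
    range_value = np.max(np.abs([min_value, max_value]))
    return -range_value, range_value              # symmetric qint8 range

# e.g. a calibrated weight range of (-0.3, 0.8) becomes (-0.8, 0.8)
assert symmetric_weight_range(-0.3, 0.8) == (-0.8, 0.8)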
27 changes: 26 additions & 1 deletion neural_compressor/adaptor/tf_utils/graph_util.py
@@ -1044,8 +1044,33 @@ def gen_per_iter(data):
res.append(mixed_str)
return res

def separate(line):
"""This function is to separate the strings.
Example:
';slice__print__;__max:[1];slice__print__;__min:[-1]' -->
[';slice__print__;__max:[1]', ';slice__print__;__min:[-1]']
"""
separated_lines = []
for subline in line.split("];"):
if not subline.startswith(";"):
subline = ";" + subline
if not subline.endswith("]"):
subline += "]"
separated_lines.append(subline)
return separated_lines

with open(log_path) as f:
valid_data = [i.strip() for i in f.readlines() if i.startswith(";")]
valid_data = []
for i in f.readlines():
if not i.startswith(";"):
continue
line = i.strip()
if line.find("];") != 0:
separated_lines = separate(line)
valid_data += separated_lines
else:
valid_data.append(line)

first_line = valid_data[0].rsplit(":")[0]

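
For reference, the new separate() helper can be exercised on its own. The snippet below re-implements it outside the class purely to demonstrate the docstring's example; the joined input line is taken verbatim from that example.

def separate(line):
    separated_lines = []
    for subline in line.split("];"):
        if not subline.startswith(";"):
            subline = ";" + subline
        if not subline.endswith("]"):
            subline += "]"
        separated_lines.append(subline)
    return separated_lines

joined = ";slice__print__;__max:[1];slice__print__;__min:[-1]"
print(separate(joined))
# -> [';slice__print__;__max:[1]', ';slice__print__;__min:[-1]']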
