Skip to content

Commit

Permalink
Fixed inspect tensor (#957)
Browse files Browse the repository at this point in the history
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Wang, Mengni <mengni.wang@intel.com>
  • Loading branch information
yiliu30 and mengniwang95 committed Jun 12, 2023
1 parent 86b63d1 commit 8f5f5de
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 19 deletions.
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/onnxrt.py
Expand Up @@ -1265,7 +1265,7 @@ def diagnosis_helper(self, fp32_model, int8_model, tune_cfg=None, save_path=None
'min': np.array(self.min_max[node.output[0]][0], dtype=np.float32),
'max': np.array(self.min_max[node.output[0]][1], dtype=np.float32)}
if save_path:
dump_data_to_local(filtered_params, save_path, 'dequan_min_max.pkl')
dump_data_to_local(filtered_params, save_path, 'activation_min_max.pkl')
dump_data_to_local(tune_cfg, save_path, 'cfg.pkl')
return inspect_node_list, tune_cfg

Expand Down
4 changes: 3 additions & 1 deletion neural_compressor/adaptor/ox_utils/calibration.py
Expand Up @@ -526,7 +526,9 @@ def dump_tensor(self, activation=True, weight=False):
for i in range(iters):
map_node_activation[i][node_name] = \
{tensor_name.replace('_quantized', ''): tensors[i]}
else:
elif not (node.op_type in ['Conv', 'Gemm', 'FusedConv'] and tensor_name not in node.input[:2]) and \
not (node.op_type in ['QLinearConv'] and tensor_name not in node.input[:8]) and \
not (node.op_type in ['QGemm'] and tensor_name not in node.input[:6]):
map_node_weight[node_name].update({tensor_name.replace('_quantized', ''): \
tensors[0]})
dumped_tensors_map = {}
Expand Down
4 changes: 2 additions & 2 deletions neural_compressor/adaptor/tensorflow.py
Expand Up @@ -1034,7 +1034,7 @@ def _get_fp32_op_name(model, tensor_name):

def inspect_weight_and_bias(self, node_list, graph_def, graph_info, graph_node_name_mapping):
"""Inspect the weights and biases."""
from neural_compressor.utils.utility import DequantizeWeight
from neural_compressor.utils.utility import dequantize_weight
from neural_compressor.adaptor.tf_utils.util import get_tensor_val_from_graph_node
from .tf_utils.util import int8_node_name_reverse
import tensorflow as tf
Expand Down Expand Up @@ -1069,7 +1069,7 @@ def inspect_weight_and_bias(self, node_list, graph_def, graph_info, graph_node_n
else:
min_filter_val = get_tensor_val_from_graph_node(graph_node_name_mapping, min_filter_node)
max_filter_val = get_tensor_val_from_graph_node(graph_node_name_mapping, max_filter_node)
DequantizeWeight(weight_node_val, min_filter_val, max_filter_val)
weight_node_val = dequantize_weight(weight_node_val, min_filter_val, max_filter_val)
weights_result[node_name] = {weight_node_name: weight_node_val}
return weights_result

Expand Down
10 changes: 5 additions & 5 deletions neural_compressor/adaptor/tf_utils/util.py
Expand Up @@ -502,25 +502,25 @@ def tf_diagnosis_helper(fp32_model, quan_model, tune_cfg, save_path):
else:
continue
inspect_node_lst = fp32_node_lst.intersection(bf16_node_lst.union(int8_node_lst))
dequan_min_max, updated_cfg = _parse_config(quan_model.q_config, tune_cfg, inspect_node_lst)
dump_data_to_local(dequan_min_max, save_path, 'dequan_min_max.pkl')
activation_min_max, updated_cfg = _parse_config(quan_model.q_config, tune_cfg, inspect_node_lst)
dump_data_to_local(activation_min_max, save_path, 'activation_min_max.pkl')
dump_data_to_local(updated_cfg, save_path, 'cfg.pkl')

return inspect_node_lst, updated_cfg

def _parse_config(q_config, cfg, op_list):
"""Parse q_config and get dequantize min max value."""
dequan_min_max = {}
activation_min_max = {}
if '__requant_min_max' in q_config:
for node_name, val in q_config['__requant_min_max'].items():
node_name = node_name.split('_eightbit_requant_range')[0]
if node_name in op_list:
dequan_min_max[node_name] = {'min': val[0], 'max': val[1]}
activation_min_max[node_name] = {'min': val[0], 'max': val[1]}
updated_cfg = {'op' : {}}
for op_name_and_type in cfg['op'].keys():
if op_name_and_type[0] in op_list:
updated_cfg['op'][op_name_and_type] = cfg['op'][op_name_and_type]
return dequan_min_max, updated_cfg
return activation_min_max, updated_cfg

def generate_feed_dict(input_tensor, inputs):
"""Generate feed dict helper function."""
Expand Down
17 changes: 9 additions & 8 deletions neural_compressor/utils/utility.py
Expand Up @@ -413,15 +413,16 @@ def str2array(s):
return np.array(ast.literal_eval(s))


def DequantizeWeight(weight_tensor, min_filter_tensor, max_filter_tensor):
def dequantize_weight(weight_tensor, min_filter_tensor, max_filter_tensor):
"""Dequantize the weight with min-max filter tensors."""
weight_channel = weight_tensor.shape[-1]
if len(min_filter_tensor) == 1:
return weight_tensor * ((max_filter_tensor[0] - min_filter_tensor[0])/ 127.0)
# TODO to calculate the de-quantized result in a parallel way
for i in range(weight_channel):
weight_tensor[:,:,:,i] = weight_tensor[:,:,:,i] * ((max_filter_tensor[i] - min_filter_tensor[i])/ 127.0)

weight_tensor = weight_tensor * ((max_filter_tensor[0] - min_filter_tensor[0])/ 127.0)
else:
# TODO to calculate the de-quantized result in a parallel way
for i in range(weight_channel):
weight_tensor[:,:,:,i] = weight_tensor[:,:,:,i] * ((max_filter_tensor[i] - min_filter_tensor[i])/ 127.0)
return weight_tensor

def Dequantize(data, scale_info):
"""Dequantize the data with the scale_info."""
Expand Down Expand Up @@ -938,7 +939,7 @@ def print_op_list(workload_location: str):
Returns:
None
"""
minmax_file_path = os.path.join(workload_location, "inspect_saved", "dequan_min_max.pkl")
minmax_file_path = os.path.join(workload_location, "inspect_saved", "activation_min_max.pkl")
input_model_tensors = get_tensors_info(
workload_location,
model_type="input",
Expand Down Expand Up @@ -983,7 +984,7 @@ def get_op_list(minmax_file_path, input_model_tensors, optimized_model_tensors)
"""Get OP list for model.
Args:
minmax_file_path: path to dequan_min_max.pkl
minmax_file_path: path to activation_min_max.pkl
input_model_tensors: dict with input tensors details
optimized_model_tensors: dict with optimized tensors details
Expand Down
4 changes: 2 additions & 2 deletions neural_insights/components/diagnosis/diagnosis.py
Expand Up @@ -79,7 +79,7 @@ def get_op_list(self) -> List[dict]:
minmax_file_path = os.path.join(
self.workload_location,
"inspect_saved",
"dequan_min_max.pkl",
"activation_min_max.pkl",
)
with open(minmax_file_path, "rb") as min_max_file:
min_max_data: dict = pickle.load(min_max_file)
Expand Down Expand Up @@ -110,7 +110,7 @@ def get_weights_details(self, inspect_type: str) -> List[WeightsDetails]:
minmax_file_path = os.path.join(
self.workload_location,
"inspect_saved",
"dequan_min_max.pkl",
"activation_min_max.pkl",
)
with open(minmax_file_path, "rb") as min_max_file:
min_max_data: dict = pickle.load(min_max_file)
Expand Down

0 comments on commit 8f5f5de

Please sign in to comment.