Support _QuantizedMatMul output quantization mode MIN_FIRST (#1437)
lvliang-intel committed Nov 6, 2022
1 parent c3ff84a commit fcfafc5
Showing 4 changed files with 73 additions and 55 deletions.
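
Background, for readers skimming the diff: TensorFlow's quantized kernels distinguish a symmetric 'SCALED' mode from an asymmetric 'MIN_FIRST' mode, and this commit lets _QuantizedMatMul emit its output in MIN_FIRST when the quantization scheme is asymmetric. The snippet below is a simplified illustration of the two mappings, not the exact oneDNN/TensorFlow kernel arithmetic; both helper names are invented for the example.

    import numpy as np

    def quantize_scaled(x, num_bits=8):
        # Symmetric ("SCALED"-style) int8 quantization: zero maps to zero and the
        # scale is derived from the largest magnitude in the tensor.
        qmax = 2 ** (num_bits - 1) - 1                  # 127 for int8
        scale = float(np.max(np.abs(x))) / qmax
        return np.round(x / scale).astype(np.int8), scale

    def quantize_min_first(x, num_bits=8):
        # Asymmetric ("MIN_FIRST"-style) uint8 quantization: the tensor minimum
        # maps to 0, so ranges not centred on zero keep more resolution.
        qmax = 2 ** num_bits - 1                        # 255 for uint8
        x_min, x_max = float(np.min(x)), float(np.max(x))
        scale = (x_max - x_min) / qmax
        zero_point = int(round(-x_min / scale))
        return np.round(x / scale + zero_point).astype(np.uint8), scale, zero_point

    x = np.array([0.0, 0.5, 1.3, 2.7], dtype=np.float32)
    print(quantize_scaled(x))
    print(quantize_min_first(x))

When the op that produces a quantized tensor uses one mapping and the op that consumes it assumes the other, the dequantized values disagree; that mismatch is what the changes in this commit address.
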
@@ -8,10 +8,9 @@ class WidedeepDataloader(DefaultDataLoader):
def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn, sampler,
batch_sampler, num_workers, pin_memory, shuffle, distributed):

drop_last = False if last_batch == 'rollover' else True
sampler = self._generate_sampler(dataset, distributed)
self.batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.fetcher = FETCHERS[self.dataset_type](dataset, collate_fn, drop_last, distributed)
self.batch_sampler = BatchSampler(sampler, batch_size, self.drop_last)
self.fetcher = FETCHERS[self.dataset_type](dataset, collate_fn, self.drop_last, distributed)

for batched_indices in self.batch_sampler:
try:
@@ -33,18 +33,6 @@ def __init__(self, **kwargs):
self.sorted_patterns = sorted(self.patterns,
key=lambda i: len(i),
reverse=True)
# TODO Remove this when TFDO supports output_quantization_mode 'MIN_FIRST'
# Root cause of the transformer_lt_mlperf model accuracy drop:
# MatMul + Relu fusion ==> the output quantization mode can only be set to 'SCALED';
# if the input_quantization_mode of the next _QuantizedMatMul is set to 'MIN_FIRST',
# the mismatch will cause an accuracy drop.
if not self.performance_only:
if ['Dequantize', 'MatMul', 'Relu', 'QuantizeV2'] in self.sorted_patterns:
self.sorted_patterns.remove(['Dequantize', 'MatMul', 'Relu', 'QuantizeV2'])
if ['Dequantize', 'MatMul', 'BiasAdd', 'Relu', 'QuantizeV2'] in self.sorted_patterns:
self.sorted_patterns.remove(
['Dequantize', 'MatMul', 'BiasAdd', 'Relu', 'QuantizeV2'])
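
The block deleted above was a workaround for exactly that mismatch: with output_quant_mode pinned to 'SCALED', fusing Relu into _QuantizedMatMul fed a SCALED output into a consumer expecting MIN_FIRST, so the Relu patterns were dropped whenever accuracy mattered. A hypothetical sketch of the consistency rule the workaround was protecting (the function is illustrative, not part of the project's API):

    def quant_modes_consistent(producer_output_mode: bytes, consumer_input_mode: bytes) -> bool:
        # A fused _QuantizedMatMul hands its output directly to the next quantized op,
        # so the producer's output_quant_mode must match the consumer's input_quant_mode.
        # b'SCALED' feeding b'MIN_FIRST' is the combination behind the
        # transformer_lt_mlperf accuracy drop described in the removed comment.
        return producer_output_mode == consumer_input_mode

    assert quant_modes_consistent(b'MIN_FIRST', b'MIN_FIRST')
    assert not quant_modes_consistent(b'SCALED', b'MIN_FIRST')

With output MIN_FIRST supported (see the attribute changes further down), the pattern removal is no longer needed.
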

self.fusion_op_type = set(fusion[1] for fusion in self.patterns)

self.fusion_mapping = {
@@ -122,9 +110,9 @@ def apply_matmul_biasadd_relu_fusion(self, match_node_name):
helper.node_name_from_input(weight_node.input[0])].node
# FIXME We only quantize the MatMul op whose second input node is a Const. This is a
# workaround for RNN models like LSTM.
if not parent_node.op == 'Const':
self.logger.debug(' \
The weight node of matched_node {} is not Const or Const + Enter, skipped')
if parent_node.op != 'Const':
self.logger.debug( \
'The weight node of matched_node {} is not Const or Const + Enter, skipped')
self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []
@@ -248,8 +236,8 @@ def apply_matmul_biasadd_relu_fusion(self, match_node_name):
helper.set_attr_dtype(quantized_matmul_node, "Tout", dtypes.qint32)
helper.set_attr_string(quantized_matmul_node, 'input_quant_mode',
b'MIN_FIRST' if self.is_asymmetric else b'SCALED')
# TODO TFDO will extend output quantization mode to MIN_FIRST in future.
helper.set_attr_string(quantized_matmul_node, 'output_quant_mode', b'SCALED')
helper.set_attr_string(quantized_matmul_node, 'output_quant_mode',
b'MIN_FIRST' if self.is_asymmetric else b'SCALED')
if self.node_name_mapping[relu_node_name].node.op == "Relu":
helper.set_attr_string_list(quantized_matmul_node,
'fused_ops', [b'BiasAdd', b'Relu'])
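
The two-line change in this hunk picks output_quant_mode from self.is_asymmetric instead of hard-coding b'SCALED', mirroring how input_quant_mode is already chosen. For reference, string attributes like these end up on the GraphDef node as AttrValue protos; a minimal standalone sketch with TensorFlow's protobuf API (the node here is just a detached NodeDef, not a registered, runnable op):

    from tensorflow.core.framework import attr_value_pb2, node_def_pb2

    node = node_def_pb2.NodeDef(name="matmul_0_eightbit", op="_QuantizedMatMul")
    is_asymmetric = True                                # stands in for self.is_asymmetric
    mode = b'MIN_FIRST' if is_asymmetric else b'SCALED'
    # String-typed attributes are stored as bytes in AttrValue.s.
    node.attr['input_quant_mode'].CopyFrom(attr_value_pb2.AttrValue(s=mode))
    node.attr['output_quant_mode'].CopyFrom(attr_value_pb2.AttrValue(s=mode))
    print(node)
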
@@ -306,8 +294,8 @@ def apply_matmul_biasadd_relu_fusion(self, match_node_name):
if new_node.name in matmul_node_output:
for idx, node_input in enumerate(new_node.input):
if helper.node_name_from_input(node_input) == matmul_node.name:
new_node.input[idx] = \
node_input.replace(matmul_node.name, quantized_node_name)
new_node.input[idx] = node_input.replace(
matmul_node.name, quantized_node_name)
self.add_output_graph_node(new_node)
return match_node_name

@@ -340,16 +328,15 @@ def apply_matmul_biasadd_fusion(self, match_node_name):
helper.node_name_from_input(weight_node.input[0])].node
# FIXME We only quantize the MatMul op whose second input node is a Const. This is a
# workaround for RNN models like LSTM.
if not parent_node.op == 'Const':
self.logger.debug(' \
The weight node of matched_node {} is not Const or Const + Enter, skipped')
if parent_node.op != 'Const':
self.logger.debug(
'The weight node of matched_node {} is not Const or Const + Enter, skipped')
self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []
else:
enter_node = weight_node
weight_node = parent_node
weight_name = weight_node.name
enter_node = weight_node
weight_node = parent_node
weight_name = weight_node.name
# QDQ inserted for other weight nodes in phase 1
else:
_, q_weights_inputs = self._get_node_input(weight_name)
@@ -376,12 +363,6 @@ def apply_matmul_biasadd_fusion(self, match_node_name):
self.output_graph = self.input_graph
return []

# If weight node non const, can't insert dummy biasadd to do matmul fusion.
if weight_node.op != 'Const' and len(match_node_name) == 3:
self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
return []

len_output = len(matched_node.output)
is_shared_output = False
if len_output == 2:
@@ -408,6 +389,8 @@ def apply_matmul_biasadd_fusion(self, match_node_name):
if (add_a_node.op != 'Const' and add_b_node.op == 'Const') or\
(add_a_node.op != 'Const' and add_b_node.op == 'Enter'):
single_matmul_fusion = False
else:
return self.apply_matmul_biasadd_fusion(match_node_name[:2]+[match_node_name[-1]])
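
The new else branch above retries the fusion with the BiasAdd dropped from the matched pattern when the BiasAdd's other operand is neither a Const nor an Enter-wrapped Const. The slicing is terse, so here is what match_node_name[:2] + [match_node_name[-1]] amounts to (node names invented for the example):

    # A matched Dequantize + MatMul + BiasAdd + QuantizeV2 chain, by node name.
    match_node_name = ['dequantize_0', 'matmul_0', 'biasadd_0', 'quantize_v2_0']

    # Keep the Dequantize/MatMul pair plus the trailing QuantizeV2 and drop the
    # BiasAdd, turning the match into the plain MatMul fusion pattern.
    shrunk = match_node_name[:2] + [match_node_name[-1]]
    print(shrunk)  # ['dequantize_0', 'matmul_0', 'quantize_v2_0']
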

sum_node_name = ""
if len(match_node_name) == 4:
@@ -499,8 +482,8 @@ def apply_matmul_biasadd_fusion(self, match_node_name):
helper.set_attr_dtype(quantized_matmul_node, 'U', dtypes.float32)
helper.set_attr_string(quantized_matmul_node, 'input_quant_mode',
b'MIN_FIRST' if self.is_asymmetric else b'SCALED')
# TODO TFDO will extend output quantization mode to MIN_FIRST in future.
helper.set_attr_string(quantized_matmul_node, 'output_quant_mode', b'SCALED')
helper.set_attr_string(quantized_matmul_node, 'output_quant_mode',
b'MIN_FIRST' if self.is_asymmetric else b'SCALED')
helper.set_attr_dtype(quantized_matmul_node, 'Tbias', dtypes.float32)
if sum_node_name:
helper.set_attr_string_list(quantized_matmul_node, 'fused_ops',
@@ -583,8 +566,8 @@ def apply_batchmatmulv2_fusion(self, match_node_name):
helper.node_name_from_input(weight_node.input[0])].node
# FIXME We only quantize the MatMul op whose second input node is a Const. This is a
# workaround for RNN models like LSTM.
if not parent_node.op == 'Const':
self.logger.debug( \
if parent_node.op != 'Const':
self.logger.debug(
'The weight node of matched_node {} is not Const or Const + Enter, skipped')
self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
@@ -680,11 +663,13 @@ def apply_batchmatmulv2_fusion(self, match_node_name):

self.add_output_graph_node(quantized_matmul_node)
for i in self.node_name_mapping[node.name].output:
batchmatmul_next_node[i] = quantized_node_name
batchmatmul_next_node[i] = (quantized_node_name, node.name)
else:
new_node = node_def_pb2.NodeDef()
if batchmatmul_next_node.get(node.name):
node.input[0] = batchmatmul_next_node[node.name]
for index, name in enumerate(node.input):
if name == batchmatmul_next_node[node.name][1]:
node.input[index] = batchmatmul_next_node[node.name][0]
new_node.CopyFrom(node)
self.add_output_graph_node(new_node)
return match_node_name
@@ -698,6 +683,20 @@ def apply_batchmatmulv2_mul_add_fusion(self, match_node_name):
# Dequantize + BatchMatMulV2 + Mul + AddV2 + QuantizeV2
skip_node_name = match_node_name[2:]
matched_node = self.node_name_mapping[match_node_name[1]]
# oneDNN limitation: add tensor ndim must be 4
if len(match_node_name) == 4 and \
self.node_name_mapping[match_node_name[2]].node.op in ("Add","AddV2"):
add_node_input_name = self.node_name_mapping[match_node_name[2]].node.input[1]
if add_node_input_name == matched_node.node.name:
add_node_input_name = self.node_name_mapping[match_node_name[2]].node.input[0]
add_input_node = self.node_name_mapping[add_node_input_name].node
if add_input_node.op != 'Const':
return self.apply_batchmatmulv2_fusion(match_node_name[:2]+[match_node_name[-1]])

shape = tensor_util.MakeNdarray(add_input_node.attr["value"].tensor)
if shape.ndim != 4:
return self.apply_batchmatmulv2_fusion(match_node_name[:2]+[match_node_name[-1]])
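
The guard above looks at the Add/AddV2 operand that is not the BatchMatMulV2 output, materializes its Const tensor, and falls back to the plain BatchMatMulV2 fusion unless that tensor is 4-D, per the oneDNN limitation noted in the comment. A self-contained sketch of the rank check on a Const NodeDef (the pattern-matching context is omitted):

    import numpy as np
    from tensorflow.core.framework import node_def_pb2
    from tensorflow.python.framework import tensor_util

    # Build a Const node the way it would appear in a frozen graph.
    addend_value = np.zeros((1, 12, 128, 128), dtype=np.float32)
    tensor_proto = tensor_util.make_tensor_proto(addend_value)
    const_node = node_def_pb2.NodeDef(name="add_operand", op="Const")
    const_node.attr["value"].tensor.CopyFrom(tensor_proto)
    const_node.attr["dtype"].type = tensor_proto.dtype

    # The check mirrored from the diff: only a 4-D addend keeps the Mul+Add fusion.
    addend = tensor_util.MakeNdarray(const_node.attr["value"].tensor)
    print(addend.ndim == 4)  # True for this shape
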

control_inputs, normal_inputs = self._get_node_input(
matched_node.node.name)

@@ -713,8 +712,8 @@ def apply_batchmatmulv2_mul_add_fusion(self, match_node_name):
helper.node_name_from_input(weight_node.input[0])].node
# FIXME We only quantize the MatMul op whose second input node is a Const. This is a
# workaround for RNN models like LSTM.
if not parent_node.op == 'Const':
self.logger.debug( \
if parent_node.op != 'Const':
self.logger.debug(
'The weight node of matched_node {} is not Const or Const + Enter, skipped')
self.exclude_matmul_nodes.append(matched_node.node.name)
self.output_graph = self.input_graph
@@ -853,14 +852,16 @@ def apply_batchmatmulv2_mul_add_fusion(self, match_node_name):
.decode('UTF-8', 'ignore').strip() if x.isprintable())
if "MulAdd" in attr_fused_ops:
for i in self.node_name_mapping[match_node_name[3]].output:
batchmatmul_next_node[i] = quantized_node_name
batchmatmul_next_node[i] = (quantized_node_name, match_node_name[3])
else:
for i in self.node_name_mapping[match_node_name[2]].output:
batchmatmul_next_node[i] = quantized_node_name
batchmatmul_next_node[i] = (quantized_node_name, match_node_name[2])
else:
new_node = node_def_pb2.NodeDef()
if batchmatmul_next_node.get(node.name):
node.input[0] = batchmatmul_next_node[node.name]
for index, name in enumerate(node.input):
if name == batchmatmul_next_node[node.name][1]:
node.input[index] = batchmatmul_next_node[node.name][0]
new_node.CopyFrom(node)
self.add_output_graph_node(new_node)
return match_node_name
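
The bookkeeping change in this hunk (and the analogous one in apply_batchmatmulv2_fusion above) is the same: batchmatmul_next_node used to map a downstream node to just the quantized node name, and the rewiring overwrote input[0] unconditionally; it now stores a (quantized_node_name, original_producer_name) pair and rewrites only the input slot that actually referenced the original producer. A small sketch of that rewiring rule (names invented):

    # Map: consumer node name -> (replacement producer, original producer).
    batchmatmul_next_node = {
        "softmax_0": ("batch_matmul_0_eightbit", "batch_matmul_0"),
    }

    def rewire_inputs(consumer_name, inputs):
        # Replace only the input that referenced the original producer, leaving
        # other inputs (masks, residual connections, and so on) untouched.
        replacement, original = batchmatmul_next_node[consumer_name]
        return [replacement if name == original else name for name in inputs]

    print(rewire_inputs("softmax_0", ["attention_mask", "batch_matmul_0"]))
    # ['attention_mask', 'batch_matmul_0_eightbit']
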
@@ -885,7 +886,7 @@ def apply_the_transform(self):
self.logger.debug("Unknown fusion pattern {}.".format(fusion_name))
if self.remove_redundant_quant_flag:
self.input_graph = self.remove_redundant_quantization(self.input_graph)
return self.input_graph, []
return self.input_graph, self.exclude_matmul_nodes

self.input_graph = self.output_graph
self._reset_output_node_maps()
@@ -895,7 +896,7 @@

if self.remove_redundant_quant_flag:
self.input_graph = self.remove_redundant_quantization(self.input_graph)
return self.input_graph, []
return self.input_graph, self.exclude_matmul_nodes

def _is_match_matmul(self, patterns, qdq_inserted=False):
"""Detect the rule matched nodes collections.
@@ -915,6 +916,7 @@ def _is_match_matmul(self, patterns, qdq_inserted=False):

if not self.performance_only and (cur_node.op == 'BatchMatMulV2' or
cur_node.op == 'BatchMatMul') and not self.itex_mode:
self.exclude_matmul_nodes.append(cur_node.name)
continue

control_inputs, normal_inputs = self._get_node_input(cur_node.name)
@@ -925,10 +927,10 @@
# This is a workaround for RNN models like LSTM.
parent_node = None
if cur_node.op == "MatMul" and not self.itex_mode:
if control_inputs:
self.exclude_matmul_nodes.append(cur_node.name)
continue
if weight_node.op != 'Const':
if not self.performance_only:
continue

if weight_node.input:
parent_node = self.node_name_mapping \
[helper.node_name_from_input(weight_node.input[0])].node
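
Several hunks in this file accept a weight that is either a Const or a 'Const + Enter' pair: inside a while_loop frame a constant reaches the MatMul through an Enter op, so the matcher looks one hop upstream before deciding whether to quantize. A rough, simplified sketch of that lookup (the real class keeps more bookkeeping than this):

    from collections import namedtuple

    Node = namedtuple('Node', 'name op input')
    Wrapped = namedtuple('Wrapped', 'node')          # stand-in for node_name_mapping values

    def resolve_weight_const(weight_node, node_name_mapping):
        # Accept a plain Const, or a Const wrapped by an Enter op (the pattern a
        # while_loop frame produces); anything else is excluded from quantization.
        if weight_node.op == 'Const':
            return weight_node
        if weight_node.op == 'Enter' and weight_node.input:
            parent_name = weight_node.input[0].lstrip('^').split(':')[0]
            parent = node_name_mapping[parent_name].node
            if parent.op == 'Const':
                return parent
        return None

    const = Node('weights', 'Const', [])
    enter = Node('weights_enter', 'Enter', ['weights'])
    mapping = {'weights': Wrapped(const), 'weights_enter': Wrapped(enter)}
    print(resolve_weight_const(enter, mapping).name)  # 'weights'
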
@@ -17,6 +17,7 @@

import collections
import numpy as np
from math import ceil, floor
from abc import abstractmethod
from .sampler import IterableSampler, SequentialSampler, BatchSampler
from .fetcher import FETCHERS
@@ -56,6 +57,7 @@ def __init__(self, dataset, batch_size=1, last_batch='rollover', collate_fn=None
self._batch_size = batch_size
self.shuffle = shuffle
self.distributed = distributed
self.drop_last = False if last_batch == 'rollover' else True
if self.collate_fn == None:
self.collate_fn = default_collate

@@ -80,13 +82,28 @@ def __iter__(self):
shuffle=self.shuffle,
distributed=self.distributed)

def __len__(self):
try:
dataset_len = self.dataset.__len__()
except (AttributeError, TypeError):
dataset_len = 0
for _ in self.dataset:
dataset_len += 1
except:
raise ValueError(f"{self.dataset} is invalid, {self.dataset}" \
" does not support calculating the length of its dataloader")
if self.drop_last == False:
dataloader_len = ceil(dataset_len / self.batch_size)
else:
dataloader_len = floor(dataset_len / self.batch_size)
return dataloader_len

def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn, sampler,
batch_sampler, num_workers, pin_memory, shuffle, distributed):

drop_last = False if last_batch == 'rollover' else True
sampler = self._generate_sampler(dataset, distributed)
self.batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.fetcher = FETCHERS[self.dataset_type](dataset, collate_fn, drop_last, distributed)
self.batch_sampler = BatchSampler(sampler, batch_size, self.drop_last)
self.fetcher = FETCHERS[self.dataset_type](dataset, collate_fn, self.drop_last, distributed)

for batched_indices in self.batch_sampler:
try:
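
The __len__ added above first asks the dataset for its own length, falls back to counting by iteration, and then rounds the batch count up or down depending on the drop_last flag computed in __init__ (in this code, last_batch == 'rollover' keeps the final short batch). The rounding, in isolation:

    from math import ceil, floor

    dataset_len, batch_size = 10, 3

    # drop_last is False (last_batch == 'rollover'): the short final batch counts.
    print(ceil(dataset_len / batch_size))    # 4

    # drop_last is True: the short final batch is dropped.
    print(floor(dataset_len / batch_size))   # 3
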
2 changes: 1 addition & 1 deletion neural_compressor/strategy/strategy.py
@@ -451,7 +451,7 @@ def _tune_cfg_converter(self, op_tuning_cfg):
else:
tune_cfg[op_name_type] = op_config
tune_cfg['calib_sampling_size'] = op_tuning_cfg['calib_sampling_size']
if self.calib_dataloader:
if self.calib_dataloader is not None:
tune_cfg['calib_iteration'] = math.ceil(int(tune_cfg['calib_sampling_size']) / \
self.calib_dataloader.batch_size)
else:
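
The strategy change above swaps a truthiness test for an explicit is-not-None check, presumably because of the __len__ added to the dataloader in this same commit: a calibration dataloader whose length is 0 is falsy, so the old test would silently skip setting calib_iteration even though a dataloader object was supplied. A toy demonstration of the pitfall (not the project's dataloader class):

    class ToyLoader:
        def __init__(self, num_samples, batch_size):
            self.num_samples, self.batch_size = num_samples, batch_size
        def __len__(self):
            return self.num_samples // self.batch_size

    loader = ToyLoader(num_samples=2, batch_size=4)   # len(loader) == 0

    print(bool(loader))        # False: truthiness falls back to __len__
    print(loader is not None)  # True: the check the commit switches to
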
