Enable more Ops for Keras PTQ and Support Mix-precision (#835)
Signed-off-by: zehao-intel <zehao.huang@intel.com>
zehao-intel committed May 11, 2023
1 parent 62be9d1 commit 6997518
Showing 4 changed files with 181 additions and 16 deletions.
79 changes: 65 additions & 14 deletions neural_compressor/adaptor/keras.py
@@ -39,13 +39,18 @@ def _add_supported_quantized_objects(custom_objects):
from neural_compressor.adaptor.keras_utils.depthwise_conv2d import QDepthwiseConv2D
from neural_compressor.adaptor.keras_utils.separable_conv2d import QSeparableConv2D
from neural_compressor.adaptor.keras_utils.dense import QDense
from neural_compressor.adaptor.keras_utils.pool2d import QMaxPool2D, QAvgPool2D
custom_objects["Quantize"] = Quantize
custom_objects["DeQuantize"] = DeQuantize
custom_objects["FakeQuant"] = FakeQuant
custom_objects["QConv2D"] = QConv2D
custom_objects["QDepthwiseConv2D"] = QDepthwiseConv2D
custom_objects["QSeparableConv2D"] = QSeparableConv2D
custom_objects["QDense"] = QDense
custom_objects["QMaxPool2D"] = QMaxPool2D
custom_objects["QAvgPool2D"] = QAvgPool2D
custom_objects["QMaxPooling2D"] = QMaxPool2D
custom_objects["QAveragePooling2D"] = QAvgPool2D
return custom_objects
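Note: this registry is what makes a saved quantized model loadable again, since Keras can only deserialize the Q* layers if the classes are supplied at load time. A minimal sketch, assuming the quantized model was serialized to JSON (the path and the direct helper call are illustrative, not part of this commit):

    import tensorflow as tf

    # Illustrative: build the custom-object map and hand it to Keras so the
    # Q* layer classes registered above can be deserialized.
    custom_objects = _add_supported_quantized_objects({})
    with open('quantized_model.json') as f:  # hypothetical path
        model = tf.keras.models.model_from_json(f.read(), custom_objects=custom_objects)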

@adaptor_registry
@@ -62,11 +67,13 @@ def __init__(self, framework_specific_info):
#self.work_dir = os.path.abspath(self.framework_specific_info['workspace_path'])
self.recipes = deep_get(self.framework_specific_info, 'recipes', {})
#os.makedirs(self.work_dir, exist_ok=True)
-        self.supported_op = ['Conv2D', 'Dense', 'SeparableConv2D', 'DepthwiseConv2D']
+        self.supported_op = ['Conv2D', 'Dense', 'SeparableConv2D', 'DepthwiseConv2D', 'AveragePooling2D',
+                             'MaxPooling2D', 'AvgPool2D', 'MaxPool2D']

self.pre_optimized_object = None
self.pre_optimized_model = None
self.pre_optimizer_handle = None
self.bf16_ops = []
self.fp32_ops = []
self.query_handler = KerasQuery(local_config_file=os.path.join(
os.path.dirname(__file__), 'keras.yaml'))
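The query handler drives both the int8 and bf16 op lists from keras.yaml (see the yaml diff below). A rough stand-in for that lookup, assuming the simplified layout shown in that file; the real KerasQuery class also handles framework-version matching:

    import yaml

    def op_types_by_precision(yaml_path, precision):
        # Illustrative only: read the per-precision op lists under the 'ops' key.
        with open(yaml_path) as f:
            cfg = yaml.safe_load(f)
        entry = cfg[0] if isinstance(cfg, list) else cfg  # yaml may hold a list of version entries
        return entry['ops'][precision]

    # op_types_by_precision('keras.yaml', 'bf16') would then include 'Conv2D'.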
@@ -84,6 +91,8 @@ def tuning_cfg_to_fw(self, tuning_cfg):
self.quantize_config['device'] = self.device
self.quantize_config['advance'] = deep_get(tuning_cfg, 'advance')
fp32_ops = []
bf16_ops = []
bf16_type = set(self.query_handler.get_op_types_by_precision(precision='bf16'))
dispatched_op_names = [j[0] for j in tuning_cfg['op']]
invalid_op_names = [i for i in self.quantize_config['op_wise_config']
if i not in dispatched_op_names]
@@ -93,6 +102,12 @@ def tuning_cfg_to_fw(self, tuning_cfg):

for each_op_info in tuning_cfg['op']:
op_name = each_op_info[0]

if tuning_cfg['op'][each_op_info]['activation']['dtype'] == 'bf16':
if each_op_info[1] in bf16_type:
bf16_ops.append(op_name)
continue

if tuning_cfg['op'][each_op_info]['activation']['dtype'] == 'fp32':
if op_name in self.quantize_config['op_wise_config']:
self.quantize_config['op_wise_config'].pop(op_name)
@@ -114,6 +129,8 @@ def tuning_cfg_to_fw(self, tuning_cfg):
algorithm,
is_asymmetric,
weight_bit)
self.bf16_ops = bf16_ops
self.bf16_ops.pop(-1)
self.fp32_ops = fp32_ops
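A self-contained sketch of the dispatch above, with a made-up tuning config (key and field names mirror the code; the values are illustrative):

    tuning_cfg = {'op': {
        ('conv2d', 'Conv2D'):   {'activation': {'dtype': 'bf16'}},
        ('dense', 'Dense'):     {'activation': {'dtype': 'fp32'}},
        ('conv2d_1', 'Conv2D'): {'activation': {'dtype': 'int8'}},
    }}
    bf16_type = {'Conv2D', 'Dense'}  # would come from keras.yaml in practice
    bf16_ops, fp32_ops = [], []
    for (op_name, op_type), cfg in tuning_cfg['op'].items():
        dtype = cfg['activation']['dtype']
        if dtype == 'bf16' and op_type in bf16_type:
            bf16_ops.append(op_name)
        elif dtype == 'fp32':
            fp32_ops.append(op_name)
    # bf16_ops == ['conv2d'], fp32_ops == ['dense']; int8 ops keep their
    # entries in op_wise_config instead.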

def _pre_optimize(self, model):
@@ -266,6 +283,10 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
q_func (optional): training function for quantization aware training mode.
'''
self.tuning_cfg_to_fw(tune_cfg)
# just convert the input model to mixed_bfloat16
if self.bf16_ops and not self.quantize_config['op_wise_config']:
converted_model = self.convert_bf16()
return converted_model
logger.debug("Dump quantization configurations:")
logger.debug(self.quantize_config)
calib_sampling_size = tune_cfg.get('calib_sampling_size', 1)
@@ -394,16 +415,23 @@ def _calibrate(self, model, dataloader, calib_interation):
q_layer_name = 'Q' + layer['class_name']
# this is for inbounds search
q_name = layer['config']['name']
-            kernel = self.layer_weights[layer['config']['name']][0]
-            dim = list(range(0, kernel.ndim))
-            t_dim = [dim.pop(-1)]
-            t_dim.extend(dim)
-            channel_size = kernel.shape[-1]
-            kernel_channel = kernel.transpose(t_dim).reshape(channel_size, -1)
-            layer_config['min_value'] = json.dumps(\
-                np.min(kernel_channel, axis=1).tolist())
-            layer_config['max_value'] = json.dumps(\
-                np.max(kernel_channel, axis=1).tolist())
+            # for layers that have weights
+            if layer['config']['name'] in self.layer_weights:
+                kernel = self.layer_weights[layer['config']['name']][0]
+                dim = list(range(0, kernel.ndim))
+                t_dim = [dim.pop(-1)]
+                t_dim.extend(dim)
+                channel_size = kernel.shape[-1]
+                kernel_channel = kernel.transpose(t_dim).reshape(channel_size, -1)
+                layer_config['min_value'] = json.dumps(\
+                    np.min(kernel_channel, axis=1).tolist())
+                layer_config['max_value'] = json.dumps(\
+                    np.max(kernel_channel, axis=1).tolist())
+            else:
+                # default values; never expected to be used
+                # because this layer has no kernel weights
+                layer_config['min_value'] = json.dumps([-10000])
+                layer_config['max_value'] = json.dumps([10000])
layer_config['name'] = q_name
q_layer = {'class_name': q_layer_name,
'name': q_name,
@@ -418,6 +446,22 @@ def _calibrate(self, model, dataloader, calib_interation):
quantized_model = self._restore_model_from_json(json_model)
return quantized_model
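A standalone numpy illustration of the per-channel statistics computed above: a Conv2D kernel has shape (kh, kw, in_ch, out_ch), so moving the last axis to the front and flattening yields one row per output channel:

    import numpy as np

    kernel = np.arange(2 * 2 * 3 * 4, dtype=np.float32).reshape(2, 2, 3, 4)
    dim = list(range(kernel.ndim))       # [0, 1, 2, 3]
    t_dim = [dim.pop(-1)] + dim          # [3, 0, 1, 2]: channel axis first
    kernel_channel = kernel.transpose(t_dim).reshape(kernel.shape[-1], -1)  # (4, 12)
    per_channel_min = np.min(kernel_channel, axis=1)  # one min per output channel
    per_channel_max = np.max(kernel_channel, axis=1)  # one max per output channel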

def convert_bf16(self):
'''Execute the BF16 conversion.
'''
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
json_model = copy.deepcopy(json.loads(self.pre_optimized_object.to_json()))

for layer in json_model['config']['layers']:
if layer['config']['name'] in self.bf16_ops:
layer['config']['dtype'] = 'mixed_bfloat16'

converted_model = self._restore_model_from_json(json_model)
tf.keras.mixed_precision.set_global_policy('float32')

from neural_compressor.model.keras_model import KerasModel
converted_model = KerasModel(converted_model)
return converted_model
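The global policy is flipped only while the model is rebuilt from JSON, then restored to float32, so the conversion does not leak into other models. Per Keras semantics, a 'mixed_bfloat16' policy means bf16 compute with fp32 variables, which the new test below also checks:

    import tensorflow as tf

    policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
    assert policy.compute_dtype == 'bfloat16'   # matmuls/convs run in bf16
    assert policy.variable_dtype == 'float32'   # weights are kept in fp32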

# (TODO) choose the proper quantization mode
def _check_quantize_mode(self, json_model):
@@ -510,12 +554,15 @@ def query_fw_capability(self, model):
model (object): The model to query quantization tuning capability.
'''
fp32_config = {'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}}
bf16_config = {'weight': {'dtype': 'bf16'}, 'activation': {'dtype': 'bf16'}}
int8_type = self.query_handler.get_op_types_by_precision(precision='int8')
op_capability = self.query_handler.get_quantization_capability()
conv_config = copy.deepcopy(op_capability['int8']['Conv2D'])
conv_config = copy.deepcopy(op_capability['int8']['SeparableConv2D'])
conv_config = copy.deepcopy(op_capability['int8']['DepthwiseConv2D'])
dense_config = copy.deepcopy(op_capability['int8']['Dense'])
maxpool_config = copy.deepcopy(op_capability['int8']['MaxPooling2D'])
avgpool_config = copy.deepcopy(op_capability['int8']['AveragePooling2D'])
other_config = copy.deepcopy(op_capability['int8']['default'])

# # get fp32 layer weights
@@ -545,11 +592,15 @@ def query_fw_capability(self, model):
node_op = details['class_name']
node_name = details['config']['name']
if node_op == 'Conv2D':
-                quantizable_op_details[(node_name, node_op)] = [conv_config, fp32_config]
+                quantizable_op_details[(node_name, node_op)] = [conv_config, bf16_config, fp32_config]
             elif node_op == 'Dense':
-                quantizable_op_details[(node_name, node_op)] = [dense_config, fp32_config]
+                quantizable_op_details[(node_name, node_op)] = [dense_config, bf16_config, fp32_config]
+            elif node_op in {'AveragePooling2D', 'AvgPool2D'}:
+                quantizable_op_details[(node_name, node_op)] = [avgpool_config, bf16_config, fp32_config]
+            elif node_op in {'MaxPooling2D', 'MaxPool2D'}:
+                quantizable_op_details[(node_name, node_op)] = [maxpool_config, bf16_config, fp32_config]
             else:
-                quantizable_op_details[(node_name, node_op)] = [fp32_config]
+                quantizable_op_details[(node_name, node_op)] = [bf16_config, fp32_config]

capability = {
'opwise': copy.deepcopy(quantizable_op_details),
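For orientation, the entries built above have roughly this shape (abridged, illustrative values): each quantizable op maps to an ordered preference list that the tuning strategy walks front to back, so int8 is tried first, then bf16, then fp32:

    quantizable_op_details = {
        ('conv2d', 'Conv2D'): [
            {'activation': {'dtype': ['int8'], 'quant_mode': 'static'}},     # conv_config, abridged
            {'weight': {'dtype': 'bf16'}, 'activation': {'dtype': 'bf16'}},  # bf16_config
            {'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}},  # fp32_config
        ],
    }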
26 changes: 24 additions & 2 deletions neural_compressor/adaptor/keras.yaml
@@ -22,7 +22,25 @@
valid_mixed_precisions: []

ops: &common_ops
-    int8: ['Conv2D', 'SeparableConv2D', 'DepthwiseConv2D', 'Dense']
+    int8: ['Conv2D', 'SeparableConv2D', 'DepthwiseConv2D', 'Dense', 'AveragePooling2D', 'MaxPooling2D',
+           'AvgPool2D', 'MaxPool2D']
+    bf16: ['Dense', 'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv1D', 'SeparableConv2D', 'SeparableConv3D',
+           'DepthwiseConv2D', 'Conv1DTranspose', 'Conv2DTranspose', 'Conv3DTranspose', 'AveragePooling2D',
+           'MaxPooling2D', 'AvgPool2D', 'MaxPool2D', 'MaxPooling1D', 'MaxPooling3D', 'AveragePooling1D',
+           'AveragePooling3D', 'GlobalMaxPooling1D', 'GlobalMaxPooling2D', 'GlobalMaxPooling3D', 'SimpleRNN',
+           'GlobalAveragePooling1D', 'GlobalAveragePooling2D', 'GlobalAveragePooling3D', 'LSTM', 'GRU',
+           'TimeDistributed', 'Bidirectional', 'ConvLSTM1D', 'ConvLSTM2D', 'ConvLSTM3D', 'TextVectorization',
+           'Normalization', 'Discretization', 'CategoryEncoding', 'Hashing', 'StringLookup', 'IntegerLookup',
+           'Resizing', 'Rescaling', 'CenterCrop', 'RandomCrop', 'RandomFlip', 'RandomTranslation', 'Activation',
+           'RandomRotation', 'RandomZoom', 'RandomHeight', 'RandomWidth', 'RandomContrast', 'RandomBrightness',
+           'BatchNormalization', 'LayerNormalization', 'UnitNormalization', 'GroupNormalization', 'Dropout',
+           'SpatialDropout1D', 'SpatialDropout2D', 'SpatialDropout3D', 'GaussianDropout', 'GaussianNoise',
+           'ActivityRegularization', 'AlphaDropout', 'MultiHeadAttention', 'Attention', 'AdditiveAttention',
+           'Reshape', 'Flatten', 'RepeatVector', 'Permute', 'Cropping1D', 'Cropping2D', 'Cropping3D', 'UpSampling1D',
+           'UpSampling2D', 'UpSampling3D', 'ZeroPadding1D', 'ZeroPadding2D', 'ZeroPadding3D', 'Concatenate', 'Average',
+           'Maximum', 'Minimum', 'Add', 'Subtract', 'Multiply', 'Dot', 'LocallyConnected1D', 'LocallyConnected2D',
+           'Embedding', 'Masking', 'Lambda', 'ReLU', 'Softmax', 'LeakyReLU', 'PReLU', 'ELU', 'ThresholdedReLU'
+          ]
fp32: ['*'] # '*' means all op types

capabilities: &common_capabilities
@@ -87,7 +105,7 @@
'granularity': ['per_tensor'],
}
},
-      'default': {
+      'default': &ref_default_static {
'activation': {
'dtype': ['int8'],
'quant_mode': 'static',
@@ -96,4 +114,8 @@
'granularity': ['per_tensor']
}
},
'AveragePooling2D': *ref_default_static,
'AvgPool2D': *ref_default_static,
'MaxPooling2D': *ref_default_static,
'MaxPool2D': *ref_default_static,
}
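The '&ref_default_static' / '*ref_default_static' pair is standard YAML anchor/alias syntax: the four pooling entries reuse the default static-int8 capability instead of repeating it. A small demonstration of the mechanics:

    import yaml

    doc = (
        "default: &ref_default_static\n"
        "  activation: {dtype: ['int8'], quant_mode: 'static'}\n"
        "MaxPooling2D: *ref_default_static\n"
    )
    cfg = yaml.safe_load(doc)
    assert cfg['MaxPooling2D'] == cfg['default']  # the alias resolves to the same content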
64 changes: 64 additions & 0 deletions neural_compressor/adaptor/keras_utils/pool2d.py
@@ -0,0 +1,64 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import tensorflow as tf
from tensorflow.keras import activations
from tensorflow.keras import backend
from tensorflow.keras import constraints
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import AveragePooling2D
from tensorflow import quantization

class QAvgPool2D(AveragePooling2D):
def __init__(self,
pool_size=(2, 2),
strides=None,
padding="valid",
data_format=None,
min_value=-10000,
max_value=10000,
**kwargs):
super(QAvgPool2D, self).__init__(
pool_size=pool_size,
strides=strides,
padding=padding,
data_format=data_format,
**kwargs)
self.min_value = json.loads(min_value)
self.max_value = json.loads(max_value)


class QMaxPool2D(MaxPooling2D):
def __init__(self,
pool_size=(2, 2),
strides=None,
padding="valid",
data_format=None,
min_value=-10000,
max_value=10000,
**kwargs):
super(QMaxPool2D, self).__init__(
pool_size=pool_size,
strides=strides,
padding=padding,
data_format=data_format,
**kwargs)
self.min_value = json.loads(min_value)
self.max_value = json.loads(max_value)
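Note on usage: min_value and max_value are expected as JSON-encoded strings, matching the json.dumps(...) calls in _calibrate above; passing the bare integer defaults through json.loads would raise a TypeError. A hedged instantiation sketch:

    q_pool = QMaxPool2D(pool_size=(2, 2),
                        min_value=json.dumps([-1.0]),
                        max_value=json.dumps([1.0]))
    print(q_pool.min_value, q_pool.max_value)  # [-1.0] [1.0]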
28 changes: 28 additions & 0 deletions test/mixed_precision/test_mixed_precision_keras_model.py
@@ -75,6 +75,10 @@ def result(self):
np.array(self.pred_list) == np.array(self.label_list))
return correct_num / self.samples

class MyMetric_keras(MyMetric):
def __init__(self, *args):
super(MyMetric_keras, self).__init__(*args)

class TestMixedPrecisionWithKerasModel(unittest.TestCase):
@classmethod
def setUpClass(self):
@@ -117,5 +121,29 @@ def test_mixed_precision_with_keras_model(self):
break
self.assertEqual(found_cast, True)

def test_mixed_precision_with_keras_adaptor(self):
from neural_compressor.data import DataLoader
dataset = Dataset()
dataloader = DataLoader(framework='tensorflow', dataset=dataset)

from neural_compressor.config import MixedPrecisionConfig
from neural_compressor import mix_precision
# add backend='itex' to run on keras adaptor
config = MixedPrecisionConfig(backend='itex')

bf16_model = mix_precision.fit(
model='./models/saved_model',
config=config,
eval_dataloader=dataloader,
eval_metric=MyMetric_keras())

bf16_policy = keras.mixed_precision.Policy('mixed_bfloat16')
# bf16_model.model is an obj of tf.keras.Model
model_policy = bf16_model.model.dtype_policy
conv2d_layer_policy = bf16_model.model.get_layer('conv2d').dtype_policy

self.assertEqual(model_policy.compute_dtype, bf16_policy.compute_dtype)
self.assertEqual(conv2d_layer_policy.compute_dtype, bf16_policy.compute_dtype)

if __name__ == "__main__":
unittest.main()
