From 1deb7d2f80524714bd0c6c1192842fea9f0e340e Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Fri, 9 Dec 2022 22:45:57 +0800 Subject: [PATCH] Refactor Quantization Aware Training of TF backend (#250) Signed-off-by: zehao-intel --- .../scripts/codeScan/pyspelling/inc_dict.txt | 3 + examples/.config/model_params_tensorflow.json | 14 + .../mnist/quantization/qat/README.md | 114 ++++-- .../mnist/quantization/qat/benchmark.py | 28 -- .../mnist/quantization/qat/convert.py | 7 - .../mnist/quantization/qat/main.py | 179 ++++++++++ .../mnist/quantization/qat/mnist.yaml | 26 -- .../mnist/quantization/qat/mnist_itex.yaml | 26 -- .../mnist/quantization/qat/prepare_model.py | 74 ++++ .../mnist/quantization/qat/qat.py | 37 -- .../mnist/quantization/qat/requirements.txt | 2 + .../mnist/quantization/qat/run_benchmark.sh | 40 +++ .../mnist/quantization/qat/run_tuning.sh | 35 ++ .../mnist/quantization/qat/train.py | 38 -- .../resnet50/quantization/qat/README | 119 +++++++ .../resnet50/quantization/qat/main.py | 184 ++++++++++ .../quantization/qat/prepare_model.py | 37 ++ .../quantization/qat/requirements.txt | 2 + .../quantization/qat/run_benchmark.sh | 44 +++ .../resnet50/quantization/qat/run_tuning.sh | 39 +++ neural_compressor/adaptor/tensorflow.py | 45 ++- .../tf_utils/quantize_graph/qat/__init__.py | 16 + .../quantize_graph/qat/fake_quantize.py | 233 +++++++++++++ .../quantize_graph/qat/quantize_config.py | 119 +++++++ .../quantize_graph/qat/quantize_helper.py | 85 +++++ .../qat/quantize_layers/__init__.py | 16 + .../qat/quantize_layers/optimize_layer.py | 32 ++ .../qat/quantize_layers/quantize_layer_add.py | 79 +++++ .../quantize_layers/quantize_layer_base.py | 84 +++++ .../qat/quantize_layers/quantize_layer_bn.py | 55 +++ .../quantize_graph/qat/quantize_wrapper.py | 330 ++++++++++++++++++ neural_compressor/experimental/component.py | 4 +- .../experimental/quantization.py | 34 ++ neural_compressor/model/model.py | 29 ++ test/model/test_model.py | 21 ++ test/quantization/test_tensorflow_qat.py | 134 ++----- 36 files changed, 2064 insertions(+), 300 deletions(-) delete mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/benchmark.py delete mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/convert.py create mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/main.py delete mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/mnist.yaml delete mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/mnist_itex.yaml create mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/prepare_model.py delete mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/qat.py create mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/requirements.txt create mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/run_benchmark.sh create mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/run_tuning.sh delete mode 100644 examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/train.py create mode 100644 examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/README create mode 100644 examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/main.py create mode 100644 examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/prepare_model.py create mode 100644 examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/requirements.txt create mode 100644 examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/run_benchmark.sh create mode 100644 examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/run_tuning.sh create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/__init__.py create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/fake_quantize.py create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_config.py create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_helper.py create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/__init__.py create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/optimize_layer.py create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_add.py create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_base.py create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_bn.py create mode 100644 neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_wrapper.py diff --git a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt index 9a1c2de0509..cecf227cbaf 100644 --- a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt +++ b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt @@ -2436,6 +2436,9 @@ npmjs AWSSageMakerSupport sagemaker xpu +dgpu +BenchmarkConfig +QuantizationAwareTrainingConfig Startup doesn startup diff --git a/examples/.config/model_params_tensorflow.json b/examples/.config/model_params_tensorflow.json index fb70465dc90..db1e355e945 100644 --- a/examples/.config/model_params_tensorflow.json +++ b/examples/.config/model_params_tensorflow.json @@ -1944,6 +1944,13 @@ "batch_size": 1, "new_benchmark": false }, + "mnist_keras": { + "model_src_dir": "image_recognition/keras_models/mnist/quantization/qat", + "dataset_location": "", + "input_model": "/tf_dataset2/models/tensorflow/mnist_keras/saved_model/", + "main_script": "main.py", + "batch_size": 32 + }, "resnet50_fashion": { "model_src_dir": "image_recognition/keras_models/resnet50_fashion/quantization/ptq", "dataset_location": "/tf_dataset2/datasets/mnist/FashionMNIST", @@ -1962,6 +1969,13 @@ "batch_size": 1, "new_benchmark": true }, + "resnet50_keras_qat": { + "model_src_dir": "image_recognition/keras_models/resnet50/quantization/qat", + "dataset_location": "/tf_dataset/dataset/imagenet", + "input_model": "/tf_dataset2/models/tensorflow/resnet50_keras/resnet50", + "main_script": "main.py", + "batch_size": 32 + }, "resnet50_keras_h5": { "model_src_dir": "image_recognition/keras_models/resnet50/quantization/ptq", "dataset_location": "/tf_dataset/dataset/imagenet", diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/README.md b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/README.md index e1318d4638d..682d4de9bcd 100644 --- a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/README.md +++ b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/README.md @@ -1,10 +1,9 @@ Step-by-Step ============ -This document is used to list steps of reproducing TensorFlow keras Intel® Neural Compressor QAT conversion. +This document is used to apply QAT to Tensorflow Keras models using Intel® Neural Compressor. This example can run on Intel CPUs and GPUs. - ## Prerequisite ### 1. Installation @@ -12,45 +11,108 @@ This example can run on Intel CPUs and GPUs. # Install Intel® Neural Compressor pip install neural-compressor ``` -### 2. Install Intel Tensorflow and TensorFlow Model Optimization +### 2. Install requirements +The Tensorflow and intel-extension-for-tensorflow is mandatory to be installed to run this QAT example. +The Intel Extension for Tensorflow for Intel CPUs is installed as default. ```shell -pip install intel-tensorflow==2.4.0 -pip install tensorflow_model_optimization==0.5.0 +pip install -r requirements.txt ``` -> Note: To generate correct qat model with tensorflow_model_optimization 0.5.0, pls use TensorFlow 2.4 or above. +> Note: Supported Tensorflow [Version](../../../../../../../README.md). -### 3. Install Intel Extension for Tensorflow +### 3. Benchmarking the model on Intel GPU (Optional) -#### Quantizing the model on Intel GPU -Intel Extension for Tensorflow is mandatory to be installed for quantizing the model on Intel GPUs. +To run benchmark of the model on Intel GPUs, Intel Extension for Tensorflow for Intel GPUs is required. ```shell pip install --upgrade intel-extension-for-tensorflow[gpu] ``` -For any more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel-innersource/frameworks.ai.infrastructure.intel-extension-for-tensorflow.intel-extension-for-tensorflow/blob/master/docs/install/install_for_gpu.md#install-gpu-drivers) - -#### Quantizing the model on Intel CPU(Experimental) -Intel Extension for Tensorflow for Intel CPUs is experimental currently. It's not mandatory for quantizing the model on Intel CPUs. -```shell -pip install --upgrade intel-extension-for-tensorflow[cpu] -``` +Please refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/ubuntu/ubuntu-focal-dc.html) for latest Intel GPU driver installation. +For any more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel-innersource/frameworks.ai.infrastructure.intel-extension-for-tensorflow.intel-extension-for-tensorflow/blob/master/docs/install/install_for_gpu.md#install-gpu-drivers). ### 4. Prepare Pretrained model -Run the `train.py` script to get pretrained fp32 model. +The pretrained model is provided by [Keras Applications](https://keras.io/api/applications/). prepare the model, Run as follow: + ``` -### 5. Prepare QAT model - -Run the `qat.py` script to get QAT model which in fact is a fp32 model with quant/dequant pair inserted. - -## Write Yaml config file -In examples directory, there is a mnist.yaml for tuning the model on Intel CPUs. The 'framework' in the yaml is set to 'tensorflow'. If running this example on Intel GPUs, the 'framework' should be set to 'tensorflow_itex' and the device in yaml file should be set to 'gpu'. The mnist_itex.yaml is prepared for the GPU case. We could remove most of items and only keep mandatory item for tuning. We also implement a calibration dataloader and have evaluation field for creation of evaluation function at internal neural_compressor. +python prepare_model.py --output_model=/path/to/model + ``` +`--output_model ` the model should be saved as SavedModel format or H5 format. ## Run Command ```shell - python convert.py # to convert QAT model to quantized model. - - python benchmark.py # to run accuracy benchmark. + bash run_tuning.sh --input_model=./path/to/model --output_model=./result + bash run_benchmark.sh --input_model=./path/to/model --mode=performance --batch_size=32 ``` +Details of enabling Intel® Neural Compressor to apply QAT. +========================= + +This is a tutorial of how to to apply QAT with Intel® Neural Compressor. +## User Code Analysis +1. User specifies fp32 *model* to apply quantization, the dataset is automatically downloaded. In this step, QDQ patterns will be inserted to the keras model, but the fp32 model will not be converted to a int8 model. + +2. User specifies *model* with QDQ patterns inserted, evaluate function to run benchmark. The model we get from the previous step will be run on ITEX backend. Then, the model is going to be fused and inferred. + +### Quantization Config +The Quantization Config class has default parameters setting for running on Intel CPUs. If running this example on Intel GPUs, the 'backend' parameter should be set to 'itex' and the 'device' parameter should be set to 'gpu'. + +``` +config = QuantizationAwareTrainingConfig( + device="gpu", + backend="itex", + ... + ) +``` + +### Code update + +After prepare step is done, we add quantization and benchmark code to generate quantized model and benchmark. + +#### Tune +```python + logger.info('start quantizing the model...') + from neural_compressor import training, QuantizationAwareTrainingConfig + config = QuantizationAwareTrainingConfig() + # create a compression_manager instance to implement QAT + compression_manager = training.prepare_compression(FLAGS.input_model, config) + # QDQ patterns will be inserted to the input keras model + compression_manager.callbacks.on_train_begin() + # get the model with QDQ patterns inserted + q_aware_model = compression_manager.model.model + + # training code defined by users + q_aware_model.compile(optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy']) + q_aware_model.summary() + train_images_subset = train_images[0:1000] + train_labels_subset = train_labels[0:1000] + q_aware_model.fit(train_images_subset, train_labels_subset, + batch_size=500, epochs=1, validation_split=0.1) + _, q_aware_model_accuracy = q_aware_model.evaluate( + test_images, test_labels, verbose=0) + print('Quant test accuracy:', q_aware_model_accuracy) + + # apply some post process steps and save the output model + compression_manager.callbacks.on_train_end() + compression_manager.save(FLAGS.output_model) +``` +#### Benchmark +```python + from neural_compressor.benchmark import fit + from neural_compressor.experimental import common + from neural_compressor.config import BenchmarkConfig + assert FLAGS.mode == 'performance' or FLAGS.mode == 'accuracy', \ + "Benchmark only supports performance or accuracy mode." + + # convert the quantized keras model to graph_def so that it can be fused by ITEX + model = common.Model(FLAGS.input_model).graph_def + if FLAGS.mode == 'performance': + conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=7) + fit(model, conf, b_func=evaluate) + elif FLAGS.mode == 'accuracy': + accuracy = evaluate(model) + print('Batch size = %d' % FLAGS.batch_size) + print("Accuracy: %.5f" % accuracy) +``` \ No newline at end of file diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/benchmark.py b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/benchmark.py deleted file mode 100644 index df49ab3b075..00000000000 --- a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/benchmark.py +++ /dev/null @@ -1,28 +0,0 @@ -import tensorflow as tf -from tensorflow import keras -import numpy as np - -class dataloader(object): - def __init__(self, batch_size=100): - mnist = keras.datasets.mnist - (train_images, train_labels), (test_images, test_labels) = mnist.load_data() - - # Normalize the input image so that each pixel value is between 0 to 1. - self.train_images = train_images / 255.0 - self.test_images = test_images / 255.0 - self.train_labels = train_labels - self.test_labels = test_labels - - self.batch_size = batch_size - self.i = 0 - - def __iter__(self): - while self.i < len(self.test_images): - yield self.test_images[self.i: self.i + self.batch_size], self.test_labels[self.i: self.i + self.batch_size] - self.i = self.i + self.batch_size - -from neural_compressor.experimental import Benchmark, common -evaluator = Benchmark('mnist.yaml') -evaluator.model = common.Model('quantized_model') -evaluator.b_dataloader = dataloader() -evaluator('accuracy') diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/convert.py b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/convert.py deleted file mode 100644 index f1b8c7054b3..00000000000 --- a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/convert.py +++ /dev/null @@ -1,7 +0,0 @@ -from neural_compressor.experimental import ModelConversion, common -conversion = ModelConversion() -conversion.source = 'QAT' -conversion.destination = 'default' -conversion.model = common.Model('../qat/trained_qat_model') -q_model = conversion() -q_model.save('quantized_model') diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/main.py b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/main.py new file mode 100644 index 00000000000..0709d23794d --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/main.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import numpy as np +import tensorflow as tf + +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + +flags = tf.compat.v1.flags +FLAGS = flags.FLAGS +logger = logging.getLogger(__name__) + +## Required parameters +flags.DEFINE_string( + 'input_model', None, 'Run inference with specified keras model.') + +flags.DEFINE_string( + 'output_model', None, 'The output quantized model.') + +flags.DEFINE_string( + 'mode', 'performance', 'define benchmark mode for accuracy or performance') + +flags.DEFINE_bool( + 'tune', False, 'whether to tune the model') + +flags.DEFINE_bool( + 'benchmark', False, 'whether to benchmark the model') + +flags.DEFINE_integer( + 'batch_size', 32, 'batch_size') + + +def prepare_data(): + """Load the dataset of MNIST. + + Returns: + train (tuple): The images and labels for training. + test (tuple): The images and labels for testing. + """ + # Load MNIST dataset + mnist = tf.keras.datasets.mnist + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + + # Normalize the input image so that each pixel value is between 0 to 1. + train_images = train_images / 255.0 + test_images = test_images / 255.0 + + return (train_images, train_labels), (test_images, test_labels) + +(train_images, train_labels), (test_images, test_labels) = prepare_data() + +class dataloader(object): + def __init__(self, batch_size=100): + mnist = tf.keras.datasets.mnist + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + + # Normalize the input image so that each pixel value is between 0 to 1. + self.train_images = train_images / 255.0 + self.test_images = test_images / 255.0 + self.train_labels = train_labels + self.test_labels = test_labels + + self.batch_size = batch_size + self.i = 0 + + def __iter__(self): + while self.i < len(self.test_images): + yield self.test_images[self.i: self.i + self.batch_size], self.test_labels[self.i: self.i + self.batch_size] + self.i = self.i + self.batch_size + +def evaluate(model): + """Custom evaluate function to estimate the accuracy of the model. + + Args: + model (tf.Graph_def): The input model graph + + Returns: + accuracy (float): evaluation result, the larger is better. + """ + from neural_compressor.experimental import common + model = common.Model(model) + input_tensor = model.input_tensor + output_tensor = model.output_tensor if len(model.output_tensor)>1 else \ + model.output_tensor[0] + iteration = -1 + if FLAGS.benchmark and FLAGS.mode == 'performance': + iteration = 100 + postprocess = LabelShift(label_shift=1) + metric = TensorflowTopK(k=1) + + def eval_func(dataloader): + latency_list = [] + for idx, (inputs, labels) in enumerate(dataloader): + # dataloader should keep the order and len of inputs same with input_tensor + assert len(input_tensor) == len(inputs), \ + 'inputs len must equal with input_tensor' + feed_dict = dict(zip(input_tensor, inputs)) + + start = time.time() + predictions = model.sess.run(output_tensor, feed_dict) + end = time.time() + + predictions, labels = postprocess((predictions, labels)) + metric.update(predictions, labels) + latency_list.append(end-start) + if idx + 1 == iteration: + break + latency = np.array(latency_list).mean() / FLAGS.batch_size + return latency + + dataloader = dataloader(batch_size=FLAGS.batch_size) + latency = eval_func(dataloader) + if FLAGS.benchmark and FLAGS.mode == 'performance': + print("Batch size = {}".format(FLAGS.batch_size)) + print("Latency: {:.3f} ms".format(latency * 1000)) + print("Throughput: {:.3f} images/sec".format(1. / latency)) + acc = metric.result() + return acc + + +def main(): + if FLAGS.tune: + logger.info('start quantizing the model...') + from neural_compressor import training, QuantizationAwareTrainingConfig + config = QuantizationAwareTrainingConfig() + compression_manager = training.prepare_compression(FLAGS.input_model, config) + compression_manager.callbacks.on_train_begin() + + q_aware_model = compression_manager.model.model + + q_aware_model.compile(optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy']) + + q_aware_model.summary() + train_images_subset = train_images[0:1000] + train_labels_subset = train_labels[0:1000] + q_aware_model.fit(train_images_subset, train_labels_subset, + batch_size=500, epochs=1, validation_split=0.1) + _, q_aware_model_accuracy = q_aware_model.evaluate( + test_images, test_labels, verbose=0) + print('Quant test accuracy:', q_aware_model_accuracy) + + compression_manager.callbacks.on_train_end() + compression_manager.save(FLAGS.output_model) + + if FLAGS.benchmark: + from neural_compressor.benchmark import fit + from neural_compressor.experimental import common + from neural_compressor.config import BenchmarkConfig + assert FLAGS.mode == 'performance' or FLAGS.mode == 'accuracy', \ + "Benchmark only supports performance or accuracy mode." + + model = common.Model(FLAGS.input_model).graph_def + if FLAGS.mode == 'performance': + conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=7) + fit(model, conf, b_func=evaluate) + elif FLAGS.mode == 'accuracy': + accuracy = evaluate(model) + print('Batch size = %d' % FLAGS.batch_size) + print("Accuracy: %.5f" % accuracy) + +if __name__ == "__main__": + main() diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/mnist.yaml b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/mnist.yaml deleted file mode 100644 index 30c89e41ce6..00000000000 --- a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/mnist.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# -# Copyright (c) 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -model: # mandatory. used to specify model specific information. - name: mnist - framework: tensorflow # mandatory. supported values are tensorflow, pytorch, pytorch_ipex, onnxrt_integer, onnxrt_qlinear or mxnet; allow new framework backend extension. - -device: cpu # optional. default value is cpu, other value is gpu. - -evaluation: # optional. required if user doesn't provide eval_func in neural_compressor.Quantization. - accuracy: # optional. required if user doesn't provide eval_func in neural_compressor.Quantization. - metric: - Accuracy: {} # built-in metrics are topk, map, f1, allow user to register new metric. - diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/mnist_itex.yaml b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/mnist_itex.yaml deleted file mode 100644 index 5681e991f3a..00000000000 --- a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/mnist_itex.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# -# Copyright (c) 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -model: # mandatory. used to specify model specific information. - name: mnist - framework: tensorflow_itex # mandatory. supported values are tensorflow, tensorflow_itex, pytorch, pytorch_ipex, onnxrt_integer, onnxrt_qlinear or mxnet; allow new framework backend extension. - -device: gpu # optional. set cpu if installed intel-extension-for-tensorflow[cpu], set gpu if installed intel-extension-for-tensorflow[gpu]. - -evaluation: # optional. required if user doesn't provide eval_func in neural_compressor.Quantization. - accuracy: # optional. required if user doesn't provide eval_func in neural_compressor.Quantization. - metric: - Accuracy: {} # built-in metrics are topk, map, f1, allow user to register new metric. - diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/prepare_model.py b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/prepare_model.py new file mode 100644 index 00000000000..907196a046d --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/prepare_model.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import tensorflow as tf +from tensorflow import keras + +def train_func(): + # Load MNIST dataset + mnist = keras.datasets.mnist + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + + # Normalize the input image so that each pixel value is between 0 to 1. + train_images = train_images / 255.0 + test_images = test_images / 255.0 + + # Define the model architecture. + model = keras.Sequential([ + keras.layers.InputLayer(input_shape=(28, 28)), + keras.layers.Reshape(target_shape=(28, 28, 1)), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'), + keras.layers.MaxPooling2D(pool_size=(2, 2)), + keras.layers.Flatten(), + keras.layers.Dense(10) + ]) + + # Train the digit classification model + model.compile(optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy']) + + model.fit( + train_images, + train_labels, + epochs=1, + validation_split=0.1, + ) + + _, baseline_model_accuracy = model.evaluate( + test_images, test_labels, verbose=0) + + print('Baseline test accuracy:', baseline_model_accuracy) + + return model + +def get_mnist_model(saved_path): + assert saved_path is not None, "save path should not be None" + model = train_func() + model.save(saved_path) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Export pretained keras model', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--output_model', + type=str, + help='path to exported model file') + + args = parser.parse_args() + get_mnist_model(args.output_model) \ No newline at end of file diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/qat.py b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/qat.py deleted file mode 100644 index 655f70fc9dd..00000000000 --- a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/qat.py +++ /dev/null @@ -1,37 +0,0 @@ -import tensorflow as tf -from tensorflow import keras - -# Load MNIST dataset -mnist = keras.datasets.mnist -(train_images, train_labels), (test_images, test_labels) = mnist.load_data() - -# Normalize the input image so that each pixel value is between 0 to 1. -train_images = train_images / 255.0 -test_images = test_images / 255.0 - -model = tf.keras.models.load_model("baseline_model") - -import tensorflow_model_optimization as tfmot -quantize_model = tfmot.quantization.keras.quantize_model - -# q_aware stands for for quantization aware. -q_aware_model = quantize_model(model) - -# `quantize_model` requires a recompile. -q_aware_model.compile(optimizer='adam', - loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), - metrics=['accuracy']) - -q_aware_model.summary() - -train_images_subset = train_images[0:1000] # out of 60000 -train_labels_subset = train_labels[0:1000] - -q_aware_model.fit(train_images_subset, train_labels_subset, - batch_size=500, epochs=1, validation_split=0.1) - -_, q_aware_model_accuracy = q_aware_model.evaluate( - test_images, test_labels, verbose=0) - -print('Quant test accuracy:', q_aware_model_accuracy) -q_aware_model.save("trained_qat_model") diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/requirements.txt b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/requirements.txt new file mode 100644 index 00000000000..c8cbd6d70a6 --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/requirements.txt @@ -0,0 +1,2 @@ +tensorflow +intel-extension-for-tensorflow[cpu] \ No newline at end of file diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/run_benchmark.sh b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/run_benchmark.sh new file mode 100644 index 00000000000..a50d81dcd9c --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/run_benchmark.sh @@ -0,0 +1,40 @@ +#!/bin/bash +set -x + +function main { + + init_params "$@" + run_benchmark + +} + +# init params +function init_params { + for var in "$@" + do + case $var in + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --mode=*) + mode=$(echo $var |cut -f2 -d=) + ;; + --batch_size=*) + batch_size=$(echo $var |cut -f2 -d=) + ;; + esac + done + +} + +# run_tuning +function run_benchmark { + + python main.py \ + --input_model ${input_model} \ + --benchmark \ + --mode ${mode} \ + --batch_size ${batch_size} \ +} + +main "$@" diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/run_tuning.sh b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/run_tuning.sh new file mode 100644 index 00000000000..ad02bf6ea2f --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/run_tuning.sh @@ -0,0 +1,35 @@ +#!/bin/bash +set -x + +function main { + init_params "$@" + run_tuning + +} + +# init params +function init_params { + + for var in "$@" + do + case $var in + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --output_model=*) + output_model=$(echo $var |cut -f2 -d=) + ;; + esac + done + +} + +# run_tuning +function run_tuning { + python main.py \ + --input_model ${input_model} \ + --output_model ${output_model} \ + --tune +} + +main "$@" diff --git a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/train.py b/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/train.py deleted file mode 100644 index 5820b434628..00000000000 --- a/examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/train.py +++ /dev/null @@ -1,38 +0,0 @@ -import tensorflow as tf -from tensorflow import keras - -# Load MNIST dataset -mnist = keras.datasets.mnist -(train_images, train_labels), (test_images, test_labels) = mnist.load_data() - -# Normalize the input image so that each pixel value is between 0 to 1. -train_images = train_images / 255.0 -test_images = test_images / 255.0 - -# Define the model architecture. -model = keras.Sequential([ - keras.layers.InputLayer(input_shape=(28, 28)), - keras.layers.Reshape(target_shape=(28, 28, 1)), - keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'), - keras.layers.MaxPooling2D(pool_size=(2, 2)), - keras.layers.Flatten(), - keras.layers.Dense(10) -]) - -# Train the digit classification model -model.compile(optimizer='adam', - loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), - metrics=['accuracy']) - -model.fit( - train_images, - train_labels, - epochs=1, - validation_split=0.1, -) - -_, baseline_model_accuracy = model.evaluate( - test_images, test_labels, verbose=0) - -print('Baseline test accuracy:', baseline_model_accuracy) -model.save("baseline_model") diff --git a/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/README b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/README new file mode 100644 index 00000000000..e755742156d --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/README @@ -0,0 +1,119 @@ +Step-by-Step +============ + +This document is used to apply QAT to Tensorflow Keras models using Intel® Neural Compressor. +This example can run on Intel CPUs and GPUs. + + +## Prerequisite + +### 1. Installation +```shell +# Install Intel® Neural Compressor +pip install neural-compressor +``` +### 2. Install requirements +The Tensorflow and intel-extension-for-tensorflow is mandatory to be installed to run this QAT example. +The Intel Extension for Tensorflow for Intel CPUs is installed as default. +```shell +pip install -r requirements.txt +``` +> Note: Supported Tensorflow [Version](../../../../../../../README.md). + +### 3. Benchmarking the model on Intel GPU (Optional) + +To run benchmark of the model on Intel GPUs, Intel Extension for Tensorflow for Intel GPUs is required. + +```shell +pip install --upgrade intel-extension-for-tensorflow[gpu] +``` + +Please refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/ubuntu/ubuntu-focal-dc.html) for latest Intel GPU driver installation. +For any more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel-innersource/frameworks.ai.infrastructure.intel-extension-for-tensorflow.intel-extension-for-tensorflow/blob/master/docs/install/install_for_gpu.md#install-gpu-drivers). + +### 4. Prepare Pretrained model + +The pretrained model is provided by [Keras Applications](https://keras.io/api/applications/). prepare the model, Run as follow: + ``` + +python prepare_model.py --output_model=/path/to/model + ``` +`--output_model ` the model should be saved as SavedModel format or H5 format. + +## Run Command + ```shell + bash run_tuning.sh --input_model=./path/to/model --output_model=./result --dataset_location=/path/to/evaluation/dataset + bash run_benchmark.sh --input_model=./path/to/model --mode=performance --dataset_location=/path/to/evaluation/dataset --batch_size=100 + ``` + +Details of enabling Intel® Neural Compressor to apply QAT. +========================= + +This is a tutorial of how to to apply QAT with Intel® Neural Compressor. +## User Code Analysis +1. User specifies fp32 *model*, training dataset *dataset_location* to apply quantization. In this step, QDQ patterns will be inserted to the keras model, but the fp32 model will not be converted to a int8 model. + +2. User specifies *model* with QDQ patterns inserted, evaluate function to run benchmark. The model we get from the previous step will be run on ITEX backend. Then, the model is going to be fused and inferred. + +### Quantization Config +The Quantization Config class has default parameters setting for running on Intel CPUs. If running this example on Intel GPUs, the 'backend' parameter should be set to 'itex' and the 'device' parameter should be set to 'gpu'. + +``` +config = QuantizationAwareTrainingConfig( + device="gpu", + backend="itex", + ... + ) +``` + +### Code update + +After prepare step is done, we add quantization and benchmark code to generate quantized model and benchmark. + +#### Tune +```python + logger.info('start quantizing the model...') + from neural_compressor import training, QuantizationAwareTrainingConfig + config = QuantizationAwareTrainingConfig() + # create a compression_manager instance to implement QAT + compression_manager = training.prepare_compression(FLAGS.input_model, config) + # QDQ patterns will be inserted to the input keras model + compression_manager.callbacks.on_train_begin() + # get the model with QDQ patterns inserted + q_aware_model = compression_manager.model.model + + # training code defined by users + q_aware_model.compile( + optimizer='sgd', + loss=tf.keras.losses.SparseCategoricalCrossentropy(), + metrics=["accuracy"], + ) + q_aware_model.summary() + x_train, y_train = prepare_data(FLAGS.dataset_location) + q_aware_model.fit(x_train, + y_train, + batch_size=64, + epochs=1) + + # apply some post process steps and save the output model + compression_manager.callbacks.on_train_end() + compression_manager.save(FLAGS.output_model) +``` +#### Benchmark +```python + from neural_compressor.benchmark import fit + from neural_compressor.experimental import common + from neural_compressor.config import BenchmarkConfig + assert FLAGS.mode == 'performance' or FLAGS.mode == 'accuracy', \ + "Benchmark only supports performance or accuracy mode." + + # convert the quantized keras model to graph_def so that it can be fused by ITEX + model = common.Model(FLAGS.input_model).graph_def + if FLAGS.mode == 'performance': + conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=7) + fit(model, conf, b_func=evaluate) + elif FLAGS.mode == 'accuracy': + accuracy = evaluate(model) + print('Batch size = %d' % FLAGS.batch_size) + print("Accuracy: %.5f" % accuracy) +``` \ No newline at end of file diff --git a/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/main.py b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/main.py new file mode 100644 index 00000000000..d7a150f665a --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/main.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import codeop +import logging +import numpy as np +import tensorflow as tf + +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + +flags = tf.compat.v1.flags +FLAGS = flags.FLAGS +logger = logging.getLogger(__name__) + +## Required parameters +flags.DEFINE_string( + 'input_model', None, 'Run inference with specified keras model.') + +flags.DEFINE_string( + 'output_model', None, 'The output quantized model.') + +flags.DEFINE_string( + 'mode', 'performance', 'define benchmark mode for accuracy or performance') + +flags.DEFINE_bool( + 'tune', False, 'whether to tune the model') + +flags.DEFINE_bool( + 'benchmark', False, 'whether to benchmark the model') + +flags.DEFINE_string( + 'dataset_location', None, 'location of the dataset on tfrecord format') + +flags.DEFINE_integer( + 'batch_size', 32, 'batch_size') + +from neural_compressor.experimental.metric.metric import TensorflowTopK +from neural_compressor.experimental.data.transforms.transform import ComposeTransform +from neural_compressor.experimental.data.datasets.dataset import TensorflowImageRecord +from neural_compressor.experimental.data.transforms.imagenet_transform import LabelShift +from neural_compressor.data.transforms.imagenet_transform import TensorflowResizeCropImagenetTransform + +def prepare_data(root): + """ + Parse the input tf_record data. + + Args: + root (string): The path to tfrecord files. + + Returns: + data (float): The images that can be used for training or evaluation. + label (float): The labels corresponding to the images. + """ + dataset = TensorflowImageRecord( + root=root, + transform=ComposeTransform(transform_list=[ + TensorflowResizeCropImagenetTransform( + height=224, width=224) + ])) + + data = np.array(list(dataset.map(lambda x, y: x))) + data = tf.keras.applications.resnet.preprocess_input(data) + label = np.array(list(dataset.map(lambda x, y: y))).squeeze(1) + + if len(data) > 10000: + data = data[:10000] + label = label[:10000] + + for idx, i in enumerate(label): + label[idx] = i-1 + + return data, label + +def evaluate(model): + """Custom evaluate function to estimate the accuracy of the model. + + Args: + model (tf.Graph_def): The input model graph + + Returns: + accuracy (float): evaluation result, the larger is better. + """ + from neural_compressor.experimental import common + model = common.Model(model) + input_tensor = model.input_tensor + output_tensor = model.output_tensor if len(model.output_tensor)>1 else \ + model.output_tensor[0] + iteration = -1 + if FLAGS.benchmark and FLAGS.mode == 'performance': + iteration = 100 + postprocess = LabelShift(label_shift=1) + metric = TensorflowTopK(k=1) + + def eval_func(dataloader): + latency_list = [] + for idx, (inputs, labels) in enumerate(dataloader): + # dataloader should keep the order and len of inputs same with input_tensor + inputs = np.array([inputs]) + assert len(input_tensor) == len(inputs), \ + 'inputs len must equal with input_tensor' + feed_dict = dict(zip(input_tensor, inputs)) + + start = time.time() + predictions = model.sess.run(output_tensor, feed_dict) + end = time.time() + + predictions, labels = postprocess((predictions, labels)) + metric.update(predictions, labels) + latency_list.append(end-start) + if idx + 1 == iteration: + break + latency = np.array(latency_list).mean() / FLAGS.batch_size + return latency + + from neural_compressor.experimental.data.dataloaders.default_dataloader import DefaultDataLoader + dataset = TensorflowImageRecord(root=FLAGS.dataset_location, transform=ComposeTransform(transform_list=[ + TensorflowResizeCropImagenetTransform(height=224, width=224)])) + dataloader = DefaultDataLoader(dataset, batch_size=FLAGS.batch_size) + latency = eval_func(dataloader) + if FLAGS.benchmark and FLAGS.mode == 'performance': + print("Batch size = {}".format(FLAGS.batch_size)) + print("Latency: {:.3f} ms".format(latency * 1000)) + print("Throughput: {:.3f} images/sec".format(1. / latency)) + acc = metric.result() + return acc + +def main(): + if FLAGS.tune: + logger.info('start quantizing the model...') + from neural_compressor import training, QuantizationAwareTrainingConfig + config = QuantizationAwareTrainingConfig() + compression_manager = training.prepare_compression(FLAGS.input_model, config) + compression_manager.callbacks.on_train_begin() + + q_aware_model = compression_manager.model.model + q_aware_model.compile( + optimizer='sgd', + loss=tf.keras.losses.SparseCategoricalCrossentropy(), + metrics=["accuracy"], + ) + + q_aware_model.summary() + x_train, y_train = prepare_data(FLAGS.dataset_location) + q_aware_model.fit(x_train, + y_train, + batch_size=64, + epochs=1) + + compression_manager.callbacks.on_train_end() + compression_manager.save(FLAGS.output_model) + + if FLAGS.benchmark: + from neural_compressor.benchmark import fit + from neural_compressor.experimental import common + from neural_compressor.config import BenchmarkConfig + assert FLAGS.mode == 'performance' or FLAGS.mode == 'accuracy', \ + "Benchmark only supports performance or accuracy mode." + + model = common.Model(FLAGS.input_model).graph_def + if FLAGS.mode == 'performance': + conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=7) + fit(model, conf, b_func=evaluate) + elif FLAGS.mode == 'accuracy': + accuracy = evaluate(model) + print('Batch size = %d' % FLAGS.batch_size) + print("Accuracy: %.5f" % accuracy) + +if __name__ == "__main__": + main() diff --git a/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/prepare_model.py b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/prepare_model.py new file mode 100644 index 00000000000..086690ea37b --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/prepare_model.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from tensorflow.keras.applications import ResNet50 + + +def get_resnet50_model(saved_path): + assert saved_path is not None, "save path should not be None" + model = ResNet50(weights='imagenet') + model.save(saved_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Export pretained keras model', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--output_model', + type=str, + help='path to exported model file') + + args = parser.parse_args() + get_resnet50_model(args.output_model) diff --git a/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/requirements.txt b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/requirements.txt new file mode 100644 index 00000000000..c8cbd6d70a6 --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/requirements.txt @@ -0,0 +1,2 @@ +tensorflow +intel-extension-for-tensorflow[cpu] \ No newline at end of file diff --git a/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/run_benchmark.sh b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/run_benchmark.sh new file mode 100644 index 00000000000..203630c8d95 --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/run_benchmark.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -x + +function main { + + init_params "$@" + run_benchmark + +} + +# init params +function init_params { + for var in "$@" + do + case $var in + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --mode=*) + mode=$(echo $var |cut -f2 -d=) + ;; + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + --batch_size=*) + batch_size=$(echo $var |cut -f2 -d=) + ;; + esac + done + +} + +# run_tuning +function run_benchmark { + + python main.py \ + --input_model ${input_model} \ + --benchmark \ + --mode ${mode} \ + --batch_size ${batch_size} \ + --dataset_location ${dataset_location} +} + +main "$@" diff --git a/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/run_tuning.sh b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/run_tuning.sh new file mode 100644 index 00000000000..43c392a1be0 --- /dev/null +++ b/examples/tensorflow/image_recognition/keras_models/resnet50/quantization/qat/run_tuning.sh @@ -0,0 +1,39 @@ +#!/bin/bash +set -x + +function main { + init_params "$@" + run_tuning + +} + +# init params +function init_params { + + for var in "$@" + do + case $var in + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --output_model=*) + output_model=$(echo $var |cut -f2 -d=) + ;; + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + esac + done + +} + +# run_tuning +function run_tuning { + python main.py \ + --input_model ${input_model} \ + --output_model ${output_model} \ + --dataset_location ${dataset_location} \ + --tune +} + +main "$@" diff --git a/neural_compressor/adaptor/tensorflow.py b/neural_compressor/adaptor/tensorflow.py index d56760bfd7a..d4956ce4167 100644 --- a/neural_compressor/adaptor/tensorflow.py +++ b/neural_compressor/adaptor/tensorflow.py @@ -68,6 +68,7 @@ def __init__(self, framework_specific_info): self.format = self.framework_specific_info['format'] os.makedirs(self.work_dir, exist_ok=True) + self.model = None self.pre_optimized_model = None self.pre_optimizer_handle = None @@ -524,17 +525,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): Returns: tf.compat.v1.GraphDef: the quantized model """ - if self.approach == "quant_aware_training": - assert q_func is not None, "quantization aware training mode \ - is not configured correctly" - - from neural_compressor.experimental import common - qat_model = q_func(model) - - return self.convert(common.Model(qat_model), 'QAT', 'default') - - assert q_func is None, \ - "post-training quantization mode is not support calibration function for Tensorflow!" self.tuning_cfg_to_fw(tune_cfg) logger.debug("Dump quantization configurations:") logger.debug(self.quantize_config) @@ -1370,6 +1360,12 @@ def get_optype_wise_ability(self): res[op[1]] = {'activation': {'dtype': ['bf16']}, 'weight': {'dtype': ['bf16']}} return res + def _pre_hook_for_qat(self, dataloader=None): + self.model.model = self.qat_convert(self.model.model) + + def _post_hook_for_qat(self): + pass + def _pre_eval_hook(self, model): return model @@ -1380,6 +1376,8 @@ def _post_eval_hook(self, model, **kwargs): def save(self, model, path): pass + # this function is used to convert keras QAT model to pb in old QAT implementation, + # and it's not used in refactored QAT def convert(self, model, source, destination): '''The function is used to convert a source model format to another. @@ -1426,6 +1424,31 @@ def convert(self, model, source, destination): return converter.convert() + def qat_convert(self, model, quantize_recipe=None): + """ + Convert a fp32 'tf.keras' model to be a int8 one with quantization aware training implementation. + + Args: + model (tf.keras.Model): The model to be quantized, expected to be a Keras Functional or Sequential model. + quantize_recipe (dict): A dict that decide whether given layers should be quantized. + + Returns: + converted_model (tf.keras.Model): Quantized model with fake quant nodes inserted. + """ + import tensorflow as tf + assert isinstance(model, tf.keras.Model), ("The model to be converted is expected to be " + "a `tf.keras.Model` instance. You should not pass an instance of type: {input}.".format( + input=model.__class__.__name__)) + + assert ( + model.__class__.__name__ in ['Functional', 'Sequential'] + ), "Only `Functional` or `Sequential` keras model is supported for QAT." + + from .tf_utils.quantize_graph.qat.quantize_helper import init_quantize_config, qat_clone_function + config = init_quantize_config(model, quantize_recipe) + q_model = tf.keras.models.clone_model(model, input_tensors=None, clone_function=qat_clone_function) + return q_model + @dump_elapsed_time("Pass recover model") def recover_tuned_model(self, model, q_config): """Execute the recover process on the specified model. diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/__init__.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/__init__.py new file mode 100644 index 00000000000..369707c0ef6 --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/fake_quantize.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/fake_quantize.py new file mode 100644 index 00000000000..ffa016a1888 --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/fake_quantize.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import six +import tensorflow as tf + +@six.add_metaclass(abc.ABCMeta) +class FakeQuantizeBase(object): + """ABC interface class for applying fake quantization by insert qdq.""" + + @abc.abstractmethod + def __call__(self, inputs, range, training, **kwargs): + """Apply quantization to the input tensor. + This is the main logic of the 'FakeQuantize' which implements the core logic + to quantize the tensor. It is invoked during the `call` stage of the layer, + and allows modifying the tensors used in graph construction. + + Args: + inputs (tf.Tensor): Input tensor to be quantized. + range (dict): The min-max range of input tensor. + training (bool): Whether the graph is currently training. + **kwargs: Additional variables which may be passed to the FakeQuantize class. + + Returns: + output (tf.Tensor): The tensor to be quantized. + """ + raise NotImplementedError + + @abc.abstractmethod + def get_config(self): + """Returns the config used to serialize the 'FakeQuantize'.""" + raise NotImplementedError('FakeQuantize should implement get_config().') + + @classmethod + def from_config(cls, config): + """Instantiates a 'FakeQuantize' from its config. + + Args: + config (dict): A dict containing required information. + + Returns: + output (FakeQuantize): A 'FakeQuantize' instance. + """ + return cls(**config) + +class FakeQuantize(FakeQuantizeBase): + """The class that applies fake quantization.""" + + def __init__( + self, + per_channel=False, + num_bits=8, + channel_axis=-1, + symmetric=True, + narrow_range=True + ): + """Initialize a FakeQuantize class. + + Args: + per_channel (bool): Whether to apply per_channel quantization. The last dimension is + used as the channel. + num_bits (int): Number of bits for quantization + symmetric (bool): If true, use symmetric quantization limits instead of training + the minimum and maximum of each quantization range separately. + narrow_range (bool): In case of 8 bits, narrow_range nudges the quantized range + to be [-127, 127] instead of [-128, 127]. This ensures symmetric range + has 0 as the centre. + """ + self.num_bits = num_bits + self.per_channel = per_channel + self.symmetric = symmetric + self.narrow_range = narrow_range + self.channel_axis = channel_axis + self.name_prefix = 'FakeQuantize' + + def __call__(self, inputs, ranges, training, **kwargs): + """Applying fake quantization by insert qdq. + The quantized tensor is calculated based on range of the last batch of values. + + Args: + inputs (tf.Tensor): Input tensor to be quantized. + range (dict): The min-max range of input tensor. + training (bool): Whether the graph is currently training. + **kwargs: Additional variables which may be passed to the FakeQuantize class. + + Returns: + output (tf.Tensor): The tensor to be quantized. + """ + with tf.name_scope(self.name_prefix): + input_shape = inputs.get_shape() + input_dim = len(input_shape) + if self.channel_axis == -1: + self.channel_axis += input_dim + + if not training: + return self._insert_qdq(inputs, ranges["min_var"], ranges["max_var"]) + + if self.per_channel: + if input_dim == 2: + reduce_dims = [0] + elif input_dim == 4: + reduce_dims = [i for i in range(input_dim) if i != self.channel_axis] + + if self.per_channel: + if input_dim >= 2: + batch_min = tf.math.reduce_min( + inputs, axis=reduce_dims, name="BatchMin" + ) + else: + batch_min = inputs + else: + batch_min = tf.math.reduce_min(inputs, name="BatchMin") + + if self.per_channel: + if input_dim >= 2: + batch_max = tf.math.reduce_max( + inputs, axis=reduce_dims, name="BatchMax" + ) + else: + batch_max = inputs + else: + batch_max = tf.math.reduce_max(inputs, name="BatchMax") + + if self.symmetric: + if self.narrow_range: + min_max_ratio = -1 + else: + min_max_ratio = -((1 << self.num_bits) - 2) / (1 << self.num_bits) + + range_min = tf.math.minimum(batch_min, batch_max / min_max_ratio) + range_max = tf.math.maximum(batch_max, batch_min * min_max_ratio) + else: + range_min = tf.math.minimum(batch_min, 0.0) + range_max = tf.math.maximum(batch_max, 0.0) + + assign_min = ranges["min_var"].assign(range_min, name="AssignMinLast") + assign_max = ranges["max_var"].assign(range_max, name="AssignMaxLast") + + return self._insert_qdq(inputs, assign_min, assign_max) + + def _insert_qdq(self, inputs, min_var, max_var): + """Adds a fake quantization operation. + Depending on value of self.per_channel, this operation may do global quantization + or per channel quantization. min_var and max_var should have corresponding + shapes: [1] when per_channel == False and [d] when per_channel == True. + + Args: + inputs (tf.Tensor): A tensor containing values to be quantized. + min_var (tf.Variable): A variable containing quantization range lower end(s). + max_var (tf.Variable): A variable containing quantization range upper end(s). + + Returns: + outputs (tf.Tensor): A tensor containing quantized values. + """ + if self.per_channel: + + return tf.quantization.quantize_and_dequantize_v2( + inputs, + min_var, + max_var, + num_bits=self.num_bits, + narrow_range=self.narrow_range, + axis=self.channel_axis, + range_given=True, + ) + else: + assert min_var.get_shape() == [] + assert max_var.get_shape() == [] + + return tf.quantization.quantize_and_dequantize_v2( + inputs, + min_var, + max_var, + num_bits=self.num_bits, + narrow_range=self.narrow_range, + range_given=True, + ) + + def get_config(self): + """Returns the config used to serialize the 'FakeQuantize'. + + Returns: + config (dict): A dict containing required information. + """ + return { + 'num_bits': self.num_bits, + 'per_channel': self.per_channel, + 'symmetric': self.symmetric, + 'narrow_range': self.narrow_range + } + + def __eq__(self, other): + """Check if this instance is equal to another instance. + + Args: + other (FakeQuantize): Another instance to be checked. + + Returns: + is_equal (bool): If the two instances are equal. + """ + if not isinstance(other, FakeQuantize): + return False + + return (self.num_bits == other.num_bits and + self.per_channel == other.per_channel and + self.symmetric == other.symmetric and + self.narrow_range == other.narrow_range) + + def __ne__(self, other): + """Check if this instance is not equal to another instance. + + Args: + other (FakeQuantize): Another instance to be checked. + + Returns: + not_equal (bool): If the two instances are not equal. + """ + return not self.__eq__(other) diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_config.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_config.py new file mode 100644 index 00000000000..278f79c28ac --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_config.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +global_config = {} +logger = logging.getLogger("neural_compressor") + +class QuantizeConfig(): + """Class for building custom quantize config. + There should be only one QuantizeConfig instance for global setting. + """ + + def __new__(cls): + """Created a QuantizeConfig instance and add it to the global_config dict. + + Returns: + instance (QuantizeConfig) : The created QuantizeConfig instance. + """ + instance = super().__new__(cls) + global_config['quantize_config'] = instance + return instance + + def __init__(self): + """Initialize QuantizeConfig instance.""" + self.quantize_recipe = {} + self.model_name = None + + def add_quantize_recipe(self, quantize_recipe): + """Add custom recipe for quantization to the QuantizeConfig instance. + + Args: + quantize_recipe (dict): A dict that decide whether given layers should be quantized. + A typical quantize_recipe will be a dict of layer_name and + dict as key-value pairs. In each value dict, there should be + a {'quantize': bool} key-value pair and a {'index': list} pair. + The latter one is used to decide which inputs should be quantized + in some layers with multiple inputs. + For example: + {'conv5_block3_3_conv': {'quantize': Flase} + 'conv5_block3_3_add' : {'quantize': True, 'index': [1, 3]} + } + """ + self.quantize_recipe.update(quantize_recipe) + + def query_layer(self, layer_name): + """Query if a specific layer is in the quantize_recipe dict. + + Args: + layer_name (string): The input layer name. + Returns: + layer_recipe (dict): The quantize recipe for this input layer. + """ + if layer_name in self.quantize_recipe: + return self.quantize_recipe[layer_name] + return {} + + def remove_layer(self, layer_name): + """Remove a specific layer from the quantize_recipe dict. + + Args: + layer_name (string): The name of layer to be removed. + """ + if layer_name in self.quantize_recipe: + del self.quantize_recipe[layer_name] + + def remove_layers(self, layer_names): + """Remove a batch of layers from the quantize_recipe dict. + + Args: + layers_names (List): The names of layers to be removed. + """ + for layer_name in layer_names: + self.remove_layer(layer_name) + + def get_quantize_recipe(self): + """Get the current recipe dict for quantization. + + Returns: + quantize_recipe (dict): A dict that decide whether given layers should be quantized. + """ + return self.quantize_recipe + + def is_empty(self): + """Check if the recipe of quantization is an empty dict. + + Returns: + is_empty (bool): True if no custom recipe is updated to this class. + """ + if self.quantize_recipe: + return False + return True + + def clear_quantize_recipe(self): + """Clear recipe of quantization to be an empty dict.""" + self.quantize_recipe.clear() + +layer_wise_config = { + 'quantize_layers': {'Conv2D', 'Dense', 'DepthwiseConv2D', 'MaxPooling2D', + 'AveragePooling2D', 'GlobalAveragePooling2D'}, + 'possible_quantize_layers': {'Multiply', 'Concatenate', 'Add', 'BatchNormalization'}, + 'weighted_layers': {'Conv2D', 'Dense', 'DepthwiseConv2D'}, + 'multiple_inputs_layers': {'Multiply', 'Concatenate', 'Add'} +} + diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_helper.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_helper.py new file mode 100644 index 00000000000..26faf2ada1e --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_helper.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .quantize_wrapper import QuantizeWrapper +from .quantize_layers.optimize_layer import config_quantizable_layers +from .quantize_config import layer_wise_config, global_config, QuantizeConfig + +def init_quantize_config(model, quantize_recipe=None): + """Initialize quantization config at the beginning of QAT process. + + Args: + model_name (string): Special pre-optimized model name. + quantize_recipe (dict): A dict that decide whether given layers should be quantized. + + Returns: + config (QuantizeConfig): QuantizeConfig instance used to decide whether a specific layer + should be quantized. + """ + assert 'quantize_config' not in global_config, ("quantize_config has been unexpectedly" + "created. Please check your QAT workflow") + + config = QuantizeConfig() + config_quantizable_layers(model) + + if quantize_recipe: + config.add_quantize_recipe(quantize_recipe) + + return config + +def _is_quantizable_layer(layer): + """Query if the input layer should be quantized. + + Args: + layer (tf.keras.layers.Layer): input Keras layer + + Returns: + capability (bool): whether the input layer is capable of quantization. + """ + quantizable = True + layer_class = layer.__class__.__name__ + + quantize_config = global_config['quantize_config'] + specific_layer_config = quantize_config.query_layer(layer) + if specific_layer_config: + # the layer is set to be unquantizable by QuantizeConfig + if not specific_layer_config['quantize']: + return False + else: + if layer_class in layer_wise_config['quantize_layers'] or \ + layer_class in layer_wise_config['possible_quantize_layers']: + return True + + if layer_class not in layer_wise_config['quantize_layers']: + quantizable = False + + return quantizable + +def qat_clone_function(layer): + """Wrap or leave given layer based on quantize config object parameters. + + Args: + layer (tf.keras.layers.Layer): input Keras layer + + Returns: + wrapped_layer (QuantizeWrapper): layer wrapped by QuantizeWrapper class. + """ + wrapped_layer= layer + if _is_quantizable_layer(layer): + wrapped_layer = QuantizeWrapper(layer) + + return wrapped_layer \ No newline at end of file diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/__init__.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/__init__.py new file mode 100644 index 00000000000..369707c0ef6 --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/optimize_layer.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/optimize_layer.py new file mode 100644 index 00000000000..5d5a87083a6 --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/optimize_layer.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .quantize_layer_add import QuantizeLayerAdd +from .quantize_layer_bn import QuantizeLayerBatchNormalization + +def config_quantizable_layers(model): + quantize_layer_mapping = { + 'Add': QuantizeLayerAdd, + 'BatchNormalization': QuantizeLayerBatchNormalization + } + + for layer_class, quantize_layer in quantize_layer_mapping.items(): + quantize_layer_mapping[layer_class] = quantize_layer() + + for layer in model.layers: + if layer.__class__.__name__ in quantize_layer_mapping: + quantize_layer_mapping[layer.__class__.__name__](layer) \ No newline at end of file diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_add.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_add.py new file mode 100644 index 00000000000..f2e413d8b16 --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_add.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from .quantize_layer_base import QuantizeLayerBase + +logger = logging.getLogger("neural_compressor") + +class QuantizeLayerAdd(QuantizeLayerBase): # pragma: no cover + """The class for quantization of Add.""" + + def __init__(self): + """Initialize QuantizeLayerAdd class.""" + self.quantize_patterns = [ + ['Conv', 'BatchNorm', 'Add'], + ['Conv', 'BatchNorm', 'Activation', 'Add'], + ['Conv', 'BatchNorm', 'Activation', 'Dropout', 'Add'] + ] + + super().__init__() + + def _quantizable_add(self): + """Check if the input layer meets criteria of quantization. + + Args: + layer (tf.keras.layers.Layer): The input layer. + + Returns: + quantizable (bool): If this layer should be quantized. + """ + input_layer = self._find_input_layers(self.layer) + if len(input_layer) == 1: + logger.warning("The layer 'Add' should have more than one input. " + "You input a model with layer {} which has only one input".format(self.layer.name)) + return False + + return True + + def __call__(self, layer): + """The main logic of QuantizeLayerAdd. + Neural Compressor will enumerate all layers of the input model to check + if there are any layer meeting the criteria. The choosen ones will be marked + as quantizable by QuantizeConfig. + + Args: + layer (tf.keras.layers.Layer): The keras layer to be estimated. + """ + self.layer = layer + if self._quantizable_add(): + input_layers = self._find_input_layers(self.layer) + fused_conv_index = None + for i, input_layer in enumerate(input_layers): + # Check that the input is a Conv pattern + if 'Conv' in input_layer.__class__.__name__ or self._find_patterns(input_layer): + if hasattr(input_layer, 'outbound_nodes') and \ + len(getattr(input_layer, 'outbound_nodes')) == 1: + fused_conv_index = i + break + + input_indexes = [i for i in range(0, len(input_layers))] + if fused_conv_index: + del input_indexes[fused_conv_index] + + self.quantize_config.add_quantize_recipe({self.layer.name: {'quantize': True, + 'index': input_indexes}}) \ No newline at end of file diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_base.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_base.py new file mode 100644 index 00000000000..e57970703c7 --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_base.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..quantize_config import global_config + +class QuantizeLayerBase(): # pragma: no cover + + def __init__(self): + """Initialize QuantizeLayerBase class.""" + self.quantize_patterns = [] + assert 'quantize_config' in global_config, \ + "QuantizeConfig is not correctly created." + self.quantize_config = global_config['quantize_config'] + + def _find_input_layers(self, layer): + """Find all inputs of a specific layer. + + Args: + layer (tf.keras.layers.Layer): The target keras layer that this method + is to find its input layers. + + Returns: + input_layers (list): List of input layers found by this method. + """ + input_layers = [] + if isinstance(layer.input, list): + for input_tensor in layer.input: + input_layer = input_tensor._keras_history.layer + input_layers.append(input_layer) + else: + input_layer = layer.input._keras_history.layer + input_layers.append(input_layer) + return input_layers + + def _find_patterns(self, layer): + """ Checks if the input layer can satisfy the patterns. + + Args: + layer (tf.keras.layers.Layer): The input keras layer that this method + is to find patterns. + + Returns: + valid_patterns (bool): If the input layer can satisfy any pattern. + """ + if not self.quantize_patterns: + return False + + for quantize_pattern in self.quantize_patterns: + index = len(quantize_pattern) - 2 + previous_layer = layer + while(index >= 0): + previous_layer = self._find_input_layers(previous_layer) + if quantize_pattern[index] not in previous_layer.__class__.__name__: + break + index -= 1 + if index == -1: + return True + + return False + + def __call__(self, layer): + """The main logic of QuantizeLayerBase. + Neural Compressor will enumerate all layers of the input model to check + if there are any layer meeting the criteria. The choosen ones will be marked + as quantizable by QuantizeConfig. + + Args: + layer (tf.keras.layers.Layer): The keras layer to be estimated. + """ + raise NotImplementedError() diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_bn.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_bn.py new file mode 100644 index 00000000000..840e91addb5 --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_layers/quantize_layer_bn.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .quantize_layer_base import QuantizeLayerBase + +class QuantizeLayerBatchNormalization(QuantizeLayerBase): # pragma: no cover + """The class for quantization of BatchNormalization.""" + + def __init__(self): + """Initialize QuantizeLayerBatchNormalization class.""" + super().__init__() + + def _quantizable_bn(self): + """Check if the input layer meets criteria of quantization. + + Args: + layer (tf.keras.layers.Layer): The input layer. + + Returns: + quantizable (bool): If this layer should be quantized. + """ + input_layer = self._find_input_layers(self.layer) + assert len(input_layer) == 1, "BatchNormalization only has one input." + input_layer_class = input_layer.__class__.__name__ + if 'Conv' not in input_layer_class: + return True + + return False + + def __call__(self, layer): + """The main logic of QuantizeLayerBatchNormalization. + Neural Compressor will enumerate all layers of the input model to check + if there are any layer meeting the criteria. The choosen ones will be marked + as quantizable by QuantizeConfig. + + Args: + layer (tf.keras.layers.Layer): The keras layer to be estimated. + """ + self.layer = layer + if self._quantizable_bn(): + self.quantize_config.add_quantize_recipe({self.layer.name: {'quantize': True}}) diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_wrapper.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_wrapper.py new file mode 100644 index 00000000000..d05f47e9300 --- /dev/null +++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qat/quantize_wrapper.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from abc import abstractmethod +from .fake_quantize import FakeQuantize +from tensorflow.python.util import tf_inspect +from .quantize_config import layer_wise_config, global_config + +class QuantizeWrapperBase(tf.keras.layers.Wrapper): + """Base class for quantize wrapper""" + + def __init__(self, layer, **kwargs): + """Create a quantize wrapper for a keras layer. + This wrapper provides options to quantize inputs and weights of the layer. + + Args: + layer (tf.keras.layers.Layer): The keras layer to be wrapped. + **kwargs: Additional keyword arguments to be passed. + """ + assert layer is not None, "'layer' should not be None." + + assert isinstance(layer, tf.keras.layers.Layer) or isinstance(layer, + tf.keras.Model),("'layer' can only be a 'tf.keras.layers.Layer' instance." + " You passed an instance of type: {input}.".format(input=layer.__class__.__name__)) + + if "name" not in kwargs: + kwargs["name"] = self._make_layer_name(layer) + + super(QuantizeWrapperBase, self).__init__(layer, **kwargs) + + self.index = None + self._layer_class = layer.__class__.__name__ + self._track_trackable(layer, name="layer") + + @staticmethod + def _make_layer_name(layer): + """Modify the layer name to be quantized layer.""" + return "{}_{}".format("quant", layer.name) + + @staticmethod + def _weight_name(name): + """Extracts the weight name from the full TensorFlow variable name. + For example, returns 'kernel' for 'dense_2/kernel:0'. + + Args: + name (string): TensorFlow variable name. + + Returns: + weight_name (string): Extracted weight name. + """ + return name.split(":")[0].split("/")[-1] + + def build(self, input_shape): + """Creates the variables of the layer. + + Args: + input_shape (tf.TensorShape or list): shapes of input tensors + """ + super(QuantizeWrapperBase, self).build(input_shape) + + self.optimizer_step = self.add_weight( + "optimizer_step", + initializer=tf.keras.initializers.Constant(-1), + dtype=tf.dtypes.int32, + trainable=False, + ) + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer. + This method will cause the layer's state to be built, if that has not + happened before. This requires that the layer will later be used with + inputs that match the input shape provided here. + + Args: + input_shape (tuple of integers or tf.TensorShape): input shape of the layer. + + Returns: + output_shape(tf.TensorShape) : output shape of the layer. + """ + return self.layer.compute_output_shape(self.layer.input_shape) + + def _init_min_max_variables(self, name, shape): + """Initialize the minimum and maximum values of variables to the wrapped layer. + + Args: + name (string): Name prefix of the variables. + shape (tf.TensorShape): shape of variables to be added. + + Returns: + min_variable (tf.Variable) : The initialized minimum value of given variables. + min_variable (tf.Variable) : The initialized maximum value of given variables. + """ + min_variable = self.layer.add_weight( + name + "_min", + shape = (shape), + trainable = False, + initializer = tf.keras.initializers.Constant(-6.0), + ) + max_variable = self.layer.add_weight( + name + "_max", + shape = (shape), + trainable = False, + initializer = tf.keras.initializers.Constant(6.0), + ) + + return min_variable, max_variable + + def query_input_index(self): + """Query QuantizeConfig to check if there is any designated input index for this layer.""" + quantize_config = global_config['quantize_config'] + custom_layer_config = quantize_config.query_layer(self.layer) + if custom_layer_config and 'index' in custom_layer_config: + self.index = custom_layer_config['index'] + + @abstractmethod + def call(self, inputs, training=None): + """This is where the quantize wrapper's logic lives. + + Args: + inputs (tf.Tensor or dict/list/tuple): Inputs of the wrapped layer. + + Returns: + outputs (tf.Tensor or dict/list/tuple): Outputs of the wrapped layer. + """ + raise NotImplementedError + + def get_config(self): + """Get the config of the quantize wrapper. + + Returns: + config (dict): dict of wrapper config. + """ + base_config = super(QuantizeWrapperBase, self).get_config() + config = {"quantize_config": None} + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + """Creates a quantize wrapper instance from its config. + + Args: + config (dict): A Python dictionary, typically the output of get_config. + + Returns: + output_obj: (QuantizeWrapperBase): A quantize wrapper instance. + """ + config = config.copy() + quantize_config = tf.keras.utils.deserialize_keras_object( + config.pop("quantize_config"), module_objects=globals(), custom_objects=None + ) + + layer = tf.keras.layers.deserialize(config.pop("layer")) + + return cls(layer=layer, quantize_config=quantize_config, **config) + + @property + def trainable(self): + """Get trainable attribute for the layer and its sublayers.""" + return self.layer.trainable + + @trainable.setter + def trainable(self, value): + """Set trainable attribute for the layer and its sublayers. + + Args: + value (Boolean): The desired state for the layer's trainable attribute. + """ + self.layer.trainable = value + + @property + def trainable_weights(self): + """List of all trainable weights tracked by this layer. + Trainable weights are updated via gradient descent during training. + + Returns: + trainable_weights (list): A list of trainable variables. + """ + return self.layer.trainable_weights + self._trainable_weights + + @property + def non_trainable_weights(self): + """List of all non-trainable weights tracked by this layer. + Non-trainable weights are *not* updated during training. They are + expected to be updated manually in `call()`. + + Returns: + non_trainable_weights (list): A list of non-trainable variables. + """ + return self.layer.non_trainable_weights + self._non_trainable_weights + + @property + def updates(self): + """update layer """ + return self.layer.updates + self._updates + + @property + def losses(self): + """List of losses added using the `add_loss()` API. + Variable regularization tensors are created when this property is + accessed, so it is eager safe: accessing `losses` under a + `tf.GradientTape` will propagate gradients back to the corresponding + variables. + + Returns: + losses (list): A list of tensors. + """ + return self.layer.losses + self._losses + +class QuantizeWrapper(QuantizeWrapperBase): + """General QuantizeWrapper for quantizable layers. Weights and inputs will be quantized + according to the layer type and quantize config. + """ + + def __init__(self, layer, **kwargs): + """Create a quantize wrapper for a keras layer. + This wrapper provides options to quantize inputs and weights of the layer. + + Args: + layer (tf.keras.layers.Layer): The keras layer to be wrapped. + **kwargs: Additional keyword arguments to be passed. + """ + super().__init__(layer, **kwargs) + + self.kernel = 'kernel' + self.kernel_weights = None + self.channel_axis = kwargs.get("axis", -1) + if self._layer_class == 'DepthwiseConv2D': + self.kernel = 'depthwise_kernel' + self.channel_axis = 2 + if self._layer_class in layer_wise_config['multiple_inputs_layers']: + self.query_input_index() + + def build(self, input_shape): + """Creates the variables of the layer. + + Args: + input_shape (tf.TensorShape or list): shapes of input tensors + """ + super().build(input_shape) + + if self._layer_class in layer_wise_config['weighted_layers']: + self.kernel_weights = getattr(self.layer, self.kernel) + + weight_min, weight_max = self._init_min_max_variables( + name = self.kernel_weights.name.split(":")[0], + shape = self.kernel_weights.shape[self.channel_axis] + ) + + self.weight_range = {"min_var": weight_min, "max_var": weight_max} + self._trainable_weights.append(self.kernel_weights) + + num_input = 1 + if not isinstance(input_shape, tf.TensorShape): + num_input = len(input_shape) + if not self.index: + self.index = [i for i in range(num_input)] + + if num_input == 1: + inputs_min, inputs_max = self._init_min_max_variables( + name = self.layer.name + "_input{}".format(0), + shape = None + ) + self.inputs_range = {"min_var": inputs_min, "max_var": inputs_max} + else: + self.inputs_range = [] + for i in range(num_input): + self.inputs_range.append({}) + if i in self.index: + inputs_min, inputs_max = self._init_min_max_variables( + name = self.layer.name + "_input{}".format(i), + shape = None + ) + self.inputs_range[i] = {"min_var": inputs_min, "max_var": inputs_max} + + def call(self, inputs, training=None): + """This is where the quantize wrapper's logic lives. + + Args: + inputs (tf.Tensor or dict/list/tuple): Inputs of the wrapped layer. + + Returns: + outputs (tf.Tensor or dict/list/tuple): Outputs of the wrapped layer. + """ + if training is None: + training = tf.keras.backend.learning_phase() + + # Quantize all weights, and replace them in the underlying layer. + if self._layer_class in layer_wise_config['weighted_layers']: + weight_quantizer = FakeQuantize( + per_channel = True, + channel_axis = self.channel_axis, + ) + quantized_weight = weight_quantizer(self.kernel_weights, self.weight_range, training) + setattr(self.layer, self.kernel, quantized_weight) + + quantized_inputs = inputs + inputs_quantizer = FakeQuantize( + per_channel = False, + channel_axis = self.channel_axis, + ) + + if not isinstance(quantized_inputs, tf.Tensor): + for i in range(len(quantized_inputs)): + if i in self.index: + quantized_inputs[i] = inputs_quantizer(inputs[i], self.inputs_range[i], training) + else: + quantized_inputs = inputs_quantizer(inputs, self.inputs_range, training) + + args = tf_inspect.getfullargspec(self.layer.call).args + if "training" in args: + outputs = self.layer.call(quantized_inputs, training=training) + else: + outputs = self.layer.call(quantized_inputs) + + return outputs \ No newline at end of file diff --git a/neural_compressor/experimental/component.py b/neural_compressor/experimental/component.py index 8afc1703c23..dbbcf5d2d71 100644 --- a/neural_compressor/experimental/component.py +++ b/neural_compressor/experimental/component.py @@ -130,7 +130,9 @@ def prepare_qat(self): framework_specific_info = {'device': self.cfg.device, 'random_seed': self.cfg.tuning.random_seed, 'workspace_path': self.cfg.tuning.workspace.path, - 'q_dataloader': None} + 'q_dataloader': None, + 'backend': self.cfg.model.get('backend', 'default'), + 'format': self.cfg.model.get('quant_format', 'default')} if self.cfg.quantization.approach is not None: framework_specific_info['approach'] = self.cfg.quantization.approach diff --git a/neural_compressor/experimental/quantization.py b/neural_compressor/experimental/quantization.py index cab874bcca7..c1161d11217 100644 --- a/neural_compressor/experimental/quantization.py +++ b/neural_compressor/experimental/quantization.py @@ -28,6 +28,7 @@ from ..utils.utility import time_limit from ..utils.create_obj_from_config import create_dataloader from ..model import BaseModel +from ..model.model import TensorflowQATModel, get_model_fwk_name from ..conf.config import QuantConf from ..conf.pythonic_config import Config from deprecated import deprecated @@ -408,6 +409,39 @@ def q_func(self, user_q_func): calib_func = q_func + @property + def model(self): + """Override model getter method to handle quantization aware training case.""" + return self._model + + @model.setter + def model(self, user_model): + """Override model setter method to handle quantization aware training case. + + Args: + user_model: user are supported to set model from original framework model format + (eg, tensorflow frozen_pb or path to a saved model), + but not recommended. Best practice is to set from a initialized + neural_compressor.experimental.common.Model. + If tensorflow model is used, model's inputs/outputs will be + auto inferenced, but sometimes auto inferenced + inputs/outputs will not meet your requests, + set them manually in config yaml file. + Another corner case is slim model of tensorflow, + be careful of the name of model configured in yaml file, + make sure the name is in supported slim model list. + """ + approach_cfg = deep_get(self.cfg, 'quantization.approach') + if not self.framework: + self.framework = get_model_fwk_name(user_model) + if self.framework == 'tensorflow' and approach_cfg == 'quant_aware_training': + if type(user_model) == str: + self._model = TensorflowQATModel(user_model) + else: + self._model = TensorflowQATModel(user_model._model) + else: + Component.model.__set__(self, user_model) + def __repr__(self): """Return the class string.""" return 'Quantization' diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index 9d1830db47f..cca8b7d8d65 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -1064,6 +1064,35 @@ def save(self, root=None): builder.save() logger.info("Save quantized model to {}.".format(root)) +class TensorflowQATModel(TensorflowSavedModelModel): + def __init__(self, model='', **kwargs): + super(TensorflowQATModel, self).__init__(model) + self.keras_model = None + self.model_type = 'keras' + + @property + def model(self): + if self.keras_model == None: + self.keras_model = tf.keras.models.load_model(self._model) + return self.keras_model + + @model.setter + def model(self, q_model): + self.keras_model = q_model + + def save(self, root=None): + if not root: + root = cfg.default_workspace + '/saved_model' + root = os.path.abspath(os.path.expanduser(root)) + # if not have suffix, default append .pb + os.makedirs(os.path.dirname(root), exist_ok=True) + q_aware_model = self.keras_model + q_aware_model.save(root) + saved_format = 'saved_model' + if root.endswith('.h5'): + saved_format = 'h5 file' + logger.info("Save quantized model to {}.".format(saved_format)) + return root class TensorflowCheckpointModel(TensorflowBaseModel): diff --git a/test/model/test_model.py b/test/model/test_model.py index 8508105f193..492ada7f844 100644 --- a/test/model/test_model.py +++ b/test/model/test_model.py @@ -247,6 +247,27 @@ def test_keras_saved_model(self): os.system('rm -rf simple_model') os.system('rm -rf keras_model') + def test_tf_qat_model(self): + if tf.version.VERSION < '2.3.0': + return + keras_model = build_keras() + self.assertEqual('tensorflow', get_model_fwk_name(keras_model)) + + from neural_compressor.model.model import TensorflowQATModel + model = TensorflowQATModel(keras_model) + assert isinstance(model.model, tf.keras.Model) + keras_model.save('./simple_model') + # load from path + model = TensorflowQATModel('./simple_model') + assert isinstance(model.model, tf.keras.Model) + + + os.makedirs('./keras_model', exist_ok=True) + model.save('./keras_model') + load_model = tf.keras.models.load_model('./keras_model') + os.system('rm -rf simple_model') + os.system('rm -rf keras_model') + @unittest.skipIf(tf.version.VERSION < '2.4.0' or platform.system().lower() == "windows", "Only supports tf 2.4.0 or above") def test_saved_model(self): ssd_resnet50_ckpt_url = 'http://download.tensorflow.org/models/object_detection/ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz' diff --git a/test/quantization/test_tensorflow_qat.py b/test/quantization/test_tensorflow_qat.py index 48bfe4ac8d3..9f2860afe62 100644 --- a/test/quantization/test_tensorflow_qat.py +++ b/test/quantization/test_tensorflow_qat.py @@ -3,54 +3,6 @@ import yaml import shutil -def build_fake_yaml(): - fake_yaml = ''' - model: - name: fake_yaml - framework: tensorflow - - device: cpu - quantization: - approach: quant_aware_training - evaluation: - accuracy: - metric: - Accuracy: {} - ''' - y = yaml.load(fake_yaml, Loader=yaml.SafeLoader) - with open('fake_yaml.yaml', "w", encoding="utf-8") as f: - yaml.dump(y, f) - f.close() - - -def build_fake_yaml_by_train(): - fake_yaml = ''' - model: - name: fake_yaml - framework: tensorflow - - device: cpu - quantization: - approach: quant_aware_training - train: - optimizer: - SGD: - learning_rate: 0.1 - criterion: - CrossEntropyLoss: - reduction: none - evaluation: - accuracy: - metric: - Accuracy: {} - ''' - - y = yaml.load(fake_yaml, Loader=yaml.SafeLoader) - with open('fake_yaml_train.yaml', "w", encoding="utf-8") as f: - yaml.dump(y, f) - f.close() - - def train_func(): import tensorflow as tf from tensorflow import keras @@ -89,46 +41,7 @@ def train_func(): print('Baseline test accuracy:', baseline_model_accuracy) model.save("baseline_model") - - -def q_func(model): - import tensorflow as tf - from tensorflow import keras - mnist = keras.datasets.mnist - (train_images, train_labels), (test_images, test_labels) = mnist.load_data() - - # Normalize the input image so that each pixel value is between 0 to 1. - train_images = train_images / 255.0 - test_images = test_images / 255.0 - - model = tf.keras.models.load_model("baseline_model") - - import tensorflow_model_optimization as tfmot - quantize_model = tfmot.quantization.keras.quantize_model - - # q_aware stands for for quantization aware. - q_aware_model = quantize_model(model) - - # `quantize_model` requires a recompile. - q_aware_model.compile(optimizer='adam', - loss=tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True), - metrics=['accuracy']) - - train_images_subset = train_images[0:1000] # out of 60000 - train_labels_subset = train_labels[0:1000] - - q_aware_model.fit(train_images_subset, train_labels_subset, - batch_size=500, epochs=1, validation_split=0.1) - - _, q_aware_model_accuracy = q_aware_model.evaluate( - test_images, test_labels, verbose=0) - - print('Quant test accuracy:', q_aware_model_accuracy) - q_aware_model.save("trained_qat_model") - return 'trained_qat_model' - - + class Dataset(object): def __init__(self, batch_size=100): import tensorflow as tf @@ -148,31 +61,52 @@ def __len__(self): def __getitem__(self, idx): return self.test_images[idx], self.test_labels[idx] - class TestTensorflowQAT(unittest.TestCase): import tensorflow as tf @classmethod def setUpClass(self): - build_fake_yaml() train_func() - build_fake_yaml_by_train() @classmethod def tearDownClass(self): - os.remove('fake_yaml.yaml') shutil.rmtree('baseline_model',ignore_errors=True) shutil.rmtree('trained_qat_model',ignore_errors=True) - os.remove('fake_yaml_train.yaml') @unittest.skipIf(tf.version.VERSION < '2.3.0', " keras model need tensorflow version >= 2.3.0, so the case is skipped") - def test_qat_with_external_q_func(self): - from neural_compressor.experimental import Quantization, common - quantizer = Quantization('fake_yaml.yaml') - quantizer.eval_dataloader = common.DataLoader(Dataset()) - quantizer.model = './baseline_model' - quantizer.q_func = q_func - quantizer.fit() + def test_qat(self): + import tensorflow as tf + from tensorflow import keras + mnist = keras.datasets.mnist + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + + # Normalize the input image so that each pixel value is between 0 to 1. + train_images = train_images / 255.0 + test_images = test_images / 255.0 + + from neural_compressor import training, QuantizationAwareTrainingConfig + config = QuantizationAwareTrainingConfig() + compression_manager = training.prepare_compression('./baseline_model', config) + compression_manager.callbacks.on_train_begin() + + # `quantize_model` requires a recompile. + q_aware_model.compile(optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True), + metrics=['accuracy']) + + train_images_subset = train_images[0:1000] # out of 60000 + train_labels_subset = train_labels[0:1000] + + q_aware_model.fit(train_images_subset, train_labels_subset, + batch_size=500, epochs=1, validation_split=0.1) + + _, q_aware_model_accuracy = q_aware_model.evaluate( + test_images, test_labels, verbose=0) + + print('Quant test accuracy:', q_aware_model_accuracy) + compression_manager.callbacks.on_train_end() + compression_manager.save("trained_qat_model") if __name__ == '__main__': unittest.main()