Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Tensorflow image recognition saved model examples using new API #435

Merged
merged 15 commits into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 10 additions & 16 deletions examples/.config/model_params_tensorflow.json
Original file line number Diff line number Diff line change
Expand Up @@ -367,22 +367,18 @@
"new_benchmark": true
},
"mobilenetv1_saved": {
"model_src_dir": "image_recognition/SavedModel/quantization/ptq",
"model_src_dir": "image_recognition/SavedModel/mobilenet_v1/quantization/ptq",
"dataset_location": "/tf_dataset/dataset/imagenet",
"input_model": "/tf_dataset/tensorflow/saved_model/mobilenet_v1",
"yaml": "mobilenet_v1.yaml",
"strategy": "basic",
"batch_size": 100,
"new_benchmark": true
"main_script": "main.py",
"batch_size": 100
},
"mobilenetv2_saved": {
"model_src_dir": "image_recognition/SavedModel/quantization/ptq",
"model_src_dir": "image_recognition/SavedModel/mobilenet_v2/quantization/ptq",
"dataset_location": "/tf_dataset/dataset/imagenet",
"input_model": "/tf_dataset/tensorflow/saved_model/mobilenet_v2",
"yaml": "mobilenet_v2.yaml",
"strategy": "basic",
"batch_size": 100,
"new_benchmark": true
"main_script": "main.py",
"batch_size": 100
},
"nasnet_mobile": {
"model_src_dir": "image_recognition/tensorflow_models/quantization/ptq",
Expand All @@ -403,13 +399,11 @@
"new_benchmark": true
},
"efficientnet_v2_b0": {
"model_src_dir": "image_recognition/SavedModel/quantization/ptq",
"dataset_location": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ILSVRC2012_img_val",
"model_src_dir": "image_recognition/SavedModel/efficientnet_v2_b0/quantization/ptq",
"dataset_location": "/tf_dataset2/datasets/imagenet/ImagenetRaw/",
"input_model": "/tf_dataset/tensorflow/saved_model/efficientnet_v2",
"yaml": "efficientnet_v2_b0.yaml",
"strategy": "basic",
"batch_size": 100,
"new_benchmark": true
"main_script": "main.py",
"batch_size": 100
},
"ssd_resnet50_v1": {
"model_src_dir": "object_detection/tensorflow_models/ssd_resnet50_v1/quantization/ptq",
Expand Down
3 changes: 1 addition & 2 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,8 @@ Intel® Neural Compressor validated examples with multiple compression technique
<td>*EfficientNet V2 B0</td>
<td>Image Recognition</td>
<td>Post-Training Static Quantization</td>
<td><a href="https://github.com/intel/neural-compressor/tree/old_api_examples/examples/tensorflow/image_recognition/SavedModel/quantization/ptq">SavedModel</a></td>
<td><a href="./tensorflow/image_recognition/SavedModel/efficientnet_v2_b0/quantization/ptq">SavedModel</a></td>
</tr>
<tr>
<td>BERT base MRPC</td>
<td>Natural Language Processing</td>
<td>Post-Training Static Quantization</td>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
Step-by-Step
============

This document is used to enable Tensorflow efficientnet_v2_b0 SavedModel format using Intel® Neural Compressor.
This example can run on Intel CPUs and GPUs.


# Prerequisite

## 1. Environment

### Install Intel® Neural Compressor
```shell
# Install Intel® Neural Compressor
pip install neural-compressor
```
### Install Intel Tensorflow
```shell
pip install intel-tensorflow
```
> Note: Supported Tensorflow >= 2.4.0.

### Install Intel Extension for Tensorflow
#### Quantizing the model on Intel GPU
Intel Extension for Tensorflow is mandatory to be installed for quantizing the model on Intel GPUs.

```shell
pip install --upgrade intel-extension-for-tensorflow[gpu]
```
For any more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel-innersource/frameworks.ai.infrastructure.intel-extension-for-tensorflow.intel-extension-for-tensorflow/blob/master/docs/install/install_for_gpu.md#install-gpu-drivers)

#### Quantizing the model on Intel CPU(Experimental)
Intel Extension for Tensorflow for Intel CPUs is experimental currently. It's not mandatory for quantizing the model on Intel CPUs.

```shell
pip install --upgrade intel-extension-for-tensorflow[cpu]
```

## 2. Prepare Pretrained model
Download the mobilenetv1 model from tensorflow-hub.

image recognition
- [efficientnet_v2_b0](https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b0/classification/2)

## 3. Prepare Dataset

Download [ImageNet](http://www.image-net.org/) Raw image to dir: /path/to/ImageNet. The dir include below folder and files:

```bash
ls /path/to/ImageNet
ILSVRC2012_img_val val.txt
```

# Run Command
## 1. Quantization
```shell
bash run_tuning.sh --input_model=./SavedModel --output_model=./nc_SavedModel --dataset_location=/path/to/ImageNet/
```

## 2. Benchmark
```shell
bash run_benchmark.sh --input_model=./SavedModel --mode=accuracy --dataset_location=/path/to/ImageNet/ --batch_size=32
bash run_benchmark.sh --input_model=./SavedModel --mode=performance --dataset_location=/path/to/ImageNet/ --batch_size=1
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#

from __future__ import division
import time
import os
import tensorflow as tf
import numpy as np
from argparse import ArgumentParser

arg_parser = ArgumentParser(description='Parse args')
arg_parser.add_argument('-g', "--input-graph",
help='Specify the input graph for the transform tool',
dest='input_graph')
arg_parser.add_argument("--output-graph",
help='Specify tune result model save dir',
dest='output_graph')
arg_parser.add_argument('--benchmark', dest='benchmark', action='store_true', help='run benchmark')
arg_parser.add_argument('--mode', dest='mode', default='performance', help='benchmark mode')
arg_parser.add_argument('--export', dest='export', action='store_true', help='use neural_compressor to export.')
arg_parser.add_argument('--tune', dest='tune', action='store_true', help='use neural_compressor to tune.')
arg_parser.add_argument('--dataset_location', dest='dataset_location',
help='location of calibration dataset and evaluate dataset')
arg_parser.add_argument('--batch_size', type=int, default=32, dest='batch_size', help='batch_size of evaluation')
arg_parser.add_argument('--iters', type=int, default=100, dest='iters', help='interations')
args = arg_parser.parse_args()

def evaluate(model):
"""Custom evaluate function to estimate the accuracy of the model.

Args:
model (tf.Graph_def): The input model graph

Returns:
accuracy (float): evaluation result, the larger is better.
"""
infer = model.signatures["serving_default"]
output_dict_keys = infer.structured_outputs.keys()
output_name = list(output_dict_keys )[0]
from neural_compressor.metric import TensorflowTopK
metric = TensorflowTopK(k=1)

def eval_func(dataloader, metric):
warmup = 5
iteration = None
latency_list = []
if args.benchmark and args.mode == 'performance':
iteration = args.iters
for idx, (inputs, labels) in enumerate(dataloader):
inputs = np.array(inputs)
input_tensor = tf.constant(inputs, dtype=tf.float32)
start = time.time()
predictions = infer(input_tensor)[output_name]
end = time.time()
predictions = predictions.numpy()
metric.update(predictions, labels)
latency_list.append(end - start)
if iteration and idx >= iteration:
break
latency = np.array(latency_list[warmup:]).mean() / eval_dataloader.batch_size
return latency

from neural_compressor.utils.create_obj_from_config import create_dataloader
data_path = os.path.join(args.dataset_location, 'ILSVRC2012_img_val')
label_path = os.path.join(args.dataset_location, 'val.txt')
dataloader_args = {
'batch_size': args.batch_size,
'dataset': {"ImagenetRaw": {'data_path':data_path, 'image_list':label_path}},
'transform': {'PaddedCenterCrop': {'size': 224, 'crop_padding': 32},
'Resize': {'size': 224, 'interpolation': 'bicubic'},
'Normalize': {'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}
},
'filter': None
}
eval_dataloader = create_dataloader('tensorflow', dataloader_args)
latency = eval_func(eval_dataloader, metric)
if args.benchmark and args.mode == 'performance':
print("Batch size = {}".format(eval_dataloader.batch_size))
print("Latency: {:.3f} ms".format(latency * 1000))
print("Throughput: {:.3f} images/sec".format(1. / latency))
acc = metric.result()
return acc

class eval_object_detection_optimized_graph(object):
def run(self):
from neural_compressor.utils import set_random_seed
set_random_seed(9527)
if args.tune:
from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.utils.create_obj_from_config import create_dataloader
data_path = os.path.join(args.dataset_location, 'ILSVRC2012_img_val')
label_path = os.path.join(args.dataset_location, 'val.txt')
calib_dataloader_args = {
'dataset': {"ImagenetRaw": {'data_path':data_path, 'image_list':label_path}},
'transform': {'PaddedCenterCrop': {'size': 224, 'crop_padding': 32},
'Resize': {'size': 224, 'interpolation': 'bicubic'},
'Normalize': {'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}
},
'filter': None
}
calib_dataloader = create_dataloader('tensorflow', calib_dataloader_args)
conf = PostTrainingQuantConfig(calibration_sampling_size=[5, 10, 50, 100])
q_model = quantization.fit(model=args.input_graph, conf=conf,
calib_dataloader=calib_dataloader, eval_func=evaluate)
q_model.save(args.output_graph)

if args.benchmark:
from neural_compressor.benchmark import fit
from neural_compressor.config import BenchmarkConfig
if args.mode == 'performance':
conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=7)
fit(model=args.input_graph, config=conf, b_func=evaluate)
else:
from neural_compressor.model import Model
model = Model(args.input_graph).model
accuracy = evaluate(model)
print('Batch size = %d' % args.batch_size)
print("Accuracy: %.5f" % accuracy)

if __name__ == "__main__":
evaluate_opt_graph = eval_object_detection_optimized_graph()
evaluate_opt_graph.run()
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/bin/bash
set -x

function main {

init_params "$@"
run_benchmark

}

# init params
function init_params {
batch_size=32
iters=100

for var in "$@"
do
case $var in
--input_model=*)
input_model=$(echo $var |cut -f2 -d=)
;;
--mode=*)
mode=$(echo $var |cut -f2 -d=)
;;
--dataset_location=*)
dataset_location=$(echo $var |cut -f2 -d=)
;;
--batch_size=*)
batch_size=$(echo $var |cut -f2 -d=)
;;
--iters=*)
iters=$(echo $var |cut -f2 -d=)
;;
esac
done

}

# run_tuning
function run_benchmark {

python main.py \
--input-graph ${input_model} \
--mode ${mode} \
--dataset_location ${dataset_location} \
--batch_size ${batch_size} \
--benchmark \
--iters ${iters}
}

main "$@"
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash
set -x

function main {

init_params "$@"

run_tuning

}

# init params
function init_params {

for var in "$@"
do
case $var in
--input_model=*)
input_model=$(echo "$var" |cut -f2 -d=)
;;
--output_model=*)
output_model=$(echo "$var" |cut -f2 -d=)
;;
--dataset_location=*)
dataset_location=$(echo $var |cut -f2 -d=)
esac
done

}

# run_tuning
function run_tuning {
python main.py \
--input-graph "${input_model}" \
--output-graph "${output_model}" \
--dataset_location ${dataset_location} \
--tune
}

main "$@"