support > 2GB onnx model and update graph sort (#1359)
mengniwang95 committed Nov 7, 2022
1 parent fcfafc5 commit 8d83cc8
Showing 13 changed files with 371 additions and 36 deletions.
9 changes: 9 additions & 0 deletions examples/.config/model_params_onnxrt.json
@@ -621,6 +621,15 @@
"batch_size": 1,
"new_benchmark": true
},
"unet": {
"model_src_dir": "image_recognition/unet/quantization/ptq",
"dataset_location": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ILSVRC2012_img_val",
"input_model": "/tf_dataset2/models/onnx/unet/model.onnx",
"yaml": "unet.yaml",
"strategy": "basic",
"batch_size": 1,
"new_benchmark": true
},
"BiDAF": {
"model_src_dir": "language_translation/onnx_model_zoo/BiDAF/quantization/ptq",
"dataset_location": "/tf_dataset2/datasets/squad/dev-v1.1.json",
82 changes: 82 additions & 0 deletions examples/onnxrt/image_recognition/unet/quantization/ptq/main.py
@@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name,logging-format-interpolation


import logging
import argparse

import numpy as np
import onnx

logger = logging.getLogger(__name__)
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.WARN)

if __name__ == "__main__":
logger.info("Evaluating ONNXRuntime full precision accuracy and performance:")
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
'--model_path',
type=str,
help="Pre-trained mobilenet_v3 model on onnx file"
)
parser.add_argument(
'--benchmark',
action='store_true',
default=False
)
parser.add_argument(
'--tune',
action='store_true',
default=False,
help="whether quantize the model"
)
parser.add_argument(
'--config',
type=str,
help="config yaml path"
)
parser.add_argument(
'--output_model',
type=str,
help="output model path"
)
parser.add_argument(
'--mode',
type=str,
default='performance',
help="benchmark mode of performance or accuracy"
)
args = parser.parse_args()
if args.benchmark:
from neural_compressor.experimental import Benchmark, common
evaluator = Benchmark(args.config)
evaluator.model = common.Model(args.model_path)
evaluator(args.mode)

if args.tune:
from neural_compressor.experimental import Quantization, common

quantize = Quantization(args.config)
quantize.model = common.Model(args.model_path)
q_model = quantize()
q_model.save(args.output_model)

31 changes: 31 additions & 0 deletions examples/onnxrt/image_recognition/unet/quantization/ptq/readme.md
@@ -0,0 +1,31 @@
# Evaluate performance of ONNX Runtime (unet)

This is an experimental example for quantizing the unet model. Dummy data is used for both quantization and evaluation, so accuracy is not guaranteed.

### Environment
onnx: 1.12.0
onnxruntime: 1.12.1

### Prepare model

```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers/scripts/
python convert_stable_diffusion_checkpoint_to_onnx.py --model_path "CompVis/stable-diffusion-v1-4" --output_path /workdir/output_path
```
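
The exported fp32 unet is typically well above protobuf's 2GB single-file limit, which is exactly the case the external-data support in this commit targets. The sketch below is a quick sanity check only; it assumes the `/workdir/output_path` used in the command above and simply reports the on-disk size of the exported unet directory:

```python
# Sketch only: report the on-disk size of the exported unet. The directory
# path is assumed from the export command above. Models above the 2GB
# protobuf limit rely on the external-data handling added in this commit.
import os

unet_dir = "/workdir/output_path/unet"
total_bytes = sum(
    os.path.getsize(os.path.join(unet_dir, name))
    for name in os.listdir(unet_dir)
    if os.path.isfile(os.path.join(unet_dir, name))
)
print(f"exported unet size: {total_bytes / 1024 ** 3:.2f} GiB")
```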

### Quantization

```bash
bash run_tuning.sh --input_model=/workdir/output_path/unet/model.onnx \
--config=unet.yaml \
--output_model=path/to/save
```

### Benchmark

```bash
bash run_benchmark.sh --input_model=/workdir/output_path/unet/model.onnx \
--config=unet.yaml \
--mode=performance
```
3 changes: 3 additions & 0 deletions examples/onnxrt/image_recognition/unet/quantization/ptq/requirements.txt
@@ -0,0 +1,3 @@
onnx==1.12.0
onnxruntime==1.12.0
onnxruntime-extensions; python_version < '3.10'
41 changes: 41 additions & 0 deletions examples/onnxrt/image_recognition/unet/quantization/ptq/run_benchmark.sh
@@ -0,0 +1,41 @@
#!/bin/bash
set -x

function main {
init_params "$@"
run_benchmark

}

# init params
function init_params {

for var in "$@"
do
case $var in
--config=*)
config=$(echo $var |cut -f2 -d=)
;;
--input_model=*)
input_model=$(echo $var |cut -f2 -d=)
;;
--mode=*)
mode=$(echo $var |cut -f2 -d=)
;;
esac
done

}

# run_benchmark
function run_benchmark {

python main.py \
--model_path ${input_model} \
--config ${config} \
--mode=${mode} \
--benchmark

}

main "$@"
39 changes: 39 additions & 0 deletions examples/onnxrt/image_recognition/unet/quantization/ptq/run_tuning.sh
@@ -0,0 +1,39 @@
#!/bin/bash
set -x

function main {
init_params "$@"
run_tuning

}

# init params
function init_params {

for var in "$@"
do
case $var in
--config=*)
config=$(echo $var |cut -f2 -d=)
;;
--input_model=*)
input_model=$(echo $var |cut -f2 -d=)
;;
--output_model=*)
output_model=$(echo $var |cut -f2 -d=)
;;
esac
done

}

# run_tuning
function run_tuning {
python main.py \
--model_path ${input_model} \
--output_model ${output_model} \
--config ${config} \
--tune
}

main "$@"
57 changes: 57 additions & 0 deletions examples/onnxrt/image_recognition/unet/quantization/ptq/unet.yaml
@@ -0,0 +1,57 @@
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

version: 1.0

model: # mandatory. used to specify model specific information.
name: unet
framework: onnxrt_qlinearops # mandatory. supported values are tensorflow, pytorch, pytorch_ipex, onnxrt_integer, onnxrt_qlinear or mxnet; allow new framework backend extension.

quantization: # optional. tuning constraints on model-wise for advance user to reduce tuning space.
approach: post_training_static_quant # optional. default value is post_training_static_quant.
calibration:
dataloader:
batch_size: 1
dataset:
dummy:
shape: [[1, 4, 64, 64], [1], [1, 77, 768]]
dtype: ['float32', 'int64', 'float32']

evaluation: # optional. required if user doesn't provide eval_func in neural_compressor.Quantization.
accuracy: # optional. required if user doesn't provide eval_func in neural_compressor.Quantization.
dataloader:
batch_size: 1
dataset:
dummy:
shape: [[1, 4, 64, 64], [1], [1, 77, 768]]
dtype: ['float32', 'int64', 'float32']

performance: # optional. used to benchmark performance of passing model.
warmup: 10
iteration: 500
configs:
cores_per_instance: 4
num_of_instance: 7
dataloader:
batch_size: 1
dataset:
dummy:
shape: [[1, 4, 64, 64], [1], [1, 77, 768]]
dtype: ['float32', 'int64', 'float32']

tuning:
exit_policy:
performance_only: True
random_seed: 9527 # optional. random seed for deterministic tuning.
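
The three dummy inputs correspond to the unet's inputs from the stable-diffusion export: the latent sample, the timestep, and the text-encoder hidden states. The sketch below is not part of the example; it assumes the model path from the readme and that the session's input order matches the yaml ordering, and simply feeds tensors of exactly these shapes and dtypes through onnxruntime:

```python
import numpy as np
import onnxruntime as ort

# Shapes and dtypes taken from the dummy dataset in unet.yaml:
# latent sample, timestep, text-encoder hidden states.
shapes = [[1, 4, 64, 64], [1], [1, 77, 768]]
dtypes = [np.float32, np.int64, np.float32]

# Path assumed from the readme's export step.
session = ort.InferenceSession("/workdir/output_path/unet/model.onnx")

# Build the feed dict from the session's declared inputs, assuming their
# order matches the yaml ordering above.
feed = {
    inp.name: np.ones(shape, dtype=dtype)
    for inp, shape, dtype in zip(session.get_inputs(), shapes, dtypes)
}
outputs = session.run(None, feed)
print([out.shape for out in outputs])
```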
35 changes: 27 additions & 8 deletions neural_compressor/adaptor/onnxrt.py
@@ -142,7 +142,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
break
tmp_iterations = int(math.ceil(calib_sampling_size / calib_batch_size))
data_loader.batch(calib_batch_size)
quantize_params = self._get_quantize_params(tmp_model.model, data_loader, \
quantize_params = self._get_quantize_params(tmp_model, data_loader, \
quantize_config, tmp_iterations)
except Exception as e: # pragma: no cover
if 'Got invalid dimensions for input' in str(e):
@@ -153,7 +153,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
"Fail to forward with batch size={}, set to {} now.".
format(batch_size, 1))
data_loader.batch(1)
quantize_params = self._get_quantize_params(tmp_model.model, data_loader, \
quantize_params = self._get_quantize_params(tmp_model, data_loader, \
quantize_config, calib_sampling_size)
else: # pragma: no cover
if hasattr(data_loader, 'batch_size') and \
@@ -164,13 +164,13 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
"So the real sampling size is {}.".
format(calib_sampling_size, data_loader.batch_size,
data_loader.batch_size * iterations))
quantize_params = self._get_quantize_params(tmp_model.model, data_loader, \
quantize_params = self._get_quantize_params(tmp_model, data_loader, \
quantize_config, iterations)
else:
quantize_params = None
self.quantize_params = quantize_params
from neural_compressor.adaptor.ox_utils.quantizer import Quantizer
quantizer = Quantizer(tmp_model.model,
quantizer = Quantizer(copy.deepcopy(model),
quantize_config,
backend,
self.static,
@@ -459,15 +459,25 @@ def _pre_optimize(self, model, level=1):
if sys.version_info < (3,10) and find_spec('onnxruntime_extensions'): # pragma: no cover
from onnxruntime_extensions import get_library_path
sess_options.register_custom_ops_library(get_library_path())
_ = ort.InferenceSession(model.model.SerializeToString(), sess_options)
tmp_model = onnx.load(sess_options.optimized_model_filepath)
if not model.large_size:
ort.InferenceSession(model.model.SerializeToString(), sess_options)
elif model.model_path is not None: # pragma: no cover
ort.InferenceSession(model.model_path, sess_options)
else: # pragma: no cover
logger.warning('Please use model path instead of onnx model object to quantize')

tmp_model = onnx.load(sess_options.optimized_model_filepath, load_external_data=False)
if model.large_size: # pragma: no cover
from onnx.external_data_helper import load_external_data_for_model
load_external_data_for_model(tmp_model, os.path.split(model.model_path)[0])
model.model_path = sess_options.optimized_model_filepath
model.model = self._replace_gemm_with_matmul(tmp_model).model \
if self.graph_optimization.gemm2matmul else tmp_model
model.model = self._rename_node(model.model)
model = self._revert_fusedconv(model)
model = split_shared_bias(model)
model.topological_sort()
self.pre_optimized_model = model
self.pre_optimized_model = copy.deepcopy(model)

def _revert_fusedconv(self, model):
from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg
@@ -787,6 +797,13 @@ def evaluate(self, input_graph, dataloader, postprocess=None,
Returns:
(float) evaluation results. acc, f1 e.g.
"""
if input_graph.large_size: # pragma: no cover
onnx.save_model(input_graph.model,
self.work_space + 'eval.onnx',
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
convert_attribute=False)
sess_options = ort.SessionOptions()
if measurer:
# https://github.com/microsoft/onnxruntime/issues/7347
@@ -796,7 +813,9 @@ def evaluate(self, input_graph, dataloader, postprocess=None,
if sys.version_info < (3,10) and find_spec('onnxruntime_extensions'): # pragma: no cover
from onnxruntime_extensions import get_library_path
sess_options.register_custom_ops_library(get_library_path())
session = ort.InferenceSession(input_graph.model.SerializeToString(), sess_options)
session = ort.InferenceSession(self.work_space + 'eval.onnx', sess_options) if \
input_graph.large_size else \
ort.InferenceSession(input_graph.model.SerializeToString(), sess_options)
results = []
if metrics:
for metric in metrics:
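
The adaptor changes above boil down to one pattern for models beyond protobuf's 2GB limit: hand onnxruntime a file path instead of a serialized protobuf, and keep the weights in external data files when loading or re-saving the graph. A condensed sketch of that flow, with placeholder paths rather than values from the commit:

```python
import os

import onnx
import onnxruntime as ort
from onnx.external_data_helper import load_external_data_for_model

model_path = "/path/to/large/model.onnx"  # placeholder

# A >2GB model cannot be serialized into a single protobuf, so the session
# is created from the file path rather than model.SerializeToString().
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.optimized_model_filepath = "./optimized_model.onnx"
_ = ort.InferenceSession(model_path, sess_options)

# Load the optimized graph without pulling external tensors into the proto,
# then attach them from the directory that holds the original weight files.
optimized = onnx.load(sess_options.optimized_model_filepath, load_external_data=False)
load_external_data_for_model(optimized, os.path.split(model_path)[0])

# Re-saving spills the tensors back out to an external file, mirroring what
# evaluate() now does before building its inference session.
onnx.save_model(optimized, "./eval.onnx",
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                location="weights.pb",
                convert_attribute=False)
_ = ort.InferenceSession("./eval.onnx")
```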
