examples/tensorflow/image_recognition/resnet50_v1.yaml (11 additions, 5 deletions)
@@ -13,19 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-model:                                               # mandatory. used to specify model specific information.
+model:                                               # mandatory. lpot uses this model name and framework name to decide where to save tuning history and deploy yaml.
   name: resnet50_v1
   framework: tensorflow                              # mandatory. supported values are tensorflow, pytorch, pytorch_ipex, onnxrt_integer, onnxrt_qlinear or mxnet; allow new framework backend extension.
 
 quantization:                                        # optional. tuning constraints on model-wise for advance user to reduce tuning space.
   calibration:
-    sampling_size: 50, 100                           # optional. default value is 100. used to set how many samples should be used in calibration.
+    sampling_size: 50, 100                           # optional. default value is the size of whole dataset. used to set how many portions of calibration dataset is used. exclusive with iterations field.
     dataloader:
       batch_size: 10
       dataset:
         ImageRecord:
-          root: /path/to/calibration/dataset         # NOTE: modify to calibration dataset location if needed
+          root: /home2/changwa1/sig/LowPrecisionInferenceTool/examples/tensorflow/image_recognition/sub_imagenet  # NOTE: modify to calibration dataset location if needed
       transform:
+        ParseDecodeImagenet: {}
         ResizeCropImagenet:
           height: 224
           width: 224
@@ -41,8 +42,9 @@ evaluation:   # optional. required if use
       batch_size: 32
       dataset:
         ImageRecord:
-          root: /path/to/evaluation/dataset          # NOTE: modify to evaluation dataset location if needed
+          root: /home2/changwa1/sig/LowPrecisionInferenceTool/examples/tensorflow/image_recognition/sub_imagenet  # NOTE: modify to evaluation dataset location if needed
       transform:
+        ParseDecodeImagenet: {}
         ResizeCropImagenet:
           height: 224
           width: 224
@@ -55,13 +57,17 @@ evaluation:   # optional. required if use
       batch_size: 1
       dataset:
         ImageRecord:
-          root: /path/to/evaluation/dataset          # NOTE: modify to evaluation dataset location if needed
+          root: /home2/changwa1/sig/LowPrecisionInferenceTool/examples/tensorflow/image_recognition/sub_imagenet  # NOTE: modify to evaluation dataset location if needed
       transform:
+        ParseDecodeImagenet: {}
         BilinearImagenet:
           height: 224
           width: 224
 
 tuning:
+  strategy:
+    name: sigopt
+    sigopt_api_token: WECSOKFVSWZNKNOPOMMIILBOXJHNMOBWPIKLKEGXSDRPCIGR
   accuracy_criterion:
     relative: 0.01                                   # optional. default value is relative, other value is absolute. this example allows relative accuracy loss: 1%.
   exit_policy:
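For orientation, the yaml above is the only file a user edits to enable the new strategy: the calibration and evaluation dataloaders and the SigOpt settings all come from it. A minimal sketch of how it is consumed, assuming the lpot 1.x Quantization entry point and a hypothetical frozen-graph path (neither is part of this diff):

from lpot import Quantization

# The yaml's model/quantization/evaluation/tuning sections configure everything;
# only the FP32 graph is passed in explicitly here.
quantizer = Quantization('./resnet50_v1.yaml')
q_model = quantizer('./resnet50_v1_fp32.pb')   # hypothetical path to the FP32 frozen graph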
lpot/conf/config.py (2 additions, 2 deletions)
@@ -607,8 +607,8 @@ def percent_to_float(data):
         'exit_policy': {'timeout': 0, 'max_trials': 100, 'performance_only': False},
         'random_seed': 1978, 'tensorboard': False,
         'workspace': {'path': default_workspace}}): {
-        Optional('strategy', default={'name': 'basic'}): {
-            'name': And(str, lambda s: s in STRATEGIES),
+        Optional('strategy', default={'name': 'basic', 'sigopt_api_token': None}): {
+            'name': And(str, lambda s: s in STRATEGIES), 'sigopt_api_token': str,
             Optional('accuracy_weight', default=1.0): float,
             Optional('latency_weight', default=1.0): float
         } ,
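The schema change above leans on the schema package's Optional and And combinators. The following standalone sketch shows how the new field validates; STRATEGIES here is an illustrative subset rather than the real lpot registry, and the snippet does not reproduce the full tuning schema:

from schema import And, Optional, Schema

STRATEGIES = ('basic', 'bayesian', 'mse', 'exhaustive', 'random', 'sigopt')  # illustrative subset

strategy_schema = Schema({
    Optional('strategy', default={'name': 'basic', 'sigopt_api_token': None}): {
        'name': And(str, lambda s: s in STRATEGIES),
        'sigopt_api_token': str,
    }
})

# When the strategy key is absent, the declared default is filled in.
print(strategy_schema.validate({}))
# When a strategy block is present, this schema requires a string token alongside the name.
print(strategy_schema.validate({'strategy': {'name': 'sigopt',
                                             'sigopt_api_token': 'MY_SIGOPT_API_TOKEN'}}))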
lpot/strategy/sigopt.py (202 additions, 0 deletions, new file)
@@ -0,0 +1,202 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from ..utils import logger
from ..utils.utility import Timeout
from .strategy import strategy_registry, TuneStrategy
from sigopt import Connection


@strategy_registry
class SigOptTuneStrategy(TuneStrategy):
"""The tuning strategy using SigOpt HPO search in tuning space.

Args:
model (object): The FP32 model specified for low precision tuning.
conf (Conf): The Conf class instance initialized from user yaml
config file.
q_dataloader (generator): Data loader for calibration, mandatory for
post-training quantization.
It is iterable and should yield a tuple (input,
label) for calibration dataset containing label,
or yield (input, _) for label-free calibration
dataset. The input could be a object, list, tuple or
dict, depending on user implementation, as well as
it can be taken as model input.
q_func (function, optional): Reserved for future use.
eval_dataloader (generator, optional): Data loader for evaluation. It is iterable
and should yield a tuple of (input, label).
The input could be a object, list, tuple or dict,
depending on user implementation, as well as it can
be taken as model input. The label should be able
to take as input of supported metrics. If this
parameter is not None, user needs to specify
pre-defined evaluation metrics through configuration
file and should set "eval_func" parameter as None.
Tuner will combine model, eval_dataloader and
pre-defined metrics to run evaluation process.
eval_func (function, optional): The evaluation function provided by user.
This function takes model as parameter, and
evaluation dataset and metrics should be
encapsulated in this function implementation and
outputs a higher-is-better accuracy scalar value.

The pseudo code should be something like:

def eval_func(model):
input, label = dataloader()
output = model(input)
accuracy = metric(output, label)
return accuracy
dicts (dict, optional): The dict containing resume information.
Defaults to None.

"""

    def __init__(self, model, conf, q_dataloader, q_func=None,
                 eval_dataloader=None, eval_func=None, dicts=None):
        super().__init__(
            model,
            conf,
            q_dataloader,
            q_func,
            eval_dataloader,
            eval_func,
            dicts)

        # SigOpt init
        client_token = conf.usr_cfg.tuning.strategy.sigopt_api_token
        self.conn = Connection(client_token)
        self.conn.set_proxies({'http': 'http://child-prc.intel.com:913',
                               'https': 'http://child-prc.intel.com:913'})
        self.experiment = None

    def params_to_tune_configs(self, params):
        op_cfgs = {}
        op_cfgs['op'] = {}
        for op, configs in self.opwise_quant_cfgs.items():
            if len(configs) > 1:
                # clamp the integer suggestion into the valid config index range
                value = int(params[op[0]])
                if value == len(configs):
                    value = len(configs) - 1
                op_cfgs['op'][op] = copy.deepcopy(configs[value])
            elif len(configs) == 1:
                op_cfgs['op'][op] = copy.deepcopy(configs[0])
            else:
                op_cfgs['op'][op] = copy.deepcopy(self.opwise_tune_cfgs[op][0])
        if len(self.calib_iter) > 1:
            value = int(params['calib_iteration'])
            if value == len(self.calib_iter):
                value = len(self.calib_iter) - 1
            op_cfgs['calib_iteration'] = int(self.calib_iter[value])
        else:
            op_cfgs['calib_iteration'] = int(self.calib_iter[0])
        return op_cfgs

    def next_tune_cfg(self):
        """Yield the next tuning config to traverse, based on the last tuning result.

        """
        while self.experiment.progress.observation_count < self.experiment.observation_budget:
            suggestion = self.conn.experiments(self.experiment.id).suggestions().create()
            yield self.params_to_tune_configs(suggestion.assignments)
            values = [
                dict(name='accuracy', value=self.last_tune_result[0]),
                dict(name='latency', value=self.last_tune_result[1])
            ]
            obs = self.conn.experiments(self.experiment.id).observations().create(
                suggestion=suggestion.id, values=values)
            logger.info('[suggestion_id, observation_id]: [%s, %s]' % (suggestion.id, obs.id))
            self.experiment = self.conn.experiments(self.experiment.id).fetch()

    def get_acc_target(self, base_acc):
        if self.cfg.tuning.accuracy_criterion.relative:
            return base_acc * (1. - self.cfg.tuning.accuracy_criterion.relative)
        else:
            return base_acc - self.cfg.tuning.accuracy_criterion.absolute

    def traverse(self):
        """The main traverse logic, which can be overridden by concrete strategies that need
        more hooks.
        This is the SigOpt version of traverse, with additional constraints passed to the HPO.
        """
        with Timeout(self.cfg.tuning.exit_policy.timeout) as t:
            # get fp32 model baseline
            if self.baseline is None:
                logger.info('Getting FP32 model baseline...')
                self.baseline = self._evaluate(self.model)
                # record the FP32 baseline
                self._add_tuning_history()
            logger.info('FP32 baseline is: ' +
                        ('[{:.4f}, {:.4f}]'.format(*self.baseline) if self.baseline else 'None'))
            # now initiate the HPO here
            logger.info("now initiate the HPO here")
            self.experiment = self.create_exp(acc_target=self.get_acc_target(self.baseline[0]))
            trials_count = 0
            for tune_cfg in self.next_tune_cfg():
                # add tune_cfg here as quantize use tune_cfg
                print("add tune_cfg here as quantize use tune_cfg")
                tune_cfg['advance'] = self.cfg.quantization.advance
                trials_count += 1
                tuning_history = self._find_tuning_history(tune_cfg)
                if tuning_history and trials_count < self.cfg.tuning.exit_policy.max_trials:
                    self.last_tune_result = tuning_history['last_tune_result']
                    self.best_tune_result = tuning_history['best_tune_result']
                    logger.debug('This tuning config was evaluated, skip!')
                    continue

                logger.debug('Dump current tuning configuration:')
                logger.debug(tune_cfg)
                self.last_qmodel = self.adaptor.quantize(
                    tune_cfg, self.model, self.calib_dataloader, self.q_func)
                assert self.last_qmodel
                self.last_tune_result = self._evaluate(self.last_qmodel)

                need_stop = self.stop(t, trials_count)

                # record the tuning history
                saved_tune_cfg = copy.deepcopy(tune_cfg)
                saved_last_tune_result = copy.deepcopy(self.last_tune_result)
                self._add_tuning_history(saved_tune_cfg, saved_last_tune_result)

                if need_stop:
                    break

    def create_exp(self, acc_target):
        params = []
        for op, configs in self.opwise_quant_cfgs.items():
            if len(configs) > 1:
                params.append(dict(name=op[0], type='int',
                                   bounds=dict(min=0, max=len(configs) - 1)))
        if len(self.calib_iter) > 1:
            params.append(dict(name='calib_iteration', type='int',
                               bounds=dict(min=0, max=len(self.calib_iter) - 1)))
        experiment = self.conn.experiments().create(
            name='lpot-tune',
            parameters=params,
            metrics=[
                dict(name='accuracy', objective='maximize', strategy='constraint', threshold=acc_target),
                dict(name='latency', objective='minimize', strategy='optimize'),
            ],
            parallel_bandwidth=1,
            # Define an Observation Budget for your experiment
            observation_budget=100,
            project='lpot',
        )

        logger.info("created experiment: https://app.sigopt.com/experiment/" + experiment.id)

        return experiment
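To make the constraint and the parameter mapping concrete, here is a small, self-contained illustration of the quantities the strategy exchanges with SigOpt. The accuracy value and the op name are hypothetical placeholders, not results from this PR:

# Hypothetical numbers, only to illustrate get_acc_target() and params_to_tune_configs().
baseline_acc = 0.765                  # FP32 accuracy returned by self._evaluate(self.model)
relative_loss = 0.01                  # tuning.accuracy_criterion.relative in resnet50_v1.yaml

# Threshold handed to create_exp() as the 'accuracy' constraint metric:
acc_target = baseline_acc * (1.0 - relative_loss)   # 0.75735

# A suggestion.assignments dict as params_to_tune_configs() consumes it: one integer
# index per op with more than one candidate config, plus 'calib_iteration' when
# several sampling sizes are listed in the yaml.
assignments = {
    'resnet_v1_50/conv1/Conv2D': 1,   # hypothetical op name, index into its config list
    'calib_iteration': 0,             # index 0 picks sampling_size 50 of "50, 100"
}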
requirements.txt (2 additions, 1 deletion)
@@ -19,4 +19,5 @@ psutil
 ruamel.yaml
 pycocotools-windows; sys_platform != 'linux'
 pycocotools; sys_platform == 'linux'
-opencv-python
+opencv-python
+sigopt
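For reference, the @strategy_registry decorator keys each strategy by the lowercased class-name prefix, so SigOptTuneStrategy should become selectable as name: sigopt in the yaml. A hedged sketch of that lookup, assuming the lpot strategy package imports all strategy modules on init and exposes the STRATEGIES dict from lpot/strategy/strategy.py:

from lpot.strategy.strategy import STRATEGIES   # assumed location of the registry dict

strategy_name = 'sigopt'                  # tuning.strategy.name from resnet50_v1.yaml
strategy_cls = STRATEGIES[strategy_name]
print(strategy_cls.__name__)              # expected: SigOptTuneStrategy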