Enhance autotune to return the best q_model directly #1875

Merged (5 commits) on Jun 18, 2024

8 changes: 6 additions & 2 deletions neural_compressor/common/base_tuning.py
@@ -336,13 +336,17 @@ def set_baseline(self, baseline: float):
def get_number_of_trials(self):
return len(self.tuning_history)

def get_best_quant_config(self) -> BaseConfig:
def get_best_trial_record(self) -> _TrialRecord:
assert self.get_number_of_trials() > 0, "No trial record in tuning monitor."
# Put the record with a higher score at the beginning
sorted_trials_records: List[_TrialRecord] = sorted(
self.tuning_history, key=lambda x: x.trial_result, reverse=True
)
return sorted_trials_records[0].quant_config
return sorted_trials_records[0]

def get_best_quant_config(self) -> BaseConfig:
best_trial_record = self.get_best_trial_record()
return best_trial_record.quant_config

def need_stop(self) -> bool:
"""Check if need to stop tuning. Either accuracy goal is met, max trials is reached or timeout is reached.
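For context, the refactor above splits record selection out of `get_best_quant_config` so callers can also read the winning trial's index. A minimal standalone sketch of the same pattern (the `TrialRecord`/`Monitor` names and the dict configs are illustrative stand-ins, not the library's real types):

```python
from collections import namedtuple
from typing import List

# Illustrative stand-in for the library's internal _TrialRecord.
TrialRecord = namedtuple("TrialRecord", ["trial_index", "trial_result", "quant_config"])


class Monitor:
    def __init__(self, tuning_history: List[TrialRecord]):
        self.tuning_history = tuning_history

    def get_best_trial_record(self) -> TrialRecord:
        # Highest score first, mirroring the sort in base_tuning.py.
        return sorted(self.tuning_history, key=lambda r: r.trial_result, reverse=True)[0]

    def get_best_quant_config(self):
        # Kept as a thin wrapper so existing callers keep working.
        return self.get_best_trial_record().quant_config


monitor = Monitor([TrialRecord(1, 0.90, {"bits": 4}), TrialRecord(2, 0.95, {"bits": 6})])
assert monitor.get_best_trial_record().trial_index == 2
assert monitor.get_best_quant_config() == {"bits": 6}
```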
15 changes: 15 additions & 0 deletions neural_compressor/common/utils/utility.py
@@ -15,9 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import importlib
import subprocess
import time
from typing import Dict

import cpuinfo
import psutil
@@ -35,6 +37,7 @@
"LazyImport",
"CpuInfo",
"default_tuning_logger",
"call_counter",
]


@@ -225,3 +228,15 @@ def inner_wrapper(*args, **kwargs):
return inner_wrapper

return log_process_wrapper


# decorator for recording number of times a function is called
FUNC_CALL_COUNTS: Dict[str, int] = collections.defaultdict(int)


def call_counter(func):
def wrapper(*args, **kwargs):
FUNC_CALL_COUNTS[func.__name__] += 1
return func(*args, **kwargs)

return wrapper
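Intended usage of the new `call_counter` decorator, matching what the tests further down exercise (assumes the `neural_compressor` package is importable):

```python
import neural_compressor.common.utils.utility as inc_utils


@inc_utils.call_counter
def add(a, b):
    return a + b


add(1, 2)
add(3, 4)
# Counts are keyed by the wrapped function's __name__.
assert inc_utils.FUNC_CALL_COUNTS["add"] == 2
```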
15 changes: 11 additions & 4 deletions neural_compressor/tensorflow/quantization/autotune.py
@@ -20,7 +20,7 @@
from neural_compressor.common import logger
from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
from neural_compressor.common.utils import dump_elapsed_time
from neural_compressor.common.utils import call_counter, dump_elapsed_time
from neural_compressor.tensorflow.quantization import quantize_model
from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME, StaticQuantConfig
from neural_compressor.tensorflow.utils import BaseModel, Model, constants
@@ -36,6 +36,7 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:


@dump_elapsed_time("Pass auto-tune")
@call_counter
def autotune(
model: Union[str, tf.keras.Model, BaseModel],
tune_config: TuningConfig,
@@ -52,7 +53,7 @@
baseline: float = eval_func_wrapper.evaluate(model)
tuning_monitor.set_baseline(baseline)
tuning_logger.tuning_start()
for trial_index, quant_config in enumerate(config_loader):
for trial_index, quant_config in enumerate(config_loader, 1):
tuning_logger.trial_start(trial_index=trial_index)
tuning_logger.execution_start()
logger.info(quant_config.to_dict())
@@ -65,8 +66,14 @@
tuning_logger.trial_end(trial_index)
if tuning_monitor.need_stop():
logger.info("Stopped tuning.")
best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
best_quant_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
best_trial_record = tuning_monitor.get_best_trial_record()
if best_trial_record.trial_index != trial_index:
logger.info("Re-quantizing with best quantization config...")
del q_model
best_quant_config: BaseConfig = best_trial_record.quant_config
best_quant_model = quantize_model(model, best_quant_config, calib_dataloader, calib_iteration)
else:
best_quant_model = q_model
break
tuning_logger.tuning_end()
return best_quant_model
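The branch added here (and mirrored in the PyTorch entry point below) boils down to: if the best trial is the one that just ran, hand back its `q_model` as-is; otherwise re-quantize once with the best recorded config. A schematic sketch of that decision, with the monitor and the quantize call passed in as stand-ins:

```python
def pick_final_model(tuning_monitor, trial_index, q_model, requantize_fn):
    """Schematic version of the branch added to autotune().

    `requantize_fn` stands in for quantize_model()/quantize() with the
    calibration arguments already bound.
    """
    best_trial_record = tuning_monitor.get_best_trial_record()
    if best_trial_record.trial_index == trial_index:
        # The last trial is already the best one: no extra quantize call.
        return q_model
    # Otherwise rebuild the model from the best config found earlier.
    del q_model
    return requantize_fn(best_trial_record.quant_config)
```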
9 changes: 5 additions & 4 deletions neural_compressor/torch/quantization/autotune.py
@@ -72,7 +72,7 @@ def autotune(
baseline: float = eval_func_wrapper.evaluate(model)
tuning_monitor.set_baseline(baseline)
tuning_logger.tuning_start()
for trial_index, quant_config in enumerate(config_loader):
for trial_index, quant_config in enumerate(config_loader, 1):
tuning_logger.trial_start(trial_index=trial_index)
tuning_logger.execution_start()
logger.info(quant_config.to_dict())
@@ -93,10 +93,11 @@
tuning_logger.trial_end(trial_index)
if tuning_monitor.need_stop():
logger.info("Stopped tuning.")
if trial_index == 0: # recover the best q_model from previous results.
logger.info("Reconvering the best quantized model...")
best_trial_record = tuning_monitor.get_best_trial_record()
if best_trial_record.trial_index != trial_index:
logger.info("Re-quantizing with best quantization config...")
del q_model # maybe gc.collect() is needed for memory release
best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
best_quant_config: BaseConfig = best_trial_record.quant_config
# !!! Make sure to use deepcopy only when inplace is set to `True`.
q_model = quantize(
deepcopy(model),
3 changes: 2 additions & 1 deletion neural_compressor/torch/quantization/quantize.py
@@ -18,7 +18,7 @@
import torch

from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
from neural_compressor.common.utils import Mode, log_process
from neural_compressor.common.utils import Mode, call_counter, log_process
from neural_compressor.torch.quantization.config import SmoothQuantConfig, StaticQuantConfig
from neural_compressor.torch.utils import is_ipex_available, logger
from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info
@@ -31,6 +31,7 @@ def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_nam


@log_process(mode=Mode.QUANTIZE)
@call_counter
def quantize(
model: torch.nn.Module,
quant_config: BaseConfig,
22 changes: 22 additions & 0 deletions test/3x/common/test_utility.py
@@ -11,6 +11,7 @@
import unittest
from unittest.mock import MagicMock, patch

import neural_compressor.common.utils.utility as inc_utils
from neural_compressor.common import options
from neural_compressor.common.utils import (
CpuInfo,
@@ -166,5 +167,26 @@ def __init__(self):
assert instance2.value == 1, "Singleton should return the same instance"


class TestCallCounter(unittest.TestCase):
def test_call_counter(self):
# empty dict
inc_utils.FUNC_CALL_COUNTS.clear()

@inc_utils.call_counter
def add(a, b):
return a + b

# Initial count should be 0
self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 0)

# Call the function multiple times
add(1, 2)
add(3, 4)
add(5, 6)

# Count should be incremented accordingly
self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 3)


if __name__ == "__main__":
unittest.main()
38 changes: 38 additions & 0 deletions test/3x/torch/test_autotune.py
@@ -6,6 +6,7 @@
import torch
import transformers

import neural_compressor.common.utils.utility as inc_utils
from neural_compressor.common import logger
from neural_compressor.torch.quantization import (
MixPrecisionConfig,
@@ -163,6 +164,43 @@ def eval_acc_fn(model) -> float:

custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
print(inc_utils.FUNC_CALL_COUNTS)
self.assertIsNotNone(best_model)

def test_autotune_return_qmodel_directly(self):
inc_utils.FUNC_CALL_COUNTS.clear()

baseline = 1
eval_result = [0.9, 1.1]
acc_list = [baseline] + eval_result

def eval_acc_fn(model) -> float:
acc = acc_list.pop(0)
return acc

custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
assert (
inc_utils.FUNC_CALL_COUNTS.get("quantize") == 2
), f"quantize should be called twice, but got {inc_utils.FUNC_CALL_COUNTS.get('quantize')}"
self.assertIsNotNone(best_model)

def test_autotune_return_re_quant_qmodel(self):
inc_utils.FUNC_CALL_COUNTS.clear()

baseline = 1
eval_result = [0.9, 0.8]
acc_list = [baseline] + eval_result

def eval_acc_fn(model) -> float:
acc = acc_list.pop(0)
return acc

custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
assert (
inc_utils.FUNC_CALL_COUNTS.get("quantize") == 3
), f"quantize should be called three times, but got {inc_utils.FUNC_CALL_COUNTS.get('quantize')}"
self.assertIsNotNone(best_model)

@reset_tuning_target
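As a sanity check on the two expected call counts above, here is a small self-contained simulation of the loop this PR modifies; it only counts quantize calls and assumes, for simplicity, that beating the baseline stops tuning, which is all that matters for these counts:

```python
from collections import namedtuple

TrialRecord = namedtuple("TrialRecord", ["trial_index", "trial_result", "quant_config"])
QUANTIZE_CALLS = {"n": 0}


def quantize_stub(config):
    # Stand-in for quantize()/quantize_model(): just counts invocations.
    QUANTIZE_CALLS["n"] += 1
    return f"q_model({config})"


def autotune_sim(baseline, eval_results, configs, max_trials=2):
    """Schematic reproduction of the tuning loop changed in this PR."""
    history = []
    for trial_index, (config, acc) in enumerate(zip(configs, eval_results), 1):
        q_model = quantize_stub(config)
        history.append(TrialRecord(trial_index, acc, config))
        if acc >= baseline or trial_index >= max_trials:
            best = max(history, key=lambda r: r.trial_result)
            if best.trial_index == trial_index:
                return q_model                         # returned directly
            return quantize_stub(best.quant_config)    # re-quantized once


QUANTIZE_CALLS["n"] = 0
autotune_sim(baseline=1, eval_results=[0.9, 1.1], configs=["rtn-4bit", "rtn-6bit"])
assert QUANTIZE_CALLS["n"] == 2  # best trial was the last one

QUANTIZE_CALLS["n"] = 0
autotune_sim(baseline=1, eval_results=[0.9, 0.8], configs=["rtn-4bit", "rtn-6bit"])
assert QUANTIZE_CALLS["n"] == 3  # best trial (0.9) had to be re-quantized
```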