Enhance 3.x common logger and update 3.x torch example #1783

Merged 15 commits on May 15, 2024
@@ -236,11 +236,11 @@ def get_user_model():

# 3.x api
if args.approach == 'weight_only':
from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, quantize
from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, prepare, convert, quantize
from neural_compressor.torch.utils import get_double_quant_config
weight_sym = True if args.woq_scheme == "sym" else False
double_quant_config_dict = get_double_quant_config(args.double_quant_type)

if args.woq_algo == "RTN":
if args.double_quant_type is not None:
double_quant_config_dict.update(
@@ -269,9 +269,8 @@ def get_user_model():
double_quant_group_size=args.double_quant_group_size,
)
quant_config.set_local("lm_head", RTNConfig(dtype="fp32"))
user_model = quantize(
model=user_model, quant_config=quant_config
)
user_model = prepare(model=user_model, quant_config=quant_config)
user_model = convert(model=user_model)
elif args.woq_algo == "GPTQ":
from utils import DataloaderPreprocessor
dataloaderPreprocessor = DataloaderPreprocessor(
@@ -326,24 +325,24 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
double_quant_group_size=args.double_quant_group_size,
)
quant_config.set_local("lm_head", GPTQConfig(dtype="fp32"))
user_model = quantize(
model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=(dataloader_for_calibration, )
)
user_model = prepare(model=user_model, quant_config=quant_config)
run_fn_for_gptq(user_model, dataloader_for_calibration)
user_model = convert(user_model)
else:
if args.sq:
from neural_compressor.torch.quantization import SmoothQuantConfig, quantize
from neural_compressor.torch.quantization import SmoothQuantConfig

# alpha can be a float number or a list of float numbers.
args.alpha = args.alpha if args.alpha == "auto" else eval(args.alpha)
if re.search("falcon", user_model.config.model_type):
quant_config = SmoothQuantConfig(alpha=args.alpha, folding=False)
else:
quant_config = SmoothQuantConfig(alpha=args.alpha, folding=True)

if re.search("gpt", user_model.config.model_type):
quant_config.set_local("add", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
else:
from neural_compressor.torch.quantization import quantize, get_default_static_config, StaticQuantConfig
from neural_compressor.torch.quantization import get_default_static_config, StaticQuantConfig

quant_config = get_default_static_config()
if re.search("gpt", user_model.config.model_type):
@@ -364,12 +363,23 @@ def run_fn(model):
except ValueError:
pass
return

from utils import get_example_inputs
example_inputs = get_example_inputs(user_model, calib_dataloader)
user_model = quantize(
model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
)
if args.sq:
# currently, smooth quant only supports the quantize API
# TODO: support prepare/convert API for smooth quant
from neural_compressor.torch.quantization import quantize

user_model = quantize(
model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
)
else:
from neural_compressor.torch.quantization import prepare, convert

user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)
user_model.save(args.output_dir)


@@ -394,7 +404,7 @@ def run_fn(model):
user_model.eval()
from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
eval_args = LMEvalParser(
model="hf",
model="hf",
user_model=user_model,
tokenizer=tokenizer,
batch_size=args.batch_size,
@@ -417,7 +427,7 @@ def run_fn(model):

samples = args.iters * args.batch_size
eval_args = LMEvalParser(
model="hf",
model="hf",
user_model=user_model,
tokenizer=tokenizer,
batch_size=args.batch_size,
@@ -436,4 +446,4 @@ def run_fn(model):
print("Accuracy: %.5f" % acc)
print('Throughput: %.3f samples/sec' % (samples / (end - start)))
print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
print('Batch size = %d' % args.batch_size)
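
Taken together, the example changes above replace the one-shot quantize() call with the new two-step prepare/convert flow for the RTN, GPTQ, and static-quant paths (SmoothQuant keeps quantize() for now). A minimal sketch of the new pattern, reusing the names from the diff; user_model, quant_config, example_inputs, and a calibration run_fn are assumed to be set up exactly as in the example:

from neural_compressor.torch.quantization import prepare, convert

# Step 1: attach observers / prepare the model according to quant_config.
# example_inputs is only passed for static quantization, as in the diff above.
user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)

# Step 2: run calibration by feeding representative data through the prepared model.
run_fn(user_model)

# Step 3: fold the collected calibration statistics into a quantized model.
user_model = convert(user_model)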
16 changes: 16 additions & 0 deletions neural_compressor/common/utils/logger.py
@@ -162,6 +162,22 @@ def quantization_start(cls, stacklevel=2) -> None:
def quantization_end(cls, stacklevel=2) -> None:
logger.info("Quantization end.", stacklevel=stacklevel)

@classmethod
def preparation_start(cls, stacklevel=2) -> None:
logger.info("Preparation started.", stacklevel=stacklevel)

@classmethod
def preparation_end(cls, stacklevel=2) -> None:
logger.info("Preparation end.", stacklevel=stacklevel)

@classmethod
def conversion_start(cls, stacklevel=2) -> None:
logger.info("Conversion started.", stacklevel=stacklevel)

@classmethod
def conversion_end(cls, stacklevel=2) -> None:
logger.info("Conversion end.", stacklevel=stacklevel)

@classmethod
def evaluation_start(cls) -> None:
logger.info("Evaluation started.")
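
The four new helpers mirror the existing quantization_start/quantization_end pair, so the common tuning logger can frame the prepare and convert stages separately. A small usage sketch; it assumes these classmethods belong to TuningLogger and that the class is importable from the module shown above (utility.py below instantiates it as default_tuning_logger):

# Hedged sketch: the import path is inferred from the file header above.
from neural_compressor.common.utils.logger import TuningLogger

TuningLogger.preparation_start()  # logs "Preparation started."
# ... insert observers / prepare the model here ...
TuningLogger.preparation_end()    # logs "Preparation end."

TuningLogger.conversion_start()   # logs "Conversion started."
# ... fold calibration statistics into quantized modules here ...
TuningLogger.conversion_end()     # logs "Conversion end."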
39 changes: 28 additions & 11 deletions neural_compressor/common/utils/utility.py
@@ -206,14 +206,31 @@ def set_tensorboard(tensorboard: bool):
default_tuning_logger = TuningLogger()


def log_quant_execution(func):
def wrapper(*args, **kwargs):
default_tuning_logger.quantization_start(stacklevel=4)

# Call the original function
result = func(*args, **kwargs)

default_tuning_logger.quantization_end(stacklevel=4)
return result

return wrapper
def log_quant_execution(mode="quantize"):
def log_quant_execution_wrapper(func):
def inner_wrapper(*args, **kwargs):
start_log, end_log = None, None
if mode == "quantize":
start_log = default_tuning_logger.quantization_start
end_log = default_tuning_logger.quantization_end
elif mode == "prepare":
start_log = default_tuning_logger.preparation_start
end_log = default_tuning_logger.preparation_end
elif mode == "convert":
start_log = default_tuning_logger.conversion_start
end_log = default_tuning_logger.conversion_end

if start_log is not None:
start_log(stacklevel=4)

# Call the original function
result = func(*args, **kwargs)

if end_log is not None:
end_log(stacklevel=4)

return result

return inner_wrapper

return log_quant_execution_wrapper
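
The refactor turns log_quant_execution from a plain decorator into a decorator factory: the mode argument selects which start/end pair to emit, and an unrecognized mode simply skips logging. A self-contained re-expression of the same pattern with generic names (an illustrative sketch, not the library's code):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def log_stage(mode="quantize"):
    """Decorator factory: the mode argument picks the start/end messages."""
    messages = {
        "quantize": ("Quantization started.", "Quantization end."),
        "prepare": ("Preparation started.", "Preparation end."),
        "convert": ("Conversion started.", "Conversion end."),
    }

    def decorator(func):
        def inner(*args, **kwargs):
            pair = messages.get(mode)  # unknown mode -> no logging
            if pair:
                logger.info(pair[0])
            result = func(*args, **kwargs)
            if pair:
                logger.info(pair[1])
            return result
        return inner
    return decorator

@log_stage(mode="prepare")
def prepare_model():
    return "prepared"

prepare_model()  # emits "Preparation started." and "Preparation end." around the call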
4 changes: 3 additions & 1 deletion neural_compressor/torch/quantization/quantize.py
@@ -30,7 +30,7 @@ def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_nam
return any(config.name == algo_name for config in configs_mapping.values())


@log_quant_execution
@log_quant_execution(mode=Mode.QUANTIZE.value)
Collaborator comment: Hi @chensuyue, please be aware of this change; the quantization process is complete when either convert or quantize completes.

def quantize(
model: torch.nn.Module,
quant_config: BaseConfig,
@@ -86,6 +86,7 @@ def quantize(
return q_model


@log_quant_execution(mode=Mode.PREPARE.value)
def prepare(
model: torch.nn.Module,
quant_config: BaseConfig,
@@ -143,6 +144,7 @@ def prepare(
return prepared_model


@log_quant_execution(mode=Mode.CONVERT.value)
def convert(
model: torch.nn.Module,
quant_config: BaseConfig = None,