diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py index ab3060a823e..d0d0511bdeb 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py +++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py @@ -154,7 +154,7 @@ help="minmax learning rate, if None,it will beset to be the same with lr", ) parser.add_argument( - "--enable_quanted_input", + "--disable_quanted_input", action="store_true", help="whether to use the output of quantized block to tune the next block", ) @@ -286,7 +286,7 @@ calib_len=args.calib_len, lr=args.lr, minmax_lr=args.minmax_lr, - enable_quanted_input=args.enable_quanted_input, + disable_quanted_input=args.disable_quanted_input, use_ipex=args.use_ipex, ) else: diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py index fc0e4f86221..e04d778c95c 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py @@ -527,7 +527,7 @@ def default_calib_func(model): "seqlen": config.calib_len, "iters": config.iters, "scale_dtype": config.scale_dtype, - "enable_quanted_input": config.enable_quanted_input, + "enable_quanted_input": not config.disable_quanted_input, "lr": config.lr, "minmax_lr": config.minmax_lr, } diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py index 503f18e9889..a63d22ba0ea 100644 --- a/intel_extension_for_transformers/transformers/utils/config.py +++ b/intel_extension_for_transformers/transformers/utils/config.py @@ -1056,7 +1056,7 @@ def __init__( sym: bool = False, lr: float = None, minmax_lr: float = None, - enable_quanted_input: bool = True, + disable_quanted_input: bool = False, nsamples: int = 512, iters: int = 200, use_ggml: bool = False, @@ -1083,7 +1083,7 @@ def __init__( self.group_size = group_size self.lr = lr self.minmax_lr = minmax_lr - self.enable_quanted_input = enable_quanted_input + self.disable_quanted_input = disable_quanted_input self.iters = iters self.llm_int8_skip_modules = ( llm_int8_skip_modules if llm_int8_skip_modules else []