WOQ support autoround algo on cpu device (#1312)
* woq support autoround algo on cpu device

Signed-off-by: changwangss <chang1.wang@intel.com>

---------

Signed-off-by: changwangss <chang1.wang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
changwangss and pre-commit-ci[bot] committed Feb 27, 2024
1 parent 7b3abca commit 6c42b53
Showing 7 changed files with 127 additions and 16 deletions.
@@ -12,4 +12,7 @@ neural-compressor
intel_extension_for_pytorch==2.2.0
optimum-intel
git+https://github.com/bigcode-project/bigcode-evaluation-harness@00967d12093ef614de7bdad0772aed8e4118f1fd
git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf



@@ -68,7 +68,7 @@
parser.add_argument(
"--woq_algo",
default="RTN",
choices=["RTN", "AWQ", "TEQ", "GPTQ"],
choices=["RTN", "AWQ", "TEQ", "GPTQ", "AUTOROUND"],
help="Weight-only parameter.",
)
parser.add_argument(
@@ -133,6 +133,18 @@
help="Calibration dataset sequence max length, this should align with your model config",
)
parser.add_argument('--gptq_static_groups', action='store_true', help='Use determined group to do quantization')
# ============AUTOROUND configs==============
parser.add_argument(
"--autoround_nsamples",
type=int, default=128,
help="Number of calibration data samples.",
)
parser.add_argument(
"--autoround_seq_len",
type=int,
default=2048,
help="Calibration dataset sequence max length, this should align with your model config",
)
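    # The two AUTOROUND flags above feed the algorithm_args dict built further below:
    #   args.autoround_nsamples -> algorithm_args["n_samples"]
    #   args.autoround_seq_len  -> algorithm_args["seq_len"]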
# ============Harness configs============
parser.add_argument("--tasks", default=None, help="Evaluation tasks")
parser.add_argument(
@@ -281,6 +293,26 @@
algorithm_args=algorithm_args,
calib_dataset=args.dataset
)
elif args.woq_algo == "AUTOROUND":
algorithm_args = {
"n_samples": args.autoround_nsamples,
"amp": False,
"seq_len": args.autoround_seq_len,
"iters": args.calib_iters,
"scale_dtype": "fp32",
"device": "cpu",
}
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.woq_compute_dtype,
scale_dtype=args.woq_scale_dtype,
weight_dtype=args.woq_weight_dtype,
scheme=args.woq_scheme,
group_size=args.woq_group_size,
algorithm=args.woq_algo,
tokenizer=tokenizer,
algorithm_args=algorithm_args,
calib_dataset=args.dataset
)
else:
quantization_config = WeightOnlyQuantConfig(
weight_dtype=args.woq_weight_dtype,
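For reference, a minimal end-to-end sketch of the AUTOROUND path this script enables, condensed from the unit test further down in this commit; the checkpoint name is a placeholder, the sample counts mirror the new flag defaults, and the imports assume the intel_extension_for_transformers.transformers entry points used by the example scripts:

from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantConfig,
)

model_name_or_path = "facebook/opt-125m"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Mirrors the algorithm_args dict built in the AUTOROUND branch above.
algorithm_args = {
    "n_samples": 128,       # --autoround_nsamples
    "amp": False,
    "seq_len": 2048,        # --autoround_seq_len
    "iters": 200,           # the scripts pass --calib_iters here
    "scale_dtype": "fp32",
    "device": "cpu",
}
woq_config = WeightOnlyQuantConfig(
    weight_dtype="int4_clip",
    algorithm="AUTOROUND",
    algorithm_args=algorithm_args,
    tokenizer=tokenizer,
)
woq_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=woq_config,
    use_neural_speed=False,
)
woq_model.eval()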
@@ -14,3 +14,4 @@ tiktoken #qwen
einops #qwen
git+https://github.com/intel/neural-compressor.git
git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf
@@ -95,7 +95,7 @@
parser.add_argument(
"--woq_algo",
default="RTN",
choices=["RTN", "AWQ", "TEQ", "GPTQ"],
choices=["RTN", "AWQ", "TEQ", "GPTQ", "AUTOROUND"],
help="Weight-only parameter.",
)
parser.add_argument(
@@ -159,6 +159,18 @@
help="Calibration dataset sequence max length, this should align with your model config",
)
parser.add_argument('--gptq_static_groups', action='store_true', help='Use determined group to do quantization')
# ============AUTOROUND configs==============
parser.add_argument(
"--autoround_nsamples",
type=int, default=128,
help="Number of calibration data samples.",
)
parser.add_argument(
"--autoround_seq_len",
type=int,
default=2048,
help="Calibration dataset sequence max length, this should align with your model config",
)
# ============BitsAndBytes configs==============
parser.add_argument("--bitsandbytes", action="store_true")
# ============AutoModel parameters==============
@@ -292,6 +304,26 @@
tokenizer=tokenizer,
algorithm_args=algorithm_args,
)
elif args.woq_algo == "AUTOROUND":
algorithm_args = {
"n_samples": args.autoround_nsamples,
"amp": False,
"seq_len": args.autoround_seq_len,
"iters": args.calib_iters,
"scale_dtype": "fp32",
"device": "cpu",
}
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.woq_compute_dtype,
scale_dtype=args.woq_scale_dtype,
weight_dtype=args.woq_weight_dtype,
scheme=args.woq_scheme,
group_size=args.woq_group_size,
algorithm=args.woq_algo,
tokenizer=tokenizer,
algorithm_args=algorithm_args,
calib_dataset=args.dataset
)
else:
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.woq_compute_dtype,
26 changes: 21 additions & 5 deletions intel_extension_for_transformers/llm/quantization/utils.py
@@ -173,7 +173,7 @@ def _replace_linear(
model._modules[name].requires_grad_(False)
if device == "cpu" or device == torch.device("cpu") or device == "auto":
if not empty_weights:
if quantization_config.algorithm == "GPTQ":
if quantization_config.algorithm == "GPTQ" or quantization_config.algorithm == "AUTOROUND":
from .gptq_utils import unpack_weight
int_weight, gptq_scales, gptq_zeros = unpack_weight(
module.qweight,
@@ -237,7 +237,7 @@ def convert_to_quantized_model(model, config, device="cpu"):
calib_func = config.calib_func
calib_iters = config.calib_iters
model_device = next(model.parameters()).device
if calib_dataloader is None and config.algorithm in ["TEQ", "AWQ", "GPTQ"]:
if calib_dataloader is None and config.algorithm in ["TEQ", "AWQ", "GPTQ", "AUTOROUND"]:
from datasets import load_dataset
from torch.utils.data import DataLoader

@@ -320,7 +320,8 @@ def default_calib_func(model):
},
"awq_args": config.algorithm_args.update({"enable_mse_search": config.mse_range})
if config.algorithm == "AWQ" and config.algorithm_args is not None else {},
"gptq_args": config.algorithm_args if config.algorithm == "GPTQ" else None
"gptq_args": config.algorithm_args if config.algorithm == "GPTQ" else None,
"autoround_args": config.algorithm_args if config.algorithm == "AUTOROUND" else None
}
conf = PostTrainingQuantConfig(
approach="weight_only",
@@ -346,7 +347,7 @@ def default_calib_func(model):
)
# TEQ: set calib_func=None, use default training func as calib_func
# RTN: doesn't need calib_func
if config.algorithm in ["TEQ", "RTN", "GPTQ"]:
if config.algorithm in ["TEQ", "RTN", "GPTQ", "AUTOROUND"]:
calib_func = None

orig_dtype = torch.float32
@@ -360,6 +361,7 @@ def default_calib_func(model):
conf,
calib_func=calib_func,
calib_dataloader=calib_dataloader)

if device == "xpu" or device == torch.device("xpu"):
model = inc_model.export_compressed_model(compression_dtype=torch.int8,
compression_dim=0,
@@ -374,7 +376,6 @@ def default_calib_func(model):
if config.algorithm == "GPTQ":
inc_model = inc_model.export_compressed_model(use_optimum_format=True)
inc_model.eval()

quantize_config = {
"bits": bits,
"group_size": config.group_size,
@@ -386,6 +387,21 @@ def default_calib_func(model):
"model_file_base_name": "model",
}

setattr(config, "gptq_quantize_config", quantize_config)
q_model = replace_linear(inc_model, None, None, config, device=device)
elif config.algorithm == "AUTOROUND":
inc_model = inc_model.export_compressed_model(use_optimum_format=True)
inc_model.eval()
quantize_config = {
"bits": bits,
"group_size": config.group_size,
"desc_act": False,
"sym": True if config.scheme == "sym" else False,
"true_sequential": True,
"model_name_or_path": "null",
"model_file_base_name": "model",
}

setattr(config, "gptq_quantize_config", quantize_config)
q_model = replace_linear(inc_model, None, None, config, device=device)
else:
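A side note on the utils.py change: the new AUTOROUND export branch mirrors the GPTQ branch line for line (same optimum-format export and the same GPTQ-style packing metadata), so the two could plausibly be folded into one branch. A sketch of that consolidation, using only names already in scope in convert_to_quantized_model:

# Sketch only: GPTQ and AUTOROUND share the optimum-format export path and
# the GPTQ-style packing metadata, so the two branches could be merged.
if config.algorithm in ("GPTQ", "AUTOROUND"):
    inc_model = inc_model.export_compressed_model(use_optimum_format=True)
    inc_model.eval()
    quantize_config = {
        "bits": bits,
        "group_size": config.group_size,
        "desc_act": False,
        "sym": config.scheme == "sym",
        "true_sequential": True,
        "model_name_or_path": "null",
        "model_file_base_name": "model",
    }
    setattr(config, "gptq_quantize_config", quantize_config)
    q_model = replace_linear(inc_model, None, None, config, device=device)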
44 changes: 35 additions & 9 deletions tests/CI/test_quantization.py
@@ -321,7 +321,7 @@ def test_quantization_for_llm(self):
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, use_neural_speed=False)
dummy_input = fp32_model.dummy_inputs["input_ids"]
#smooth-quant
# SQ
sq_config = SmoothQuantConfig(
tokenizer=tokenizer, # either two of one, tokenizer or calib_func
calib_iters=2,
@@ -332,7 +332,8 @@
use_neural_speed=False
)
self.assertTrue(isinstance(q_model.model, torch.jit.ScriptModule))
#SQ auto

# SQ auto
recipes = {
"smooth_quant": True,
"smooth_quant_args": { "alpha": "auto", "auto_alpha_args":{"alpha_max": 0.6,
@@ -349,18 +350,19 @@
use_neural_speed=False
)
self.assertTrue(isinstance(q_model.model, torch.jit.ScriptModule))

# weight-only
#RTN
# RTN
woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange")
woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
quantization_config=woq_config,
use_neural_speed=False
)
woq_model.eval()
output = woq_model(dummy_input)
print("output:", float(output[0][0][0][0]))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.16387596726417542, rel_tol=1e-04))
#AWQ

# AWQ
woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange",
calib_iters=5,
tokenizer=tokenizer,
@@ -373,7 +375,8 @@
output = woq_model(dummy_input)
print("output:", float(output[0][0][0][0]))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.17239853739738464, rel_tol=1e-04))
#TEQ

# TEQ
woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange",
calib_iters=5,
tokenizer=tokenizer,
@@ -384,7 +387,8 @@
)
woq_model.eval()
output = woq_model(dummy_input)
#fp8

# fp8
woq_config = WeightOnlyQuantConfig(weight_dtype="fp8_e5m2", scale_dtype="fp8_e8m0")
woq_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, quantization_config=woq_config, use_neural_speed=False
@@ -394,6 +398,7 @@
self.assertTrue(
isclose(float(output[0][0][0][0]), 0.16162332892417908, rel_tol=1e-04)
)

# amp
amp_config = MixedPrecisionConfig()
amp_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
@@ -403,6 +408,7 @@
amp_model.eval()
output = amp_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.1689453125, rel_tol=1e-04))

# bitsandbytes, for cpu is fp32 model
bab_config = BitsAndBytesConfig()
bab_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
@@ -430,7 +436,7 @@ def test_quantization_for_llm(self):
print("output:", float(output[0][0][0][0]))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.1675747185945511, rel_tol=1e-04))

#GPTQ
# GPTQ
algorithm_args = {
"act_order": False,
"percdamp": 0.01,
@@ -449,9 +455,29 @@
)
woq_model.eval()
output = woq_model(dummy_input)
print("output:", float(output[0][0][0][0]))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.17126554250717163, rel_tol=1e-04))

# AUTOROUND
algorithm_args = {
"n_samples": 128,
"amp": False,
"seq_len": 32,
"iters": 5,
"scale_dtype": "fp32",
"device": "cpu",
}
woq_config = WeightOnlyQuantConfig(weight_dtype="int4_clip",
algorithm_args=algorithm_args,
tokenizer=tokenizer,
algorithm="AUTOROUND")
woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
quantization_config=woq_config,
use_neural_speed=False
)
woq_model.eval()
output = woq_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.18015708029270172, rel_tol=1e-04))

def test_export(self):
# test model with model_id
self.trainer.export_to_onnx("export.onnx")
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -5,6 +5,7 @@ datasets==2.14.7
einops
evaluate
gguf
git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf
git+https://github.com/intel/neural-compressor.git
intel-extension-for-pytorch==2.2.0
intel-tensorflow==2.14.0
