weight only quantization (#1349)
* Update weight only quantization config

Signed-off-by: Cheng Penghui <penghui.cheng@intel.com>
PenghuiCheng committed Mar 9, 2024
1 parent a2ce911 commit 6a458f0
Showing 3 changed files with 12 additions and 6 deletions.
@@ -176,10 +176,12 @@
         args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
         if user_model is None else user_model
     user_model = user_model.to(memory_format=torch.channels_last)
+    if quantization_config is None:
+        quantization_config = WeightOnlyQuantConfig.from_pretrained(args.model)
     if not args.disable_optimize_transformers:
         print("Optimize with IPEX...")
         user_model = ipex.optimize_transformers(
-            user_model.eval(), device=args.device, inplace=True, woq=(hasattr(user_model, "quantization_config")), dtype=torch_dtype)
+            user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype)
     else:
         print("Disabled optimization with IPEX...")
     # start
@@ -263,10 +265,12 @@
     user_model = AutoModelForCausalLM.from_pretrained(
         args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
         if user_model is None else user_model
+    if quantization_config is None:
+        quantization_config = WeightOnlyQuantConfig.from_pretrained(args.model)
     if not args.disable_optimize_transformers:
         print("Optimize with IPEX...")
         user_model = ipex.optimize_transformers(
-            user_model.eval(), device=args.device, inplace=True, woq=(hasattr(user_model, "quantization_config")), dtype=torch_dtype)
+            user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype)
     else:
         print("Disabled optimization with IPEX...")
     results = evaluate(
results = evaluate(
@@ -287,4 +291,3 @@
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
         else:
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))

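Taken together, the change in this script is: when the caller has not supplied a quantization_config, recover it from the quantized checkpoint via WeightOnlyQuantConfig.from_pretrained, then pass the config object to ipex.optimize_transformers in place of the removed woq= boolean. A minimal usage sketch, assuming WeightOnlyQuantConfig is importable from intel_extension_for_transformers.transformers in this release and that "./saved-woq-model" stands in for a real quantized checkpoint:

import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM
# Assumption: this import path matches the ITREX release this commit targets.
from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig

model_id = "./saved-woq-model"  # illustrative path to a quantized checkpoint
user_model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, device_map="xpu", torch_dtype=torch.float16)

# New in this commit: reload the WOQ config saved with the checkpoint ...
quantization_config = WeightOnlyQuantConfig.from_pretrained(model_id)
# ... and hand the config object to IPEX instead of the old woq=... flag.
user_model = ipex.optimize_transformers(
    user_model.eval(), device="xpu", inplace=True,
    quantization_config=quantization_config, dtype=torch.float16)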
intel_extension_for_transformers/llm/quantization/utils.py (1 addition, 1 deletion)
@@ -135,7 +135,7 @@ def _replace_linear(
                 )
             elif device == "xpu" or device == torch.device("xpu"):
                 from intel_extension_for_pytorch.nn.utils._quantize_convert \
-                    import WeightOnlyLinear as ipex_linear  # pylint: disable=E0401
+                    import WeightOnlyQuantizedLinear as ipex_linear  # pylint: disable=E0401
                 model._modules[name] = ipex_linear(
                     in_features,
                     out_features,
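The only change here is the class name on the XPU path: aliasing the import as ipex_linear keeps the replacement code below it unchanged. A sketch of the dispatch, using a hypothetical helper name, with the CPU branch and the remaining constructor arguments elided because they fall outside this hunk:

import torch

def _woq_linear_cls(device):  # hypothetical helper, for illustration only
    """Return the IPEX weight-only-quantized Linear class for the device."""
    if device == "xpu" or device == torch.device("xpu"):
        # pylint: disable=E0401
        from intel_extension_for_pytorch.nn.utils._quantize_convert \
            import WeightOnlyQuantizedLinear as ipex_linear  # renamed class
        return ipex_linear
    raise NotImplementedError(f"no WOQ Linear sketched for device {device}")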
@@ -70,9 +70,12 @@


 def convert_model_to_public(model):
-    from intel_extension_for_pytorch.nn.utils._quantize_convert import WeightOnlyLinear  # pylint: disable=E0401
+    # pylint: disable=E0401
+    from intel_extension_for_pytorch.nn.utils._quantize_convert import (
+        WeightOnlyQuantizedLinear
+    )
     for name, module in model.named_modules():
-        if isinstance(module, WeightOnlyLinear):
+        if isinstance(module, WeightOnlyQuantizedLinear):
             if module.weight_transposed:
                 module.qweight.data = module.qweight.t_().contiguous()
                 module.scales.data = module.scales.t_().contiguous()
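convert_model_to_public restores the public checkpoint layout: on XPU the quantized weights and scales may be stored transposed for inference, and the in-place t_() calls flip them back before the model is shared or saved. A hedged sketch of a call site, continuing from the names in the first sketch above (the call site is not shown in this commit and the save path is illustrative):

# Restore the public qweight/scales layout, then serialize as usual.
convert_model_to_public(user_model)
user_model.save_pretrained("./woq-model-public")  # illustrative path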
