GPTQ: use optimum format by default (#2568)
* Use HuggingFace Optimum format for GPTQ checkpoint

* Fix issue in LLM examples

---------

Co-authored-by: WeizhuoZhang-intel <weizhuo.zhang@intel.com>
Xia-Weiwen and WeizhuoZhang-intel committed Feb 5, 2024
1 parent cfaa7a2 commit 9fcc489
Showing 7 changed files with 178 additions and 27 deletions.
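For orientation before the file-by-file diff: the two checkpoint layouts differ mainly in the per-layer tensor names (and in how the packed tensors are laid out; see the conversion helper further down). A minimal sketch of the key mapping, based on the config dictionaries added to intel_extension_for_pytorch/utils/weight_only_quantization.py in this commit; the legacy bias key is not visible in this view:

# HuggingFace Optimum GPTQ format (the new default), per quantized linear layer:
#   <layer>.qweight, <layer>.scales, <layer>.qzeros, <layer>.bias, <layer>.g_idx
OPTIMUM_KEYS = {
    "weight_key": "qweight",
    "scale_key": "scales",
    "zero_point_key": "qzeros",
    "bias_key": "bias",
    "g_idx_key": "g_idx",
}

# Legacy IPEX format (still accepted via the new --gptq-legacy-format flag):
#   <layer>.packed_weight, <layer>.scale, <layer>.packed_zp
LEGACY_KEYS = {
    "weight_key": "packed_weight",
    "scale_key": "scale",
    "zero_point_key": "packed_zp",
}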
11 changes: 10 additions & 1 deletion examples/cpu/inference/python/llm/run.py
@@ -130,7 +130,14 @@ def main(args_in: Optional[List[str]] = None) -> None:
"--gptq",
action="store_true",
help="Run GPTQ calibration to generate optimized INT4 weight for weight-only quantization."
"This is recommended for INT4 to minimize accuracy drop after quantization."
" This is recommended for INT4 to minimize accuracy drop after quantization."
)
parser.add_argument(
"--gptq-legacy-format",
action="store_true",
help="Indicate that the low-precision checkpoint is in the legacy format rather than the"
" HuggingFace Optimum format for backward compatibility. It must be used with"
" --low-precision-checkpoint. Otherwise, it has no effect."
)
parser.add_argument(
"--group-size",
@@ -357,6 +364,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
str(args.low_precision_checkpoint),
]
)
if args.gptq_legacy_format:
quant_cmd.extend(["--gptq-legacy-format"])
else:
# No need to set group size if args.gptq is true
# Group size is read from the checkpoint
(next changed file; path not shown in this view)
@@ -149,6 +149,13 @@
"PER_BATCH_IC_BLOCK(3): quantize per block of size 1 x IC_BLOCK. "
"IC_BLOCK is determined by IC automatically.",
)
parser.add_argument(
"--gptq-legacy-format",
action="store_true",
help="Indicate that the low-precision checkpoint is in the legacy format rather than the"
" HuggingFace Optimum format for backward compatibility. It must be used with"
" --low-precision-checkpoint. Otherwise, it has no effect."
)
args = parser.parse_args()


@@ -605,15 +612,11 @@ def calib_func(prepared_model):
)
if args.low_precision_checkpoint != "":
low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
- config_dict = {
- "weight_key": "qweight",
- "scale_key": "scales",
- "zero_point_key": "qzeros",
- "bias_key": "bias",
- "g_idx_key": "g_idx"
- }
- state_dict_and_config = (low_precision_checkpoint, config_dict)
- low_precision_checkpoint = state_dict_and_config
+ if args.gptq_legacy_format:
+ config_dict = (
+ ipex.utils.weight_only_quantization._legacy_lowp_checkpoint_config()
+ )
+ low_precision_checkpoint = (low_precision_checkpoint, config_dict)
else:
low_precision_checkpoint = None
user_model = ipex.llm.optimize(
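The user-visible effect of the hunk above: with Optimum now the default, the loaded state_dict is passed to ipex.llm.optimize as-is, and only legacy checkpoints need the extra config tuple. A minimal sketch, assuming model and the checkpoint path already exist (both names are illustrative):

import torch
import intel_extension_for_pytorch as ipex

qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    lowp_mode=ipex.quantization.WoqLowpMode.INT8
)

# Optimum-format GPTQ checkpoint (new default): pass the state_dict directly.
low_precision_checkpoint = torch.load("gptq_checkpoint_g128.pt")  # illustrative path

# Legacy-format checkpoint: attach the legacy key map instead
# (this is what the new --gptq-legacy-format flag triggers).
legacy = False
if legacy:
    config_dict = ipex.utils.weight_only_quantization._legacy_lowp_checkpoint_config()
    low_precision_checkpoint = (low_precision_checkpoint, config_dict)

model = ipex.llm.optimize(
    model,  # assumed: a HuggingFace causal LM already loaded in eval mode
    dtype=torch.float,
    quantization_config=qconfig,
    low_precision_checkpoint=low_precision_checkpoint,
    deployment_mode=True,
    inplace=True,
)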
(next changed file; path not shown in this view)
@@ -112,7 +112,7 @@ def gptq_export(
compression_dim=compression_dim,
scale_dtype=scale_dtype,
device=torch.device("cpu"),
- use_optimum_format=False,
+ use_optimum_format=True,
)
new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias, gptq_perm)
set_module(model, k, new_module)
(next changed file; path not shown in this view)
@@ -304,7 +304,7 @@ def __init__(self, module, tpp=False, woq=False):
concat_weight = torch.concat(weights_list, 0)
concat_scales = torch.concat(scales_list, 0)
concat_zeros = torch.concat(zeros_list, 0)
- use_bias = all(bias_list)
+ use_bias = all([b is not None for b in bias_list])
concat_bias = torch.concat(bias_list, 0) if use_bias else None
mod = nn.Linear(
concat_weight.shape[1], concat_weight.shape[0], use_bias
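A note on the one-line change above: calling all() directly on a list of bias tensors invokes bool() on each tensor, which raises for any tensor with more than one element, whereas the rewritten check only asks whether every bias is present. A small standalone illustration of the Python/PyTorch behavior involved (not IPEX code):

import torch

bias_list = [torch.zeros(4), torch.zeros(4)]
# all(bias_list) would raise:
#   RuntimeError: Boolean value of Tensor with more than one element is ambiguous
use_bias = all(b is not None for b in bias_list)
assert use_bias  # every layer has a bias, so the concatenated bias is built

bias_list = [None, None]
assert not all(b is not None for b in bias_list)  # no biases; concat_bias stays None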
87 changes: 83 additions & 4 deletions intel_extension_for_pytorch/utils/weight_only_quantization.py
@@ -7,7 +7,18 @@
# Weight shape is N by K if transposed is False otherwise K by N.
# Bias is optional. If bias is not provided in the checkpoint, we read the original model.
DEFAULT_LOWP_CHECKPOINT_CONFIG = {
"name": "default",
"name": "optimum",
"use_optimum_format": True,
"weight_key": "qweight",
"scale_key": "scales",
"zero_point_key": "qzeros",
"bias_key": "bias",
"g_idx_key": "g_idx",
}

LEGACY_LOWP_CHECKPOINT_CONFIG = {
"name": "legacy",
"use_optimum_format": False,
"weight_key": "packed_weight",
"scale_key": "scale",
"zero_point_key": "packed_zp",
@@ -31,14 +42,75 @@ def _default_lowp_checkpoint_config():
return DEFAULT_LOWP_CHECKPOINT_CONFIG


def _legacy_lowp_checkpoint_config():
return LEGACY_LOWP_CHECKPOINT_CONFIG


def _get_keys_from_config(checkpoint_config):
- weight_key = checkpoint_config.get("weight_key", "weight")
- scales_key = checkpoint_config.get("scale_key", "scale")
- zeros_key = checkpoint_config.get("zero_point_key", "zero")
+ weight_key = checkpoint_config.get("weight_key", "qweight")
+ scales_key = checkpoint_config.get("scale_key", "scales")
+ zeros_key = checkpoint_config.get("zero_point_key", "qzeros")
bias_key = checkpoint_config.get("bias_key", "bias")
return weight_key, scales_key, zeros_key, bias_key


def _convert_optimum_format_to_desired(qweight, scales, qzeros):
"""
Optimum format:
qweight: (math.ceil(IC / comp_ratio), OC)
scales: (n_groups, OC)
qzeros: (n_groups, math.ceil(OC / comp_ratio))
qzeros have 1 subtracted before packing
Desired format:
compression_dim = 1
qweight: (OC, math.ceil(IC / comp_ratio))
scales: (OC, n_groups)
qzeros: (OC, math.ceil(n_groups / comp_ratio))
Note:
IC = input channels or input features
OC = output channels or output features
n_groups = math.ceil(IC / group_size)
comp_ratio = compression data type bits // weight or zeros data type bits
E.g., compression dtype = int32, weight dtype = int4, comp_ratio = 32 / 4 = 8
"""
if qweight is None:
return qweight, scales, qzeros
oc = qweight.shape[1]
assert oc == scales.shape[1]
n_groups = scales.shape[0]
qweight = qweight.t_().contiguous()
scales = scales.t_().contiguous()
if qzeros is None:
return qweight, scales, qzeros
zp_dtype = torch.int32
zp = torch.empty((n_groups, oc), dtype=zp_dtype)
# Steps to convert qzeros:
# (1) unpack qzeros to (n_groups, OC)
# (2) take transpose
# (3) plus one and handle overflow
zp_bits = 4 # int4
comp_dtype_bits = 32 # int32
comp_ratio = comp_dtype_bits // zp_bits
mask = torch.tensor(2**zp_bits - 1, dtype=zp_dtype)
for j in range(qzeros.shape[1]):
packed_data = qzeros[:, j]
for e in range(comp_ratio):
index = j * comp_ratio + e
if index >= zp.shape[1]:
continue
data = (packed_data >> (zp_bits * e)) & mask
zp[:, index] = data.type(zp_dtype)
zp = zp.t_().contiguous()
zp += 1
# it may overflow after adding one
zp = torch.where(zp > (2**zp_bits - 1), 0, zp)

return qweight, scales, zp


def _get_linear_parameters(attr_name, state_dict, checkpoint_config):
weight_key, scales_key, zeros_key, bias_key = _get_keys_from_config(
checkpoint_config
@@ -52,6 +124,13 @@ def _get_linear_parameters(attr_name, state_dict, checkpoint_config):
scales = state_dict.get(s_key, None)
qzeros = state_dict.get(z_key, None)
bias = state_dict.get(b_key, None)

use_optimum_format = checkpoint_config.get("use_optimum_format", True)
if use_optimum_format:
qweight, scales, qzeros = _convert_optimum_format_to_desired(
qweight, scales, qzeros
)

group_size = -1
if qweight is not None and scales is not None:
assert scales.dim() == 2, "Unexpected scales tensor dimension"
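To make _convert_optimum_format_to_desired concrete, a small standalone sketch of the same transpose/unpack/plus-one steps for a toy layer (sizes are illustrative; the real function runs over every quantized linear layer in the checkpoint):

import torch

IC, OC, group_size = 16, 8, 8          # toy sizes
comp_ratio = 32 // 4                   # int4 values packed into int32
n_groups = IC // group_size            # 2

qweight = torch.randint(-(2**31), 2**31 - 1, (IC // comp_ratio, OC), dtype=torch.int32)
scales = torch.randn(n_groups, OC, dtype=torch.half)
qzeros = torch.randint(-(2**31), 2**31 - 1, (n_groups, OC // comp_ratio), dtype=torch.int32)

# Transpose weight and scales so the output-channel dimension comes first.
qweight_t = qweight.t().contiguous()   # (OC, IC // comp_ratio) = (8, 2)
scales_t = scales.t().contiguous()     # (OC, n_groups) = (8, 2)

# Unpack the int4 zero points, transpose, add back the 1 that the Optimum
# format subtracts before packing, and wrap the overflow value 16 back to 0.
mask = torch.tensor(2**4 - 1, dtype=torch.int32)
zp = torch.empty(n_groups, OC, dtype=torch.int32)
for j in range(qzeros.shape[1]):
    for e in range(comp_ratio):
        idx = j * comp_ratio + e
        if idx < OC:
            zp[:, idx] = (qzeros[:, j] >> (4 * e)) & mask
zp = zp.t().contiguous() + 1
zp = torch.where(zp > 2**4 - 1, 0, zp)  # (OC, n_groups) = (8, 2)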
72 changes: 70 additions & 2 deletions tests/cpu/test_ipex_optimize_transformers.py
@@ -373,8 +373,8 @@ def test_static_quant_flow(self):
if not hasattr(ipex_m, "trace_graph"):
AssertionError(False)

- def test_weight_only_quant_gptq(self):
- # import json
+ def test_weight_only_quant_gptq_legacy(self):
+ # Test the legacy format
config = AutoConfig.from_pretrained(
f"{curpath}/hf_configs/gptj", return_dict=False
)
@@ -404,6 +404,74 @@ def test_weight_only_quant_gptq(self):
torch.save(state_dict, checkpoint_file_name)
state_dict = torch.load(checkpoint_file_name)

# test loading checkpoint and quant info
lowp_mode = ipex.quantization.WoqLowpMode.INT8
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
lowp_mode=lowp_mode
)
config_dict = (
ipex.utils.weight_only_quantization._legacy_lowp_checkpoint_config()
)
ipex_m = ipex.llm.optimize(
ipex_m,
dtype=torch.float,
quantization_config=qconfig,
low_precision_checkpoint=(state_dict, config_dict),
deployment_mode=True,
inplace=True,
)
assert hasattr(ipex_m, "trace_graph")

# Ensure model can run without errors
with torch.no_grad():
example_inputs = _get_gptj_example_inputs()
# the optimized model is ipex_m.trace_graph
ipex_m.trace_graph(*example_inputs)

def test_weight_only_quant_gptq(self):
# Test the HuggingFace Optimum format
config = AutoConfig.from_pretrained(
f"{curpath}/hf_configs/gptj", return_dict=False
)
m = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
ipex_m = copy.deepcopy(m)
with tempfile.TemporaryDirectory() as work_dir:
# Generate dummy checkpoint
checkpoint_file_name = work_dir + "/checkpoint.pt"
state_dict = ipex_m.state_dict()
linear_keys = []
for k, v in state_dict.items():
if any(
k.endswith(suffix)
for suffix in ["proj.weight", "fc_in.weight", "fc_out.weight"]
):
linear_keys.append(k[:-7])
group_size = 128
comp_ratio = 8
for k in linear_keys:
N = state_dict[k + ".weight"].shape[0]
K = state_dict[k + ".weight"].shape[1]
del state_dict[k + ".weight"]
n_groups = K // group_size
stored_weight_shape = (K // comp_ratio, N)
stored_scales_shape = (n_groups, N)
stored_zeros_shape = (n_groups, N // comp_ratio)
state_dict[k + ".qweight"] = torch.randint(
-(2**31), 2**31 - 1, stored_weight_shape, dtype=torch.int32
)
state_dict[k + ".scales"] = torch.randn(
stored_scales_shape, dtype=torch.half
)
state_dict[k + ".qzeros"] = torch.randint(
-(2**31), 2**31 - 1, stored_zeros_shape, dtype=torch.int32
)
g_idx = torch.arange(n_groups).repeat(group_size)
g_idx[:] = g_idx[torch.randperm(K)]
state_dict[k + ".g_idx"] = g_idx

torch.save(state_dict, checkpoint_file_name)
state_dict = torch.load(checkpoint_file_name)

# test loading checkpoint and quant info
lowp_mode = ipex.quantization.WoqLowpMode.INT8
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
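For reference, the shapes stored per linear layer in the new Optimum-format test follow directly from the comp_ratio arithmetic used by the conversion helper; a quick sketch of the bookkeeping for a hypothetical 4096 x 4096 projection with group size 128:

import math

K, N, group_size = 4096, 4096, 128      # hypothetical layer: K inputs, N outputs
comp_ratio = 32 // 4                    # int4 packed into int32
n_groups = math.ceil(K / group_size)    # 32

qweight_shape = (K // comp_ratio, N)        # (512, 4096)
scales_shape = (n_groups, N)                # (32, 4096)
qzeros_shape = (n_groups, N // comp_ratio)  # (32, 512)
g_idx_len = K                               # one group index per input channel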
10 changes: 1 addition & 9 deletions tests/cpu/test_quantization_default_recipe.py
@@ -729,14 +729,6 @@ def _get_gptj_example_inputs():
self.assertTrue(torch.allclose(out0[0], out1[0], atol=1e-05))

low_precision_checkpoint = torch.load(work_dir + "/gptq_checkpoint_g128.pt")
- config_dict = {
- "weight_key": "qweight",
- "scale_key": "scales",
- "zero_point_key": "qzeros",
- "bias_key": "bias",
- "g_idx_key": "g_idx",
- }
- state_dict_and_config = (low_precision_checkpoint, config_dict)
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=torch.quint4x2,
lowp_mode=ipex.quantization.WoqLowpMode.INT8,
@@ -748,7 +740,7 @@ def _get_gptj_example_inputs():
dtype=torch.float,
quantization_config=qconfig,
inplace=True,
- low_precision_checkpoint=state_dict_and_config,
+ low_precision_checkpoint=low_precision_checkpoint,
deployment_mode=False,
)
_IPEXAttentionCPU = (
