Fix WOQ int8 unpack weight (#1393)
changwangss committed Mar 20, 2024
1 parent ae54f69 commit edede40
Showing 5 changed files with 41 additions and 12 deletions.
@@ -301,6 +301,8 @@ def recover_qparms(self):
qzeros = torch.ops.bestlaop.acquire_woq_packw_info(self.weight, 10)
if bits == 4:
qzeros = qzeros // 16 + 8
else:
qzeros = (qzeros.to(torch.int32) + 128).to(torch.uint8)
else:
qzeros = None
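
For context, a minimal standalone sketch (plain PyTorch with made-up values, not the repository's code) of the offset this new else branch reverses: per the comments in the unpack_weight hunk below, asymmetric 8-bit zero points reach the backend as int8 shifted by 128, so recover_qparms adds 128 back to return them to their original uint8 range.

import torch

# Hypothetical int8 zero points as the backend stores them.
stored_int8 = torch.tensor([-128, -1, 0, 127], dtype=torch.int8)
# Same expression as the added line: widen to int32, add 128, narrow to uint8.
recovered_uint8 = (stored_int8.to(torch.int32) + 128).to(torch.uint8)
print(recovered_uint8)  # tensor([  0, 127, 128, 255], dtype=torch.uint8)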

@@ -54,24 +54,45 @@


def unpack_weight(qweight, scales, qzeros, q_config):
sym = q_config.sym
bits = q_config.bits
wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)

zeros = torch.bitwise_right_shift(
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)
).to(torch.int16 if bits == 8 else torch.int8)
torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)
if bits == 8:
zeros = zeros.to(torch.int8)
zeros = zeros.to(torch.int8 if sym else torch.uint8)
# due to INC minus one
zeros = zeros + 1
zeros = zeros.reshape(scales.shape)
try:
zeros = zeros.reshape(scales.shape)
except:
# zeros and scales have different numbers of elements.
# remove the 1s (due to 0 + 1 in line 68)
zeros = zeros[zeros != 1]
zeros = zeros.reshape(scales.shape)

# due to INC asym returning torch.uint8 but the backend requesting int8,
# change it to int8 with offset 128
if not sym and bits == 8:
zeros = (zeros.to(torch.int32) - 128).to(torch.int8)

weight = torch.bitwise_right_shift(
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)
).to(torch.int16 if bits == 8 else torch.int8)
torch.bitwise_and(weight, (2**bits) - 1, out=weight)

if bits == 8:
weight = weight.to(torch.int8)
# due to INC add shift bias for sym
if sym:
shift_bias = 2 ** (bits - 1)
weight -= shift_bias
weight = weight.to(torch.int8 if sym else torch.uint8)
# due to INC asym returning torch.uint8 but the backend requesting int8,
# change it to int8 with offset 128
if not sym:
weight = (weight.to(torch.int32) - 128).to(torch.int8)
return weight, scales, zeros
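
As a companion to the unpack_weight changes above, a small self-contained sketch (plain PyTorch, illustrative values chosen to keep the packed word inside int32 range, not the repository's code) of the right-shift-and-mask pattern used to pull 8-bit fields out of packed int32 words, plus the uint8-to-int8 re-centering by 128 that the asymmetric path now applies:

import torch

bits = 8
# Hypothetical uint8 values to pack into a single 32-bit word.
vals = torch.tensor([0, 1, 128, 127], dtype=torch.int32)
packed = (vals << torch.arange(0, 32, bits)).sum().to(torch.int32)

# Unpack with the same shift-then-mask pattern unpack_weight uses.
shifts = torch.arange(0, 32, bits, dtype=torch.int32)
unpacked = torch.bitwise_right_shift(packed, shifts) & ((1 << bits) - 1)
assert torch.equal(unpacked, vals)

# Asymmetric int8 path: INC yields uint8 values but the backend wants int8,
# so re-center by 128 exactly as the added lines do.
as_int8 = (unpacked - 128).to(torch.int8)
print(as_int8)  # tensor([-128, -127, 0, -1], dtype=torch.int8)

On the symmetric path, the comments indicate INC instead adds a 2 ** (bits - 1) shift bias, which the existing if sym: branch subtracts.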


@@ -238,7 +259,7 @@ def _replace_linear(
model._modules[name].requires_grad_(False)
if device == "cpu" or device == torch.device("cpu") or device == "auto":
if quantization_config.weight_dtype in \
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int8", "int4_fullrange"]:
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
model._modules[name].set_fp_weights_bias(
module.weight.data,
None if module.bias is None else module.bias.data,
@@ -506,7 +527,7 @@ def default_calib_func(model):

q_model = replace_linear(model, None, None, config, device=device)
else:
if config.weight_dtype not in ["nf4", "fp4", "int8", "int4_fullrange"]:
if config.weight_dtype not in ["nf4", "fp4", "int4_fullrange"]:
inc_model = inc_model.export_compressed_model(use_optimum_format=True)
inc_model.eval()
q_model = replace_linear(inc_model, None, None, config, device=device)
@@ -195,7 +195,7 @@ def save_low_bit(
return

if self.quantization_config.weight_dtype not in \
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int8", "int4_fullrange"]:
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
convert_model_to_public(self)
os.makedirs(save_directory, exist_ok=True)
# use transformers original `save_pretrained` function
@@ -1171,7 +1171,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
else:
model = model_class(config, *model_args, **kwargs)
if config.quantization_config["weight_dtype"] not in \
["fp8_e5m2", "fp8_e4m3", "fp4", "nf4", "int8", "int4_fullrange"]:
["fp8_e5m2", "fp8_e4m3", "fp4", "nf4", "int4_fullrange"]:
model = build_woq_model(model, quantization_config)
else:
model = replace_linear(
@@ -1222,7 +1222,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if config.quantization_config["weight_dtype"] not in \
["fp8_e5m2", "fp8_e4m3", "int8", "nf4", "fp4" "int4_fullrange"]:
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4" "int4_fullrange"]:
model = replace_linear(
model,
quantization_config=quantization_config,
2 changes: 1 addition & 1 deletion tests/CI/test_quantization.py
@@ -433,7 +433,7 @@ def test_quantization_for_llm(self):
)
bit8_model.eval()
output = bit8_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.1675747185945511, rel_tol=1e-04))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.16759155690670013, rel_tol=1e-04))

# GPTQ
woq_config = GPTQConfig(bits=4,
10 changes: 8 additions & 2 deletions tests/CI/test_weight_only.py
@@ -29,8 +29,14 @@
Trainer
)
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
from intel_extension_for_transformers.transformers.llm.quantization.nn.modules import QuantizedLinearQBits, QuantizedLoraLinearQBits
from intel_extension_for_transformers.transformers.llm.quantization.utils import convert_to_quantized_model, replace_linear
from intel_extension_for_transformers.transformers.llm.quantization.nn.modules import (
QuantizedLinearQBits,
QuantizedLoraLinearQBits
)
from intel_extension_for_transformers.transformers.llm.quantization.utils import (
convert_to_quantized_model,
replace_linear
)
from intel_extension_for_transformers.transformers.llm.utils.generation import _beam_search, _greedy_search
from intel_extension_for_transformers.transformers import RtnConfig

