[LLM] Support WOQ scheme asym (#1266)
Signed-off-by: Wang, Chang <chang1.wang@intel.com>
changwangss committed Feb 8, 2024
1 parent 6c0bd77 commit c7f0b70
Showing 4 changed files with 26 additions and 10 deletions.
@@ -76,6 +76,7 @@ def forward(
compute_dtype=None,
weight_dtype=None,
scale_dtype=None,
scheme=None,
):
# # 1. Dequantize
# B_dequant = torch.zeros(out.shape[-1], A.shape[-1], dtype=torch.float)
@@ -114,7 +115,7 @@ def forward(
compute_dtype,
weight_dtype,
scale_dtype,
False,
False if scheme == "sym" else True,
)
else:
out = qbits_woq_linear_ref_impl(
@@ -154,7 +155,7 @@ def backward(ctx, grad_output):
None,
)

req_gradA, _, _, req_gradBias, _, _, _ = ctx.needs_input_grad
req_gradA, _, _, req_gradBias, _, _, _, _ = ctx.needs_input_grad
A, B = ctx.tensors
grad_A, grad_B, grad_bias = None, None, None

@@ -176,7 +177,7 @@ def backward(ctx, grad_output):
if req_gradA:
grad_A = torch.matmul(grad_output, B.to(grad_output.dtype))

return grad_A, grad_B, None, grad_bias, None, None, None
return grad_A, grad_B, None, grad_bias, None, None, None, None


def matmul_kbit(
@@ -187,11 +188,13 @@ def matmul_kbit(
compute_dtype,
weight_dtype,
scale_dtype,
scheme,
do_dequant=False,
):

if do_dequant:
return MatMulKBit.apply(
A, B, out, bias, compute_dtype, weight_dtype, scale_dtype
A, B, out, bias, compute_dtype, weight_dtype, scale_dtype, scheme
)
else:
qbits_debug_flag = os.getenv('QBITS_DEBUG', 'NULL')
@@ -206,7 +209,7 @@ def matmul_kbit(
compute_dtype,
weight_dtype,
scale_dtype,
False,
False if scheme == "sym" else True,
)
else:
out = qbits_woq_linear_ref_impl(
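The hunks above thread the new scheme argument through the qbits matmul path: the string is translated into the boolean asym flag (False if scheme == "sym" else True) expected by the kernel, and the custom autograd function gains one more input. Because torch.autograd.Function.backward must return one gradient per forward() input, ctx.needs_input_grad grows by one slot and the return tuple gains a trailing None. A minimal, self-contained sketch of that rule (illustrative code, not taken from the commit):

    import torch

    class MatMulWithScheme(torch.autograd.Function):
        @staticmethod
        def forward(ctx, A, B, scheme=None):  # scheme is an extra, non-tensor input
            ctx.save_for_backward(A, B)
            return A @ B

        @staticmethod
        def backward(ctx, grad_output):
            A, B = ctx.saved_tensors
            req_gradA, req_gradB, _ = ctx.needs_input_grad  # one flag per forward() input
            grad_A = grad_output @ B.t() if req_gradA else None
            grad_B = A.t() @ grad_output if req_gradB else None
            return grad_A, grad_B, None  # trailing None matches the non-tensor scheme input

    x = torch.randn(2, 3, requires_grad=True)
    w = torch.randn(3, 4, requires_grad=True)
    MatMulWithScheme.apply(x, w, "asym").sum().backward()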
@@ -129,6 +129,7 @@ def forward(self, x: torch.Tensor):
self.compute_dtype if self.compute_dtype is not None else "fp32",
self.weight_dtype,
self.scale_dtype if self.scale_dtype is not None else "fp32",
self.scheme,
do_dequant=self.training,
)
shape[-1] = self.out_features
@@ -146,7 +147,7 @@ def set_weights_bias(self, weight_data, bias=None):
self.compute_dtype if self.compute_dtype is not None else "fp32",
self.weight_dtype,
self.scale_dtype if self.scale_dtype is not None else "fp32",
False,
False if self.scheme == "sym" else True,
)
self.weight = ParamsQBits(
data=weight,
@@ -315,7 +316,7 @@ def merge(self, safe_merge: bool = False) -> None:
self.compute_dtype,
self.weight_dtype,
self.scale_dtype,
False,
False if self.scheme == "sym" else True,
)

self.weight = ParamsQBits(
@@ -360,7 +361,7 @@ def unmerge(self) -> None:
self.compute_dtype,
self.weight_dtype,
self.scale_dtype,
False,
False if self.scheme == "sym" else True,
)

self.weight = ParamsQBits(
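The module-level hunks above repeat the same scheme-to-flag mapping at every qbits call site (forward, set_weights_bias, merge, unmerge): the previously hard-coded False becomes False if self.scheme == "sym" else True. A tiny sketch of that mapping; the helper name is illustrative, the commit inlines the expression instead:

    def scheme_to_asym(scheme: str) -> bool:
        # "sym" keeps the old behaviour (flag False); any other value,
        # e.g. "asym", enables asymmetric quantization with a zero point.
        return False if scheme == "sym" else True

    assert scheme_to_asym("sym") is False
    assert scheme_to_asym("asym") is True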
5 changes: 5 additions & 0 deletions intel_extension_for_transformers/transformers/utils/config.py
@@ -142,6 +142,11 @@ def post_init(self):
if not isinstance(self.scheme, str):
raise ValueError("scheme must be a string")

if self.scheme == "asym" and (self.compute_dtype == "int8" or self.weight_dtype.startswith("fp") \
or self.weight_dtype.startswith("nf") or self.scale_dtype != "fp32"):
raise ValueError("WeightOnlyQuantization doesn't support asym with \
compute_dtype int8 or weight_dtype float or scale_dtype non-fp32 now, \
please use sym scheme")
self.use_llm_runtime = False

def post_init_xpu(self):
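The new post_init() check rejects asym when it is combined with compute_dtype "int8", an fp*/nf* weight_dtype, or a non-fp32 scale_dtype. A hedged usage sketch; the import path and the dtype strings ("int4_clip", "nf4") follow the project's usual examples and are assumptions here, not part of this diff:

    from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig

    # Expected to pass: asym with an integer weight dtype and fp32 scales.
    ok_cfg = WeightOnlyQuantConfig(weight_dtype="int4_clip", scale_dtype="fp32", scheme="asym")
    ok_cfg.post_init()  # normally invoked for CPU quantization during from_pretrained

    # Expected to fail: asym with a float weight dtype raises the new ValueError.
    bad_cfg = WeightOnlyQuantConfig(weight_dtype="nf4", scheme="asym")
    try:
        bad_cfg.post_init()
    except ValueError as err:
        print(err)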
11 changes: 9 additions & 2 deletions tests/CI/test_quantization.py
@@ -356,6 +356,7 @@ def test_quantization_for_llm(self):
quantization_config=woq_config,
use_llm_runtime=False
)
woq_model.eval()
output = woq_model(dummy_input)
print("output:", float(output[0][0][0][0]))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.16387596726417542, rel_tol=1e-04))
@@ -368,6 +369,7 @@ def test_quantization_for_llm(self):
quantization_config=woq_config,
use_llm_runtime=False
)
woq_model.eval()
output = woq_model(dummy_input)
print("output:", float(output[0][0][0][0]))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.17239853739738464, rel_tol=1e-04))
@@ -380,23 +382,25 @@ def test_quantization_for_llm(self):
quantization_config=woq_config,
use_llm_runtime=False
)
woq_model.eval()
output = woq_model(dummy_input)
# fp8
#fp8
woq_config = WeightOnlyQuantConfig(weight_dtype="fp8_e5m2", scale_dtype="fp8_e8m0")
woq_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, quantization_config=woq_config, use_llm_runtime=False
)
woq_model.eval()
output = woq_model(dummy_input)
self.assertTrue(
isclose(float(output[0][0][0][0]), 0.16162332892417908, rel_tol=1e-04)
)

# amp
amp_config = MixedPrecisionConfig()
amp_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
quantization_config=amp_config,
use_llm_runtime=False
)
amp_model.eval()
output = amp_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.1689453125, rel_tol=1e-04))
# bitsandbytes, for cpu is fp32 model
@@ -410,6 +414,7 @@ def test_quantization_for_llm(self):
load_in_4bit=True,
use_llm_runtime=False
)
bit4_model.eval()
output = bit4_model(dummy_input)
print("output:", float(output[0][0][0][0]))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.18726778030395508, rel_tol=1e-04))
@@ -420,6 +425,7 @@ def test_quantization_for_llm(self):
use_llm_runtime=False,
device_map="cpu"
)
bit8_model.eval()
output = bit8_model(dummy_input)
print("output:", float(output[0][0][0][0]))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.1675747185945511, rel_tol=1e-04))
@@ -441,6 +447,7 @@ def test_quantization_for_llm(self):
quantization_config=woq_config,
use_llm_runtime=False
)
woq_model.eval()
output = woq_model(dummy_input)
print("output:", float(output[0][0][0][0]))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.17126554250717163, rel_tol=1e-04))
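The added eval() calls switch each quantized model to inference mode before the numerical assertions. A hedged end-to-end sketch in the same style as the test; the checkpoint name is illustrative, and the imports are assumed to come from intel_extension_for_transformers.transformers as in the project's examples:

    from intel_extension_for_transformers.transformers import (
        AutoModelForCausalLM,
        WeightOnlyQuantConfig,
    )

    woq_config = WeightOnlyQuantConfig(weight_dtype="int4_clip", scheme="asym")
    woq_model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",  # illustrative checkpoint, not the one used in the test
        quantization_config=woq_config,
        use_llm_runtime=False,
    )
    woq_model.eval()  # mirrors the eval() calls added in this commit
    # output = woq_model(dummy_input) would then run inference as in the test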
