support double quant for weight-only (#1420)
Description
Popular repos such as bitsandbytes and llama.cpp provide double quant on scales to improve the compression ratio of weight-only models. Conceptually, double quant quantizes the scales of several blocks and uses a hyper scale and a hyper zero point to recover them; a minimal sketch of the idea follows.
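A minimal PyTorch sketch of this idea (illustration only, not the library's code): the per-group fp32 scales are themselves quantized with a single hyper scale and hyper zero point, using the same round/clamp scheme as the integer weight path in this diff.

# Sketch: double quant of per-group weight scales (8-bit asymmetric assumed here).
import torch

def double_quant_scales(scale: torch.Tensor, bits: int = 8):
    """Quantize fp32 scales with one hyper scale / hyper zero point."""
    maxq = 2**bits - 1
    smin, smax = scale.min(), scale.max()
    hyper_scale = (smax - smin) / maxq                   # one fp32 value per scale group
    hyper_zero = torch.round(-smin / hyper_scale)        # one value per scale group
    q_scale = torch.clamp(torch.round(scale / hyper_scale) + hyper_zero, 0, maxq)
    return q_scale, hyper_scale, hyper_zero

def recover_scales(q_scale, hyper_scale, hyper_zero):
    """Recover approximate fp32 scales from the low-bit representation."""
    return hyper_scale * (q_scale - hyper_zero)

scales = torch.rand(128) * 0.1 + 0.01                    # stand-in per-group weight scales
q, hs, hz = double_quant_scales(scales)
print((scales - recover_scales(q, hs, hz)).abs().max())  # small reconstruction error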

Use the args below in RTNWeightQuantConfig to set double quant for scales:

double_quant_dtype
double_quant_bits
double_quant_sym
double_quant_group_size
Demo code can be found in the unit tests (UT); a hedged usage sketch is shown below.
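As a rough illustration only (the import path, the quantize() entry point, and the exact keyword names are assumptions, not confirmed by this commit; the values shown mirror the defaults in the diff), enabling double quant might look like:

# Hypothetical usage sketch; see the unit tests for the authoritative demo.
from neural_compressor.torch import RTNWeightQuantConfig, quantize  # assumed API surface

quant_config = RTNWeightQuantConfig(
    weight_bits=4,
    weight_group_size=128,
    double_quant_dtype="int",        # default "fp32" leaves double quant disabled
    double_quant_bits=4,             # bits used to quantize the scales
    double_quant_sym=False,
    double_quant_group_size=128,     # how many scales share one hyper scale/zero point
)
# q_model = quantize(model, quant_config)   # model: a user-supplied torch.nn.Module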
xin3he committed Nov 30, 2023
1 parent de385a4 commit 05c15a4
Showing 9 changed files with 1,008 additions and 298 deletions.
13 changes: 13 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -233,6 +233,7 @@ def __init__(
# weight config
self.weight_config = weight_config
# default settings, check configs
self.wdtype_default = "int"
self.wbits_default = 4
self.group_size_default = 128
self.block_size_default = 128
@@ -241,6 +242,10 @@ def __init__(
self.act_order_default = False
self.perchannel_default = True
self.mse_default = False
self.double_quant_dtype_default = "fp32"
self.double_quant_bits_default = 4
self.double_quant_group_size_default = 128
self.double_quant_sym_default = False
self.check_layer_config()

# device
@@ -285,6 +290,7 @@ def check_layer_config(self):
tmp_weight_config = {}
for name, module in self.model.named_modules():
tmp_weight_config[name] = {}
tmp_weight_config[name]["wdtype"] = self.weight_config.get("wdtype", self.wdtype_default)
tmp_weight_config[name]["wbits"] = self.weight_config.get("wbits", self.wbits_default)
tmp_weight_config[name]["group_size"] = self.weight_config.get("group_size", self.group_size_default)
tmp_weight_config[name]["block_size"] = self.weight_config.get("block_size", self.group_size_default)
@@ -293,9 +299,22 @@ def check_layer_config(self):
tmp_weight_config[name]["act_order"] = self.weight_config.get("act_order", self.act_order_default)
tmp_weight_config[name]["perchannel"] = self.weight_config.get("perchannel", self.perchannel_default)
tmp_weight_config[name]["mse"] = self.weight_config.get("mse", self.mse_default)
tmp_weight_config[name]["double_quant_dtype"] = self.weight_config.get(
"double_quant_dtype", self.double_quant_dtype_default
)
tmp_weight_config[name]["double_quant_bits"] = self.weight_config.get(
"double_quant_bits", self.double_quant_bits_default
)
tmp_weight_config[name]["double_quant_group_size"] = self.weight_config.get(
"double_quant_group_size", self.double_quant_group_size_default
)
tmp_weight_config[name]["double_quant_sym"] = self.weight_config.get(
"double_quant_sym", self.double_quant_sym_default
)
self.weight_config = tmp_weight_config
else:
for layer_name, config in self.weight_config.items():
self.weight_config[layer_name]["wdtype"] = config.get("wdtype", self.wdtype_default)
self.weight_config[layer_name]["wbits"] = config.get("wbits", self.wbits_default)
self.weight_config[layer_name]["group_size"] = config.get("group_size", self.group_size_default)
self.weight_config[layer_name]["block_size"] = config.get("block_size", self.group_size_default)
@@ -304,6 +323,18 @@ def check_layer_config(self):
self.weight_config[layer_name]["act_order"] = config.get("act_order", self.act_order_default)
self.weight_config[layer_name]["perchannel"] = config.get("perchannel", self.perchannel_default)
self.weight_config[layer_name]["mse"] = config.get("mse", self.mse_default)
self.weight_config[layer_name]["double_quant_dtype"] = config.get(
"double_quant_dtype", self.double_quant_dtype_default
)
self.weight_config[layer_name]["double_quant_bits"] = config.get(
"double_quant_bits", self.double_quant_bits_default
)
self.weight_config[layer_name]["double_quant_group_size"] = config.get(
"double_quant_group_size", self.double_quant_group_size_default
)
self.weight_config[layer_name]["double_quant_sym"] = config.get(
"double_quant_sym", self.double_quant_sym_default
)

def get_layer_config(self, layer_name):
"""Obtain config for one layer, since GPTQ supports layer-wise config."""
@@ -467,12 +498,7 @@ def execute_quantization(self, means=None, stds=None, model_path=None):
W = sub_layers[layer_name].weight.data.clone()
gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device)
# gptq_for_this_block[layer_name].quantizer = Quantizer()
gptq_for_this_block[layer_name].quantizer.configure(
weight_config_this_layer["wbits"],
weight_config_this_layer["perchannel"],
weight_config_this_layer["sym"],
weight_config_this_layer["mse"],
)
gptq_for_this_block[layer_name].quantizer.configure(weight_config_this_layer)

# Step 2.3: modify forward functions to hook inputs data (used in gptq execution)
def add_batch(_name):
@@ -656,7 +682,9 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)

q = quantize(w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq).flatten()
q = self.quantizer.quantize(
w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2

@@ -712,11 +740,13 @@ def __init__(self, shape=1):
self.register_buffer("scale", torch.zeros(shape))
self.register_buffer("zero", torch.zeros(shape))

def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8, trits=False):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
def configure(self, weight_config_this_layer, norm=2.4, grid=100, maxshrink=0.8, trits=False):
for k, v in weight_config_this_layer.items():
setattr(self, k, v)
self.maxq = torch.tensor(2**self.wbits - 1)
self.scheme = "sym" if self.sym else "asym"
self.double_quant = self.double_quant_dtype != "fp32"
self.double_quant_scheme = "sym" if self.double_quant_sym else "asym"
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
@@ -726,7 +756,30 @@ def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=
def find_params(self, x, weight=False):
dev = x.device
self.maxq = self.maxq.to(dev)

# NF4 FP4
if self.wdtype != "int":
from .rtn import quant_weight

_, scale, zero = quant_weight(
x,
self.wbits,
self.group_size,
scheme=self.scheme,
data_type=self.wdtype,
quantile=1.0,
return_int=True,
full_range=False,
double_quant=self.double_quant,
double_quant_dtype=self.double_quant_dtype,
double_quant_bits=self.double_quant_bits,
double_quant_scheme=self.double_quant_scheme,
double_quant_group_size=self.double_quant_group_size,
double_quant_return_int=False,
)
self.scale = scale
self.zero = torch.zeros_like(scale)
return
# INT
shape = x.shape
if self.perchannel:
if weight:
@@ -773,7 +826,7 @@ def find_params(self, x, weight=False):
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q = self.quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
@@ -795,6 +848,23 @@ def find_params(self, x, weight=False):
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)

if self.double_quant:
from .rtn import quant_weight

orig_scale_shape = self.scale.shape
self.scale = self.scale.reshape(1, -1)
self.scale = quant_weight(
self.scale,
self.double_quant_bits,
self.double_quant_group_size,
scheme=self.double_quant_scheme,
data_type=self.double_quant_dtype,
quantile=1.0,
return_int=False,
full_range=False,
)
self.scale = self.scale.reshape(orig_scale_shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
@@ -806,13 +876,17 @@ def find_params(self, x, weight=False):
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)

# def quantize(self, x):
# if self.ready():
# return quantize(x, self.scale, self.zero, self.maxq)
# return x
def quantize(self, x, scale, zero, maxq):
"""Do quantization."""
if self.wdtype != "int":
from .rtn import quantize_4bit

# def enabled(self):
# return self.maxq > 0
return quantize_4bit(x, data_type=self.wdtype, scale=scale)
else:
if maxq < 0:
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)

def ready(self):
return torch.all(self.scale != 0)
@@ -848,13 +922,18 @@ def gptq_config_mapping(configs_mapping: Dict[Tuple[str, Callable], GPTQConfig])
continue
else:
weight_config[op_name] = {
"wdtype": op_config.weight_dtype,
"wbits": op_config.weight_bits,
"group_size": op_config.weight_group_size,
"sym": op_config.weight_sym,
"percdamp": op_config.percdamp,
"act_order": op_config.act_order,
"block_size": op_config.block_size,
"mse": op_config.enable_mse_search,
"double_quant_dtype": op_config.double_quant_dtype,
"double_quant_bits": op_config.double_quant_bits,
"double_quant_group_size": op_config.double_quant_group_size,
"double_quant_sym": op_config.double_quant_sym,
}
nsamples = op_config.nsamples
dataloader_len = op_config.dataloader_len
