support lwq for gptq (#1324)
* [LLM]support lwq for gptq

Signed-off-by: Guo, Heng <heng.guo@intel.com>
n1ck-guo committed Oct 30, 2023
1 parent 30142f1 commit ee54507
Showing 11 changed files with 200 additions and 71 deletions.
10 changes: 4 additions & 6 deletions docs/source/quantization_weight_only.md
@@ -173,22 +173,19 @@ Large language models (LLMs) have shown exceptional performance across various t
|:--------------:|:----------:|
| RTN | &#10004; |
| AWQ | &#10005; |
| GPTQ | &#10005; |
| GPTQ | &#10004; |
| TEQ | &#10005; |

### Example
```python
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model

fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True)
fp32_model = load_empty_model(model_name_or_path, torchscript=True)
conf = PostTrainingQuantConfig(
approach="weight_only",
recipes={
"layer_wise_quant": True,
"layer_wise_quant_args": {
"model_path": "facebook/opt-125m",
},
"rtn_args": {"enable_full_range": True},
},
)
@@ -201,6 +198,7 @@ q_model = quantization.fit(
)
output_dir = "./saved_model"
q_model.save(output_dir)
q_model = load(output_dir, fp32_model, weight_only=True, layer_wise=True)
```
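
Since the table above now marks GPTQ as supported for layer-wise weight-only quantization, the same recipe can drive a GPTQ run. The sketch below is illustrative only: the `op_type_dict` entry, the quantization parameters, and the calibration `dataloader` are assumptions for demonstration and are not part of this commit.

```python
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model

# Load a weight-free model shell; layer weights are streamed from disk during quantization.
fp32_model = load_empty_model(model_name_or_path, torchscript=True)
conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # apply to all matched op types; narrow the pattern as needed
            "weight": {"bits": 4, "group_size": 128, "scheme": "sym", "algorithm": "GPTQ"},
        },
    },
    recipes={"layer_wise_quant": True},
)
q_model = quantization.fit(
    fp32_model,
    conf,
    calib_dataloader=dataloader,  # assumed calibration dataloader
)
q_model.save("./saved_model")
```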

## Reference
@@ -219,7 +219,7 @@ def skip(*args, **kwargs):
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True)
model = model.eval()

@@ -294,7 +294,8 @@ def skip(*args, **kwargs):
dataloader=calib_dataloader,
nsamples = args.nsamples,
use_max_length = args.use_max_length,
pad_max_length = args.pad_max_length
pad_max_length = args.pad_max_length,
device = DEV,
)

results = lm_evaluate(
51 changes: 37 additions & 14 deletions neural_compressor/adaptor/pytorch.py
@@ -3502,13 +3502,13 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
):
from .torch_utils.layer_wise_quant import LayerWiseQuant

model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
# model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
model_path = model._model.path
smooth_quant = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant", False)
alpha = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant_alpha", 0.5)
assert (
model_path is not None
), "the layer_wise_quant_args should have args model_path to load the weight of model."
device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu")
# device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu")
assert model_path is not None, "The model_path should not be None."
device = self.device
lw_quant = LayerWiseQuant(
q_model._model,
model_path,
@@ -4541,14 +4541,12 @@ def rtn_quantize(self, model, tune_cfg):
# for layer_wise quant mode
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if recipe_cfgs.get("layer_wise_quant", False):
from neural_compressor.config import options

from .torch_utils.layer_wise_quant.utils import _get_path, load_module
from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, load_module

lwq_workspace = os.path.join(options.workspace, "lwq_tmpdir")
os.makedirs(lwq_workspace, exist_ok=True)
model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
assert model_path, "model_path should specify in layer_wise_quant_args."
os.makedirs(LWQ_WORKSPACE, exist_ok=True)
# model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
model_path = model.path
assert model_path, "model_path should not be None."
model_path = _get_path(model_path)

for key, config in tune_cfg["op"].items():
@@ -4584,7 +4582,7 @@ def rtn_quantize(self, model, tune_cfg):
# save and clean weight
from .torch_utils.layer_wise_quant.utils import clean_module_weight

torch.save(m.state_dict(), os.path.join(lwq_workspace, f"{op_name}.pt"))
torch.save(m.state_dict(), os.path.join(LWQ_WORKSPACE, f"{op_name}.pt"))
clean_module_weight(m)
set_module(model, op_name, m)
if recipe_cfgs.get("layer_wise_quant", False):
@@ -4619,6 +4617,23 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
...
}
"""
# for layer_wise quant mode
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
model_path = None
layer_wise = False
if recipe_cfgs.get("layer_wise_quant", False):
layer_wise = True
from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks

os.makedirs(LWQ_WORKSPACE, exist_ok=True)
# model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
model_path = model.path
assert model_path, "model_path should not be None."
model_path = _get_path(model_path)
lwq_handles = register_weight_hooks(
model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE
)

weight_config = {}
for key, config in tune_cfg["op"].items():
op_name, op_type = key
@@ -4643,7 +4658,15 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
)
# tune_cfg => weight_config
model, quantization_perm = gptq_quantize(
model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device
model,
weight_config,
dataloader,
nsamples,
use_max_length,
pad_max_length,
self.device,
layer_wise,
model_path,
)
return model, quantization_perm

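The adaptor changes above do two things for GPTQ: the checkpoint location is now taken from the model wrapper (`model.path` / `model._model.path`) rather than a required `model_path` entry in `layer_wise_quant_args`, and `register_weight_hooks` is attached with `clean_weight=True` so each layer's weights are materialized from `model_path` only while that layer is being processed, with quantized modules cached under `LWQ_WORKSPACE`. The sketch below illustrates the hook idea only; it is an assumption about what `register_weight_hooks` does (load before forward, release after), not its actual implementation, and `load_weight` / `clean_weight` are hypothetical placeholders.

```python
from torch import nn


def attach_lazy_weight_hooks(module: nn.Module, load_weight, clean_weight):
    """Load a module's weights just before its forward pass and drop them right
    after, so only one layer's weights are resident in memory at a time."""

    def pre_hook(mod, inputs):
        load_weight(mod)  # materialize this layer's weights from the on-disk checkpoint

    def post_hook(mod, inputs, output):
        clean_weight(mod)  # swap the weights back to empty tensors

    # Keep the handles so the hooks can be removed once quantization is done,
    # mirroring the `lwq_handles` captured in gptq_quantize above.
    return [
        module.register_forward_pre_hook(pre_hook),
        module.register_forward_hook(post_hook),
    ]
```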
104 changes: 81 additions & 23 deletions neural_compressor/adaptor/torch_utils/gptq.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import random
import re
@@ -175,6 +176,7 @@ def __init__(
use_max_length=True,
pad_max_length=2048,
device=None,
layer_wise=False,
):
"""
Args:
@@ -215,9 +217,13 @@ def __init__(
self.check_layer_config()

# device
self.device = model.device
self.device = device
if str(self.model.device).startswith("cuda"):
self.device = self.model.device
self.is_ready = False

self.layer_wise = layer_wise

# dataloader
self.use_max_length = use_max_length
self.pad_max_length = pad_max_length
@@ -438,11 +444,13 @@ def forward(layer, *args, **kwargs):
raise ValueError

# Step1: fetch the embeddings and other layers before the transformer stack.
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer = embedding_layer.to(self.device)
if not self.layer_wise:
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer = embedding_layer.to(self.device)

# Step2: modify the first transformer block's forward function to obtain inputs for calibration
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
if not self.layer_wise:
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
forward_cache = self.gptq_related_blocks["transformers"][0].forward
self.gptq_related_blocks["transformers"][0].forward = partial(
forward, self.gptq_related_blocks["transformers"][0]
@@ -451,7 +459,8 @@ def forward(layer, *args, **kwargs):
# Step3: run forward to obtain calibration datasets
logger.info("Collecting calibration inputs...")
for batch in tqdm(self.dataloader):
batch = move_input_to_device(batch, self.device)
if not self.layer_wise:
batch = move_input_to_device(batch, self.device)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
self.model(batch[0])
@@ -473,9 +482,10 @@

# Step 4: restore original forward function, relocate layers back to cpu.
self.gptq_related_blocks["transformers"][0].forward = forward_cache
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer.to(self.device)
if not self.layer_wise:
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer.to(self.device)
torch.cuda.empty_cache()
# end
logger.info("GPTQ quantization prepared.")
@@ -501,7 +511,7 @@ def update_blockwise_hidden_states(self, outs):
self.cache_positional_arguments[0] = outs[:]

@torch.no_grad()
def execute_quantization(self, means=None, stds=None):
def execute_quantization(self, means=None, stds=None, model_path=None):
"""Run quantization."""
# Step1: prepare quantization (calibration datasets)

@@ -513,7 +523,11 @@ def execute_quantization(self, means=None, stds=None):
tblock_length = len(self.gptq_related_blocks["transformers"])
for block_idx in range(tblock_length):
logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..")
transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device)
if not self.layer_wise:
# if we do not apply layer-wise feature, we still place the entire block on the GPU
transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device)
else:
transformer_block = self.gptq_related_blocks["transformers"][block_idx] # .to(self.device)
# Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized.
sub_layers = find_layers(transformer_block)
sub_layers_to_quant = {}
@@ -534,8 +548,16 @@
# weight_config_this_layer = self.weight_config.get(
# self.get_full_layer_name(layer_name, block_idx), None
# )
weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name])
full_layer_name = self.get_full_layer_name(layer_name, block_idx)
weight_config_this_layer = self.get_layer_config(full_layer_name)
if self.layer_wise:
from ..torch_utils.layer_wise_quant.utils import load_value

W = load_value(self.model, full_layer_name + ".weight", model_path)
else:
W = sub_layers[layer_name].weight.data.clone()

gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device)
# gptq_for_this_block[layer_name].quantizer = Quantizer()
gptq_for_this_block[layer_name].quantizer.configure(
weight_config_this_layer["wbits"],
@@ -555,7 +577,6 @@ def tmp(_, inp, out):
for layer_name in sub_layers:
handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name)))
idx = self.cache_key_arguments.pop("i")
# import pdb;pdb.set_trace()
for j in range(len(self.dataloader)):
cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
@@ -570,12 +591,44 @@ def tmp(_, inp, out):
# )
weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
logger.info(f"Quantizing layer {layer_name}")
scale, zp = gptq_for_this_block[layer_name].fasterquant(
if self.layer_wise:
from ..torch_utils.layer_wise_quant.utils import load_value

full_layer_name = self.get_full_layer_name(layer_name, block_idx)
W = load_value(self.model, full_layer_name + ".weight", model_path)
else:
W = sub_layers[layer_name].weight.data.clone()
scale, zp, Q = gptq_for_this_block[layer_name].fasterquant(
W,
blocksize=weight_config_this_layer["block_size"],
percdamp=weight_config_this_layer["percdamp"],
groupsize=weight_config_this_layer["group_size"],
act_order=weight_config_this_layer["act_order"],
)
if self.layer_wise:
from ..torch_utils.layer_wise_quant.utils import (
LWQ_WORKSPACE,
clean_module_weight,
load_value,
set_module_tensor_to_device,
)

sub_layer = sub_layers[layer_name]
full_layer_name = self.get_full_layer_name(layer_name, block_idx)
for n, p in sub_layer.named_parameters():
param_name = full_layer_name + "." + n
if n == "weight":
set_module_tensor_to_device(self.model, param_name, self.device, Q)
else:
value = load_value(self.model, param_name, model_path)
set_module_tensor_to_device(self.model, param_name, self.device, value)
# sub_layer.weight.data = Q
torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
clean_module_weight(sub_layer)
del Q
gc.collect()
else:
sub_layers[layer_name].weight.data = Q
gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale}
if not weight_config_this_layer["sym"]:
gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp
@@ -594,7 +647,10 @@ def tmp(_, inp, out):
out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0]
outs.append(out)
self.cache_key_arguments["i"] = idx
self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
if self.layer_wise:
self.gptq_related_blocks["transformers"][block_idx] = transformer_block
else:
self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
del gptq_for_this_block
torch.cuda.empty_cache()
# iteratively replace the input with output, thus layerwise quantization can continue.
@@ -617,10 +673,10 @@ class GPTQ:
GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers (https://arxiv.org/abs/2210.17323)
"""

def __init__(self, layer):
def __init__(self, layer, W, device="cpu"):
self.layer = layer
self.device = self.layer.weight.device
W = layer.weight.data.clone()
self.device = device
# W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
@@ -661,8 +717,9 @@ def add_batch(self, inp, out):
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t()) # H = X*X, which should be a sysm matrix

def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
W = self.layer.weight.data.clone()
def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
# W = self.layer.weight.data.clone()
weight_shape, weight_dtype = W.shape, W.data.dtype
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
@@ -740,7 +797,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
# logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}")
# logger.info(f"{torch.sum(Losses)}")

if self.device != torch.device("cpu"):
if str(self.device).startswith("cuda"):
torch.cuda.synchronize()
logger.info(f"time {(time.time() - tick)}")
logger.info(f"error {torch.sum(Losses).item()}")
@@ -751,7 +808,8 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals

if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
# self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
Q = Q.reshape(weight_shape).to(weight_dtype)
if DEBUG:
logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}")

Expand All @@ -760,7 +818,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
zero.append(self.quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
return scale, zero
return scale, zero, Q

def free(self):
if DEBUG:
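The layer-wise branch added to `execute_quantization` follows a load, quantize, persist, release cycle for every sub-layer. The sketch below condenses that cycle; the helpers are imported as in the hunks above, while the wrapper function itself and its argument list are illustrative only, not part of the commit.

```python
import gc
import os

import torch

from neural_compressor.adaptor.torch_utils.layer_wise_quant.utils import (
    LWQ_WORKSPACE,
    clean_module_weight,
    load_value,
    set_module_tensor_to_device,
)


def quantize_layer_lwq(model, sub_layer, full_layer_name, gptq_obj, model_path, device):
    """Illustrative per-layer cycle used by layer-wise GPTQ."""
    # 1. Pull only this layer's fp32 weight out of the on-disk checkpoint.
    W = load_value(model, full_layer_name + ".weight", model_path)

    # 2. fasterquant() now takes W explicitly and returns the quantized tensor Q.
    scale, zp, Q = gptq_obj.fasterquant(W, blocksize=128, percdamp=0.01, groupsize=-1)

    # 3. Write Q back onto the (empty) module, cache the module in the LWQ workspace,
    #    then strip its weights again so peak memory stays at roughly one layer.
    set_module_tensor_to_device(model, full_layer_name + ".weight", device, Q)
    torch.save(sub_layer.state_dict(), os.path.join(LWQ_WORKSPACE, f"{full_layer_name}.pt"))
    clean_module_weight(sub_layer)
    del Q
    gc.collect()
    return scale, zp
```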
@@ -15,5 +15,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Torch layer-wise quantization module."""
from .utils import load_shell
from .utils import load_empty_model
from .quantize import LayerWiseQuant
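
For downstream code, the user-visible change in this module is the rename of the weight-free loader, as already reflected in the documentation hunk above (`model_name_or_path` is the usual model id or local path):

```python
# Before this commit:
# from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
# fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True)

# After this commit:
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model

fp32_model = load_empty_model(model_name_or_path, torchscript=True)
```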
@@ -40,7 +40,7 @@
update_module,
)

TMP_DIR = os.path.join(default_workspace, "layer_wise_quant_tmp_dir")
TMP_DIR = os.path.join(default_workspace, "lwq_tmpdir")


def mk_tmp_dir():
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
alpha=0.5,
):
"""Init LayerWiseQuant."""
# self.q_model = load_shell(pretrained_model_name_or_path, cls)
# self.q_model = load_empty_model(pretrained_model_name_or_path, cls)
self.q_model = q_model
self.fp32_model = deepcopy(self.q_model)
self.path = _get_path(pretrained_model_name_or_path)