[bugfix] fix gptq_v2 #6126
Changes from all commits
```diff
@@ -1,3 +1,5 @@
+pip install "transformers<4.52"
+
 CUDA_VISIBLE_DEVICES=0 \
 swift export \
     --model Qwen/Qwen2.5-72B-Instruct \
```
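For orientation, here is a sketch of what a complete quantization export invocation might look like, assuming the standard `swift export` quantization flags (`--quant_method`, `--quant_bits`, `--dataset`, `--output_dir`); the dataset and output path are illustrative and not taken from this PR:

```shell
# Hypothetical end-to-end invocation; the flag values below are illustrative.
pip install "transformers<4.52"

CUDA_VISIBLE_DEVICES=0 \
swift export \
    --model Qwen/Qwen2.5-72B-Instruct \
    --quant_method gptq \
    --quant_bits 4 \
    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh' \
    --output_dir ./Qwen2.5-72B-Instruct-GPTQ-Int4
```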
```diff
@@ -1,3 +1,5 @@
+pip install "transformers<4.52"
+
 CUDA_VISIBLE_DEVICES=0,1 \
 swift export \
     --model Qwen/Qwen3-30B-A3B \
```

Contributor
Similar to the other script.

Suggested change
```diff
@@ -5,6 +5,8 @@
 import torch
 import torch.nn as nn
+import transformers
+from packaging import version
 from tqdm import tqdm

 from swift.llm import (ExportArguments, HfConfigFactory, MaxLengthError, ProcessorMixin, deep_getattr, load_dataset,
```
```diff
@@ -41,6 +43,10 @@ def quantize(self):
         elif args.quant_method in {'gptq', 'gptq_v2'}:
             self.template.model = self.model
             gptq_quantizer = self.gptq_model_quantize(v2=(args.quant_method == 'gptq_v2'))
+            if args.quant_method == 'gptq_v2':
+                if not getattr(self.model, '_dynamic_tied_weights_keys', None):
+                    self.model._dynamic_tied_weights_keys = []
+                self.model._dynamic_tied_weights_keys += ['wf_unsqueeze_zero', 'wf_unsqueeze_neg_one']
             gptq_quantizer.save(
                 self.model,
                 args.output_dir,
```
```diff
@@ -76,7 +82,7 @@ def _prepare_gptq_dataset(self, examples: List[Dict[str, torch.LongTensor]], bat
     @torch.inference_mode()
     def _get_quant_dataset(self, *args, **kwargs):
         args = self.args
-        assert args.quant_method in {'awq', 'gptq'}
+        assert args.quant_method in {'awq', 'gptq', 'gptq_v2'}
         template = self.template
         n_samples = args.quant_n_samples
         block_size = args.max_length
```
```diff
@@ -96,7 +102,7 @@ def _get_quant_dataset(self, *args, **kwargs):
                 inputs = template.encode(data)
             except MaxLengthError:
                 continue
-            if is_multimodal and args.quant_method == 'gptq':
+            if is_multimodal and args.quant_method in {'gptq', 'gptq_v2'}:
                 inputs.pop('labels', None)
                 samples.append(inputs)
             else:
```

Contributor
The condition `args.quant_method in {'gptq', 'gptq_v2'}` is repeated in several places. Consider extracting it into a variable, e.g. `is_gptq_family = args.quant_method in {'gptq', 'gptq_v2'}`. You could then use `is_gptq_family` in each of these checks.
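A minimal, self-contained sketch of the suggested `is_gptq_family` refactor; the surrounding objects are stand-ins rather than the real ms-swift classes:

```python
# Stand-ins for the real objects, so the suggested refactor can run on its own.
class Args:
    quant_method = 'gptq_v2'


args = Args()
is_multimodal = True
samples = []
inputs = {'input_ids': [1, 2, 3], 'labels': [1, 2, 3]}

# Compute the membership test once and reuse it wherever the method is checked.
is_gptq_family = args.quant_method in {'gptq', 'gptq_v2'}

if is_multimodal and is_gptq_family:
    inputs.pop('labels', None)  # calibration does not need labels
    samples.append(inputs)
```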
```diff
@@ -107,15 +113,15 @@ def _get_quant_dataset(self, *args, **kwargs):
             if i == n_samples:
                 break
         prog_bar.close()
-        if is_multimodal and args.quant_method == 'gptq':
+        if is_multimodal and args.quant_method in {'gptq', 'gptq_v2'}:
             return samples
         # now concatenate all samples and split according to block size
         n_split = max(len(samples) // block_size, 1)
         logger.info(f'Split into {n_split} blocks')
         res = []
         for i in range(n_split):
             input_ids = samples[i * block_size:(i + 1) * block_size]
-            if args.quant_method == 'gptq':
+            if args.quant_method in {'gptq', 'gptq_v2'}:
                 res.append({'input_ids': input_ids})
             else:
                 res.append(torch.tensor(input_ids)[None])
```
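As a rough illustration of the block-splitting branch above, a tiny standalone example with made-up token ids standing in for the concatenated calibration tokens (`block_size` plays the role of `args.max_length`):

```python
# Simplified stand-in for the concatenated calibration token ids.
samples = list(range(10))
block_size = 4

n_split = max(len(samples) // block_size, 1)
res = []
for i in range(n_split):
    input_ids = samples[i * block_size:(i + 1) * block_size]
    res.append({'input_ids': input_ids})  # gptq/gptq_v2 branch

# Two blocks of 4 tokens each; the trailing remainder (tokens 8 and 9) is dropped.
print(res)
```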
```diff
@@ -226,6 +232,29 @@ def get_modules_in_block_to_quantize(model, block_name: str):
                 res[experts_idx:experts_idx] = experts.values()
         return res

+    @contextmanager
+    def _patch_gptq_block(self, model, block_name_to_quantize):
+        if version.parse(transformers.__version__) < version.parse('4.54'):
+            yield
+            return
+        # compat transformers>=4.54
+        blocks = deep_getattr(model, block_name_to_quantize)
+        hooks = []
+
+        def _to_tuple(module, input, output):
+            if not isinstance(output, (list, tuple)):
+                output = (output, )
+            return output
+
+        for block in blocks:
+            hooks.append(block.register_forward_hook(_to_tuple))
+
+        try:
+            yield
+        finally:
+            for hook in hooks:
+                hook.remove()
+
     def gptq_model_quantize(self, v2: bool = False):
         from optimum.gptq import GPTQQuantizer
         args = self.args
```
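For intuition, a self-contained toy example (not ms-swift code) of the forward-hook trick used by `_patch_gptq_block`: newer transformers releases apparently let decoder blocks return a bare tensor, while the GPTQ calibration loop in optimum expects to index the block output with `[0]`, so the hook re-wraps any non-tuple output:

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Stand-in for a decoder block that returns a bare tensor instead of a tuple."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


def _to_tuple(module, inputs, output):
    # A forward hook that returns a non-None value replaces the module's output,
    # so wrapping the bare tensor restores the `output[0]` indexing convention.
    if not isinstance(output, (list, tuple)):
        output = (output, )
    return output


block = ToyBlock()
hook = block.register_forward_hook(_to_tuple)
out = block(torch.randn(2, 4))
assert isinstance(out, tuple) and out[0].shape == (2, 4)
hook.remove()
```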
```diff
@@ -247,7 +276,8 @@ def gptq_model_quantize(self, v2: bool = False):
         logger.info('Start quantizing the model...')
         logger.warning('The process of packing the model takes a long time and there is no progress bar. '
                        'Please be patient and wait...')
-        gptq_quantizer.quantize_model(self.model, self.tokenizer)
+        with self._patch_gptq_block(self.model, block_name_to_quantize):
+            gptq_quantizer.quantize_model(self.model, self.tokenizer)
         self.model.config.quantization_config.pop('dataset', None)
         return gptq_quantizer
```
To improve maintainability and help other developers, it would be beneficial to add a comment explaining why `transformers<4.52` is required. This provides context for the version pinning.
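For illustration only, one way the example scripts could carry that explanation; the comment wording below is an assumption, not part of this PR:

```shell
# GPTQ calibration in the swift export path currently requires transformers<4.52 (see PR #6126).
pip install "transformers<4.52"
```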