-
Notifications
You must be signed in to change notification settings - Fork 62
Refactor input normalization by replaying inputs for consistent preprocessing #1094
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
a8494ac
add test for llmc
yiliu30 f9e07ca
add llmc
yiliu30 cd16b7d
rm ct
yiliu30 2f8fce1
fix
yiliu30 ce60aee
Merge branch 'main' into add-llmc-test
yiliu30 081926f
merge main
yiliu30 62d9160
fix
yiliu30 1252c26
fix device
yiliu30 0c2ec8d
Merge branch 'main' into add-llmc-test
XuehaoSun 19f5da9
fix
yiliu30 80d20b6
tmp fix attn mask
yiliu30 ca018f9
fix attn mask
yiliu30 6b21310
nornalize input by replaying
yiliu30 ab5221b
merge main
yiliu30 0eefd47
update
yiliu30 e36dba9
update
yiliu30 8ea0e5b
update
yiliu30 b8af29e
fix
yiliu30 fd858f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 86402e9
fix
yiliu30 01422dc
Merge branch 'fix-attn-mask' of https://github.com/intel/auto-round i…
yiliu30 677b414
update hints
yiliu30 1d209d1
fix
yiliu30 aed8626
fix
yiliu30 9aa573d
fix preprocess
yiliu30 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| import traceback | ||
| from collections import defaultdict | ||
| from dataclasses import asdict, fields | ||
| from functools import partial | ||
| from typing import Any, Callable, Optional, Union | ||
|
|
||
| import accelerate | ||
|
|
@@ -96,7 +97,6 @@ | |
| llm_load_model, | ||
| memory_monitor, | ||
| mv_module_from_gpu, | ||
| normalize_input, | ||
| set_amax_for_all_moe_layers, | ||
| set_module, | ||
| to_device, | ||
|
|
@@ -1918,6 +1918,45 @@ def _get_block_outputs( | |
|
|
||
| return output | ||
|
|
||
| def normalize_decoding_layer_inputs_(self, decoding_layer_inputs: list[tuple[tuple[Any, dict[str, Any]]]]): | ||
| """ | ||
| Processes and stores decoding layer inputs for block quantization. | ||
|
|
||
| This function iterates through a list of captured decoding layer calls, | ||
| replaying them through a fake decoding layer to extract and store the | ||
| inputs required for the decoding block in `self.inputs`. This effectively | ||
| "normalizes" the inputs by making them accessible in a consistent format | ||
| for subsequent quantization steps. | ||
|
|
||
| Args: | ||
| decoding_layer_inputs: | ||
| A list of entries captured by a forward hook on the decoding layer. | ||
| Each element is expected to be a tuple whose first item is | ||
| `(args, kwargs)`, where `args` are the positional arguments and | ||
| `kwargs` are the keyword arguments seen during the original | ||
| forward pass. | ||
|
|
||
| The capture hook look like: | ||
|
|
||
| def input_capture_hook(module, *args, **kwargs): | ||
| _all_module_input[module._tmp_name].append((args, kwargs)) | ||
| """ | ||
| first_block_name = self.quant_block_list[0][0] | ||
|
|
||
| class _FakeDecodingLayer(torch.nn.Module): | ||
| def forward(self, *args, **kwargs): | ||
| return args, kwargs | ||
|
|
||
| fake_layer = _FakeDecodingLayer() | ||
| fake_layer.orig_forward = fake_layer.forward | ||
| fake_layer.forward = partial(self._get_block_forward_func(first_block_name), fake_layer) | ||
|
|
||
| self.inputs = {} | ||
| self.last_cache_name = None | ||
| for step_input in decoding_layer_inputs: | ||
| args, kwargs = step_input[0] | ||
| fake_layer(*args, **kwargs) | ||
|
|
||
| @torch.no_grad() | ||
| def calib(self, nsamples, bs): | ||
| """Perform calibration for quantization. | ||
|
|
@@ -2346,7 +2385,6 @@ def _recover_forward(self): | |
|
|
||
| def _replace_forward(self): | ||
| """Replaces the forward function.""" | ||
| from functools import partial | ||
|
|
||
| for n, m in self.model.named_modules(): | ||
| if n in self.to_cached_layers and type(m) not in self.supported_types: ##block | ||
|
|
@@ -2652,7 +2690,10 @@ def quantize_block( | |
| "DiffusionCompressor", | ||
| "MLLMCompressor", | ||
| ], f"Currently, {self.__class__.__name__} does not support support quantize block with this function." | ||
| input_ids, input_others = normalize_input(inputs) | ||
| self.normalize_decoding_layer_inputs_(inputs) | ||
| block_inputs = self.inputs[self.quant_block_list[0][0]] | ||
| decoding_layer_first_input_name = "hidden_states" | ||
| input_ids, input_others = self._preprocess_block_inputs(block_inputs, decoding_layer_first_input_name) | ||
| return self._quantize_block(block, input_ids, input_others, q_input, device, auto_offload) | ||
|
|
||
| def _get_loss( | ||
|
|
@@ -2959,12 +3000,32 @@ def _quantize_block( | |
|
|
||
| return None, output | ||
|
|
||
| def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]: | ||
| input_ids = inputs["input_ids"] | ||
| inputs.pop("input_ids", None) | ||
| def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[torch.Tensor, dict]: | ||
| input_ids = inputs[first_input_name] | ||
| inputs.pop(first_input_name, None) | ||
| input_others = inputs | ||
| return input_ids, input_others | ||
|
|
||
| def _preprocess_block_inputs(self, inputs, first_input_name="input_ids"): | ||
| input_ids, input_others = self._split_inputs(inputs, first_input_name) | ||
| clear_memory(device_list=self.device_list) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don''t know why clear_memory is needed here. @n1ck-guo do you know? |
||
| input_ids = to_device(input_ids, self.cache_device) | ||
| input_others = to_device(input_others, self.cache_device) | ||
| # As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage | ||
|
|
||
| tmp_dtype = self.amp_dtype if self.amp else torch.float32 | ||
| input_ids = to_dtype(input_ids, tmp_dtype) | ||
|
|
||
| for key in input_others.keys(): | ||
| if isinstance(input_others[key], torch.Tensor) and ( | ||
| input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16 | ||
| ): | ||
| input_others[key] = input_others[key].to(tmp_dtype) | ||
| elif isinstance(input_others[key], list): | ||
| for i in range(len(input_others[key])): | ||
| to_dtype(input_others[key][i], tmp_dtype) | ||
| return input_ids, input_others | ||
|
|
||
| def _quantize_blocks( | ||
| self, | ||
| model: torch.nn.Module, | ||
|
|
@@ -2991,23 +3052,7 @@ def _quantize_blocks( | |
| for n, m in model.named_parameters(): | ||
| m.requires_grad_(False) | ||
|
|
||
| input_ids, input_others = self._split_inputs(inputs) | ||
| clear_memory(device_list=self.device_list) | ||
| input_ids = to_device(input_ids, self.cache_device) | ||
| input_others = to_device(input_others, self.cache_device) | ||
| # As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage | ||
|
|
||
| tmp_dtype = self.amp_dtype if self.amp else torch.float32 | ||
| input_ids = to_dtype(input_ids, tmp_dtype) | ||
|
|
||
| for key in input_others.keys(): | ||
| if isinstance(input_others[key], torch.Tensor) and ( | ||
| input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16 | ||
| ): | ||
| input_others[key] = input_others[key].to(tmp_dtype) | ||
| elif isinstance(input_others[key], list): | ||
| for i in range(len(input_others[key])): | ||
| to_dtype(input_others[key][i], tmp_dtype) | ||
| input_ids, input_others = self._preprocess_block_inputs(inputs) | ||
|
|
||
| if pbar is None: | ||
| pbar = tqdm(range(0, len(block_names), nblocks)) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does the origin code have such assumption? Is there any case that the first_input_name is not
hidden_statesThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The origin code replaces the
hidden_stateswithinput_ids.