91 changes: 68 additions & 23 deletions auto_round/compressors/base.py
@@ -20,6 +20,7 @@
import traceback
from collections import defaultdict
from dataclasses import asdict, fields
from functools import partial
from typing import Any, Callable, Optional, Union

import accelerate
@@ -96,7 +97,6 @@
llm_load_model,
memory_monitor,
mv_module_from_gpu,
normalize_input,
set_amax_for_all_moe_layers,
set_module,
to_device,
@@ -1918,6 +1918,45 @@ def _get_block_outputs(

return output

def normalize_decoding_layer_inputs_(self, decoding_layer_inputs: list[tuple[tuple[Any, dict[str, Any]]]]):
"""
Processes and stores decoding layer inputs for block quantization.

This function iterates through a list of captured decoding layer calls,
replaying them through a fake decoding layer to extract and store the
inputs required for the decoding block in `self.inputs`. This effectively
"normalizes" the inputs by making them accessible in a consistent format
for subsequent quantization steps.

Args:
decoding_layer_inputs:
A list of entries captured by a forward hook on the decoding layer.
Each element is expected to be a tuple whose first item is
`(args, kwargs)`, where `args` are the positional arguments and
`kwargs` are the keyword arguments seen during the original
forward pass.

        The capture hook looks like:

def input_capture_hook(module, *args, **kwargs):
_all_module_input[module._tmp_name].append((args, kwargs))
"""
first_block_name = self.quant_block_list[0][0]

class _FakeDecodingLayer(torch.nn.Module):
def forward(self, *args, **kwargs):
return args, kwargs

fake_layer = _FakeDecodingLayer()
fake_layer.orig_forward = fake_layer.forward
fake_layer.forward = partial(self._get_block_forward_func(first_block_name), fake_layer)

self.inputs = {}
self.last_cache_name = None
for step_input in decoding_layer_inputs:
args, kwargs = step_input[0]
fake_layer(*args, **kwargs)

@torch.no_grad()
def calib(self, nsamples, bs):
"""Perform calibration for quantization.
@@ -2346,7 +2385,6 @@ def _recover_forward(self):

def _replace_forward(self):
"""Replaces the forward function."""
from functools import partial

for n, m in self.model.named_modules():
if n in self.to_cached_layers and type(m) not in self.supported_types: ##block
@@ -2652,7 +2690,10 @@ def quantize_block(
"DiffusionCompressor",
"MLLMCompressor",
], f"Currently, {self.__class__.__name__} does not support support quantize block with this function."
input_ids, input_others = normalize_input(inputs)
self.normalize_decoding_layer_inputs_(inputs)
block_inputs = self.inputs[self.quant_block_list[0][0]]
decoding_layer_first_input_name = "hidden_states"
wenhuach21 (Contributor) commented on Dec 4, 2025:

does the origin code have such assumption? Is there any case that the first_input_name is not hidden_states

Contributor Author replied:

The origin code replaces the hidden_states with input_ids.
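To illustrate the reply, here is a minimal sketch of how the first positional argument of a captured decoding-layer call (hidden_states) takes over the role the old code gave to input_ids. Only the capture format described in the docstring above is taken from the diff; the tensor shape and kwargs are assumptions for illustration.

```python
import torch

# Hypothetical captured entry, in the ((args), {kwargs}) format produced by
# input_capture_hook (shape and kwargs are made up for illustration).
hidden_states = torch.randn(1, 8, 32)
captured_entry = ((hidden_states,), {"attention_mask": None, "use_cache": False})

args, kwargs = captured_entry
# Old path (removed normalize_input): args[0] was stored under "input_ids".
# New path: the same tensor keeps its real name and is popped later via
# first_input_name="hidden_states" in _preprocess_block_inputs/_split_inputs.
block_inputs = {"hidden_states": [args[0]], **kwargs}
input_ids = block_inputs.pop("hidden_states")
input_others = block_inputs
print(len(input_ids), sorted(input_others))  # 1 ['attention_mask', 'use_cache']
```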

input_ids, input_others = self._preprocess_block_inputs(block_inputs, decoding_layer_first_input_name)
return self._quantize_block(block, input_ids, input_others, q_input, device, auto_offload)

def _get_loss(
@@ -2959,12 +3000,32 @@ def _quantize_block(

return None, output

def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]:
input_ids = inputs["input_ids"]
inputs.pop("input_ids", None)
def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[torch.Tensor, dict]:
input_ids = inputs[first_input_name]
inputs.pop(first_input_name, None)
input_others = inputs
return input_ids, input_others

def _preprocess_block_inputs(self, inputs, first_input_name="input_ids"):
input_ids, input_others = self._split_inputs(inputs, first_input_name)
clear_memory(device_list=self.device_list)
wenhuach21 (Contributor) commented on Dec 4, 2025:

I don't know why clear_memory is needed here. @n1ck-guo do you know?

input_ids = to_device(input_ids, self.cache_device)
input_others = to_device(input_others, self.cache_device)
# As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage

tmp_dtype = self.amp_dtype if self.amp else torch.float32
input_ids = to_dtype(input_ids, tmp_dtype)

for key in input_others.keys():
if isinstance(input_others[key], torch.Tensor) and (
input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16
):
input_others[key] = input_others[key].to(tmp_dtype)
elif isinstance(input_others[key], list):
for i in range(len(input_others[key])):
to_dtype(input_others[key][i], tmp_dtype)
return input_ids, input_others

def _quantize_blocks(
self,
model: torch.nn.Module,
@@ -2991,23 +3052,7 @@ def _quantize_blocks(
for n, m in model.named_parameters():
m.requires_grad_(False)

input_ids, input_others = self._split_inputs(inputs)
clear_memory(device_list=self.device_list)
input_ids = to_device(input_ids, self.cache_device)
input_others = to_device(input_others, self.cache_device)
# As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage

tmp_dtype = self.amp_dtype if self.amp else torch.float32
input_ids = to_dtype(input_ids, tmp_dtype)

for key in input_others.keys():
if isinstance(input_others[key], torch.Tensor) and (
input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16
):
input_others[key] = input_others[key].to(tmp_dtype)
elif isinstance(input_others[key], list):
for i in range(len(input_others[key])):
to_dtype(input_others[key][i], tmp_dtype)
input_ids, input_others = self._preprocess_block_inputs(inputs)

if pbar is None:
pbar = tqdm(range(0, len(block_names), nblocks))
17 changes: 0 additions & 17 deletions auto_round/utils/common.py
@@ -329,20 +329,3 @@ def get_reciprocal(tensor):
recip[mask] = 0.0

return recip


def normalize_input(
decoding_layer_inputs: tuple[Union[list[torch.Tensor], dict, Any], Optional[dict]],
) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
"""Normalize the decoding layer inputs into input_ids and other inputs."""
input_ids = []
input_others = {"positional_inputs": []}
for cur_inp in decoding_layer_inputs:
input_ids.append(cur_inp[0][0][0])
for key, val in cur_inp[0][1].items():
input_others[key] = val
# Force 'use_cache' to be False
if "use_cache" in input_others and input_others["use_cache"] is True:
logger.warning_once("Forcing 'use_cache' to be False during calibration.")
input_others["use_cache"] = False
return input_ids, input_others
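
For reference, a self-contained sketch of the dtype/device normalization that the new _preprocess_block_inputs performs on these captured inputs. The toy tensors, default cache_device/amp settings, and the helper name are assumptions for illustration, not part of the diff.

```python
import torch

def preprocess_block_inputs_sketch(inputs: dict, first_input_name: str = "hidden_states",
                                   cache_device: str = "cpu", amp: bool = True,
                                   amp_dtype: torch.dtype = torch.bfloat16):
    """Toy restatement: split off the first input, move tensors to the cache
    device, and cast half-precision floats to the AMP dtype (else float32)."""
    input_ids = inputs.pop(first_input_name)
    input_others = inputs
    tmp_dtype = amp_dtype if amp else torch.float32
    input_ids = [t.to(cache_device, tmp_dtype) for t in input_ids]
    for key, val in input_others.items():
        if isinstance(val, torch.Tensor) and val.dtype in (torch.float16, torch.bfloat16):
            input_others[key] = val.to(cache_device, tmp_dtype)
    return input_ids, input_others

# Toy usage with assumed shapes:
captured = {
    "hidden_states": [torch.randn(1, 4, 8, dtype=torch.float16)],
    "attention_mask": torch.ones(1, 4, dtype=torch.bfloat16),
    "use_cache": False,
}
ids, others = preprocess_block_inputs_sketch(captured)
print(ids[0].dtype, others["attention_mask"].dtype)  # torch.bfloat16 torch.bfloat16
```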