diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 48cf0cd1a..b186b92e8 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -111,7 +111,7 @@ class BaseCompressor(object):
         sym (bool): Whether to use symmetric weight quantization.
         layer_config (dict): Per-layer quantization configuration.
         nsamples (int): Number of calibration samples.
-        enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers.
+        enable_torch_compile (bool): Whether to enable compile_func for quant blocks/layers.
     """

     bits: int | None
@@ -361,6 +361,7 @@ def __init__(
         self.infer_bs_coeff = 1
         self.enable_torch_compile = enable_torch_compile
         self._adjust_torch_compile(enable_torch_compile)
+        self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
         self._check_configs()
         torch.set_printoptions(precision=3, sci_mode=True)

@@ -1428,6 +1429,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
                 enable_round_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
             )
             m = m.unwrapper({})
             m.to("cpu")
@@ -1443,6 +1445,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                     enable_minmax_tuning=False,
                     enable_norm_bias_tuning=False,
                     enable_round_tuning=False,
+                    enable_torch_compile=self.enable_torch_compile,
                 )
                 m = m.unwrapper({})
             except Exception as e:
@@ -1882,6 +1885,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 enable_round_tuning=False,
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
                 device=self.device,
             )
             new_layer = wrapper_layer.unwrapper({})
@@ -1911,10 +1915,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
             self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
             clear_memory()

-        if self.enable_torch_compile:
-            quant_layer = compile_func(self._quantize_layer, self.device)
-        else:
-            quant_layer = self._quantize_layer
+        quant_layer = self._quantize_layer
         for layer_name in layer_names:
             layer_input = layer_inputs[layer_name]
             layer_input = to_device(layer_input, self.cache_device)
@@ -2093,9 +2094,9 @@ def _get_block_outputs(
             tmp_input_ids, tmp_input_others = self._sampling_inputs(
                 input_ids, input_others, indices, self.seqlen, self.batch_dim, share_cache_keys=self.shared_cache_keys
             )
-            tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to(
-                cache_device
-            )
+            tmp_output = self.block_forward(
+                block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device
+            ).to(cache_device)
             if save_output:
                 if self.batch_size == 1:
                     output.append(tmp_output)
@@ -2518,7 +2519,12 @@ def _quantize_layer(
             if q_inputs is not None:
                 q_inputs[i] = q_inputs[i].to(layer.weight.dtype)

-        wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to(device)
+        wrapper_linear = WrapperLinear(
+            layer,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=device,
+        ).to(device)
         round_params = []
         minmax_params = []
         for key in wrapper_linear.params.keys():
@@ -2696,7 +2702,7 @@ def _get_current_q_output(
             batch_dim=self.batch_dim,
             share_cache_keys=self.shared_cache_keys,
         )
-        output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
+        output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
         return output_q

     def _get_current_num_elm(
@@ -2781,7 +2787,11 @@ def _quantize_block(
             input_ids = q_input

         quantized_layer_names, unquantized_layer_names = wrapper_block(
-            block, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device
+            block,
+            self.enable_minmax_tuning,
+            self.enable_norm_bias_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=self.device,
         )
         if is_nv_fp(self.data_type):  # enable qkv and moe structure global_scale fuse
             from auto_round.data_type.utils import update_fused_layer_global_scales
@@ -3008,14 +3018,8 @@ def _quantize_blocks(
                 logger.info("using algorithm extension for quantization.")
             except (ImportError, ModuleNotFoundError):
                 quantize_block = self._quantize_block
-                if self.enable_torch_compile:
-                    quantize_block = compile_func(quantize_block, device)
-                else:
-                    quantize_block = quantize_block
         else:
             quantize_block = self._quantize_block
-            if self.enable_torch_compile:
-                quantize_block = compile_func(quantize_block, device)

         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))
diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py
index f6bffa94d..98fa80289 100644
--- a/auto_round/wrapper.py
+++ b/auto_round/wrapper.py
@@ -22,6 +22,7 @@
 from .utils import (
     SUPPORTED_LAYER_TYPES,
     check_to_quantized,
+    compile_func,
     deepspeed_exists,
     get_scale_shape,
     is_mx_fp,
@@ -67,6 +68,7 @@ class WrapperLinear(torch.nn.Module):
         orig_layer (torch.nn.Module): The original layer to be wrapped (linear or conv1d).
         enable_minmax_tuning (bool): Whether to enable min-max scale tuning.
         enable_norm_bias_tuning (bool): Whether to enable normalization and tuning of the bias term.
+        enable_torch_compile (bool): Whether to enable torch compilation.
         device (str): Device on which to run computations (e.g., 'cpu' or 'cuda').
     """

@@ -77,6 +79,7 @@ def __init__(
         self,
         orig_layer,
         enable_minmax_tuning=True,
         enable_norm_bias_tuning=False,
         device="cpu",
         enable_round_tuning=True,
+        enable_torch_compile=False,
         **kwargs,
     ):
         """Initializes the WrapperLinear module.
@@ -93,6 +96,7 @@ def __init__(
         self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device
         self.enable_minmax_tuning = enable_minmax_tuning
         self.enable_round_tuning = enable_round_tuning
+        self.enable_torch_compile = enable_torch_compile
         self.enable_norm_bias_tuning = enable_norm_bias_tuning and (orig_layer.bias is not None)
         self.enable_act_quant = self.orig_layer.act_bits <= 8
         self.weight_global_scale = getattr(self.orig_layer, "weight_global_scale", None)
@@ -143,11 +147,15 @@ def _init_tuning_params_and_quant_func(self):
         self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))

         self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym)
+        if self.enable_torch_compile:
+            self.weight_quant_func = compile_func(self.weight_quant_func, self.device)

         if self.enable_act_quant:
             self.act_quant_func, self.act_data_type = get_quant_func(
                 orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym
             )
+            if self.enable_torch_compile:
+                self.act_quant_func = compile_func(self.act_quant_func, self.device)
             self._init_params("act_max_scale", p_dtype, (1), 1.0, not orig_layer.act_dynamic)

         ## bias tuning
@@ -372,7 +380,11 @@ def _set_dict_attr(attr_dict, attr_name):
             self.orig_layer.act_data_type = self.act_data_type
             self.orig_layer.act_quant_func = self.act_quant_func
-            wrapper_layer = WrapperWALayer(self.orig_layer)
+            wrapper_layer = WrapperWALayer(
+                self.orig_layer,
+                enable_torch_compile=self.enable_torch_compile,
+                device=self.device,
+            )
             return wrapper_layer

         return self.orig_layer

@@ -452,12 +464,16 @@ def forward(self, x):


 class WrapperWALayer(torch.nn.Module):
-    def __init__(self, orig_layer):
+    def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"):
         super(WrapperWALayer, self).__init__()
         self.orig_layer = orig_layer
+        self.enable_torch_compile = enable_torch_compile
+        self.device = device
         self.data_type = orig_layer.data_type if hasattr(orig_layer, "data_type") else None
         self.act_data_type = orig_layer.act_data_type if hasattr(orig_layer, "act_data_type") else None
         self.act_quant_func = self.orig_layer.act_quant_func
+        if self.enable_torch_compile:
+            self.act_quant_func = compile_func(self.act_quant_func, self.device)
         self.extra_repr_org = orig_layer.extra_repr

     def forward(self, x):
@@ -609,12 +625,17 @@ def forward(self, x, **kwargs):
         return hidden_states


-def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="cpu", **kwargs):
+def wrapper_block(
+    block, enable_minmax_tuning, enable_norm_bias_tuning, enable_torch_compile=False, device="cpu", **kwargs
+):
     """Wraps the layers in the given block with a custom Wrapper module.

     Args:
         block: The input block containing linear and conv1d layers to be wrapped.
         enable_minmax_tuning: A boolean indicating whether min-max tuning is enabled.
+        enable_norm_bias_tuning: A boolean indicating whether normalization and bias tuning is enabled.
+        enable_torch_compile: A boolean indicating whether to enable torch compilation.
+        device: The device to which the wrapped layers should be moved.

     Returns:
         list: A list of names of the wrapped layers and unwrapped layers.
@@ -630,6 +651,7 @@ def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="
                 m,
                 enable_minmax_tuning=enable_minmax_tuning,
                 enable_norm_bias_tuning=enable_norm_bias_tuning,
+                enable_torch_compile=enable_torch_compile,
                 device=device,
                 **kwargs,
             )
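
Illustration (not part of the patch). The change moves compilation from the whole tuning loop (_quantize_block / _quantize_layer were previously passed to compile_func) down to the small, frequently called pieces: block_forward and the per-layer weight/activation quant functions, each guarded by enable_torch_compile. The sketch below shows that pattern in isolation; toy_weight_quant and TinyWrapper are hypothetical stand-ins, and compile_func here is a simplified stand-in for auto_round.utils.compile_func (same (func, device) signature as used in the diff; the real helper may do more with the device argument).

import torch


def compile_func(func, device):
    # Simplified stand-in: wrap a callable with torch.compile.
    # The device argument is unused in this sketch.
    return torch.compile(func)


def toy_weight_quant(weight):
    # Placeholder for a real weight quantization kernel.
    return torch.round(weight * 16.0) / 16.0


class TinyWrapper(torch.nn.Module):
    def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"):
        super().__init__()
        self.orig_layer = orig_layer
        # Only the hot quant function is compiled, mirroring the WrapperLinear change above.
        self.weight_quant_func = (
            compile_func(toy_weight_quant, device) if enable_torch_compile else toy_weight_quant
        )

    def forward(self, x):
        qweight = self.weight_quant_func(self.orig_layer.weight)
        return torch.nn.functional.linear(x, qweight, self.orig_layer.bias)


if __name__ == "__main__":
    layer = torch.nn.Linear(8, 8)
    wrapped = TinyWrapper(layer, enable_torch_compile=True)
    print(wrapped(torch.randn(2, 8)).shape)  # torch.Size([2, 8])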