diff --git a/README.md b/README.md index 5a98d1b46..9874745f4 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ AutoRound AutoRound is an advanced weight-only quantization algorithm for low-bits LLM inference. It's tailored for a wide range of models and consistently delivers noticeable improvements, often significantly outperforming SignRound with the cost of more tuning time for quantization. -our method adopts sign gradient descent to fine-tune rounding values and minmax values of weights in just 200 steps, which competes impressively against recent methods without introducing any additional inference overhead. The below image presents an overview of AutoRound. +Our method adopts sign gradient descent to fine-tune rounding values and minmax values of weights in just 200 steps, which competes impressively against recent methods without introducing any additional inference overhead. The image below presents an overview of AutoRound.
@@ -114,7 +114,7 @@ Please run the quantization code first. ##Install the latest https://github.com/intel/intel-extension-for-transformers from source first. from intel_extension_for_transformers.transformers import AutoModelForCausalLM from transformers import AutoTokenizer - + quantized_model_path = "./tmp_autoround" model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(quantized_model_path, use_fast=True) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 6f96506d7..2f6fa8c16 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -18,6 +18,7 @@ import torch +from .special_model_handler import check_hidden_state_dim, check_share_attention_mask from .utils import ( CpuInfo, block_forward, @@ -26,7 +27,6 @@ collect_minmax_scale, collect_round_v, detect_device, - get_batch_dim, get_block_names, get_module, get_scale_shape, @@ -591,7 +591,7 @@ def set_layerwise_config(self, weight_config): m.scale_dtype = weight_config[n]["scale_dtype"] @torch.no_grad() - def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_device, batch_dim): + def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_device): """Compute the output of a given block of the model for a given input. 
Args: @@ -611,12 +611,14 @@ def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_de for i in range(0, self.n_samples, bs): end_index = min(self.n_samples, i + bs) indices = torch.arange(i, end_index).to(torch.long) - tmp_input_ids, tmp_input_others = sampling_inputs(input_ids, input_others, indices, self.seqlen) + tmp_input_ids, tmp_input_others = sampling_inputs( + input_ids, input_others, indices, self.seqlen, self.share_attention_mask_flag, self.input_dim + ) tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to( cache_device ) output.append(tmp_output) - output = torch.cat(output, dim=batch_dim) + output = torch.cat(output, dim=self.input_dim) torch.cuda.empty_cache() return output @@ -711,6 +713,8 @@ def cache_block_input(self, block_name, n_samples): """ self.inputs = {} self.tmp_block_name = block_name + self.share_attention_mask_flag = None + self.hidden_dim_flag = None self._replace_forward() self.calib(n_samples) self._recover_forward() @@ -729,9 +733,21 @@ def get_forward_func(self, name): """ def forward(_, hidden_states, *positional_args, **kwargs): - dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type)) + """Rewrite forward function, process and collect input data. + + Args: + hidden_states (torch.Tensor): The hidden states tensor. + *positional_args: Variable number of positional arguments. + **kwargs: Variable number of keyword arguments. + + Returns: + NotImplementedError: Getting the first layer inputs and then raise the error to save runtime. 
+ """ + if self.share_attention_mask_flag is None: + self.input_dim = check_hidden_state_dim(self.model, positional_args) + self.share_attention_mask_flag = check_share_attention_mask(self.model, hidden_states, **kwargs) if name in self.inputs: - data = torch.cat([self.inputs[name]["input_ids"], hidden_states.to("cpu")], dim=dim) + data = torch.cat([self.inputs[name]["input_ids"], hidden_states.to("cpu")], dim=self.input_dim) self.inputs[name]["input_ids"] = data else: self.inputs[name] = {} @@ -748,7 +764,7 @@ def forward(_, hidden_states, *positional_args, **kwargs): if key not in self.inputs[name].keys(): self.inputs[name][key] = None if kwargs[key] is not None: - if self.inputs[name][key] is not None: + if (not self.share_attention_mask_flag) and self.inputs[name][key] is not None: self.inputs[name][key] = torch.cat( [self.inputs[name][key], kwargs[key].to("cpu")], dim=0 ) @@ -761,7 +777,7 @@ def forward(_, hidden_states, *positional_args, **kwargs): alibi = kwargs[key] batch = kwargs["attention_mask"].shape[0] alibi = alibi.reshape(batch, -1, alibi.shape[1], alibi.shape[2]) - if self.inputs[name][key] is not None: + if (not self.share_attention_mask_flag) and self.inputs[name][key] is not None: self.inputs[name][key] = torch.cat([self.inputs[name][key], alibi.to("cpu")], dim=0) else: self.inputs[name][key] = alibi.to("cpu") @@ -804,14 +820,13 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch """ from torch.amp import autocast - batch_dim = get_batch_dim(input_others) if not self.low_gpu_mem_usage and input_ids.device != device: input_ids = move_input_to_device(input_ids, device) input_others = move_input_to_device(input_others, device) cache_device = device if self.low_gpu_mem_usage: cache_device = "cpu" - output = self.get_block_outputs(block, input_ids, input_others, self.train_bs, device, cache_device, batch_dim) + output = self.get_block_outputs(block, input_ids, input_others, self.train_bs, device, cache_device) if 
q_input is not None: input_ids = q_input.to(cache_device) @@ -842,7 +857,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch pick_samples = self.train_bs if len(input_ids.shape) == 3: - n_samples = input_ids.shape[batch_dim] + n_samples = input_ids.shape[self.input_dim] else: n_samples = input_ids.shape[0] // self.seqlen if self.sampler != "rand": @@ -860,12 +875,17 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch total_loss = 0 for _ in range(self.gradient_accumulate_steps): current_input_ids, current_input_others = sampling_inputs( - input_ids, input_others, indices, seqlen=self.seqlen + input_ids, + input_others, + indices, + seqlen=self.seqlen, + share_attention_mask_flag=self.share_attention_mask_flag, + input_dim=self.input_dim, ) if len(input_ids.shape) == 3: - if batch_dim == 0: + if self.input_dim == 0: current_output = output[indices, :, :] - elif batch_dim == 1: + elif self.input_dim == 1: current_output = output[:, indices, :] else: current_output = output[:, :, indices] @@ -923,9 +943,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch unwrapper_block(block, best_v, best_min_scale, best_max_scale) if self.use_quant_input: - q_outputs = self.get_block_outputs( - block, input_ids, input_others, self.train_bs, device, cache_device, batch_dim - ) + q_outputs = self.get_block_outputs(block, input_ids, input_others, self.train_bs, device, cache_device) return q_outputs, output diff --git a/auto_round/export/export_to_itrex/export.py b/auto_round/export/export_to_itrex/export.py index 6f58b6517..578f50ec8 100644 --- a/auto_round/export/export_to_itrex/export.py +++ b/auto_round/export/export_to_itrex/export.py @@ -123,7 +123,7 @@ def pack_model( m = get_module(compressed_model, k) fp_weight = m.weight.data scale, zp = v["scale"], v["zp"] - convert_dtype = torch.float32 if fp_weight.device.type == "cpu" else scale_dtype + convert_dtype = scale_dtype if not 
isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=convert_dtype) zp = torch.tensor(zp, dtype=torch.int32) @@ -133,6 +133,7 @@ def pack_model( zp = zp.clone() scale = scale.to(dtype=convert_dtype) zp = zp.to(dtype=torch.int32) + int_weight = quant_weight_w_scale(fp_weight, scale, zp, group_size, fp_weight.device) int_weight = int_weight.type(torch.int32) new_module = WeightOnlyLinear( diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py new file mode 100644 index 000000000..82c1aaa18 --- /dev/null +++ b/auto_round/special_model_handler.py @@ -0,0 +1,56 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .utils import logger, torch + +share_attention_mask_tuple = ("baichuan",) +special_states_dim_tuple = ("chatglm",) + + +def check_share_attention_mask(model, hidden_states, attention_mask=None, **kwargs): + """Checks if the attention mask states of the hidden states are shared in the model. + + Args: + hidden_states (torch.Tensor): The hidden states of the model. + attention_mask (torch.Tensor, optional): The attention mask tensor. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + bool: True if attention mask is shared in the model, False otherwise. 
+ """ + if attention_mask is None or not isinstance(hidden_states, torch.Tensor): + return False + is_special = False + for key in share_attention_mask_tuple: + if hasattr(model, "config") and key in model.config.model_type: + is_special = True + break + return bool(is_special and attention_mask.shape[0] != hidden_states.shape[0]) + + +def check_hidden_state_dim(model, positional_args): + """Checks the dimensionality of the hidden states. + + Args: + positional_args: The positional arguments. + + Returns: + int: 1 if the model type is 'chatglm' and positional arguments are not None, 0 otherwise. + """ + is_special = False + for key in special_states_dim_tuple: + if hasattr(model, "config") and key in model.config.model_type: + is_special = True + break + return int(is_special and positional_args is not None) diff --git a/auto_round/utils.py b/auto_round/utils.py index e31bf95bb..6fb78ebbe 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -25,6 +25,7 @@ logger = logging.getLogger("autoround") logger.setLevel(logging.INFO) +logger.propagate = False fh = logging.StreamHandler() fh_formatter = logging.Formatter("%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s", "%Y-%m-%d %H:%M:%S") fh.setFormatter(fh_formatter) @@ -400,21 +401,7 @@ def collect_minmax_scale(block): return min_scales, max_scales -@torch.no_grad() -def get_batch_dim(input_others): - """Gets the batch dimension based on the input positional inputs. - - Args: - input_others: A dictionary containing input data. - - Returns: - dim: The batch dimension. - """ - dim = int(len(input_others["positional_inputs"]) > 0) - return dim - - -def sampling_inputs(input_ids, input_others, indices, seqlen): +def sampling_inputs(input_ids, input_others, indices, seqlen, share_attention_mask_flag=False, input_dim=0): """Samples inputs based on the given indices and sequence length. 
Args: @@ -428,7 +415,7 @@ def sampling_inputs(input_ids, input_others, indices, seqlen): current_input_others: The sampled other input data. """ if len(input_ids.shape) == 3: - if int(len(input_others["positional_inputs"]) > 0): + if bool(input_dim): current_input_ids = input_ids[:, indices, :] else: current_input_ids = input_ids[indices, :, :] @@ -437,10 +424,9 @@ def sampling_inputs(input_ids, input_others, indices, seqlen): current_input_ids = input_ids.view(n_samples, seqlen, -1) current_input_ids = current_input_ids[indices, :, :] current_input_ids = current_input_ids.reshape(-1, input.shape[-1]) - current_input_others = {"positional_inputs": input_others["positional_inputs"]} for key in input_others.keys(): - if "attention_mask" in key or "alibi" in key: + if not share_attention_mask_flag and ("attention_mask" in key or "alibi" in key): current_input_others[key] = None if input_others[key] is not None: current_input_others[key] = input_others[key][indices, ...]