Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ AutoRound

AutoRound is an advanced weight-only quantization algorithm for low-bits LLM inference. It's tailored for a wide range of models and consistently delivers noticeable improvements, often significantly outperforming SignRound with the cost of more tuning time for quantization.

our method adopts sign gradient descent to fine-tune rounding values and minmax values of weights in just 200 steps, which competes impressively against recent methods without introducing any additional inference overhead. The below image presents an overview of AutoRound.
Our method adopts sign gradient descent to fine-tune rounding values and minmax values of weights in just 200 steps, which competes impressively against recent methods without introducing any additional inference overhead. The below image presents an overview of AutoRound.

<div align="center">

Expand Down Expand Up @@ -114,7 +114,7 @@ Please run the quantization code first.
##Install the latest https://github.com/intel/intel-extension-for-transformers from source first.
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path, use_fast=True)
Expand Down
52 changes: 35 additions & 17 deletions auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch

from .special_model_handler import check_hidden_state_dim, check_share_attention_mask
from .utils import (
CpuInfo,
block_forward,
Expand All @@ -26,7 +27,6 @@
collect_minmax_scale,
collect_round_v,
detect_device,
get_batch_dim,
get_block_names,
get_module,
get_scale_shape,
Expand Down Expand Up @@ -591,7 +591,7 @@ def set_layerwise_config(self, weight_config):
m.scale_dtype = weight_config[n]["scale_dtype"]

@torch.no_grad()
def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_device, batch_dim):
def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_device):
"""Compute the output of a given block of the model for a given input.

Args:
Expand All @@ -611,12 +611,14 @@ def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_de
for i in range(0, self.n_samples, bs):
end_index = min(self.n_samples, i + bs)
indices = torch.arange(i, end_index).to(torch.long)
tmp_input_ids, tmp_input_others = sampling_inputs(input_ids, input_others, indices, self.seqlen)
tmp_input_ids, tmp_input_others = sampling_inputs(
input_ids, input_others, indices, self.seqlen, self.share_attention_mask_flag, self.input_dim
)
tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to(
cache_device
)
output.append(tmp_output)
output = torch.cat(output, dim=batch_dim)
output = torch.cat(output, dim=self.input_dim)
torch.cuda.empty_cache()
return output

Expand Down Expand Up @@ -711,6 +713,8 @@ def cache_block_input(self, block_name, n_samples):
"""
self.inputs = {}
self.tmp_block_name = block_name
self.share_attention_mask_flag = None
self.hidden_dim_flag = None
self._replace_forward()
self.calib(n_samples)
self._recover_forward()
Expand All @@ -729,9 +733,21 @@ def get_forward_func(self, name):
"""

def forward(_, hidden_states, *positional_args, **kwargs):
dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type))
"""Rewrite forward function, process and collect input data.

Args:
hidden_states (torch.Tensor): The hidden states tensor.
*positional_args: Variable number of positional arguments.
**kwargs: Variable number of keyword arguments.

Returns:
NotImplementedError: Getting the first layer inputs and then raise the error to save runtime.
"""
if self.share_attention_mask_flag is None:
self.input_dim = check_hidden_state_dim(self.model, positional_args)
self.share_attention_mask_flag = check_share_attention_mask(self.model, hidden_states, **kwargs)
if name in self.inputs:
data = torch.cat([self.inputs[name]["input_ids"], hidden_states.to("cpu")], dim=dim)
data = torch.cat([self.inputs[name]["input_ids"], hidden_states.to("cpu")], dim=self.input_dim)
self.inputs[name]["input_ids"] = data
else:
self.inputs[name] = {}
Expand All @@ -748,7 +764,7 @@ def forward(_, hidden_states, *positional_args, **kwargs):
if key not in self.inputs[name].keys():
self.inputs[name][key] = None
if kwargs[key] is not None:
if self.inputs[name][key] is not None:
if (not self.share_attention_mask_flag) and self.inputs[name][key] is not None:
self.inputs[name][key] = torch.cat(
[self.inputs[name][key], kwargs[key].to("cpu")], dim=0
)
Expand All @@ -761,7 +777,7 @@ def forward(_, hidden_states, *positional_args, **kwargs):
alibi = kwargs[key]
batch = kwargs["attention_mask"].shape[0]
alibi = alibi.reshape(batch, -1, alibi.shape[1], alibi.shape[2])
if self.inputs[name][key] is not None:
if (not self.share_attention_mask_flag) and self.inputs[name][key] is not None:
self.inputs[name][key] = torch.cat([self.inputs[name][key], alibi.to("cpu")], dim=0)
else:
self.inputs[name][key] = alibi.to("cpu")
Expand Down Expand Up @@ -804,14 +820,13 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
"""
from torch.amp import autocast

batch_dim = get_batch_dim(input_others)
if not self.low_gpu_mem_usage and input_ids.device != device:
input_ids = move_input_to_device(input_ids, device)
input_others = move_input_to_device(input_others, device)
cache_device = device
if self.low_gpu_mem_usage:
cache_device = "cpu"
output = self.get_block_outputs(block, input_ids, input_others, self.train_bs, device, cache_device, batch_dim)
output = self.get_block_outputs(block, input_ids, input_others, self.train_bs, device, cache_device)

if q_input is not None:
input_ids = q_input.to(cache_device)
Expand Down Expand Up @@ -842,7 +857,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch

pick_samples = self.train_bs
if len(input_ids.shape) == 3:
n_samples = input_ids.shape[batch_dim]
n_samples = input_ids.shape[self.input_dim]
else:
n_samples = input_ids.shape[0] // self.seqlen
if self.sampler != "rand":
Expand All @@ -860,12 +875,17 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
total_loss = 0
for _ in range(self.gradient_accumulate_steps):
current_input_ids, current_input_others = sampling_inputs(
input_ids, input_others, indices, seqlen=self.seqlen
input_ids,
input_others,
indices,
seqlen=self.seqlen,
share_attention_mask_flag=self.share_attention_mask_flag,
input_dim=self.input_dim,
)
if len(input_ids.shape) == 3:
if batch_dim == 0:
if self.input_dim == 0:
current_output = output[indices, :, :]
elif batch_dim == 1:
elif self.input_dim == 1:
current_output = output[:, indices, :]
else:
current_output = output[:, :, indices]
Expand Down Expand Up @@ -923,9 +943,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch

unwrapper_block(block, best_v, best_min_scale, best_max_scale)
if self.use_quant_input:
q_outputs = self.get_block_outputs(
block, input_ids, input_others, self.train_bs, device, cache_device, batch_dim
)
q_outputs = self.get_block_outputs(block, input_ids, input_others, self.train_bs, device, cache_device)

return q_outputs, output

Expand Down
3 changes: 2 additions & 1 deletion auto_round/export/export_to_itrex/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def pack_model(
m = get_module(compressed_model, k)
fp_weight = m.weight.data
scale, zp = v["scale"], v["zp"]
convert_dtype = torch.float32 if fp_weight.device.type == "cpu" else scale_dtype
convert_dtype = scale_dtype
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale, dtype=convert_dtype)
zp = torch.tensor(zp, dtype=torch.int32)
Expand All @@ -133,6 +133,7 @@ def pack_model(
zp = zp.clone()
scale = scale.to(dtype=convert_dtype)
zp = zp.to(dtype=torch.int32)

int_weight = quant_weight_w_scale(fp_weight, scale, zp, group_size, fp_weight.device)
int_weight = int_weight.type(torch.int32)
new_module = WeightOnlyLinear(
Expand Down
56 changes: 56 additions & 0 deletions auto_round/special_model_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import logger, torch

share_attention_mask_tuple = ("baichuan",)
special_states_dim_tuple = ("chatglm",)


def check_share_attention_mask(model, hidden_states, attention_mask=None, **kwargs):
    """Check whether one attention mask is shared across the batch for this model.

    Some architectures (currently baichuan) pass a single attention mask whose
    batch dimension differs from that of the hidden states, i.e. the same mask
    is reused for every sample.

    Args:
        model: The model being calibrated; its ``config.model_type`` is inspected.
        hidden_states (torch.Tensor): The hidden states entering the block.
        attention_mask (torch.Tensor, optional): The attention mask tensor. Defaults to None.
        **kwargs: Additional keyword arguments (ignored).

    Returns:
        bool: True if the attention mask is shared across samples, False otherwise.
    """
    if attention_mask is None or not isinstance(hidden_states, torch.Tensor):
        return False
    model_type = model.config.model_type if hasattr(model, "config") else ""
    is_special = any(name in model_type for name in share_attention_mask_tuple)
    return bool(is_special and attention_mask.shape[0] != hidden_states.shape[0])


def check_hidden_state_dim(model, positional_args):
    """Determine which dimension of the hidden states indexes the batch.

    Args:
        model: The model being calibrated; its ``config.model_type`` is inspected.
        positional_args: The positional arguments forwarded to the block's forward.

    Returns:
        int: 1 for special models (currently chatglm) when positional arguments
        are not None, 0 otherwise.
    """
    model_type = model.config.model_type if hasattr(model, "config") else ""
    is_special = any(name in model_type for name in special_states_dim_tuple)
    # NOTE(review): positional_args arrives as a (possibly empty) tuple from the
    # hooked forward, so `is not None` is effectively always True — confirm intended.
    return int(is_special and positional_args is not None)
22 changes: 4 additions & 18 deletions auto_round/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

logger = logging.getLogger("autoround")
logger.setLevel(logging.INFO)
logger.propagate = False
fh = logging.StreamHandler()
fh_formatter = logging.Formatter("%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s", "%Y-%m-%d %H:%M:%S")
fh.setFormatter(fh_formatter)
Expand Down Expand Up @@ -400,21 +401,7 @@ def collect_minmax_scale(block):
return min_scales, max_scales


@torch.no_grad()
def get_batch_dim(input_others):
"""Gets the batch dimension based on the input positional inputs.

Args:
input_others: A dictionary containing input data.

Returns:
dim: The batch dimension.
"""
dim = int(len(input_others["positional_inputs"]) > 0)
return dim


def sampling_inputs(input_ids, input_others, indices, seqlen):
def sampling_inputs(input_ids, input_others, indices, seqlen, share_attention_mask_flag=False, input_dim=0):
"""Samples inputs based on the given indices and sequence length.

Args:
Expand All @@ -428,7 +415,7 @@ def sampling_inputs(input_ids, input_others, indices, seqlen):
current_input_others: The sampled other input data.
"""
if len(input_ids.shape) == 3:
if int(len(input_others["positional_inputs"]) > 0):
if bool(input_dim):
current_input_ids = input_ids[:, indices, :]
else:
current_input_ids = input_ids[indices, :, :]
Expand All @@ -437,10 +424,9 @@ def sampling_inputs(input_ids, input_others, indices, seqlen):
current_input_ids = input_ids.view(n_samples, seqlen, -1)
current_input_ids = current_input_ids[indices, :, :]
current_input_ids = current_input_ids.reshape(-1, input.shape[-1])

current_input_others = {"positional_inputs": input_others["positional_inputs"]}
for key in input_others.keys():
if "attention_mask" in key or "alibi" in key:
if not share_attention_mask_flag and ("attention_mask" in key or "alibi" in key):
current_input_others[key] = None
if input_others[key] is not None:
current_input_others[key] = input_others[key][indices, ...]
Expand Down