42 changes: 32 additions & 10 deletions README.md
@@ -29,20 +29,23 @@ python setup.py install
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"
model_name = "bigscience/bloom-560m"
model = AutoModelForCausalLM.from_pretrained(
model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True
)
model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
bits, group_size, scheme = 4, 128, "asym"

# need load model first, them import
# need to load model first, then import
from auto_round import AutoRound

autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, scheme=scheme,
device="hpu", scale_dtype="bf16", amp=False)
autoround.quantize()

# Intel CPU inference; currently llama, bloom, and mistral are supported.
output_dir = "/path/to/quantized_model"
autoround.export(output_dir)
# then follow ITREX to load the model and do inference

```
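The ITREX loading step itself is not shown in this diff; the snippet below is only a rough sketch of what it might look like. The `intel_extension_for_transformers` import path and the `from_pretrained` loading behavior are assumptions here, so follow the ITREX documentation for the supported flow.

```python
# Illustrative only (not part of this PR): reload the exported directory through
# ITREX's drop-in transformers class for Intel CPU inference. The exact loading
# API is an assumption; see the ITREX docs.
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

quantized_model = AutoModelForCausalLM.from_pretrained(
    "/path/to/quantized_model", trust_remote_code=True)
```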

### On GPU
@@ -53,8 +56,7 @@ from auto_round import AutoRound

model_name = "bigscience/bloom-560m"
model = AutoModelForCausalLM.from_pretrained(
model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True
)
model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
bits, group_size, scheme = 4, 128, "asym"
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, scheme=scheme)
@@ -185,7 +187,7 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc
</tr>

</tr>
<th>Ours iters1K, disable use_quant_input, minmax_lr 0.002</th>
<th>Ours iters=1K, use_quant_input=False, minmax_lr=0.002</th>
<td>67.70</td> <!-- acc avg -->
<td>60.57</td> <!-- MMLU -->
<td>73.74</td> <!-- Lambada_openai -->
@@ -246,7 +248,7 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc
<td>-</td>
</tr>
<tr>
<th>Ours iters1K, disable use_quant_input
<th>Ours iters=1K, use_quant_input=False
<td>66.78</td>
<td>68.68</td>
<td>78.61</td>
@@ -266,7 +268,7 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc

</tr>
<tr>
<td rowspan="2">microsoft/phi-2 </td>
<td rowspan="3">microsoft/phi-2 </td>
<th>FP16</th>
<td>61.80</td>
<td>56.40</td>
@@ -306,6 +308,26 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc
<td>11.37</td>

</tr>

</tr>
<th>Ours iters=1K, use_quant_input=False</th>
<td>61.47</td> <!-- acc avg -->
<td>55.41</td> <!-- MMLU -->
<td>61.77</td> <!-- Lambada_openai -->
<td>54.92</td> <!-- Hellaswag -->
<td>76.40</td> <!-- Winogrande -->
<td>78.29</td> <!-- Piqa -->
<td>31.09</td> <!-- Truthfulqa -->
<td>40.0</td> <!-- Openbookqa -->
<td>83.24</td> <!-- Boolq -->
<td>63.54</td> <!-- RTE -->
<td>79.29</td> <!-- Arc easy -->
<td>52.22</td> <!-- Arc Challenge -->
<td>9.97</td> <!-- wikitext2 ppl -->
<td>18.63</td> <!-- ptb_new ppl -->
<td>14.37</td> <!-- c4_new ppl -->
<td>11.35</td> <!-- lm-eval wikitext ppl -->
</tr>
</table>

We provide a [comprehensive analysis](docs/README.md) with other methods in our accuracy data section. Notably, our approach outperformed GPTQ with a score of 30/32 and AWQ with a score of 27/32 across llamav1/llamav2/mistral-7b on W4G-1, W4G128, W3G128, and W2G128, with comparable tuning costs.
4 changes: 2 additions & 2 deletions auto_round/__init__.py
@@ -13,6 +13,6 @@
# limitations under the License.
from .autoround import AutoRound, AutoAdamRound, AutoOPTRound
from .version import __version__
from .export_to_itrex import compress_model, QuantConfig
from .export_to_autogptq import save_quantized_to_autogptq
from auto_round.export.export_to_itrex import compress_model, QuantConfig
from auto_round.export.export_to_autogptq import save_quantized_to_autogptq
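With this change, downstream code should import the export helpers from the relocated `auto_round.export` package. A minimal sketch of the updated usage; it assumes a quantized `model` and its `weight_config` dict are already prepared (as in `AutoRound.export_to_itrex` later in this diff):

```python
# The ITREX export helper now lives under the auto_round.export package.
from auto_round.export.export_to_itrex import compress_model

# model and weight_config are assumed to come from an AutoRound run.
compressed_model, quantize_config = compress_model(model, weight_config)
```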

38 changes: 22 additions & 16 deletions auto_round/autoround.py
@@ -30,10 +30,11 @@
from functools import partial
from torch.functional import F
from .utils import (quant_weight, set_module, get_module, get_block_names, block_forward, sampling_inputs,
get_dataloader, get_scale_shape, move_input_to_device, check_is_cpu, collect_round_v,
collect_minmax_scale, get_batch_dim)
get_scale_shape, move_input_to_device, check_is_cpu, collect_round_v,
collect_minmax_scale, get_batch_dim, is_hpu_available)
from .calib_dataset import CALIB_DATASETS


from .device_utils import *


class SaveInputs:
@@ -124,15 +125,22 @@ def get_inputs(self, n_samples=512):
if data is None:
continue
if isinstance(data, torch.Tensor):
input_ids = data.to(self.model.device)
data_new = data.to(self.model.device)
input_ids = data_new
else:
input_ids = data["input_ids"].to(self.model.device)
if input_ids.shape[-1] < self.seqlen:
continue
data_new = {}
for key in data.keys():
data_new[key] = data[key].to(self.model.device)
input_ids = data_new["input_ids"]
# if input_ids.shape[-1] < self.seqlen:
# continue
if total_cnt + input_ids.shape[0] > n_samples:
input_ids = input_ids[: n_samples - total_cnt, ...]
try:
self.model(input_ids)
if isinstance(data_new, torch.Tensor):
self.model(data_new)
elif isinstance(data_new, dict):
self.model(**data_new)
except NotImplementedError:
pass
except Exception as error:
@@ -153,11 +161,6 @@ def get_inputs(self, n_samples=512):
f"Effective samples size:{total_cnt}, Target sample size:{n_samples}"
)
res = self.inputs[self.block_name]
if "input_ids" in res.keys():
total_samples = res["input_ids"].shape[0]
if total_samples < n_samples:
logger.warning("only cache {total_samples}")

return res

def _recover_forward(self):
@@ -582,13 +585,16 @@ def __init__(
self.dataset_name = dataset_name

if dataloader is None:
get_dataloader = CALIB_DATASETS.get(self.dataset_name,
CALIB_DATASETS["NeelNanda/pile-10k"])
self.dataloader = get_dataloader(
self.tokenizer,
self.seqlen,
seed=self.seed,
bs=self.train_bs,
split=self.dataset_split,
data_name=self.dataset_name,
dataset_name=self.dataset_name

)
else:
self.dataloader = dataloader
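The registry lookup above can also be exercised directly; a minimal sketch, with the tokenizer name and the seqlen/bs values as arbitrary placeholders:

```python
# Sketch of the CALIB_DATASETS lookup used in the constructor: fetch a registered
# dataloader builder by name, falling back to the default pile-10k builder.
from transformers import AutoTokenizer
from auto_round.calib_dataset import CALIB_DATASETS

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # placeholder model
dataset_name = "NeelNanda/pile-10k"  # any key registered in CALIB_DATASETS
get_dataloader = CALIB_DATASETS.get(dataset_name, CALIB_DATASETS["NeelNanda/pile-10k"])
dataloader = get_dataloader(tokenizer, 512, seed=42, bs=4, split="train",
                            dataset_name=dataset_name)
```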
@@ -989,7 +995,7 @@ def export_to_autogptq(self, output_dir, use_triton=False):

def export_to_itrex(self, output_dir):
"""Save configure file and weights for CPU backend inference."""
from .export_to_itrex import compress_model
from auto_round.export.export_to_itrex import compress_model
compressed_model, quantize_config = compress_model(self.model, self.weight_config)
if quantize_config is not None:
config = compressed_model.config
@@ -1028,6 +1034,7 @@ def quantize(self):
del save_input_actor
if "input_ids" in inputs.keys():
total_samples = inputs["input_ids"].shape[0]
self.n_samples = total_samples
if total_samples < self.train_bs:
self.train_bs = total_samples
logger.warning(f"force the train batch size to {total_samples} ")
@@ -1322,4 +1329,3 @@ def __init__(
optimizer,
**kwargs,
)

161 changes: 161 additions & 0 deletions auto_round/calib_dataset.py
@@ -0,0 +1,161 @@
import torch
import random

CALIB_DATASETS = {}


def register_dataset(name):
"""Class decorator to register a DATASET subclass to the registry.

Decorator function used before a Pattern subclass.

Args:
cls (class): The subclass of register.
name: A string. Define the pruner type.

Returns:
cls: The class of register.
"""

def register(dataset):
CALIB_DATASETS[name] = dataset
return dataset

return register
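As an aside (not part of this PR), additional calibration sets can be exposed through the same registry. The dataset name and builder below are purely hypothetical placeholders and reuse the module's `torch` import and `register_dataset` decorator:

```python
# Hypothetical illustration: registering another calibration dataloader builder.
# "my-org/my-calib-set" and get_my_dataloader are placeholders, not part of this PR.
@register_dataset("my-org/my-calib-set")
def get_my_dataloader(tokenizer, seqlen, dataset_name="my-org/my-calib-set",
                      split="train", seed=42, bs=4):
    from torch.utils.data import DataLoader
    texts = ["some calibration text"] * 8  # stand-in corpus
    encoded = [tokenizer(t, truncation=True, max_length=seqlen,
                         return_tensors="pt")["input_ids"][0] for t in texts]
    return DataLoader(encoded, batch_size=bs, shuffle=False,
                      collate_fn=lambda batch: {"input_ids": torch.vstack(batch)})
```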


def get_tokenizer_function(tokenizer, seqlen):
"""Returns a default tokenizer function.

Args:
tokenizer: The tokenizer to be used for tokenization.
seqlen: The maximum sequence length.

Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of
seqlen to the "text" field of examples.
"""

def default_tokenizer_function(examples):
example = tokenizer(examples["text"], truncation=True, max_length=seqlen)
return example

return default_tokenizer_function


@register_dataset("NeelNanda/pile-10k")
def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split="train", seed=42, bs=4):
"""Returns a dataloader for the specified dataset and split.

Args:
tokenizer: The tokenizer to be used for tokenization.
seqlen: The maximum sequence length.
dataset_name: The name of the dataset.
split: The data split to be used (e.g., "train", "test").
seed: The random seed for shuffling the dataset.
bs: The batch size for the dataloader.

Returns:
A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
"""
from datasets import load_dataset
from torch.utils.data import DataLoader

tokenizer_function = get_tokenizer_function(tokenizer, seqlen)

@torch.no_grad()
def collate_batch(batch):
input_ids_new = []
for text in batch:
input_ids = text["input_ids"]
if input_ids.shape[0] < seqlen:
continue
input_ids = input_ids[:seqlen]
input_ids_list = input_ids.tolist()
if input_ids_list.count(input_ids_list[-1]) > seqlen // 2:
continue
input_ids_new.append(input_ids)
if len(input_ids_new) == 0:
return None
tmp = torch.vstack(input_ids_new)
res = {"input_ids": tmp}
return res

calib_dataset = load_dataset(dataset_name, split=split)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
calib_dataset.set_format(type="torch", columns=["input_ids"])
calib_dataloader = DataLoader(calib_dataset, batch_size=bs, shuffle=False, collate_fn=collate_batch)
return calib_dataloader


@register_dataset("mbpp")
def get_mbpp_dataloader(tokenizer, seqlen, dataset_name="mbpp", split=['train', 'validation', 'test'], seed=42, bs=4):
"""Returns a dataloader for the specified dataset and split.

Args:
tokenizer: The tokenizer to be used for tokenization.
seqlen: The maximum sequence length.
dataset_name: The name of the dataset.
split: The data split(s) to be used, either a single split name or a list of splits.
seed: The random seed for shuffling the dataset.
bs: The batch size for the dataloader.

Returns:
A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
"""
from datasets import load_dataset
from torch.utils.data import DataLoader

def get_mbpp_tokenizer_function(tokenizer, seqlen):
"""Returns a default tokenizer function.

Args:
tokenizer: The tokenizer to be used for tokenization.
seqlen: The maximum sequence length.

Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of
seqlen to raw text examples.
"""

def default_tokenizer_function(examples):
example = tokenizer(examples, truncation=True, max_length=seqlen, return_tensors="pt")
# example = tokenizer(examples, return_tensors="pt")
return example

return default_tokenizer_function

tokenizer_function = get_mbpp_tokenizer_function(tokenizer, seqlen)

@torch.no_grad()
def collate_batch(batch):
input_ids_new = []
attention_mask_new = []
for text in batch:
token_text = tokenizer_function(text)
input_ids, attention_mask = token_text["input_ids"], token_text["attention_mask"]
if input_ids.shape[1] < seqlen:
continue
input_ids = input_ids[:, :seqlen]  # truncate along the sequence dimension; tensors are shaped [1, seq_len]
input_ids_list = input_ids[0].tolist()
if input_ids_list.count(input_ids_list[-1]) > seqlen // 2:
continue
attention_mask = attention_mask[:, :seqlen]
attention_mask_new.append(attention_mask)
input_ids_new.append(input_ids)
if len(input_ids_new) == 0:
return None
input_ids_new = torch.vstack(input_ids_new)
attention_mask_new = torch.vstack(attention_mask_new)
res = {"input_ids": input_ids_new, "attention_mask": attention_mask_new}
return res

samples = []
splits = split
for split in splits:
dataset = load_dataset(dataset_name, split=split)
for data in dataset:
samples.append(data["text"] + data["code"])
random.Random(seed).shuffle(samples)

calib_dataloader = DataLoader(samples, batch_size=bs, shuffle=False, collate_fn=collate_batch)
return calib_dataloader
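A quick usage sketch for the mbpp builder above (model name and sizes are placeholders; note that `split` takes a list of splits and each sample is the concatenation of the `text` and `code` fields):

```python
# Illustrative usage of the mbpp calibration dataloader defined above.
from transformers import AutoTokenizer
from auto_round.calib_dataset import CALIB_DATASETS

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # placeholder model
dataloader = CALIB_DATASETS["mbpp"](tokenizer, seqlen=128,
                                    split=["train", "validation", "test"], seed=42, bs=4)
for batch in dataloader:
    if batch is None:  # a batch can be fully filtered out by the length checks
        continue
    print(batch["input_ids"].shape, batch["attention_mask"].shape)
    break
```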
33 changes: 0 additions & 33 deletions auto_round/device_utils.py

This file was deleted.

File renamed without changes.