diff --git a/README.md b/README.md index 18f777471..112d61942 100644 --- a/README.md +++ b/README.md @@ -29,20 +29,23 @@ python setup.py install ```python from transformers import AutoModelForCausalLM, AutoTokenizer -model_name = "facebook/opt-125m" +model_name = "bigscience/bloom-560m" model = AutoModelForCausalLM.from_pretrained( - model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True - ) + model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) bits, group_size, scheme = 4, 128, "asym" -# need load model first, them import +# need to load model first, then import from auto_round import AutoRound - autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, scheme=scheme, device="hpu", scale_dtype="bf16", amp=False) autoround.quantize() +# Intel CPU Inference, Currently, llama, bloom, and mistral are supported. +output_dir = "/path/to/quantized_model" +autoround.export(output_dir) +# then follow ITREX to load the model and do inference + ``` ### On GPU @@ -53,8 +56,7 @@ from auto_round import AutoRound model_name = "bigscience/bloom-560m" model = AutoModelForCausalLM.from_pretrained( - model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True - ) + model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) bits, group_size, scheme = 4, 128, "asym" autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, scheme=scheme) @@ -185,7 +187,7 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc - Ours iters1K, disable use_quant_input, minmax_lr 0.002 + Ours iters=1K,use_quant_input=False, minmax_lr=0.002 67.70 60.57 73.74 @@ -246,7 +248,7 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc - - Ours iters1K, disable use_quant_input + Ours iters=1K,use_quant_input=False 66.78 68.68 78.61 @@ -266,7 +268,7 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc - microsoft/phi-2 + microsoft/phi-2 FP16 61.80 56.40 @@ -306,6 +308,26 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc 11.37 + + + Ours iters=1K,use_quant_input=False + 61.47 + 55.41 + 61.77 + 54.92 + 76.40 + 78.29 + 31.09 + 40.0 + 83.24 + 63.54 + 79.29 + 52.22 + 9.97 + 18.63 + 14.37 + 11.35 + We provide a [comprehensive analysis](docs/README.md) with other methods in our accuracy data section. Notably, our approach has outperformed GPTQ with a score of 30/32 and AWQ with a score of 27/32 across llamv1/llamav2/mistral-7b on W4G-1, W4G128, W3G128, W2G128. And the tuning costs are comparable. diff --git a/auto_round/__init__.py b/auto_round/__init__.py index 1b474ec47..01ad1c826 100644 --- a/auto_round/__init__.py +++ b/auto_round/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. 
from .autoround import AutoRound, AutoAdamRound, AutoOPTRound from .version import __version__ -from .export_to_itrex import compress_model, QuantConfig -from .export_to_autogptq import save_quantized_to_autogptq +from auto_round.export.export_to_itrex import compress_model, QuantConfig +from auto_round.export.export_to_autogptq import save_quantized_to_autogptq diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 7d7ab1945..357614ec9 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -30,10 +30,11 @@ from functools import partial from torch.functional import F from .utils import (quant_weight, set_module, get_module, get_block_names, block_forward, sampling_inputs, - get_dataloader, get_scale_shape, move_input_to_device, check_is_cpu, collect_round_v, - collect_minmax_scale, get_batch_dim) + get_scale_shape, move_input_to_device, check_is_cpu, collect_round_v, + collect_minmax_scale, get_batch_dim, is_hpu_available) +from .calib_dataset import CALIB_DATASETS + -from .device_utils import * class SaveInputs: @@ -124,15 +125,22 @@ def get_inputs(self, n_samples=512): if data is None: continue if isinstance(data, torch.Tensor): - input_ids = data.to(self.model.device) + data_new = data.to(self.model.device) + input_ids = data_new else: - input_ids = data["input_ids"].to(self.model.device) - if input_ids.shape[-1] < self.seqlen: - continue + data_new = {} + for key in data.keys(): + data_new[key] = data[key].to(self.model.device) + input_ids = data_new["input_ids"] + # if input_ids.shape[-1] < self.seqlen: + # continue if total_cnt + input_ids.shape[0] > n_samples: input_ids = input_ids[: n_samples - total_cnt, ...] try: - self.model(input_ids) + if isinstance(data_new, torch.Tensor): + self.model(data_new) + elif isinstance(data_new, dict): + self.model(**data_new) except NotImplementedError: pass except Exception as error: @@ -153,11 +161,6 @@ def get_inputs(self, n_samples=512): f"Effective samples size:{total_cnt}, Target sample size:{n_samples}" ) res = self.inputs[self.block_name] - if "input_ids" in res.keys(): - total_samples = res["input_ids"].shape[0] - if total_samples < n_samples: - logger.warning("only cache {total_samples}") - return res def _recover_forward(self): @@ -582,13 +585,16 @@ def __init__( self.dataset_name = dataset_name if dataloader is None: + get_dataloader = CALIB_DATASETS.get(self.dataset_name, + CALIB_DATASETS["NeelNanda/pile-10k"]) self.dataloader = get_dataloader( self.tokenizer, self.seqlen, seed=self.seed, bs=self.train_bs, split=self.dataset_split, - data_name=self.dataset_name, + dataset_name=self.dataset_name + ) else: self.dataloader = dataloader @@ -989,7 +995,7 @@ def export_to_autogptq(self, output_dir, use_triton=False): def export_to_itrex(self, output_dir): """Save configure file and weights for CPU backend inference.""" - from .export_to_itrex import compress_model + from auto_round.export.export_to_itrex import compress_model compressed_model, quantize_config = compress_model(self.model, self.weight_config) if quantize_config is not None: config = compressed_model.config @@ -1028,6 +1034,7 @@ def quantize(self): del save_input_actor if "input_ids" in inputs.keys(): total_samples = inputs["input_ids"].shape[0] + self.n_samples = total_samples if total_samples < self.train_bs: self.train_bs = total_samples logger.warning(f"force the train batch size to {total_samples} ") @@ -1322,4 +1329,3 @@ def __init__( optimizer, **kwargs, ) - diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py new file mode 
100644 index 000000000..617494b69 --- /dev/null +++ b/auto_round/calib_dataset.py @@ -0,0 +1,161 @@ +import torch +import random + +CALIB_DATASETS = {} + + +def register_dataset(name): + """Class decorator to register a DATASET subclass to the registry. + + Decorator function used before a Pattern subclass. + + Args: + cls (class): The subclass of register. + name: A string. Define the pruner type. + + Returns: + cls: The class of register. + """ + + def register(dataset): + CALIB_DATASETS[name] = dataset + return dataset + + return register + + +def get_tokenizer_function(tokenizer, seqlen): + """Returns a default tokenizer function. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + + Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of + seqlen to the "text" field of examples. + """ + + def default_tokenizer_function(examples): + example = tokenizer(examples["text"], truncation=True, max_length=seqlen) + return example + + return default_tokenizer_function + + +@register_dataset("NeelNanda/pile-10k") +def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split="train", seed=42, bs=4): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + bs: The batch size for the dataloader. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. + """ + from datasets import load_dataset + from torch.utils.data import DataLoader + + tokenizer_function = get_tokenizer_function(tokenizer, seqlen) + + @torch.no_grad() + def collate_batch(batch): + input_ids_new = [] + for text in batch: + input_ids = text["input_ids"] + if input_ids.shape[0] < seqlen: + continue + input_ids = input_ids[:seqlen] + input_ids_list = input_ids.tolist() + if input_ids_list.count(input_ids_list[-1]) > seqlen // 2: + continue + input_ids_new.append(input_ids) + if len(input_ids_new) == 0: + return None + tmp = torch.vstack(input_ids_new) + res = {"input_ids": tmp} + return res + + calib_dataset = load_dataset(dataset_name, split=split) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + calib_dataset.set_format(type="torch", columns=["input_ids"]) + calib_dataloader = DataLoader(calib_dataset, batch_size=bs, shuffle=False, collate_fn=collate_batch) + return calib_dataloader + + +@register_dataset("mbpp") +def get_mbpp_dataloader(tokenizer, seqlen, dataset_name="mbpp", split=['train', 'validation', 'test'], seed=42, bs=4): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + bs: The batch size for the dataloader. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. + """ + from datasets import load_dataset + from torch.utils.data import DataLoader + + def get_mbpp_tokenizer_function(tokenizer, seqlen): + """Returns a default tokenizer function. 
+ + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + + Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of + seqlen to the "text" field of examples. + """ + + def default_tokenizer_function(examples): + example = tokenizer(examples, truncation=True, max_length=seqlen, return_tensors="pt") + # example = tokenizer(examples, return_tensors="pt") + return example + + return default_tokenizer_function + + tokenizer_function = get_mbpp_tokenizer_function(tokenizer, seqlen) + + @torch.no_grad() + def collate_batch(batch): + input_ids_new = [] + attention_mask_new = [] + for text in batch: + token_text = tokenizer_function(text) + input_ids, attention_mask = token_text["input_ids"], token_text["attention_mask"] + if input_ids.shape[1] < seqlen: + continue + input_ids = input_ids[:seqlen] + input_ids_list = input_ids.tolist() + if input_ids_list.count(input_ids_list[-1]) > seqlen // 2: + continue + attention_mask = attention_mask[:seqlen] + attention_mask_new.append(attention_mask) + input_ids_new.append(input_ids) + if len(input_ids_new) == 0: + return None + input_ids_new = torch.vstack(input_ids_new) + attention_mask_new = torch.vstack(attention_mask_new) + res = {"input_ids": input_ids_new, "attention_mask": attention_mask_new} + return res + + samples = [] + splits = split + for split in splits: + dataset = load_dataset(dataset_name, split=split) + for data in dataset: + samples.append(data["text"] + data["code"]) + random.Random(seed).shuffle(samples) + + calib_dataloader = DataLoader(samples, batch_size=bs, shuffle=False, collate_fn=collate_batch) + return calib_dataloader diff --git a/auto_round/device_utils.py b/auto_round/device_utils.py deleted file mode 100644 index 2436aea0e..000000000 --- a/auto_round/device_utils.py +++ /dev/null @@ -1,33 +0,0 @@ - -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -def is_optimum_habana_available(): - import importlib - from transformers.utils.import_utils import is_optimum_available - - return is_optimum_available() and importlib.util.find_spec("optimum.habana") != None - -try: - import habana_frameworks.torch.hpu as hthpu - import habana_frameworks.torch.core as htcore - if is_optimum_habana_available(): - is_hpu_available = True - else: - print("Should install optimum-habana when the environment has habana frameworks") - is_hpu_available = False -except ImportError: - is_hpu_available = False - diff --git a/examples/eval/__init__.py b/auto_round/export/__init__.py similarity index 100% rename from examples/eval/__init__.py rename to auto_round/export/__init__.py diff --git a/auto_round/export_to_autogptq.py b/auto_round/export/export_to_autogptq.py similarity index 100% rename from auto_round/export_to_autogptq.py rename to auto_round/export/export_to_autogptq.py diff --git a/auto_round/export_to_itrex/__init__.py b/auto_round/export/export_to_itrex/__init__.py similarity index 100% rename from auto_round/export_to_itrex/__init__.py rename to auto_round/export/export_to_itrex/__init__.py diff --git a/auto_round/export_to_itrex/config.py b/auto_round/export/export_to_itrex/config.py similarity index 100% rename from auto_round/export_to_itrex/config.py rename to auto_round/export/export_to_itrex/config.py diff --git a/auto_round/export_to_itrex/export.py b/auto_round/export/export_to_itrex/export.py similarity index 98% rename from auto_round/export_to_itrex/export.py rename to auto_round/export/export_to_itrex/export.py index 1177cd17f..5af0528c6 100644 --- a/auto_round/export_to_itrex/export.py +++ b/auto_round/export/export_to_itrex/export.py @@ -24,7 +24,7 @@ import json from .config import QuantConfig from .model_wrapper import WeightOnlyLinear -from ..utils import quant_weight_w_scale, get_module, set_module +from auto_round.utils import quant_weight_w_scale, get_module, set_module def compress_model( diff --git a/auto_round/export_to_itrex/model_wrapper.py b/auto_round/export/export_to_itrex/model_wrapper.py similarity index 100% rename from auto_round/export_to_itrex/model_wrapper.py rename to auto_round/export/export_to_itrex/model_wrapper.py diff --git a/auto_round/utils.py b/auto_round/utils.py index 2e84b82fc..add95604a 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -19,6 +19,23 @@ from collections import UserDict +def is_optimum_habana_available(): + import importlib + from transformers.utils.import_utils import is_optimum_available + + return is_optimum_available() and importlib.util.find_spec("optimum.habana") != None + +try: + import habana_frameworks.torch.hpu as hthpu + import habana_frameworks.torch.core as htcore + if is_optimum_habana_available(): + is_hpu_available = True + else: + print("Should install optimum-habana when the environment has habana frameworks") + is_hpu_available = False +except ImportError: + is_hpu_available = False + def round_ste(x: torch.Tensor): """Straight-Through Estimator for rounding. This function is adapted from omniquant. @@ -313,67 +330,6 @@ def get_block_names(model): return block_names -def get_tokenizer_function(tokenizer, seqlen): - """Returns a default tokenizer function. - - Args: - tokenizer: The tokenizer to be used for tokenization. - seqlen: The maximum sequence length. - - Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of - seqlen to the "text" field of examples. 
- """ - - def default_tokenizer_function(examples): - example = tokenizer(examples["text"], truncation=True, max_length=seqlen) - return example - - return default_tokenizer_function - - -def get_dataloader(tokenizer, seqlen, data_name="NeelNanda/pile-10k", split="train", seed=42, bs=4): - """Returns a dataloader for the specified dataset and split. - - Args: - tokenizer: The tokenizer to be used for tokenization. - seqlen: The maximum sequence length. - data_name: The name of the dataset. - split: The data split to be used (e.g., "train", "test"). - seed: The random seed for shuffling the dataset. - bs: The batch size for the dataloader. - - Returns: - A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. - """ - from datasets import load_dataset - from torch.utils.data import DataLoader - - tokenizer_function = get_tokenizer_function(tokenizer, seqlen) - - @torch.no_grad() - def collate_batch(batch): - input_ids_new = [] - for text in batch: - input_ids = text["input_ids"] - if input_ids.shape[0] < seqlen: - continue - input_ids = input_ids[:seqlen] - input_ids_list = input_ids.tolist() - if input_ids_list.count(input_ids_list[-1]) > seqlen // 2: - continue - input_ids_new.append(input_ids) - if len(input_ids_new) == 0: - return None - tmp = torch.vstack(input_ids_new) - res = {"input_ids": tmp} - return res - - calib_dataset = load_dataset(data_name, split=split) - calib_dataset = calib_dataset.shuffle(seed=seed) - calib_dataset = calib_dataset.map(tokenizer_function, batched=True) - calib_dataset.set_format(type="torch", columns=["input_ids"]) - calib_dataloader = DataLoader(calib_dataset, batch_size=bs, shuffle=False, collate_fn=collate_batch) - return calib_dataloader def collect_round_v(block): diff --git a/examples/code-generation/README.md b/examples/code-generation/README.md new file mode 100644 index 000000000..d69ab3a0f --- /dev/null +++ b/examples/code-generation/README.md @@ -0,0 +1,65 @@ +Step-by-Step +============ + +This document presents step-by-step instructions for auto-round. + + +## 2. Prepare Dataset + +The mbpp in huggingface is adopted as the default calibration data and will be downloaded automatically from the datasets Hub. To customize a dataset, please kindly follow our dataset code. +See more about loading [huggingface dataset](https://huggingface.co/docs/datasets/loading_datasets.html) + +
+
+## 3. Run Examples
+Enter the examples folder.
+- **Default Settings:**
+```bash
+CUDA_VISIBLE_DEVICES=0 python3 main.py --model_name Salesforce/codegen25-7b-multi --amp --bits 4 --group_size -1 --enable_minmax_tuning --use_quant_input
+```
+- **Reduced GPU Memory Usage and Adjusted Training Batch Size:**
+```bash
+CUDA_VISIBLE_DEVICES=0 python3 main.py --model_name Salesforce/codegen25-7b-multi --amp --bits 4 --group_size -1 --low_gpu_mem_usage --train_bs 1 --gradient_accumulate_steps 8
+```
+- **Utilizing the AdamW Optimizer:**
+Include the flag `--adam`. Note that AdamW is less effective than signed gradient descent in many scenarios we tested.
+
+- **Running the Original SignRound:**
+```bash
+CUDA_VISIBLE_DEVICES=0 python3 main.py --model_name Salesforce/codegen25-7b-multi --amp --bits 4 --group_size -1 --iters 400 --lr 0.0025 --minmax_lr 0.0025
+```
+`--enable_minmax_tuning` is strongly recommended.
+
+## 4. Evaluation
+Please follow https://github.com/bigcode-project/bigcode-evaluation-harness to evaluate the model; currently we only support evaluating the fake-quantized model.
+
+| Model | Method | HumanEval top1 (t=0.2, n_samples=20) |
+|-------|--------|--------------------------------------|
+| Salesforce/codegen25-7b-multi | FP16 | 0.2854 |
+| Salesforce/codegen25-7b-multi | AutoRound use_quant_input=False | 0.2841 |
+ + + +## Reference +If you find SignRound useful for your research, please cite our paper: +```bash +@article{cheng2023optimize, + title={Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs}, + author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao}, + journal={arXiv preprint arXiv:2309.05516}, + year={2023} +} +``` + diff --git a/examples/code-generation/main.py b/examples/code-generation/main.py new file mode 100644 index 000000000..92755f7b3 --- /dev/null +++ b/examples/code-generation/main.py @@ -0,0 +1,198 @@ +import argparse +import random +import sys + +sys.path.insert(0, '../..') +from auto_round import (AutoRound, + AutoAdamRound) + +parser = argparse.ArgumentParser() +import torch +import os + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.use_deterministic_algorithms(True, warn_only=True) +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel + +from transformers import set_seed + +import re + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + + +if __name__ == '__main__': + + parser.add_argument( + "--model_name", default="facebook/opt-125m" + ) + + parser.add_argument("--bits", default=4, type=int, + help="number of bits") + + parser.add_argument("--group_size", default=128, type=int, + help="group size") + + parser.add_argument("--train_bs", default=8, type=int, + help="train batch size") + + parser.add_argument("--eval_bs", default=4, type=int, + help="eval batch size") + + parser.add_argument("--device", default=0, type=str, + help="device gpu int number, or 'cpu' ") + + parser.add_argument("--sym", action='store_true', + help=" sym quantization") + + parser.add_argument("--iters", default=200, type=int, + help=" iters") + + parser.add_argument("--use_quant_input", action='store_true', + help="whether to use the output of quantized block to tune the next block") + + parser.add_argument("--lr", default=None, type=float, + help="learning rate, if None, it will be set to 1.0/iters automatially") + + parser.add_argument("--minmax_lr", default=None, type=float, + help="minmax learning rate, if None,it will beset to be the same with lr") + + parser.add_argument("--seed", default=42, type=int, + help="seed") + + parser.add_argument("--amp", action='store_true', + help="amp") + + parser.add_argument("--adam", action='store_true', + help="adam") + + parser.add_argument("--seqlen", default=128, type=int, + help="sequence length") + + parser.add_argument("--gradient_accumulate_steps", default=1, type=int, help="gradient accumulate steps") + + parser.add_argument("--n_blocks", default=1, type=int, help="num of blocks to tune together") + + parser.add_argument("--n_samples", default=512, type=int, + help="number of samples") + + parser.add_argument("--low_gpu_mem_usage", action='store_true', + help="low_gpu_mem_usage") + + parser.add_argument("--enable_minmax_tuning", action='store_true', + help="whether enable weight minmax tuning") + + parser.add_argument("--deployment_device", default='fake', type=str, + help="targeted inference acceleration platform,The options are 'fake', 'cpu' and 'gpu'," + "default to 'fake', indicating that it only performs fake quantization and won't be exported to any device.") + + parser.add_argument("--scale_dtype", default='fp32', + help="which scale data type to use for quantization, 'fp16', 'fp32' or 'bf16'.") + + parser.add_argument("--tasks", + default=['wikitext2', 'ptb-new', 'c4-new', 'lambada_openai', 'hellaswag', 'winogrande', 'piqa', + 
"hendrycksTest-*", "wikitext", "truthfulqa_mc", "openbookqa", "boolq", "rte", + "arc_easy", "arc_challenge"], + help="lm-eval tasks") + + parser.add_argument("--output_dir", default="./tmp_signround", type=str, + help="Where to store the final model.") + + args = parser.parse_args() + set_seed(args.seed) + + model_name = args.model_name + if model_name[-1] == "/": + model_name = model_name[:-1] + print(model_name, flush=True) + + tasks = args.tasks + + if args.device == "cpu": + device_str = "cpu" + else: + device_str = f"cuda:{int(args.device)}" + torch_device = torch.device(device_str) + is_glm = bool(re.search("chatglm", model_name.lower())) + is_llava = bool(re.search("llava", model_name.lower())) + if is_llava: + from transformers import LlavaForConditionalGeneration + + model = LlavaForConditionalGeneration.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype="auto") + elif is_glm: + model = AutoModel.from_pretrained(model_name, trust_remote_code=True) + else: + model = AutoModelForCausalLM.from_pretrained( + model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True + ) + model = model.eval() + # align wigh GPTQ to eval ppl + if "opt" in model_name: + seqlen = model.config.max_position_embeddings + model.seqlen = model.config.max_position_embeddings + else: + seqlen = 2048 + model.seqlen = seqlen + seqlen = args.seqlen + + if "llama" in model_name: + from transformers import LlamaTokenizer + + tokenizer = LlamaTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + else: + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + if hasattr(tokenizer, "model_max_length"): + if tokenizer.model_max_length < seqlen: + print(f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length", + flush=True) + seqlen = min(seqlen, tokenizer.model_max_length) + args.seqlen = seqlen + + excel_name = f"{model_name}_{args.bits}_{args.group_size}" + + if not args.low_gpu_mem_usage: + model = model.to(torch_device) + + scheme = "asym" + if args.sym: + scheme = "sym" + round = AutoRound + if args.adam: + round = AutoAdamRound + autoround = round(model, tokenizer, args.bits, args.group_size, scheme, bs=args.train_bs, + seqlen=seqlen, n_blocks=args.n_blocks, iters=args.iters, lr=args.lr, + minmax_lr=args.minmax_lr, use_quant_input=args.use_quant_input, device=device_str, + amp=args.amp, n_samples=args.n_samples, low_gpu_mem_usage=args.low_gpu_mem_usage, + seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps, + scale_dtype=args.scale_dtype, dataset_name="mbpp", dataset_split=['train', 'validation', 'test']) ##TODO args pass + model, q_config = autoround.quantize() + model_name = args.model_name.rstrip("/") + export_dir = args.output_dir + "/compressed_" + model_name.split('/')[-1] + "/" + if args.deployment_device == 'cpu': + autoround.export(output_dir=export_dir) + del q_config + elif args.deployment_device == 'gpu': + autoround.export(export_dir, target="auto_gptq", use_triton=True) + if args.device != "cpu": + torch.cuda.empty_cache() + model.eval() + output_dir = args.output_dir + "_" + model_name.split('/')[-1] + f"_w{args.bits}_g{args.group_size}" + + import shutil + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + if (hasattr(model, 'config') and model.config.torch_dtype is torch.bfloat16): + dtype = 'bfloat16' + pt_dtype = torch.bfloat16 + else: + pt_dtype = torch.float16 + + model = 
model.to(pt_dtype) + model = model.to("cpu") + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) diff --git a/examples/code-generation/run_autoround.sh b/examples/code-generation/run_autoround.sh new file mode 100644 index 000000000..8993a8665 --- /dev/null +++ b/examples/code-generation/run_autoround.sh @@ -0,0 +1,16 @@ +device=0 +#!/bin/bash +set -x +model_name="Salesforce/codegen25-7b-multi" + +CUDA_VISIBLE_DEVICES=$device \ +python3 main.py \ +--model_name $model_name \ +--device $device \ +--group_size 128 \ +--bits 4 \ +--iters 200 \ +--seqlen 128 \ +--enable_minmax_tuning \ +--output_dir "./tmp_signround" \ +--amp \ No newline at end of file diff --git a/examples/README.md b/examples/language-modeling/README.md similarity index 90% rename from examples/README.md rename to examples/language-modeling/README.md index 560a5956b..e5daec3e4 100644 --- a/examples/README.md +++ b/examples/language-modeling/README.md @@ -1,7 +1,7 @@ Step-by-Step ============ -This document presents step-by-step instructions for autoround. +This document presents step-by-step instructions for auto-round. # Prerequisite @@ -32,7 +32,7 @@ The transformers version required varies across different types of models. Here, ## 2. Prepare Dataset -The dataset will be downloaded automatically from the datasets Hub. +The NeelNanda/pile-10k in huggingface is adopted as the default calibration data and will be downloaded automatically from the datasets Hub. To customize a dataset, please kindly follow our dataset code. See more about loading [huggingface dataset](https://huggingface.co/docs/datasets/loading_datasets.html)
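Besides the registered datasets, the constructor shown in this PR also accepts a `dataloader` argument and uses it directly when it is not `None`, so calibration data can be supplied without touching the registry. The following is a minimal sketch under that assumption; the batch format copies the built-in pile-10k loader (dicts carrying an `input_ids` tensor of shape `(bs, seqlen)`), and the tiny in-memory corpus is a placeholder.

```python
# Sketch: hand AutoRound a pre-built calibration dataloader (assumes the `dataloader`
# constructor argument shown in this PR; the corpus below is a stand-in).
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

seqlen = 256
texts = ["placeholder calibration text " * 64] * 32  # stand-in corpus

def collate(batch):
    enc = tokenizer(batch, truncation=True, max_length=seqlen,
                    padding="max_length", return_tensors="pt")
    return {"input_ids": enc["input_ids"]}  # same shape as the default loader's batches

calib_dataloader = DataLoader(texts, batch_size=4, shuffle=False, collate_fn=collate)

autoround = AutoRound(model, tokenizer, bits=4, group_size=128, scheme="asym",
                      seqlen=seqlen, dataloader=calib_dataloader)
autoround.quantize()
```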
diff --git a/examples/language-modeling/eval/__init__.py b/examples/language-modeling/eval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/eval/evaluation.py b/examples/language-modeling/eval/evaluation.py similarity index 100% rename from examples/eval/evaluation.py rename to examples/language-modeling/eval/evaluation.py diff --git a/examples/eval/parse_results.py b/examples/language-modeling/eval/parse_results.py similarity index 100% rename from examples/eval/parse_results.py rename to examples/language-modeling/eval/parse_results.py diff --git a/examples/eval/utils.py b/examples/language-modeling/eval/utils.py similarity index 100% rename from examples/eval/utils.py rename to examples/language-modeling/eval/utils.py diff --git a/examples/main.py b/examples/language-modeling/main.py similarity index 98% rename from examples/main.py rename to examples/language-modeling/main.py index f466c8e44..917f0b0df 100644 --- a/examples/main.py +++ b/examples/language-modeling/main.py @@ -1,6 +1,6 @@ import argparse import sys -sys.path.insert(0, '../') +sys.path.insert(0, '../..') from auto_round import (AutoRound, AutoAdamRound) @@ -25,7 +25,7 @@ if __name__ == '__main__': parser.add_argument( - "--model_name", default="/models/opt-125m" + "--model_name", default="facebook/opt-125m" ) parser.add_argument("--bits", default=4, type=int, diff --git a/examples/requirements.txt b/examples/language-modeling/requirements.txt similarity index 100% rename from examples/requirements.txt rename to examples/language-modeling/requirements.txt diff --git a/examples/run_autoround.sh b/examples/language-modeling/run_autoround.sh similarity index 82% rename from examples/run_autoround.sh rename to examples/language-modeling/run_autoround.sh index 66b3f0b44..ceabc30ca 100644 --- a/examples/run_autoround.sh +++ b/examples/language-modeling/run_autoround.sh @@ -1,7 +1,7 @@ #!/bin/bash set -x -model="neural-chat-7b-v3-3" +model="Intel/neural-chat-7b-v3" CUDA_VISIBLE_DEVICES=4 python3 main.py \ --model_name /models/${model} \ @@ -13,5 +13,5 @@ CUDA_VISIBLE_DEVICES=4 python3 main.py \ --deployment_device 'cpu' \ --scale_dtype 'fp32' \ --eval_bs 16 \ - --output_dir "SAVE/PATH" + --output_dir "./tmp_signround" diff --git a/test/test_autoround.py b/test/test_autoround.py index 4784433ce..172ecf8f2 100644 --- a/test/test_autoround.py +++ b/test/test_autoround.py @@ -1,15 +1,11 @@ import copy -import os import shutil import unittest import sys sys.path.insert(0,".") import torch import transformers -from transformers import AutoModelForCausalLM, AutoTokenizer -from auto_round import (AutoRound, - AutoAdamRound, - QuantConfig) +from auto_round import (AutoRound) class SimpleDataLoader: @@ -65,7 +61,7 @@ def test_autoround_int_quant(self): optq_1 = round(model, self.tokenizer, n_samples=20, device=device, amp=False, seqlen=10, iters=10, scale_dtype='fp32') q_model, weight_config1 = optq_1.quantize() q_model = q_model - from auto_round.export_to_itrex import compress_model + from auto_round.export.export_to_itrex import compress_model compressed_model,_ = compress_model(q_model, weight_config1) q_model = q_model model = model
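For downstream code affected by the module moves above, here is a short sketch of the old versus new import paths; the top-level names come from the updated `auto_round/__init__.py` re-exports.

```python
# The export helpers moved from flat modules into the auto_round.export package.
# Old paths (before this PR):
#   from auto_round.export_to_itrex import compress_model, QuantConfig
#   from auto_round.export_to_autogptq import save_quantized_to_autogptq
# New paths:
from auto_round.export.export_to_itrex import compress_model, QuantConfig
from auto_round.export.export_to_autogptq import save_quantized_to_autogptq

# The package root still re-exports the same names, so this also keeps working:
from auto_round import compress_model, QuantConfig, save_quantized_to_autogptq
```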