diff --git a/README.md b/README.md
index 18f777471..112d61942 100644
--- a/README.md
+++ b/README.md
@@ -29,20 +29,23 @@ python setup.py install
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
-model_name = "facebook/opt-125m"
+model_name = "bigscience/bloom-560m"
model = AutoModelForCausalLM.from_pretrained(
- model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True
- )
+ model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
bits, group_size, scheme = 4, 128, "asym"
-# need load model first, them import
+# load the model first, then import AutoRound
from auto_round import AutoRound
-
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, scheme=scheme,
device="hpu", scale_dtype="bf16", amp=False)
autoround.quantize()
+# Intel CPU inference: currently llama, bloom, and mistral are supported.
+output_dir = "/path/to/quantized_model"
+autoround.export(output_dir)
+# then follow ITREX to load the model and run inference (see the sketch below)
```
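+
+The exported folder can then be loaded with ITREX for CPU inference. Below is a rough sketch only; the exact ITREX loading API differs across versions, so treat the import path and arguments as assumptions and consult the ITREX documentation:
+
+```python
+# assumption: ITREX's drop-in AutoModelForCausalLM can load the folder
+# exported above; verify against your installed ITREX version
+from intel_extension_for_transformers.transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained("/path/to/quantized_model",
+                                             trust_remote_code=True)
+```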
### On GPU
@@ -53,8 +56,7 @@ from auto_round import AutoRound
model_name = "bigscience/bloom-560m"
model = AutoModelForCausalLM.from_pretrained(
- model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True
- )
+ model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
bits, group_size, scheme = 4, 128, "asym"
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, scheme=scheme)
@@ -185,7 +187,7 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc
-
Ours iters1K, disable use_quant_input, minmax_lr 0.002 |
+ Ours iters=1K, use_quant_input=False, minmax_lr=0.002 |
67.70 |
60.57 |
73.74 |
@@ -246,7 +248,7 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc
- |
- | Ours iters1K, disable use_quant_input
+ | Ours iters=1K, use_quant_input=False
| 66.78 |
68.68 |
78.61 |
@@ -266,7 +268,7 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc
- | microsoft/phi-2 |
+ microsoft/phi-2 |
FP16 |
61.80 |
56.40 |
@@ -306,6 +308,26 @@ For wikitext2/ptb-new/c4-new ppl, we follow the code of gptq and set the sequenc
11.37 |
+
+
+ Ours iters=1K, use_quant_input=False |
+ 61.47 |
+ 55.41 |
+ 61.77 |
+ 54.92 |
+ 76.40 |
+ 78.29 |
+ 31.09 |
+ 40.0 |
+ 83.24 |
+ 63.54 |
+ 79.29 |
+ 52.22 |
+ 9.97 |
+ 18.63 |
+ 14.37 |
+ 11.35 |
+
We provide a [comprehensive analysis](docs/README.md) against other methods in our accuracy data section. Notably, our approach outperformed GPTQ with a score of 30/32 and AWQ with a score of 27/32 across llama-v1/llama-v2/mistral-7b on W4G-1, W4G128, W3G128, and W2G128, at comparable tuning cost.
diff --git a/auto_round/__init__.py b/auto_round/__init__.py
index 1b474ec47..01ad1c826 100644
--- a/auto_round/__init__.py
+++ b/auto_round/__init__.py
@@ -13,6 +13,6 @@
# limitations under the License.
from .autoround import AutoRound, AutoAdamRound, AutoOPTRound
from .version import __version__
-from .export_to_itrex import compress_model, QuantConfig
-from .export_to_autogptq import save_quantized_to_autogptq
+from auto_round.export.export_to_itrex import compress_model, QuantConfig
+from auto_round.export.export_to_autogptq import save_quantized_to_autogptq
diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 7d7ab1945..357614ec9 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -30,10 +30,11 @@
from functools import partial
from torch.functional import F
from .utils import (quant_weight, set_module, get_module, get_block_names, block_forward, sampling_inputs,
- get_dataloader, get_scale_shape, move_input_to_device, check_is_cpu, collect_round_v,
- collect_minmax_scale, get_batch_dim)
+ get_scale_shape, move_input_to_device, check_is_cpu, collect_round_v,
+ collect_minmax_scale, get_batch_dim, is_hpu_available)
+from .calib_dataset import CALIB_DATASETS
+
-from .device_utils import *
class SaveInputs:
@@ -124,15 +125,22 @@ def get_inputs(self, n_samples=512):
if data is None:
continue
if isinstance(data, torch.Tensor):
- input_ids = data.to(self.model.device)
+ data_new = data.to(self.model.device)
+ input_ids = data_new
else:
- input_ids = data["input_ids"].to(self.model.device)
- if input_ids.shape[-1] < self.seqlen:
- continue
+ data_new = {}
+ for key in data.keys():
+ data_new[key] = data[key].to(self.model.device)
+ input_ids = data_new["input_ids"]
+ # the sequence-length filter is handled by each calibration dataset's collate_fn
if total_cnt + input_ids.shape[0] > n_samples:
input_ids = input_ids[: n_samples - total_cnt, ...]
try:
- self.model(input_ids)
+ if isinstance(data_new, torch.Tensor):
+ self.model(data_new)
+ elif isinstance(data_new, dict):
+ self.model(**data_new)
except NotImplementedError:
pass
except Exception as error:
@@ -153,11 +161,6 @@ def get_inputs(self, n_samples=512):
f"Effective samples size:{total_cnt}, Target sample size:{n_samples}"
)
res = self.inputs[self.block_name]
- if "input_ids" in res.keys():
- total_samples = res["input_ids"].shape[0]
- if total_samples < n_samples:
- logger.warning("only cache {total_samples}")
-
return res
def _recover_forward(self):
@@ -582,13 +585,16 @@ def __init__(
self.dataset_name = dataset_name
if dataloader is None:
+ get_dataloader = CALIB_DATASETS.get(self.dataset_name,
+ CALIB_DATASETS["NeelNanda/pile-10k"])
self.dataloader = get_dataloader(
self.tokenizer,
self.seqlen,
seed=self.seed,
bs=self.train_bs,
split=self.dataset_split,
- data_name=self.dataset_name,
+ dataset_name=self.dataset_name
)
else:
self.dataloader = dataloader
@@ -989,7 +995,7 @@ def export_to_autogptq(self, output_dir, use_triton=False):
def export_to_itrex(self, output_dir):
"""Save configure file and weights for CPU backend inference."""
- from .export_to_itrex import compress_model
+ from auto_round.export.export_to_itrex import compress_model
compressed_model, quantize_config = compress_model(self.model, self.weight_config)
if quantize_config is not None:
config = compressed_model.config
@@ -1028,6 +1034,7 @@ def quantize(self):
del save_input_actor
if "input_ids" in inputs.keys():
total_samples = inputs["input_ids"].shape[0]
+ self.n_samples = total_samples
if total_samples < self.train_bs:
self.train_bs = total_samples
logger.warning(f"force the train batch size to {total_samples} ")
@@ -1322,4 +1329,3 @@ def __init__(
optimizer,
**kwargs,
)
-
diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py
new file mode 100644
index 000000000..617494b69
--- /dev/null
+++ b/auto_round/calib_dataset.py
@@ -0,0 +1,161 @@
+import torch
+import random
+
+CALIB_DATASETS = {}
+
+
+def register_dataset(name):
+ """Class decorator to register a DATASET subclass to the registry.
+
+ Decorator function used before a Pattern subclass.
+
+ Args:
+ cls (class): The subclass of register.
+ name: A string. Define the pruner type.
+
+ Returns:
+ cls: The class of register.
+ """
+
+ def register(dataset):
+ CALIB_DATASETS[name] = dataset
+ return dataset
+
+ return register
+
+
+def get_tokenizer_function(tokenizer, seqlen):
+ """Returns a default tokenizer function.
+
+ Args:
+ tokenizer: The tokenizer to be used for tokenization.
+ seqlen: The maximum sequence length.
+
+ Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of
+ seqlen to the "text" field of examples.
+ """
+
+ def default_tokenizer_function(examples):
+ example = tokenizer(examples["text"], truncation=True, max_length=seqlen)
+ return example
+
+ return default_tokenizer_function
+
+
+@register_dataset("NeelNanda/pile-10k")
+def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split="train", seed=42, bs=4):
+ """Returns a dataloader for the specified dataset and split.
+
+ Args:
+ tokenizer: The tokenizer to be used for tokenization.
+ seqlen: The maximum sequence length.
+ dataset_name: The name of the dataset.
+ split: The data split to be used (e.g., "train", "test").
+ seed: The random seed for shuffling the dataset.
+ bs: The batch size for the dataloader.
+
+ Returns:
+ A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
+ """
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader
+
+ tokenizer_function = get_tokenizer_function(tokenizer, seqlen)
+
+ @torch.no_grad()
+ def collate_batch(batch):
+ input_ids_new = []
+ for text in batch:
+ input_ids = text["input_ids"]
+ if input_ids.shape[0] < seqlen:
+ continue
+ input_ids = input_ids[:seqlen]
+ input_ids_list = input_ids.tolist()
+ # skip samples dominated by a repeated trailing token (e.g. padding)
+ if input_ids_list.count(input_ids_list[-1]) > seqlen // 2:
+ continue
+ input_ids_new.append(input_ids)
+ if len(input_ids_new) == 0:
+ return None
+ tmp = torch.vstack(input_ids_new)
+ res = {"input_ids": tmp}
+ return res
+
+ calib_dataset = load_dataset(dataset_name, split=split)
+ calib_dataset = calib_dataset.shuffle(seed=seed)
+ calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
+ calib_dataset.set_format(type="torch", columns=["input_ids"])
+ calib_dataloader = DataLoader(calib_dataset, batch_size=bs, shuffle=False, collate_fn=collate_batch)
+ return calib_dataloader
+
+
+@register_dataset("mbpp")
+def get_mbpp_dataloader(tokenizer, seqlen, dataset_name="mbpp", split=['train', 'validation', 'test'], seed=42, bs=4):
+ """Returns a dataloader for the specified dataset and split.
+
+ Args:
+ tokenizer: The tokenizer to be used for tokenization.
+ seqlen: The maximum sequence length.
+ dataset_name: The name of the dataset.
+ split: The data split(s) to be used; a list of split names concatenates them (e.g., ['train', 'validation', 'test']).
+ seed: The random seed for shuffling the dataset.
+ bs: The batch size for the dataloader.
+
+ Returns:
+ A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
+ """
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader
+
+ def get_mbpp_tokenizer_function(tokenizer, seqlen):
+ """Returns a default tokenizer function.
+
+ Args:
+ tokenizer: The tokenizer to be used for tokenization.
+ seqlen: The maximum sequence length.
+
+ Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of
+ seqlen to raw text examples, returning PyTorch tensors.
+ """
+
+ def default_tokenizer_function(examples):
+ example = tokenizer(examples, truncation=True, max_length=seqlen, return_tensors="pt")
+ return example
+
+ return default_tokenizer_function
+
+ tokenizer_function = get_mbpp_tokenizer_function(tokenizer, seqlen)
+
+ @torch.no_grad()
+ def collate_batch(batch):
+ input_ids_new = []
+ attention_mask_new = []
+ for text in batch:
+ token_text = tokenizer_function(text)
+ # drop the batch dimension added by return_tensors="pt" so the
+ # slicing and repetition checks below operate on the token axis
+ input_ids = token_text["input_ids"][0]
+ attention_mask = token_text["attention_mask"][0]
+ if input_ids.shape[0] < seqlen:
+ continue
+ input_ids = input_ids[:seqlen]
+ input_ids_list = input_ids.tolist()
+ # skip samples dominated by a repeated trailing token
+ if input_ids_list.count(input_ids_list[-1]) > seqlen // 2:
+ continue
+ attention_mask = attention_mask[:seqlen]
+ attention_mask_new.append(attention_mask)
+ input_ids_new.append(input_ids)
+ if len(input_ids_new) == 0:
+ return None
+ input_ids_new = torch.vstack(input_ids_new)
+ attention_mask_new = torch.vstack(attention_mask_new)
+ res = {"input_ids": input_ids_new, "attention_mask": attention_mask_new}
+ return res
+
+ samples = []
+ splits = [split] if isinstance(split, str) else split
+ for split_name in splits:
+ dataset = load_dataset(dataset_name, split=split_name)
+ for data in dataset:
+ samples.append(data["text"] + data["code"])
+ random.Random(seed).shuffle(samples)
+
+ calib_dataloader = DataLoader(samples, batch_size=bs, shuffle=False, collate_fn=collate_batch)
+ return calib_dataloader
diff --git a/auto_round/device_utils.py b/auto_round/device_utils.py
deleted file mode 100644
index 2436aea0e..000000000
--- a/auto_round/device_utils.py
+++ /dev/null
@@ -1,33 +0,0 @@
-
-# Copyright (c) 2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-def is_optimum_habana_available():
- import importlib
- from transformers.utils.import_utils import is_optimum_available
-
- return is_optimum_available() and importlib.util.find_spec("optimum.habana") != None
-
-try:
- import habana_frameworks.torch.hpu as hthpu
- import habana_frameworks.torch.core as htcore
- if is_optimum_habana_available():
- is_hpu_available = True
- else:
- print("Should install optimum-habana when the environment has habana frameworks")
- is_hpu_available = False
-except ImportError:
- is_hpu_available = False
-
diff --git a/examples/eval/__init__.py b/auto_round/export/__init__.py
similarity index 100%
rename from examples/eval/__init__.py
rename to auto_round/export/__init__.py
diff --git a/auto_round/export_to_autogptq.py b/auto_round/export/export_to_autogptq.py
similarity index 100%
rename from auto_round/export_to_autogptq.py
rename to auto_round/export/export_to_autogptq.py
diff --git a/auto_round/export_to_itrex/__init__.py b/auto_round/export/export_to_itrex/__init__.py
similarity index 100%
rename from auto_round/export_to_itrex/__init__.py
rename to auto_round/export/export_to_itrex/__init__.py
diff --git a/auto_round/export_to_itrex/config.py b/auto_round/export/export_to_itrex/config.py
similarity index 100%
rename from auto_round/export_to_itrex/config.py
rename to auto_round/export/export_to_itrex/config.py
diff --git a/auto_round/export_to_itrex/export.py b/auto_round/export/export_to_itrex/export.py
similarity index 98%
rename from auto_round/export_to_itrex/export.py
rename to auto_round/export/export_to_itrex/export.py
index 1177cd17f..5af0528c6 100644
--- a/auto_round/export_to_itrex/export.py
+++ b/auto_round/export/export_to_itrex/export.py
@@ -24,7 +24,7 @@
import json
from .config import QuantConfig
from .model_wrapper import WeightOnlyLinear
-from ..utils import quant_weight_w_scale, get_module, set_module
+from auto_round.utils import quant_weight_w_scale, get_module, set_module
def compress_model(
diff --git a/auto_round/export_to_itrex/model_wrapper.py b/auto_round/export/export_to_itrex/model_wrapper.py
similarity index 100%
rename from auto_round/export_to_itrex/model_wrapper.py
rename to auto_round/export/export_to_itrex/model_wrapper.py
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 2e84b82fc..add95604a 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -19,6 +19,23 @@
from collections import UserDict
+def is_optimum_habana_available():
+ import importlib.util
+ from transformers.utils.import_utils import is_optimum_available
+
+ return is_optimum_available() and importlib.util.find_spec("optimum.habana") is not None
+
+try:
+ import habana_frameworks.torch.hpu as hthpu
+ import habana_frameworks.torch.core as htcore
+ if is_optimum_habana_available():
+ is_hpu_available = True
+ else:
+ print("Should install optimum-habana when the environment has habana frameworks")
+ is_hpu_available = False
+except ImportError:
+ is_hpu_available = False
+
def round_ste(x: torch.Tensor):
"""Straight-Through Estimator for rounding.
This function is adapted from omniquant.
@@ -313,67 +330,6 @@ def get_block_names(model):
return block_names
-def get_tokenizer_function(tokenizer, seqlen):
- """Returns a default tokenizer function.
-
- Args:
- tokenizer: The tokenizer to be used for tokenization.
- seqlen: The maximum sequence length.
-
- Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of
- seqlen to the "text" field of examples.
- """
-
- def default_tokenizer_function(examples):
- example = tokenizer(examples["text"], truncation=True, max_length=seqlen)
- return example
-
- return default_tokenizer_function
-
-
-def get_dataloader(tokenizer, seqlen, data_name="NeelNanda/pile-10k", split="train", seed=42, bs=4):
- """Returns a dataloader for the specified dataset and split.
-
- Args:
- tokenizer: The tokenizer to be used for tokenization.
- seqlen: The maximum sequence length.
- data_name: The name of the dataset.
- split: The data split to be used (e.g., "train", "test").
- seed: The random seed for shuffling the dataset.
- bs: The batch size for the dataloader.
-
- Returns:
- A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
- """
- from datasets import load_dataset
- from torch.utils.data import DataLoader
-
- tokenizer_function = get_tokenizer_function(tokenizer, seqlen)
-
- @torch.no_grad()
- def collate_batch(batch):
- input_ids_new = []
- for text in batch:
- input_ids = text["input_ids"]
- if input_ids.shape[0] < seqlen:
- continue
- input_ids = input_ids[:seqlen]
- input_ids_list = input_ids.tolist()
- if input_ids_list.count(input_ids_list[-1]) > seqlen // 2:
- continue
- input_ids_new.append(input_ids)
- if len(input_ids_new) == 0:
- return None
- tmp = torch.vstack(input_ids_new)
- res = {"input_ids": tmp}
- return res
-
- calib_dataset = load_dataset(data_name, split=split)
- calib_dataset = calib_dataset.shuffle(seed=seed)
- calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
- calib_dataset.set_format(type="torch", columns=["input_ids"])
- calib_dataloader = DataLoader(calib_dataset, batch_size=bs, shuffle=False, collate_fn=collate_batch)
- return calib_dataloader
def collect_round_v(block):
diff --git a/examples/code-generation/README.md b/examples/code-generation/README.md
new file mode 100644
index 000000000..d69ab3a0f
--- /dev/null
+++ b/examples/code-generation/README.md
@@ -0,0 +1,65 @@
+Step-by-Step
+============
+
+This document presents step-by-step instructions for auto-round.
+
+
+## 2. Prepare Dataset
+
+The mbpp dataset from Hugging Face is adopted as the default calibration data and will be downloaded automatically from the Datasets Hub. To customize a dataset, please follow our dataset code; a sketch is shown below.
+See more about loading [huggingface dataset](https://huggingface.co/docs/datasets/loading_datasets.html)
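+
+As a sketch of that customization path, a new calibration set can be registered through the `register_dataset` decorator from `auto_round/calib_dataset.py` (the dataset name `my_org/my_corpus` and its `"text"` column are placeholders for your own data):
+
+```python
+import torch
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+
+from auto_round.calib_dataset import register_dataset
+
+
+@register_dataset("my_org/my_corpus")
+def get_custom_dataloader(tokenizer, seqlen, dataset_name="my_org/my_corpus",
+                          split="train", seed=42, bs=4):
+    ds = load_dataset(dataset_name, split=split).shuffle(seed=seed)
+    ds = ds.map(lambda e: tokenizer(e["text"], truncation=True, max_length=seqlen),
+                batched=True)
+    ds.set_format(type="torch", columns=["input_ids"])
+
+    def collate_batch(batch):
+        # keep only samples that reach the full sequence length
+        ids = [b["input_ids"][:seqlen] for b in batch
+               if b["input_ids"].shape[0] >= seqlen]
+        return {"input_ids": torch.vstack(ids)} if ids else None
+
+    return DataLoader(ds, batch_size=bs, shuffle=False, collate_fn=collate_batch)
+```
+
+Passing `dataset_name="my_org/my_corpus"` to AutoRound then selects this loader from the `CALIB_DATASETS` registry.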
+
+
+
+## 3. Run Examples
+Enter the examples folder.
+- **Default Settings:**
+```bash
+CUDA_VISIBLE_DEVICES=0 python3 main.py --model_name Salesforce/codegen25-7b-multi --amp --bits 4 --group_size -1 --enable_minmax_tuning --use_quant_input
+```
+- **Reduced GPU Memory Usage and Adjusted Training Batch Size:**
+```bash
+CUDA_VISIBLE_DEVICES=0 python3 main.py --model_name Salesforce/codegen25-7b-multi --amp --bits 4 --group_size -1 --low_gpu_mem_usage --train_bs 1 --gradient_accumulate_steps 8
+```
+- **Utilizing the AdamW Optimizer:**
+Include the flag `--adam`. Note that AdamW is less effective than Sign gradient descent in many scenarios we tested.
+
+- **Running the Original SignRound:**
+```bash
+CUDA_VISIBLE_DEVICES=0 python3 main.py --model_name Salesforce/codegen25-7b-multi --amp --bits 4 --group_size -1 --iters 400 --lr 0.0025 --minmax_lr 0.0025
+```
+ `--enable_minmax_tuning` is strongly recommended.
+
+
+## 4. Evaluation
+Please follow https://github.com/bigcode-project/bigcode-evaluation-harness to evaluate the model. Currently, we only support evaluation of the fake-quantized model.
+
+
+| Model | Method | HumanEval top1 (t=0.2, n_samples=20) |
+|-------|--------|--------------------------------------|
+| Salesforce/codegen25-7b-multi | FP16 | 0.2854 |
+| Salesforce/codegen25-7b-multi | AutoRound use_quant_input=False | 0.2841 |
+
+
+## Reference
+If you find SignRound useful for your research, please cite our paper:
+```bibtex
+@article{cheng2023optimize,
+ title={Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs},
+ author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao},
+ journal={arXiv preprint arXiv:2309.05516},
+ year={2023}
+}
+```
+
diff --git a/examples/code-generation/main.py b/examples/code-generation/main.py
new file mode 100644
index 000000000..92755f7b3
--- /dev/null
+++ b/examples/code-generation/main.py
@@ -0,0 +1,198 @@
+import argparse
+import os
+import re
+import sys
+
+import torch
+
+sys.path.insert(0, '../..')
+from auto_round import (AutoRound,
+ AutoAdamRound)
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.use_deterministic_algorithms(True, warn_only=True)
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, set_seed
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+parser = argparse.ArgumentParser()
+
+
+if __name__ == '__main__':
+
+ parser.add_argument(
+ "--model_name", default="facebook/opt-125m"
+ )
+
+ parser.add_argument("--bits", default=4, type=int,
+ help="number of bits")
+
+ parser.add_argument("--group_size", default=128, type=int,
+ help="group size")
+
+ parser.add_argument("--train_bs", default=8, type=int,
+ help="train batch size")
+
+ parser.add_argument("--eval_bs", default=4, type=int,
+ help="eval batch size")
+
+ parser.add_argument("--device", default=0, type=str,
+ help="device gpu int number, or 'cpu' ")
+
+ parser.add_argument("--sym", action='store_true',
+ help=" sym quantization")
+
+ parser.add_argument("--iters", default=200, type=int,
+ help=" iters")
+
+ parser.add_argument("--use_quant_input", action='store_true',
+ help="whether to use the output of quantized block to tune the next block")
+
+ parser.add_argument("--lr", default=None, type=float,
+ help="learning rate, if None, it will be set to 1.0/iters automatially")
+
+ parser.add_argument("--minmax_lr", default=None, type=float,
+ help="minmax learning rate, if None,it will beset to be the same with lr")
+
+ parser.add_argument("--seed", default=42, type=int,
+ help="seed")
+
+ parser.add_argument("--amp", action='store_true',
+ help="amp")
+
+ parser.add_argument("--adam", action='store_true',
+ help="adam")
+
+ parser.add_argument("--seqlen", default=128, type=int,
+ help="sequence length")
+
+ parser.add_argument("--gradient_accumulate_steps", default=1, type=int, help="gradient accumulate steps")
+
+ parser.add_argument("--n_blocks", default=1, type=int, help="num of blocks to tune together")
+
+ parser.add_argument("--n_samples", default=512, type=int,
+ help="number of samples")
+
+ parser.add_argument("--low_gpu_mem_usage", action='store_true',
+ help="low_gpu_mem_usage")
+
+ parser.add_argument("--enable_minmax_tuning", action='store_true',
+ help="whether enable weight minmax tuning")
+
+ parser.add_argument("--deployment_device", default='fake', type=str,
+ help="targeted inference acceleration platform,The options are 'fake', 'cpu' and 'gpu',"
+ "default to 'fake', indicating that it only performs fake quantization and won't be exported to any device.")
+
+ parser.add_argument("--scale_dtype", default='fp32',
+ help="which scale data type to use for quantization, 'fp16', 'fp32' or 'bf16'.")
+
+ parser.add_argument("--tasks",
+ default=['wikitext2', 'ptb-new', 'c4-new', 'lambada_openai', 'hellaswag', 'winogrande', 'piqa',
+ "hendrycksTest-*", "wikitext", "truthfulqa_mc", "openbookqa", "boolq", "rte",
+ "arc_easy", "arc_challenge"],
+ help="lm-eval tasks")
+
+ parser.add_argument("--output_dir", default="./tmp_signround", type=str,
+ help="Where to store the final model.")
+
+ args = parser.parse_args()
+ set_seed(args.seed)
+
+ model_name = args.model_name
+ model_name = model_name.rstrip("/")
+ print(model_name, flush=True)
+
+ tasks = args.tasks
+
+ if args.device == "cpu":
+ device_str = "cpu"
+ else:
+ device_str = f"cuda:{int(args.device)}"
+ torch_device = torch.device(device_str)
+ is_glm = bool(re.search("chatglm", model_name.lower()))
+ is_llava = bool(re.search("llava", model_name.lower()))
+ if is_llava:
+ from transformers import LlavaForConditionalGeneration
+
+ model = LlavaForConditionalGeneration.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype="auto")
+ elif is_glm:
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+ else:
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True
+ )
+ model = model.eval()
+ # align with GPTQ to evaluate ppl
+ if "opt" in model_name:
+ seqlen = model.config.max_position_embeddings
+ model.seqlen = model.config.max_position_embeddings
+ else:
+ seqlen = 2048
+ model.seqlen = seqlen
+ seqlen = args.seqlen
+
+ if "llama" in model_name:
+ from transformers import LlamaTokenizer
+
+ tokenizer = LlamaTokenizer.from_pretrained(model_name)
+ if tokenizer.pad_token is None:
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ if hasattr(tokenizer, "model_max_length"):
+ if tokenizer.model_max_length < seqlen:
+ print(f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length",
+ flush=True)
+ seqlen = min(seqlen, tokenizer.model_max_length)
+ args.seqlen = seqlen
+
+ excel_name = f"{model_name}_{args.bits}_{args.group_size}"
+
+ if not args.low_gpu_mem_usage:
+ model = model.to(torch_device)
+
+ scheme = "asym"
+ if args.sym:
+ scheme = "sym"
+ round = AutoRound
+ if args.adam:
+ round = AutoAdamRound
+ autoround = round(model, tokenizer, args.bits, args.group_size, scheme, bs=args.train_bs,
+ seqlen=seqlen, n_blocks=args.n_blocks, iters=args.iters, lr=args.lr,
+ minmax_lr=args.minmax_lr, use_quant_input=args.use_quant_input, device=device_str,
+ amp=args.amp, n_samples=args.n_samples, low_gpu_mem_usage=args.low_gpu_mem_usage,
+ seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps,
+ scale_dtype=args.scale_dtype, dataset_name="mbpp", dataset_split=['train', 'validation', 'test'])  # TODO: expose these via args
+ model, q_config = autoround.quantize()
+ model_name = args.model_name.rstrip("/")
+ export_dir = args.output_dir + "/compressed_" + model_name.split('/')[-1] + "/"
+ if args.deployment_device == 'cpu':
+ autoround.export(output_dir=export_dir)
+ del q_config
+ elif args.deployment_device == 'gpu':
+ autoround.export(export_dir, target="auto_gptq", use_triton=True)
+ if args.device != "cpu":
+ torch.cuda.empty_cache()
+ model.eval()
+ output_dir = args.output_dir + "_" + model_name.split('/')[-1] + f"_w{args.bits}_g{args.group_size}"
+
+ import shutil
+
+ if os.path.exists(output_dir):
+ shutil.rmtree(output_dir)
+
+ if (hasattr(model, 'config') and model.config.torch_dtype is torch.bfloat16):
+ dtype = 'bfloat16'
+ pt_dtype = torch.bfloat16
+ else:
+ pt_dtype = torch.float16
+
+ model = model.to(pt_dtype)
+ model = model.to("cpu")
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
diff --git a/examples/code-generation/run_autoround.sh b/examples/code-generation/run_autoround.sh
new file mode 100644
index 000000000..8993a8665
--- /dev/null
+++ b/examples/code-generation/run_autoround.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+set -x
+
+device=0
+model_name="Salesforce/codegen25-7b-multi"
+
+CUDA_VISIBLE_DEVICES=$device \
+python3 main.py \
+--model_name $model_name \
+--device $device \
+--group_size 128 \
+--bits 4 \
+--iters 200 \
+--seqlen 128 \
+--enable_minmax_tuning \
+--output_dir "./tmp_signround" \
+--amp
\ No newline at end of file
diff --git a/examples/README.md b/examples/language-modeling/README.md
similarity index 90%
rename from examples/README.md
rename to examples/language-modeling/README.md
index 560a5956b..e5daec3e4 100644
--- a/examples/README.md
+++ b/examples/language-modeling/README.md
@@ -1,7 +1,7 @@
Step-by-Step
============
-This document presents step-by-step instructions for autoround.
+This document presents step-by-step instructions for auto-round.
# Prerequisite
@@ -32,7 +32,7 @@ The transformers version required varies across different types of models. Here,
## 2. Prepare Dataset
-The dataset will be downloaded automatically from the datasets Hub.
+The NeelNanda/pile-10k dataset from Hugging Face is adopted as the default calibration data and will be downloaded automatically from the Datasets Hub. To customize a dataset, please follow our dataset code; a sketch of passing a custom dataloader directly is shown below.
See more about loading [huggingface dataset](https://huggingface.co/docs/datasets/loading_datasets.html)
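+
+Alternatively, a prepared dataloader can be passed to AutoRound directly. The sketch below is a minimal example under two assumptions: the `dataloader` keyword visible in `autoround.py`, and batches shaped like the default pile-10k loader, i.e. `{"input_ids": tensor(batch, seqlen)}`; the random tokens are placeholders for real tokenized text:
+
+```python
+import torch
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from auto_round import AutoRound
+
+model_name = "facebook/opt-125m"
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+seqlen = 2048
+# toy calibration samples; replace with your own tokenized corpus
+samples = [{"input_ids": torch.randint(0, 32000, (seqlen,))} for _ in range(128)]
+
+
+def collate_batch(batch):
+    return {"input_ids": torch.vstack([b["input_ids"] for b in batch])}
+
+
+dataloader = DataLoader(samples, batch_size=4, shuffle=False, collate_fn=collate_batch)
+autoround = AutoRound(model, tokenizer, bits=4, group_size=128, scheme="asym",
+                      seqlen=seqlen, dataloader=dataloader)
+```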
diff --git a/examples/language-modeling/eval/__init__.py b/examples/language-modeling/eval/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/eval/evaluation.py b/examples/language-modeling/eval/evaluation.py
similarity index 100%
rename from examples/eval/evaluation.py
rename to examples/language-modeling/eval/evaluation.py
diff --git a/examples/eval/parse_results.py b/examples/language-modeling/eval/parse_results.py
similarity index 100%
rename from examples/eval/parse_results.py
rename to examples/language-modeling/eval/parse_results.py
diff --git a/examples/eval/utils.py b/examples/language-modeling/eval/utils.py
similarity index 100%
rename from examples/eval/utils.py
rename to examples/language-modeling/eval/utils.py
diff --git a/examples/main.py b/examples/language-modeling/main.py
similarity index 98%
rename from examples/main.py
rename to examples/language-modeling/main.py
index f466c8e44..917f0b0df 100644
--- a/examples/main.py
+++ b/examples/language-modeling/main.py
@@ -1,6 +1,6 @@
import argparse
import sys
-sys.path.insert(0, '../')
+sys.path.insert(0, '../..')
from auto_round import (AutoRound,
AutoAdamRound)
@@ -25,7 +25,7 @@
if __name__ == '__main__':
parser.add_argument(
- "--model_name", default="/models/opt-125m"
+ "--model_name", default="facebook/opt-125m"
)
parser.add_argument("--bits", default=4, type=int,
diff --git a/examples/requirements.txt b/examples/language-modeling/requirements.txt
similarity index 100%
rename from examples/requirements.txt
rename to examples/language-modeling/requirements.txt
diff --git a/examples/run_autoround.sh b/examples/language-modeling/run_autoround.sh
similarity index 82%
rename from examples/run_autoround.sh
rename to examples/language-modeling/run_autoround.sh
index 66b3f0b44..ceabc30ca 100644
--- a/examples/run_autoround.sh
+++ b/examples/language-modeling/run_autoround.sh
@@ -1,7 +1,7 @@
#!/bin/bash
set -x
-model="neural-chat-7b-v3-3"
+model="Intel/neural-chat-7b-v3"
CUDA_VISIBLE_DEVICES=4 python3 main.py \
--model_name /models/${model} \
@@ -13,5 +13,5 @@ CUDA_VISIBLE_DEVICES=4 python3 main.py \
--deployment_device 'cpu' \
--scale_dtype 'fp32' \
--eval_bs 16 \
- --output_dir "SAVE/PATH"
+ --output_dir "./tmp_signround"
diff --git a/test/test_autoround.py b/test/test_autoround.py
index 4784433ce..172ecf8f2 100644
--- a/test/test_autoround.py
+++ b/test/test_autoround.py
@@ -1,15 +1,11 @@
import copy
-import os
import shutil
import unittest
import sys
sys.path.insert(0,".")
import torch
import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from auto_round import (AutoRound,
- AutoAdamRound,
- QuantConfig)
+from auto_round import AutoRound
class SimpleDataLoader:
@@ -65,7 +61,7 @@ def test_autoround_int_quant(self):
optq_1 = round(model, self.tokenizer, n_samples=20, device=device, amp=False, seqlen=10, iters=10, scale_dtype='fp32')
q_model, weight_config1 = optq_1.quantize()
q_model = q_model
- from auto_round.export_to_itrex import compress_model
+ from auto_round.export.export_to_itrex import compress_model
compressed_model,_ = compress_model(q_model, weight_config1)
q_model = q_model
model = model