Skip to content

Commit

Permalink
Add StaticQuantConfig (#1501)
Browse files Browse the repository at this point in the history
  • Loading branch information
changwangss committed Apr 23, 2024
1 parent 8f5febc commit e1f4666
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 48 deletions.
1 change: 1 addition & 0 deletions intel_extension_for_transformers/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
MixedPrecisionConfig,
BitsAndBytesConfig,
SmoothQuantConfig,
StaticQuantConfig,
RtnConfig,
AwqConfig,
TeqConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
BitsAndBytesConfig,
MixedPrecisionConfig,
SmoothQuantConfig,
StaticQuantConfig,
RtnConfig,
AwqConfig,
TeqConfig,
Expand Down Expand Up @@ -71,6 +72,7 @@
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
from neural_compressor.model.torch_model import PyTorchFXModel
from threading import Thread
from transformers.configuration_utils import PretrainedConfig
from transformers import AutoConfig
Expand Down Expand Up @@ -211,7 +213,14 @@ def save_low_bit(
f"Provided path ({save_directory}) should be a directory, not a file"
)
return

if isinstance(self, PyTorchFXModel):
self.quantization_config.save_pretrained(save_directory, **kwargs)
self.model.config.quantization_config = self.quantization_config
self.model.config.save_pretrained(save_directory)
weights_file = os.path.join(
os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME)
torch.save(self.quantized_state_dict(), weights_file)
return
convert_model_to_public(self)
os.makedirs(save_directory, exist_ok=True)
# use transformers original `save_pretrained` function
Expand Down Expand Up @@ -403,7 +412,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"Quantization_config loading failed. If you want to load saved "
"low bit model, please check your quantizate_config.json."
)
elif use_neural_speed:
elif use_neural_speed and not config.quantization_config["quant_method"] == "static":
if not os.path.exists(pretrained_model_name_or_path):
from huggingface_hub import snapshot_download
pretrained_model_name_or_path = snapshot_download(repo_id=pretrained_model_name_or_path,
Expand Down Expand Up @@ -963,6 +972,157 @@ def calib_func(model):
),
)
logger.info("SmoothQuant done.")
elif isinstance(quantization_config, StaticQuantConfig):
if quantization_config.backend == "ipex":
try:
import intel_extension_for_pytorch as ipex
except ImportError:
logger.warning(
"Please install Intel Extension for PyTorch to accelerate the model inference."
)
config.torchscript = True
assert quantization_config.example_inputs is not None, \
"Please provide example_inputs for IPEX static quantization."

model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
low_cpu_mem_usage=True,
torch_dtype=torch.float,
**kwargs,
)

if (
not torch.cuda.is_available()
or device_map == "cpu"
or device_map == torch.device("cpu")
) and model.config.model_type == "chatglm":
model = model.float()
model.eval()
logger.info("Applying StaticQuant.")
# calibration function
calib_func = quantization_config.calib_func
tokenizer = quantization_config.tokenizer
if calib_func is None:
if quantization_config.tokenizer is None:
logger.error(
"Please provide the tokenizer or provide calib_func directly,"
+ " the following is how to get tokenizer. \n"
+ " from transformer import AutoTokenizer \n"
+ " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n"
)
exit(0)

from datasets import load_dataset
from torch.utils.data import DataLoader

calib_dataset = quantization_config.calib_dataset
calib_shuffle = quantization_config.calib_shuffle
calib_iters = quantization_config.calib_iters
calib_padding = quantization_config.calib_padding
calib_len = quantization_config.calib_len
calib_pad_val = quantization_config.calib_pad_val
from torch.nn.functional import pad

calib_dataset = load_dataset(
calib_dataset,
split=(
"test"
if calib_dataset in ["mbpp", "openai_humaneval"]
else "train"
),
)
if calib_shuffle:
calib_dataset = calib_dataset.shuffle(seed=42)

def tokenize_function(examples):
if "code" in examples:
example = tokenizer(examples["code"])
elif "prompt" in examples:
example = tokenizer(examples["prompt"])
elif "text" in examples:
example = tokenizer(examples["text"])
else:
logger.error(
"Please check dataset prompt identifier,"
+ " NeelNanda/pile-10k is default used calibration dataset."
)
exit(0)
return example

def collate_batch(batch):
input_ids_padded = []
last_ind = []
for text in batch:
input_ids = text["input_ids"]
if not calib_padding:
input_ids = (
input_ids[: int(calib_len)]
if len(input_ids) > int(calib_len)
else input_ids
) # no_padding
else:
pad_len = calib_len - input_ids.shape[0]
input_ids = pad(
input_ids, (0, pad_len), value=calib_pad_val
)

last_ind.append(input_ids.shape[0] - 1)
input_ids_padded.append(input_ids)

return (
{
"input_ids": torch.vstack(input_ids_padded),
},
torch.tensor(last_ind),
)


tokenized_dataset = calib_dataset.map(tokenize_function, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids"])
calib_dataloader = DataLoader(
tokenized_dataset,
batch_size=1,
shuffle=False,
collate_fn=collate_batch,
)

def calib_func(model):
with torch.no_grad():
for i, (inputs, last_ind) in enumerate(calib_dataloader):
if i >= calib_iters:
break
model(**inputs)

logger.info(
"The default calibration function is used, "
+ "the calibration dataset is NeelNanda/pile-10k, "
+ "batchsize is 1 and calibration iteration is 100."
)
calib_func = calib_func


# call inc static quant
from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
backend=quantization_config.backend, # default is ipex
excluded_precisions=quantization_config.excluded_precisions,
op_type_dict=quantization_config.op_type_dict,
op_name_dict=quantization_config.op_name_dict,
example_inputs=quantization_config.example_inputs,
)
model = quantization.fit(
model,
conf,
calib_func=calib_func,
)
model.save_pretrained = types.MethodType(save_low_bit, model)
quantization_config.remove_redundant_parameters()
model.quantization_config = quantization_config
logger.info("StaticQuant done.")
return model
else:
if use_neural_speed:
logger.info("Using Neural Speed with FP32 model dtype.")
Expand Down Expand Up @@ -1093,6 +1253,8 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
quantization_config = GPTQConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "autoround":
quantization_config = AutoRoundConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "static":
quantization_config = StaticQuantConfig.from_dict(quantization_config)
assert (
quantization_config is not None
), "Detect this model is not a low-bit model."
Expand Down Expand Up @@ -1336,6 +1498,16 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# by checking its first weights entry that is of a floating type
# - we assume all floating dtype weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
# Pretrained Model
if quantization_config.quant_method == "static":
model = model_class(config, *model_args, **kwargs)
from neural_compressor.utils.pytorch import load
weights_file = os.path.join(
os.path.abspath(os.path.expanduser(pretrained_model_name_or_path)), WEIGHTS_NAME)
q_model = load(weights_file, model, dataloader=None)
del model
return q_model

dtype_orig = None
if torch_dtype is not None:
if isinstance(torch_dtype, str):
Expand Down Expand Up @@ -1378,7 +1550,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if quantization_config.weight_dtype is None:
quantization_config.weight_dtype = "int4_clip"

# Pretrained Model
init_contexts = [no_init_weights(_enable=_fast_init)]
init_contexts.append(init_empty_weights())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
MixedPrecisionConfig,
BitsAndBytesConfig,
SmoothQuantConfig,
StaticQuantConfig,
SparsityConfig,
RtnConfig,
AwqConfig,
Expand Down
Loading

0 comments on commit e1f4666

Please sign in to comment.