
Commit

Add DynamicQuantConfig and QuantAwareTrainingConfig (#1505)
changwangss committed Apr 25, 2024
1 parent 8116fbb commit 6a15b48
Showing 5 changed files with 288 additions and 4 deletions.
2 changes: 2 additions & 0 deletions intel_extension_for_transformers/transformers/__init__.py
@@ -45,6 +45,8 @@
BitsAndBytesConfig,
SmoothQuantConfig,
StaticQuantConfig,
DynamicQuantConfig,
QuantAwareTrainingConfig,
RtnConfig,
AwqConfig,
TeqConfig,
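With these exports in place, both new configs are importable from the package root alongside the existing quantization configs; a minimal sketch (assuming intel-extension-for-transformers is installed) would be:

from intel_extension_for_transformers.transformers import (
    DynamicQuantConfig,
    QuantAwareTrainingConfig,
)

dq_config = DynamicQuantConfig()          # post-training dynamic quantization
qat_config = QuantAwareTrainingConfig()   # QAT; needs a tokenizer or train_func before use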
@@ -43,6 +43,8 @@
MixedPrecisionConfig,
SmoothQuantConfig,
StaticQuantConfig,
DynamicQuantConfig,
QuantAwareTrainingConfig,
RtnConfig,
AwqConfig,
TeqConfig,
@@ -412,7 +414,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"Quantization_config loading failed. If you want to load saved "
"low bit model, please check your quantizate_config.json."
)
elif use_neural_speed and not config.quantization_config["quant_method"] == "static":
elif use_neural_speed and not config.quantization_config["quant_method"] in ["dynamic", "static", "qat"]:
if not os.path.exists(pretrained_model_name_or_path):
from huggingface_hub import snapshot_download
pretrained_model_name_or_path = snapshot_download(repo_id=pretrained_model_name_or_path,
@@ -972,6 +974,42 @@ def calib_func(model):
),
)
logger.info("SmoothQuant done.")
elif isinstance(quantization_config, DynamicQuantConfig):
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
low_cpu_mem_usage=True,
torch_dtype=torch.float,
**kwargs,
)

if (
not torch.cuda.is_available()
or device_map == "cpu"
or device_map == torch.device("cpu")
) and model.config.model_type == "chatglm":
model = model.float()
model.eval()
logger.info("Applying DynamicQuant.")
# call inc dynamic quant
from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
approach="dynamic",
excluded_precisions=quantization_config.excluded_precisions,
op_type_dict=quantization_config.op_type_dict,
op_name_dict=quantization_config.op_name_dict,
)
model = quantization.fit(
model,
conf,
)
model.save_pretrained = types.MethodType(save_low_bit, model)
quantization_config.remove_redundant_parameters()
model.quantization_config = quantization_config
logger.info("DynamicQuant done.")
return model
elif isinstance(quantization_config, StaticQuantConfig):
if quantization_config.backend == "ipex":
try:
@@ -1107,7 +1145,7 @@ def calib_func(model):
from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
backend=quantization_config.backend, # default is ipex
backend=quantization_config.backend,
excluded_precisions=quantization_config.excluded_precisions,
op_type_dict=quantization_config.op_type_dict,
op_name_dict=quantization_config.op_name_dict,
@@ -1123,6 +1161,157 @@ def calib_func(model):
model.quantization_config = quantization_config
logger.info("StaticQuant done.")
return model
elif isinstance(quantization_config, QuantAwareTrainingConfig):
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
low_cpu_mem_usage=True,
torch_dtype=torch.float,
**kwargs,
)

if (
not torch.cuda.is_available()
or device_map == "cpu"
or device_map == torch.device("cpu")
) and model.config.model_type == "chatglm":
model = model.float()
logger.info("Applying QuantAwareTraining.")
# train function
train_func = quantization_config.train_func
tokenizer = quantization_config.tokenizer
if train_func is None:
if quantization_config.tokenizer is None:
logger.error(
"Please provide the tokenizer or provide train_func directly,"
+ " the following is how to get tokenizer. \n"
+ " from transformer import AutoTokenizer \n"
+ " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n"
)
exit(0)

from datasets import load_dataset
from torch.utils.data import DataLoader

train_dataset = quantization_config.train_dataset
train_shuffle = quantization_config.train_shuffle
train_iters = quantization_config.train_iters
train_padding = quantization_config.train_padding
train_len = quantization_config.train_len
train_pad_val = quantization_config.train_pad_val
from torch.nn.functional import pad

train_dataset = load_dataset(
train_dataset,
split=(
"test"
if train_dataset in ["mbpp", "openai_humaneval"]
else "train"
),
)
if train_shuffle:
train_dataset = train_dataset.shuffle(seed=42)

def tokenize_function(examples):
if "code" in examples:
example = tokenizer(examples["code"])
elif "prompt" in examples:
example = tokenizer(examples["prompt"])
elif "text" in examples:
example = tokenizer(examples["text"])
else:
logger.error(
"Please check dataset prompt identifier,"
+ " NeelNanda/pile-10k is default used calibration dataset."
)
exit(0)
return example

def collate_batch(batch):
input_ids_padded = []
last_ind = []
for text in batch:
input_ids = text["input_ids"]
if not train_padding:
input_ids = (
input_ids[: int(train_len)]
if len(input_ids) > int(train_len)
else input_ids
) # no_padding
else:
pad_len = train_len - input_ids.shape[0]
input_ids = pad(
input_ids, (0, pad_len), value=train_pad_val
)

last_ind.append(input_ids.shape[0] - 1)
input_ids_padded.append(input_ids)

return (
{
"input_ids": torch.vstack(input_ids_padded),
},
torch.tensor(last_ind),
)


tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids"])
train_dataloader = DataLoader(
tokenized_dataset,
batch_size=quantization_config.train_batch_size,
shuffle=False,
collate_fn=collate_batch,
)

def train_func(model):
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
# switch to training mode
model.train()
for i, (inputs, last_ind) in enumerate(train_dataloader):
if i >= train_iters:
break
output = model(**inputs)
if isinstance(output, tuple):
loss = output[0].mean()
elif isinstance(output, dict):
loss = output["logits"].mean()
else:
loss = output.mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Iteration [{}], Loss: {:.4f}'.format(i+1, loss))
return model

logger.info(
"The default calibration function is used, "
+ "the calibration dataset is NeelNanda/pile-10k, "
+ "batchsize is 1 and calibration iteration is 100."
)
train_func = train_func


# call inc quant aware training
from neural_compressor import QuantizationAwareTrainingConfig, quantization
from neural_compressor.training import prepare_compression
conf = QuantizationAwareTrainingConfig(
backend=quantization_config.backend,
excluded_precisions=quantization_config.excluded_precisions,
op_type_dict=quantization_config.op_type_dict,
op_name_dict=quantization_config.op_name_dict,
)
compression_manager = prepare_compression(model, conf)
compression_manager.callbacks.on_train_begin()
model = compression_manager.model
train_func(model)
compression_manager.callbacks.on_train_end()
compression_manager.model.save_pretrained = types.MethodType(save_low_bit, model)
quantization_config.remove_redundant_parameters()
compression_manager.model.quantization_config = quantization_config
logger.info("Quant Aware Training done.")
return compression_manager.model
else:
if use_neural_speed:
logger.info("Using Neural Speed with FP32 model dtype.")
@@ -1255,6 +1444,10 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
quantization_config = AutoRoundConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "static":
quantization_config = StaticQuantConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "dynamic":
quantization_config = DynamicQuantConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "qat":
quantization_config = QuantAwareTrainingConfig.from_dict(quantization_config)
assert (
quantization_config is not None
), "Detect this model is not a low-bit model."
@@ -1499,7 +1692,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# - we assume all floating dtype weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
# Pretrained Model
if quantization_config.quant_method == "static":
if quantization_config.quant_method in ["static", "dynamic", "qat"]:
model = model_class(config, *model_args, **kwargs)
from neural_compressor.utils.pytorch import load
weights_file = os.path.join(
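A rough end-to-end sketch of the new dynamic-quantization path added above (the model name and save directory are placeholders, not part of this commit):

from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    DynamicQuantConfig,
)

model_name = "facebook/opt-125m"  # placeholder checkpoint
dq_config = DynamicQuantConfig()  # dynamic quant needs no calibration data

# Internally this branch calls neural_compressor.quantization.fit with
# PostTrainingQuantConfig(approach="dynamic").
q_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=dq_config)
q_model.eval()

# Saving attaches the config; reloading dispatches through load_low_bit
# because quant_method == "dynamic".
q_model.save_pretrained("./dynamic_quant_model")
reloaded = AutoModelForCausalLM.from_pretrained("./dynamic_quant_model")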
@@ -21,6 +21,8 @@
BitsAndBytesConfig,
SmoothQuantConfig,
StaticQuantConfig,
DynamicQuantConfig,
QuantAwareTrainingConfig,
SparsityConfig,
RtnConfig,
AwqConfig,
56 changes: 55 additions & 1 deletion intel_extension_for_transformers/transformers/utils/config.py
@@ -48,8 +48,11 @@ class QuantizationMethod(str, Enum):
RTN = "rtn"
AUTOROUND = "autoround"
TEQ = "teq"
DYNAMIC = "dynamic"
STATIC = "static"
SmoothQuant = "sq"
QuantAwareTraining = "qat"



class SparsityConfig(PretrainedConfig):
@@ -537,7 +540,9 @@ def remove_redundant_parameters(self):
"double_quant_scale_dtype", "use_double_quant", "mse_range", "scheme", "tokenizer", "use_ggml",
"use_neural_speed", "use_quant", "layer_wise", "blocksize", "nsamples", "max_input_length", "static_groups",
"lr", "minmax_lr", "iters", "use_quant_input", "device", "calib_dataset", "calib_pad_val", "calib_shuffle",
"calib_padding", "example_inputs", "excluded_precisions", "op_name_dict", "op_type_dict"]
"calib_padding", "example_inputs", "excluded_precisions", "op_name_dict", "op_type_dict", "train_dataloader",
"train_func", "train_iters", "train_len", "train_padding", "train_dataset", "train_pad_val", "train_shuffle",
"train_batch_size"]
for parameter in remove_parameters:
if hasattr(self, parameter):
delattr(self, parameter)
@@ -600,6 +605,55 @@ def get_config_dict(
pretrained_model_name_or_path, _configuration_file=cf, **kwargs
)

class QuantAwareTrainingConfig(ITREXQuantizationConfigMixin):
def __init__(
self,
backend="default",
tokenizer=None,
train_dataset="NeelNanda/pile-10k",
train_dataloader=None,
train_func=None,
train_shuffle=True,
train_iters=100,
train_padding=True,
train_batch_size=8,
train_len=512,
train_pad_val=1,
op_name_dict=None,
op_type_dict=None,
excluded_precisions=[],
**kwargs,
):
self.quant_method = QuantizationMethod.QuantAwareTraining
self.backend = backend
self.tokenizer = tokenizer
self.train_dataset = train_dataset
self.train_dataloader = train_dataloader
self.train_func = train_func
self.train_shuffle = train_shuffle
self.train_iters = train_iters
self.train_padding = train_padding
self.train_len = train_len
self.train_pad_val = train_pad_val
self.train_batch_size = train_batch_size
self.op_name_dict = op_name_dict
self.op_type_dict = op_type_dict
self.excluded_precisions = excluded_precisions


class DynamicQuantConfig(ITREXQuantizationConfigMixin):
def __init__(
self,
excluded_precisions=[],
op_name_dict=None,
op_type_dict=None,
**kwargs,
):
self.quant_method = QuantizationMethod.DYNAMIC
self.excluded_precisions = excluded_precisions
self.op_name_dict = op_name_dict
self.op_type_dict = op_type_dict

class StaticQuantConfig(ITREXQuantizationConfigMixin):
def __init__(
self,
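Based on the defaults added above, constructing the two new configs might look like the sketch below (the tokenizer and model name are placeholders):

from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    DynamicQuantConfig,
    QuantAwareTrainingConfig,
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder

# Dynamic quantization: no calibration or training data required.
dq_config = DynamicQuantConfig(excluded_precisions=["bf16"])

# QAT: when train_func is None, a built-in training loop over train_dataset is used,
# so a tokenizer must be supplied in that case.
qat_config = QuantAwareTrainingConfig(
    tokenizer=tokenizer,
    train_dataset="NeelNanda/pile-10k",  # default
    train_iters=100,                     # default
    train_batch_size=8,                  # default
)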
33 changes: 33 additions & 0 deletions tests/CI/test_quantization.py
@@ -317,6 +317,8 @@ def test_quantization_for_llm(self):
MixedPrecisionConfig,
SmoothQuantConfig,
StaticQuantConfig,
DynamicQuantConfig,
QuantAwareTrainingConfig,
RtnConfig,
AwqConfig,
TeqConfig,
@@ -327,6 +329,21 @@
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, use_neural_speed=False)
dummy_input = fp32_model.dummy_inputs["input_ids"]

# Dynamic quant
dq_config = DynamicQuantConfig()
q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
quantization_config=dq_config,
)
q_model.eval()
output = q_model(dummy_input)
q_model.save_pretrained("./saved_results")
output = q_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.17140813171863556, rel_tol=1e-04))
q_model = AutoModelForCausalLM.from_pretrained("./saved_results"
)
output = q_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.17140813171863556, rel_tol=1e-04))
# Static quant
sq_config = StaticQuantConfig(
tokenizer=tokenizer, # provide either tokenizer or calib_func
@@ -343,6 +360,22 @@ def test_quantization_for_llm(self):
loading_model.eval()
output = loading_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.17378684878349304, rel_tol=1e-04))
# Quant aware training
qat_config = QuantAwareTrainingConfig(
tokenizer=tokenizer, # provide either tokenizer or train_func
train_iters=2,
)
q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
quantization_config=qat_config,
)
q_model.eval()
output = q_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.17362995445728302, rel_tol=1e-04))
q_model.save_pretrained("./saved_results")
loading_model = AutoModelForCausalLM.from_pretrained("./saved_results")
loading_model.eval()
output = loading_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.17362995445728302, rel_tol=1e-04))
# Smoothquant
sq_config = SmoothQuantConfig(
tokenizer=tokenizer, # provide either tokenizer or calib_func
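For reference, the quant-aware-training branch exercised above wraps Intel Neural Compressor's training API; stripped of the dataset plumbing, the flow is roughly the sketch below (model and train_dataloader are assumed to exist, and the model is assumed to return a dict with logits):

import torch
from neural_compressor import QuantizationAwareTrainingConfig
from neural_compressor.training import prepare_compression

conf = QuantizationAwareTrainingConfig()        # op_type_dict / op_name_dict optional
compression_manager = prepare_compression(model, conf)
compression_manager.callbacks.on_train_begin()  # prepare the model for QAT

qat_model = compression_manager.model
optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.0001)
qat_model.train()
for step, (inputs, _) in enumerate(train_dataloader):
    if step >= 100:                             # train_iters
        break
    loss = qat_model(**inputs)["logits"].mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

compression_manager.callbacks.on_train_end()    # convert to the quantized model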
