SmoothQuant support for the ipex.optimize_transformers feature (#695)
changwangss committed Nov 21, 2023
1 parent a6f84df commit d2bd4d9
Showing 6 changed files with 364 additions and 67 deletions.
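
At a high level, the commit lets SmoothQuantConfig drive quantization through ipex.optimize_transformers for supported decoder architectures (gptj, opt, llama). Below is a minimal sketch of how the new path might be exercised from user code; the checkpoint name and generation settings are placeholders, and the import paths follow the names used in the diffs on this page rather than a verified public API.

import torch
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,   # ITREX wrapper used by the example script below
    SmoothQuantConfig,
)

model_name = "facebook/opt-1.3b"  # placeholder; any supported gptj/opt/llama checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

# ipex_opt_llm defaults to None and is auto-enabled for gptj/opt/llama;
# num_beams should match the beam width used later at generation time.
sq_config = SmoothQuantConfig(
    tokenizer=tokenizer,  # consumed by the default calibration function
    alpha=0.5,
    num_beams=4,
)

# Quantization happens inside from_pretrained; the example script then saves the
# traced graph (e.g. saved_results/best_model.pt) and reloads it for benchmarking.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=sq_config,
    use_llm_runtime=False,
)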
@@ -39,6 +39,7 @@
help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
)
parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
parser.add_argument("--quantized_model_path", type=str, default="saved_results/best_model.pt", help="the int8 model path")
# ============Benchmark configs==============
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--iters", default=100, type=int, help="num iter")
@@ -75,7 +76,8 @@

# transformers version >= 4.32.0 contains the mpt modeling definition.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
check_min_version("4.32.0")
# 4.31.0 is the suggested version when ipex.optimize_transformers is used
check_min_version("4.31.0")

# get model config
if args.peft_model_id:
@@ -108,6 +110,9 @@
# use peft
args.model = args.peft_model_id if args.peft_model_id is not None else args.model

# Generation
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)

# mp/sq/woq/bitsandbytes config setting
quantization_config = None
if args.mixed_precision:
@@ -129,27 +134,13 @@
else:
op_type_dict = {}
excluded_precisions = [] if args.int8_bf16_mixed else ["bf16"]
inputs = None
if config.model_type == "chatglm":
query = "我该怎么办?"
if hasattr(tokenizer, "build_chat_inputs"):
inputs = tokenizer.build_chat_inputs(query)
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>")]
inputs["eos_token_id"] = eos_token_id
elif hasattr(tokenizer, "build_prompt"):
prompt = tokenizer.build_prompt(query)
inputs = tokenizer([prompt], return_tensors="pt")
else:
inputs = tokenizer([query], return_tensors="pt")

quantization_config = SmoothQuantConfig(
tokenizer=tokenizer, # provide either one of the two: tokenizer or calib_func
alpha="auto" if args.alpha == "auto" else float(args.alpha), # default is 0.5
op_type_dict=op_type_dict, # default is {}
excluded_precisions=excluded_precisions, # default is []
example_inputs=inputs,
)
num_beams=generate_kwargs["num_beams"],
)
elif args.woq:
quantization_config = WeightOnlyQuantConfig(compute_dtype="fp32", weight_dtype="int4_fullrange", group_size=32) #default is A32W4G32
# bitsandbytes
@@ -189,16 +180,37 @@
else:
user_model = AutoModelForCausalLM.from_pretrained(args.model, config=config, trust_remote_code=args.trust_remote_code, use_llm_runtime=False)

# Generation
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)


if args.int8 or args.int8_bf16_mixed:
# TorchScript models don't have a generate method, so a wrapper is provided.
import intel_extension_for_pytorch as ipex
user_model = TSModelForCausalLM.from_pretrained(
args.output_dir, file_name="best_model.pt", trust_remote_code=args.trust_remote_code
)
if config.model_type in ["gptj", "opt", "llama"]:
if args.accuracy:
from intel_extension_for_transformers.transformers.utils.utility import TSModelCausalLMForOPTLLM
user_model = TSModelCausalLMForOPTLLM.from_pretrained(
args.output_dir, file_name="best_model.pt", trust_remote_code=args.trust_remote_code
)
else:
torch._C._jit_set_texpr_fuser_enabled(False)
qconfig = ipex.quantization.default_static_qconfig_mapping
with ipex.OnDevice(dtype=torch.float, device="meta"):
user_model = AutoModelForCausalLM.from_pretrained(args.model, use_llm_runtime=False)
user_model = ipex.optimize_transformers(
user_model.eval(),
dtype=torch.float,
inplace=True,
quantization_config=qconfig,
deployment_mode=False,
)
if not hasattr(user_model, "trace_graph"):
print("load_quantized_model")
self_jit = torch.jit.load(args.quantized_model_path)
self_jit = torch.jit.freeze(self_jit.eval())
ipex._set_optimized_model_for_generation(user_model, optimized_model=self_jit)
else:
user_model = TSModelForCausalLM.from_pretrained(
args.output_dir, file_name="best_model.pt", trust_remote_code=args.trust_remote_code
)


if args.benchmark:
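
When args.accuracy is not set, the gptj/opt/llama branch above attaches the quantized TorchScript graph to an ipex-optimized model instead of going through TSModelForCausalLM. A condensed, self-contained sketch of that load-and-generate flow is given below; the checkpoint name and prompt are placeholders, and the private ipex._set_optimized_model_for_generation call is taken directly from the hunk above.

import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

model_name = "facebook/opt-1.3b"              # placeholder checkpoint
quantized_model_path = "saved_results/best_model.pt"

torch._C._jit_set_texpr_fuser_enabled(False)
qconfig = ipex.quantization.default_static_qconfig_mapping

# Build the fp32 skeleton on the meta device, then let IPEX rewrite the transformer blocks.
with ipex.OnDevice(dtype=torch.float, device="meta"):
    user_model = AutoModelForCausalLM.from_pretrained(model_name, use_llm_runtime=False)
user_model = ipex.optimize_transformers(
    user_model.eval(),
    dtype=torch.float,
    inplace=True,
    quantization_config=qconfig,
    deployment_mode=False,
)

# Attach the previously quantized TorchScript graph so generate() runs the int8 path.
self_jit = torch.jit.freeze(torch.jit.load(quantized_model_path).eval())
ipex._set_optimized_model_for_generation(user_model, optimized_model=self_jit)

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("What does SmoothQuant do?", return_tensors="pt")
with torch.inference_mode():
    output = user_model.generate(**inputs, do_sample=False, num_beams=4)
print(tokenizer.decode(output[0], skip_special_tokens=True))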
@@ -45,6 +45,7 @@ onnxruntime==1.15.0
pymysql
deepface
exifread
protobuf==3.20.2
einops
urllib3
langid
162 changes: 131 additions & 31 deletions intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -45,7 +45,10 @@
logger,
LazyImport,
generate_dummy_past_key_values,
get_example_inputs_for_trace,
get_example_inputs,
generate_dummy_past_key_values_for_opt_llm,
get_example_inputs_for_opt_llm,
get_example_inputs_for_chatglm
)
from transformers.utils import is_accelerate_available, is_bitsandbytes_available

@@ -71,7 +74,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
*model_args,
**kwargs,
)
if load_in_8bit or load_in_4bit:
elif load_in_8bit or load_in_4bit:
use_cpu = (
True
if device_map == torch.device("cpu")
@@ -131,7 +134,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
and quantization_config.compute_dtype == torch_dtype
), f"Quantization_config.weight_dtype should be 'int8' and compute_dtype should be {torch_dtype}."
if isinstance(quantization_config, MixedPrecisionConfig):
kwargs["torch_dtype"] = torch.bfloat16
if quantization_config.dtype == "float16" or quantization_config.dtype == "fp16":
kwargs["torch_dtype"] = torch.float16
else:
kwargs["torch_dtype"] = torch.bfloat16
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
@@ -182,6 +188,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
) and model.config.model_type == "chatglm":
model = model.float()
model.eval()
model_type = model.config.model_type
logger.info("Applying SmoothQuant.")
try:
import intel_extension_for_pytorch as ipex
@@ -190,7 +197,30 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"Please install Intel Extension for PyTorch to accelerate the model inference."
)
assert ipex.__version__ >= "2.1.0+cpu", "Please use Intel Extension for PyTorch >=2.1.0+cpu."
ipex_opt_llm_supported = ["gptj", "opt", "llama"]
calib_func = quantization_config.calib_func
example_inputs = quantization_config.example_inputs
num_beams = quantization_config.num_beams
if quantization_config.ipex_opt_llm is None:
if model_type in ipex_opt_llm_supported:
quantization_config.ipex_opt_llm = True
logger.info("quantization_config.ipex_opt_llm set to True and ipex.optimize_transformers is used.")
logger.warning("The suggest transformers version is 4.31.0 if ipex.optimize_transformers is used.")
else:
quantization_config.ipex_opt_llm = False

# ipex optimize transformers
if quantization_config.ipex_opt_llm:
qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
model = ipex.optimize_transformers(
model.eval(),
quantization_config=qconfig,
dtype=torch.float32,
inplace=True,
deployment_mode=False
)
model.eval()
# get calibration function
if calib_func is None:
if quantization_config.tokenizer is None:
logger.error(
@@ -243,50 +273,119 @@ def collate_batch(batch):
collate_fn=collate_batch,
)

def default_calib_func(model):
"""
This is the default calibration function, the dataset is NeelNanda/pile-10k,
the default calib_iters is 100.
"""
def default_calib_func(model):
"""
This is the default calibration function, the dataset is NeelNanda/pile-10k,
the default calib_iters is 100.
"""
with torch.no_grad():
for i, (input_ids) in enumerate(calib_dataloader):
input_bs, input_len = input_ids.shape
past_key_values = generate_dummy_past_key_values(input_bs, model)
attention_mask = torch.ones(input_bs, input_len + 1)
attention_mask[:, 0] = 0
if i >= calib_iters:
break
model(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
)

for i, (input_ids) in enumerate(calib_dataloader):
input_bs, input_len = input_ids.shape
past_key_values = generate_dummy_past_key_values(input_bs, model)
attention_mask = torch.ones(input_bs, input_len + 1)
attention_mask[:, 0] = 0
if i >= calib_iters:
break
model(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
)
def calib_func_for_chatglm(model):
with torch.no_grad():
for i, (input_ids) in enumerate(calib_dataloader):
input_bs, input_len = input_ids.shape
past_key_values = generate_dummy_past_key_values(input_bs, model)
attention_mask = torch.ones(input_bs, input_len + 1)
attention_mask[:, 0] = 0
position_ids = torch.vstack([torch.arange(input_len) for i in range(input_bs)])
if i >= calib_iters:
break
calib_inputs = {
"input_ids":input_ids,
"attention_mask":attention_mask,
"position_ids": position_ids,
"past_key_values": tuple(past_key_values)
}
model(**calib_inputs)

def calib_func_for_opt_llm(model):
with torch.no_grad():
for i, (input_ids) in enumerate(calib_dataloader):
input_bs, input_len = input_ids.shape
past_key_values = generate_dummy_past_key_values_for_opt_llm(input_bs, model, num_beams)
attention_mask = torch.ones(input_bs, input_len)
position_ids = torch.vstack([torch.arange(input_len) for i in range(input_bs)])
if i >= calib_iters:
break
if model.config.model_type != "opt":
calib_inputs = {
"input_ids":input_ids,
"attention_mask":attention_mask,
"position_ids": position_ids,
"past_key_values": tuple(past_key_values)
}
else:
calib_inputs = {
"input_ids":input_ids,
"attention_mask":attention_mask,
"past_key_values": tuple(past_key_values)
}
model(**calib_inputs)

logger.info(
"The default calibration funcation is used, "
+ "the calibration dataset is NeelNanda/pile-10k,"
+ "batchsize is 1 and calibration iteration is 100."
)
if quantization_config.ipex_opt_llm:
calib_func = calib_func_for_opt_llm
elif model_type == "chatglm":
calib_func = calib_func_for_chatglm
else:
calib_func = default_calib_func
# get example_inputs
if quantization_config.example_inputs is not None:
example_inputs = quantization_config.example_inputs
else:
if quantization_config.ipex_opt_llm:
example_inputs = get_example_inputs_for_opt_llm(
model,
quantization_config=quantization_config
)
else:
if model.config.model_type == "chatglm":
example_inputs = get_example_inputs_for_chatglm(
model,
quantization_config=quantization_config
)
else:
example_inputs = get_example_inputs(
model,
quantization_config=quantization_config
)
# sq recipes
recipes = {
"smooth_quant": True,
"smooth_quant_args": {"alpha": quantization_config.alpha},
}
example_inputs = get_example_inputs_for_trace(model, quantization_config=quantization_config)
from neural_compressor import PostTrainingQuantConfig, quantization

# call inc sq
from neural_compressor import PostTrainingQuantConfig, quantization
conf = PostTrainingQuantConfig(
backend="ipex",
backend=quantization_config.backend, # default is ipex
excluded_precisions=quantization_config.excluded_precisions,
op_type_dict=quantization_config.op_type_dict,
op_name_dict=quantization_config.op_name_dict,
recipes=recipes,
example_inputs=example_inputs,
)
if calib_func is None:
if model.config.torchscript is None or model.config.torchscript is False:
model.config.torchscript = True
logger.info(
"The default calibration funcation is used, "
+ "the calibration dataset is NeelNanda/pile-10k,"
+ "batchsize is 1 and calibration iteration is 100."
"Set model.config.torchscript = True for tracing."
)
calib_func = default_calib_func
else:
calib_func = calib_func
model.config.torchscript = True
model = quantization.fit(
model,
conf,
@@ -319,5 +418,6 @@ class AutoModel(_BaseQBitsAutoModelClass):
class AutoModelForSeq2SeqLM(_BaseQBitsAutoModelClass):
ORIG_MODEL = transformers.AutoModelForSeq2SeqLM


class GPTBigCodeForCausalLM(_BaseQBitsAutoModelClass):
ORIG_MODEL = transformers.GPTBigCodeForCausalLM
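
The calibration helpers above feed the model a dummy KV cache so that calibration traffic matches the signature the traced graph will see at generation time. The exact shapes come from internal helpers (generate_dummy_past_key_values and its opt-llm variant), so the following is only a conceptual, hypothetical stand-in showing what such a cache looks like for a standard decoder:

import torch

def make_dummy_past_key_values(batch_size, num_layers, num_heads, head_dim, seq_len=1):
    """Hypothetical stand-in for generate_dummy_past_key_values: one (key, value)
    pair per decoder layer, each shaped [batch, heads, seq, head_dim]."""
    shape = (batch_size, num_heads, seq_len, head_dim)
    return tuple((torch.zeros(shape), torch.zeros(shape)) for _ in range(num_layers))

# e.g. an OPT-1.3b-like layout: 24 layers, 32 heads, hidden size 2048 -> head_dim 64
past_key_values = make_dummy_past_key_values(batch_size=1, num_layers=24, num_heads=32, head_dim=64)
print(len(past_key_values), past_key_values[0][0].shape)  # 24 torch.Size([1, 32, 1, 64])

The single dummy past token is presumably why the default calibration loop builds its attention mask with input_len + 1 columns, and in the opt-llm variant the cache appears to be additionally sized by num_beams, which is why the new num_beams field is threaded from the config into the helper.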
@@ -253,6 +253,8 @@ class MixedPrecisionConfig:

@dataclass
class SmoothQuantConfig:
backend: str = "ipex"
ipex_opt_llm: bool = None
tokenizer: Any = None
calib_func: Any = None
calib_dataset: str = "NeelNanda/pile-10k"
@@ -262,3 +264,4 @@ class SmoothQuantConfig:
op_name_dict: dict = None
excluded_precisions: list = field(default_factory=list)
example_inputs: Any = None
num_beams: int = 1
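
The new dataclass fields give callers explicit control over the IPEX path. A short sketch of how they might be set (field names come from the dataclass above; the import path and values are illustrative):

from intel_extension_for_transformers.transformers import SmoothQuantConfig

sq_config = SmoothQuantConfig(
    backend="ipex",     # forwarded to INC's PostTrainingQuantConfig (default)
    ipex_opt_llm=None,  # None = auto-enable ipex.optimize_transformers for gptj/opt/llama
    num_beams=4,        # beam width used to size the dummy KV cache during calibration
)
# A tokenizer or a calib_func must still be supplied before quantization runs.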
