[LLM] text-generation example support chatglm2&3 (#638)
changwangss committed Nov 13, 2023
1 parent 554fb99 commit 4525b71
Showing 8 changed files with 106 additions and 38 deletions.
@@ -8,6 +8,6 @@ torch==2.1.0+cpu
transformers==4.34.1
intel_extension_for_pytorch
git+https://github.com/intel/neural-compressor.git
git+https://github.com/huggingface/optimum.git@deda7e3be202e2a21a7d53dd5c284e47f9c646b7
git+https://github.com/huggingface/optimum-intel.git@74912c0caa7dfb8979f59cbc94b57f5d6a448c30
git+https://github.com/huggingface/optimum-intel.git@34e85a267dcf92dc25348ef53d8b79ae928fc9b8
git+https://github.com/huggingface/optimum.git@359b38ef147b0112081d27c7cafd8b402da6ca27
git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
@@ -5,7 +5,10 @@
import torch
import logging
from transformers import AutoConfig, AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
from intel_extension_for_transformers.transformers import (
AutoModelForCausalLM,
AutoModel
)
from transformers.utils import check_min_version
from optimum.intel.generation.modeling import TSModelForCausalLM
from intel_extension_for_transformers.transformers import (
@@ -76,15 +79,18 @@

# get model config
config = AutoConfig.from_pretrained(
args.model,
args.output_dir if (args.int8 or args.int8_bf16_mixed) else args.model,
torchscript=True
if (args.sq or args.woq_algo in ['AWQ', 'TEQ'])
if (args.sq or args.woq_algo in ['AWQ', 'TEQ'] or (args.int8 or args.int8_bf16_mixed))
else False, # torchscript will force `return_dict=False` to avoid jit errors
use_cache=True, # to use kv cache.
trust_remote_code=args.trust_remote_code,
revision=args.revision,
)

# chatglm
if config.model_type == "chatglm":
AutoModelForCausalLM = AutoModel

# tokenizer
if config.model_type == "llama":
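For context: the ChatGLM remote code does not reliably map AutoModelForCausalLM, which is presumably why the script reassigns the loader class above. A minimal sketch of the same dispatch, assuming an illustrative THUDM/chatglm2-6b checkpoint name:

from transformers import AutoConfig
from intel_extension_for_transformers.transformers import AutoModel, AutoModelForCausalLM

model_name = "THUDM/chatglm2-6b"  # illustrative checkpoint name
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# ChatGLM checkpoints expose their generation model via AutoModel, so dispatch on model_type
model_class = AutoModel if config.model_type == "chatglm" else AutoModelForCausalLM
model = model_class.from_pretrained(model_name, trust_remote_code=True)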
@@ -107,19 +113,33 @@
elif re.search("mpt", config.model_type):
op_type_dict = {
"add": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}},
"<built-in function linear>":{"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}},
"<built-in function linear>": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}},
}
elif re.search("mistral", config.model_type) or re.search("baichuan", config.model_type):
op_type_dict = {".*": {"activation": {"algorithm": "minmax"}}}
else:
op_type_dict = {}
excluded_precisions = [] if args.int8_bf16_mixed else ["bf16"]
if config.model_type == "chatglm":
query = "我该怎么办?"
if hasattr(tokenizer, "build_chat_inputs"):
inputs = tokenizer.build_chat_inputs(query)
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>")]
inputs["eos_token_id"] = eos_token_id
elif hasattr(tokenizer, "build_prompt"):
prompt = tokenizer.build_prompt(query)
inputs = tokenizer([prompt], return_tensors="pt")
else:
inputs = tokenizer([query], return_tensors="pt")

quantization_config = SmoothQuantConfig(
tokenizer=tokenizer, # provide either tokenizer or calib_func
alpha="auto" if args.alpha == "auto" else float(args.alpha), # default is 0.5
op_type_dict=op_type_dict, # default is {}
excluded_precisions=excluded_precisions, # default is []
)
tokenizer=tokenizer, # provide either tokenizer or calib_func
alpha="auto" if args.alpha == "auto" else float(args.alpha), # default is 0.5
op_type_dict=op_type_dict, # default is {}
excluded_precisions=excluded_precisions, # default is []
example_inputs=inputs,
)
elif args.woq:
quantization_config = WeightOnlyQuantConfig(compute_dtype="fp32", weight_dtype="int4_fullrange", group_size=32) #default is A32W4G32
# bitsandbytes
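As a usage illustration of the new example_inputs plumbing, here is a minimal end-to-end SmoothQuant sketch for ChatGLM2 (checkpoint name and argument values are illustrative, not the example's defaults):

from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModel, SmoothQuantConfig

model_name = "THUDM/chatglm2-6b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# ChatGLM2 tokenizers build their prompt template via build_prompt()
prompt = tokenizer.build_prompt("我该怎么办?")  # "What should I do?"
calib_inputs = tokenizer([prompt], return_tensors="pt")

sq_config = SmoothQuantConfig(
    tokenizer=tokenizer,
    alpha=0.5,
    example_inputs=calib_inputs,  # field introduced by this commit
)
q_model = AutoModel.from_pretrained(
    model_name,
    quantization_config=sq_config,
    trust_remote_code=True,
    use_llm_runtime=False,
)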
@@ -163,7 +183,7 @@
# TorchScript models don't have a generate method, so the wrapper is provided.
import intel_extension_for_pytorch as ipex
user_model = TSModelForCausalLM.from_pretrained(
args.output_dir, file_name="best_model.pt", trust_remote_code=args.trust_remote_code
args.output_dir, config=config, file_name="best_model.pt", trust_remote_code=args.trust_remote_code
)
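A reload sketch for the INT8 TorchScript artifact with illustrative paths; the explicit config= argument mirrors the change above, presumably so that ChatGLM (trust_remote_code) configs resolve when loading back:

from transformers import AutoConfig
from optimum.intel.generation.modeling import TSModelForCausalLM

output_dir = "./saved_results"  # illustrative path; the script uses args.output_dir
config = AutoConfig.from_pretrained(output_dir, trust_remote_code=True)
user_model = TSModelForCausalLM.from_pretrained(
    output_dir, config=config, file_name="best_model.pt", trust_remote_code=True
)
# user_model now exposes generate(), matching the non-quantized path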


@@ -207,7 +227,8 @@
from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
results = evaluate(
model="hf-causal",
model_args='pretrained='+args.model+',tokenizer='+args.model+',dtype=float32',
model_args='pretrained=' + args.model + ',tokenizer=' + args.model + \
    ',dtype=float32' + ',trust_remote_code=' + str(args.trust_remote_code),
user_model=user_model,
batch_size=args.batch_size,
tasks=args.tasks,
@@ -221,4 +242,3 @@
print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
else:
print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))

@@ -149,12 +149,11 @@ def evaluate(model,
+ ".db",
)


task_dict = get_task_dict(tasks)

if check_integrity:
run_task_tests(task_list=tasks)

if user_model:
lm.model = user_model

Expand All @@ -169,5 +168,5 @@ def evaluate(model,
output_base_path=output_base_path
)

print(make_table(results))
print(make_table(results))
return results
@@ -419,7 +419,7 @@ def _create_auto_tokenizer(
try:
tokenizer.pad_token = tokenizer.eos_token
except:
print("token.pad_token doesn't set to equal with tokenizer.eos_token.")
print("token.pad_token setting failed.")
return tokenizer

@property
@@ -673,10 +673,14 @@ def _create_auto_tokenizer(
def _model_call(
self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
) -> TokenSequence:
if hasattr(self._config, "_name_or_path") and self._config._name_or_path == "chatglm":
if hasattr(self._config, "_name_or_path") and self._config._name_or_path == "THUDM/chatglm-6b":
input_bs, input_len = inputs.shape
bos = torch.tensor([130001, 130004]).repeat(input_bs,1)
inputs = torch.cat((inputs, bos),1)
eos = torch.tensor([130001, 130004]).repeat(input_bs, 1)
inputs = torch.cat((inputs, eos), 1)
if hasattr(self._config, "_name_or_path") and self._config._name_or_path == "THUDM/chatglm2-6b":
input_bs, input_len = inputs.shape
bos = torch.tensor([64790, 64792]).repeat(input_bs, 1)
inputs = torch.cat((bos, inputs), 1)
output = self.model(inputs) if self.model_format != "onnx" else \
self.model(inputs, torch.ones(inputs.shape, dtype=torch.int64))
if isinstance(output, tuple):
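The two branches above differ in where each ChatGLM generation places its special marker tokens; a standalone sketch of the same padding logic (token ids copied from the patch, variant labels are illustrative strings rather than config values):

import torch

def add_chatglm_special_tokens(input_ids: torch.Tensor, variant: str) -> torch.Tensor:
    batch_size = input_ids.shape[0]
    if variant == "chatglm-6b":
        # first-generation ChatGLM appends its two-token marker after the prompt
        tail = torch.tensor([130001, 130004]).repeat(batch_size, 1)
        return torch.cat((input_ids, tail), dim=1)
    if variant == "chatglm2-6b":
        # ChatGLM2 prepends its two-token marker before the prompt
        head = torch.tensor([64790, 64792]).repeat(batch_size, 1)
        return torch.cat((head, input_ids), dim=1)
    return input_ids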
@@ -148,6 +148,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
if (not torch.cuda.is_available() or
device_map == "cpu" or
device_map == torch.device("cpu")
) and model.config.model_type == "chatglm":
model = model.float()
model.eval()
quantization_config.post_init()
from intel_extension_for_transformers.llm.quantization.utils import (
@@ -159,6 +164,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
if (not torch.cuda.is_available() or
device_map == "cpu" or
device_map == torch.device("cpu")
) and model.config.model_type == "chatglm":
model = model.float()
model.eval()
logger.info("Applying SmoothQuant.")
try:
@@ -244,7 +254,7 @@ def default_calib_func(model):
"smooth_quant": True,
"smooth_quant_args": {"alpha": quantization_config.alpha},
}
example_inputs = get_example_inputs_for_trace(model)
example_inputs = get_example_inputs_for_trace(model, quantization_config=quantization_config)
from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
@@ -276,7 +286,13 @@ def default_calib_func(model):
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
model.eval()
if (not torch.cuda.is_available() or
device_map == "cpu" or
device_map == torch.device("cpu")
) and model.config.model_type == "chatglm":
model = model.float()

model.eval()
return model
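The repeated guard above is the CPU path for ChatGLM: the public checkpoints ship half-precision weights (an assumption about the checkpoints, not stated in the diff), and fp16 compute is generally unusable on CPU, so the model is upcast to fp32 before quantization or evaluation. A condensed sketch of the guard:

import torch

def maybe_upcast_for_cpu(model, device_map="cpu"):
    # Mirror of the from_pretrained() guard: upcast ChatGLM to fp32 when running on CPU
    on_cpu = (not torch.cuda.is_available()
              or device_map == "cpu"
              or device_map == torch.device("cpu"))
    if on_cpu and model.config.model_type == "chatglm":
        model = model.float()
    return model.eval()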


@@ -257,3 +257,4 @@ class SmoothQuantConfig:
op_type_dict: dict = None
op_name_dict: dict = None
excluded_precisions: list = field(default_factory=list)
example_inputs: Any = None
47 changes: 32 additions & 15 deletions intel_extension_for_transformers/transformers/utils/utility.py
@@ -91,6 +91,8 @@ def generate_dummy_past_key_values(input_bs, model):
num_key_value_heads = num_attention_heads
if hasattr(normalized_config, "num_key_value_heads"):
num_key_value_heads = normalized_config.num_key_value_heads
if hasattr(normalized_config, "multi_query_group_num"):
num_key_value_heads = normalized_config.multi_query_group_num

if model.config.model_type == "bloom":
pkv = ()
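ChatGLM2/3 configs name their grouped-KV head count multi_query_group_num rather than num_key_value_heads, which is what the new hasattr check above picks up. A hypothetical accessor with the same precedence:

def kv_head_count(config) -> int:
    # ChatGLM's multi_query_group_num wins, then num_key_value_heads,
    # and finally the full attention-head count as a fallback.
    if hasattr(config, "multi_query_group_num"):
        return config.multi_query_group_num
    if hasattr(config, "num_key_value_heads"):
        return config.num_key_value_heads
    return config.num_attention_heads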
@@ -100,38 +102,53 @@
else:
new_shape = [input_bs * num_key_value_heads, 1, d_k]
pkv = pkv + (torch.ones(size=new_shape),)
elif model.config.model_type == "mistral":
new_shape = [input_bs, num_key_value_heads, 1, d_k]
dummy_tensor = torch.ones(size=new_shape)
pkv = tuple(dummy_tensor for _ in range(nb_pkv))
elif model.config.model_type == "qwen":
new_shape = [input_bs, 1, num_key_value_heads, d_k]
dummy_tensor = torch.ones(size=new_shape)
pkv = tuple(dummy_tensor for _ in range(nb_pkv))
elif model.config.model_type == "chatglm":
new_shape = [1, input_bs, num_key_value_heads, d_k]
dummy_tensor = torch.ones(size=new_shape)
pkv = tuple(dummy_tensor for _ in range(nb_pkv))
else:
new_shape = [input_bs, num_key_value_heads, 1, d_k]
dummy_tensor = torch.ones(size=new_shape)
pkv = tuple(dummy_tensor for _ in range(nb_pkv))
past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
return past_key_values

def get_example_inputs_for_trace(model, return_type="dict"):
def get_example_inputs_for_trace(model, quantization_config=None, return_type="dict"):
"""
Generate the example_input for tracing, support models load from AutoModelForCausalLM.
"""
input_ids = model.dummy_inputs["input_ids"]
input_bs, input_len = input_ids.shape
past_key_values = generate_dummy_past_key_values(input_bs, model)
attention_mask = torch.ones(input_bs, input_len + 1)
attention_mask[:,0] = 0
example_inputs = (input_ids, tuple(past_key_values), attention_mask)
if return_type != "tuple":
if quantization_config and quantization_config.example_inputs is not None:
example_inputs = quantization_config.example_inputs
input_ids = example_inputs["input_ids"]
input_bs, input_len = input_ids.shape
attention_mask = torch.ones(input_bs, input_len + 1)
attention_mask[:, 0] = 0
past_key_values = generate_dummy_past_key_values(input_bs, model)
if "past_key_values" not in example_inputs:
example_inputs["past_key_values"] = tuple(past_key_values)
example_inputs["attention_mask"] = attention_mask
if "position_ids" in example_inputs.keys():
example_inputs.pop("position_ids")
else:
input_ids = model.dummy_inputs["input_ids"]
input_bs, input_len = input_ids.shape
past_key_values = generate_dummy_past_key_values(input_bs, model)
attention_mask = torch.ones(input_bs, input_len + 1)
attention_mask[:, 0] = 0
example_inputs = {
"input_ids": input_ids,
"past_key_values": tuple(past_key_values),
"attention_mask": attention_mask
"attention_mask": attention_mask,
}
# run inference to check that example_inputs is correct.
out = model(**example_inputs)
if return_type == "tuple":
example_inputs = (example_inputs["input_ids"], example_inputs["past_key_values"],
example_inputs["attention_mask"])

# run inference to check that example_inputs is correct.
out = model(**example_inputs) if return_type == "dict" else model(*example_inputs)
return example_inputs
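For reference, the dummy KV-cache layouts produced per model family in the hunk above (shapes read directly from the branches; the helper name is illustrative, and the bloom branch is omitted since the hunk only shows part of it):

import torch

def dummy_kv_entry(model_type: str, input_bs: int, kv_heads: int, d_k: int) -> torch.Tensor:
    if model_type == "qwen":
        shape = [input_bs, 1, kv_heads, d_k]   # batch, seq, kv_heads, head_dim
    elif model_type == "chatglm":
        shape = [1, input_bs, kv_heads, d_k]   # seq-first layout used by ChatGLM2/3
    else:
        shape = [input_bs, kv_heads, 1, d_k]   # llama/mistral/gpt-style default
    return torch.ones(size=shape)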
11 changes: 11 additions & 0 deletions tests/test_quantization.py
@@ -311,6 +311,16 @@ def test_quantization_for_llm(self):
use_llm_runtime=False
)
self.assertTrue(isinstance(q_model.model, torch.jit.ScriptModule))
sq_config = SmoothQuantConfig(
tokenizer=tokenizer, # provide either tokenizer or calib_func
calib_iters=5,
example_inputs=fp32_model.dummy_inputs
)
q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
quantization_config=sq_config,
use_llm_runtime=False
)
self.assertTrue(isinstance(q_model.model, torch.jit.ScriptModule))

# weight-only
#RTN
@@ -368,5 +378,6 @@ def test_quantization_for_llm(self):
output = bit8_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), -7.2695, rel_tol=1e-04))


if __name__ == "__main__":
unittest.main()
