
Commit

[Transformers] Refine Model from_pretrained When use_neural_speed (
zhentaoyu committed Apr 11, 2024
1 parent 818f7fe commit 39ecf38
Showing 6 changed files with 66 additions and 41 deletions.
7 changes: 3 additions & 4 deletions README.md
@@ -238,14 +238,13 @@ You can also load the low-bit model quantized by GPTQ/AWQ/RTN/AutoRound algorithm
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, GPTQConfig

- # Download Hugging Face GPTQ/AWQ model or use local quantize model
- model_name = "PATH_TO_MODEL" # local path to model
- woq_config = GPTQConfig(bits=4) # use AwqConfig for AWQ models, and AutoRoundConfig for AutoRound models
+ # Hugging Face GPTQ/AWQ model or use local quantize model
+ model_name = "MODEL_NAME_OR_PATH"
prompt = "Once upon a time, a little girl"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids
- model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
outputs = model.generate(inputs)
```
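Note: a minimal usage sketch of the simplified README flow above (the model id is only an illustration; any GPTQ/AWQ/AutoRound low-bit checkpoint whose config carries a `quantization_config` is assumed to follow the same path):

```python
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

# Illustrative GPTQ checkpoint; substitute any low-bit model id or local path.
model_name = "TheBloke/Llama-2-7B-Chat-GPTQ"
prompt = "Once upon a time, a little girl"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids

# No explicit GPTQConfig/AwqConfig is passed any more; the checkpoint's saved
# quantization_config is picked up by from_pretrained.
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))
```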

@@ -540,19 +540,12 @@
else user_model)
args.model = (peft_config.base_model_name_or_path if args.peft_model_id else args.model)
from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate
- pretrained = ',pretrained=' + args.model
+ pretrained = ',pretrained=' + args.output_dir if args.woq_loading else ',pretrained=' + args.model
args._commit_hash = "main" if args._commit_hash is None else args._commit_hash
eval_args = "tokenizer=" + args.model + ",dtype=float32" + ",_commit_hash=" + \
args._commit_hash + ",trust_remote_code=" + str(args.trust_remote_code)
if args.use_neural_speed:
eval_args += pretrained
- q_conf = user_model.config.quantization_config
- if isinstance(q_conf, dict):
-     q_algo = q_conf.get("quant_method", None)
- else:
-     q_algo = q_conf.quant_method.value
- if q_algo.upper() in ["AWQ", "GPTQ", "AUTOROUND"]:
-     eval_args += ",use_gptq=True"
results = evaluate(
model="hf-causal",
model_args=eval_args,
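Note: a sketch of the `model_args` string this builds (the values stand in for the script's argparse attributes and are illustrative only):

```python
# Stand-ins for args.model, args.output_dir, args.woq_loading, args.use_neural_speed,
# args._commit_hash and args.trust_remote_code.
model_id = "MODEL_NAME_OR_PATH"
output_dir = "./saved_results"
woq_loading = True
use_neural_speed = True
commit_hash = "main"
trust_remote_code = True

pretrained = ',pretrained=' + output_dir if woq_loading else ',pretrained=' + model_id
eval_args = ("tokenizer=" + model_id + ",dtype=float32" + ",_commit_hash=" + commit_hash
             + ",trust_remote_code=" + str(trust_remote_code))
if use_neural_speed:
    eval_args += pretrained
print(eval_args)
# tokenizer=MODEL_NAME_OR_PATH,dtype=float32,_commit_hash=main,trust_remote_code=True,pretrained=./saved_results
```

With the `quant_method` inspection removed, `,use_gptq=True` is no longer appended; the evaluation wrapper no longer needs that hint.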
@@ -620,18 +620,13 @@ def __init__(self, *args, pretrained, model_format, **kwargs):
self.model_format = model_format
if self.model_format == "neural_speed":
from intel_extension_for_transformers.transformers import RtnConfig, AwqConfig, GPTQConfig, AutoRoundConfig
- use_gptq = kwargs.pop("use_gptq", False)
- if use_gptq:
-     self.woq_config = GPTQConfig(bits=4, compute_dtype="int8", weight_dtype="int4")
- else:
-     self.woq_config = RtnConfig(bits=4, compute_dtype="int8", weight_dtype="int4")
super().__init__(*args, pretrained=pretrained, model_format=model_format, **kwargs)

if self.model_format == "neural_speed":
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
- self.runtime_model = AutoModelForCausalLM.from_pretrained(pretrained, quantization_config=self.woq_config,
-     use_neural_speed=True, trust_remote_code=kwargs.get("trust_remote_code", False))
+ self.runtime_model = AutoModelForCausalLM.from_pretrained(pretrained, use_neural_speed=True,
+     trust_remote_code=kwargs.get("trust_remote_code", False))

if self.model_format == "onnx":
if not os.path.exists(os.path.join(pretrained, "decoder_model.onnx")) and \
@@ -358,6 +358,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None)
model_hub = kwargs.pop("model_hub", "huggingface")

+ quantization_config = kwargs.pop("quantization_config", None)
if not isinstance(config, PretrainedConfig):
if model_hub == "modelscope":
import modelscope # pylint: disable=E0401
@@ -371,7 +372,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

)

- quantization_config = kwargs.pop("quantization_config", None)
if kwargs.get("use_llm_runtime", None) is not None:
use_neural_speed = kwargs.pop("use_llm_runtime", True) and not use_xpu
logger.warning(
@@ -420,6 +420,29 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
assert quantization_config.get("quant_method", None) in ConfigInit, \
"Detect this model is not a low-bit model."
quantization_config = ConfigInit[quantization_config["quant_method"]].from_dict(quantization_config)
logger.info("Loading Low Bits model by Neural Speed.")
quantization_config.post_init_runtime()

from neural_speed import Model

model = Model()
model.init( # pylint: disable=E1123
pretrained_model_name_or_path,
weight_dtype=quantization_config.weight_dtype,
alg=quantization_config.scheme,
group_size=quantization_config.group_size,
scale_dtype=quantization_config.scale_dtype,
compute_dtype=quantization_config.compute_dtype,
use_ggml=quantization_config.use_ggml,
use_quant=True,
use_gptq=quantization_config.quant_method.value == "gptq"
or quantization_config.quant_method.value == "autoround"
or quantization_config.quant_method.value == "rtn",
use_awq=quantization_config.quant_method.value == "awq",
model_hub=model_hub,
)
model.quantization_config = quantization_config
return model
else:
logger.info(
"quantization_config: {}".format(config.quantization_config)
@@ -559,7 +582,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
):
logger.info("Applying Weight Only Quantization.")
if use_neural_speed:
logger.info("Using Neural Speed.")
if not isinstance(quantization_config, RtnConfig):
logger.error("Only Supports RTN Quantization in Neural Speed.")
exit(0)
logger.info("Quantize model by Neural Speed with RTN Algorithm.")
quantization_config.post_init_runtime()
from neural_speed import Model

@@ -573,9 +599,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
compute_dtype=quantization_config.compute_dtype,
use_ggml=quantization_config.use_ggml,
use_quant=True,
- use_gptq=quantization_config.quant_method.value == "gptq"
- or quantization_config.quant_method.value == "autoround",
- use_awq=quantization_config.quant_method.value == "awq",
+ use_gptq=False,
+ use_awq=False,
model_hub=model_hub,
)
model.quantization_config = quantization_config
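Note: with the guard added above, on-the-fly weight-only quantization through Neural Speed is limited to RTN; a minimal sketch (the model id is a placeholder):

```python
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

model_name = "MODEL_NAME_OR_PATH"  # placeholder FP32 checkpoint
woq_config = RtnConfig(bits=4, compute_dtype="int8", weight_dtype="int4")

# Passing a GPTQConfig/AwqConfig/AutoRoundConfig together with use_neural_speed
# now logs an error and exits; only RTN is quantized in this path.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=woq_config,
    use_neural_speed=True,
    trust_remote_code=True,
)
```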
@@ -939,17 +964,33 @@ def calib_func(model):
)
logger.info("SmoothQuant done.")
else:
- model = cls.ORIG_MODEL.from_pretrained(
-     pretrained_model_name_or_path, *model_args, config=config, **kwargs
- )
- if (
-     not torch.cuda.is_available()
-     or device_map == "cpu"
-     or device_map == torch.device("cpu")
- ) and model.config.model_type == "chatglm":
-     model = model.float()
+ if use_neural_speed:
+     logger.info("Using Neural Speed with FP32 model dtype.")
+     from neural_speed import Model

- model.eval()
+     model = Model()
+     model.init( # pylint: disable=E1123
+         pretrained_model_name_or_path,
+         weight_dtype="fp32",
+         use_quant=False,
+         use_gptq=False,
+         use_awq=False,
+         model_hub=model_hub,
+     )
+     model.quantization_config = None
+     return model
+ else:
+     model = cls.ORIG_MODEL.from_pretrained(
+         pretrained_model_name_or_path, *model_args, config=config, **kwargs
+     )
+     if (
+         not torch.cuda.is_available()
+         or device_map == "cpu"
+         or device_map == torch.device("cpu")
+     ) and model.config.model_type == "chatglm":
+         model = model.float()
+
+     model.eval()
return model

@classmethod
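Note: the FP32 branch added above is the path the updated nightly tests exercise; a minimal sketch (the model id is a placeholder):

```python
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

model_name = "MODEL_NAME_OR_PATH"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer("Once upon a time", return_tensors="pt")

# With no quantization_config and use_neural_speed=True, the checkpoint is converted
# to Neural Speed's FP32 format (use_quant=False) instead of being loaded through the
# original transformers class.
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=None, use_neural_speed=True, trust_remote_code=True
)
ids = model.generate(inputs.input_ids, do_sample=False, max_new_tokens=16)[0]
print(tokenizer.decode(ids))
```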
3 changes: 1 addition & 2 deletions tests/Nightly/test_llm_runtime.py
@@ -53,8 +53,7 @@ def test_llm_runtime(self):
print(tokenizer.decode(pt_generate_ids))

# check output ids
- woq_config = RtnConfig(use_quant=False)
- itrex_model = AutoModel.from_pretrained(model_name, quantization_config=woq_config, use_neural_speed=True, trust_remote_code=True)
+ itrex_model = AutoModel.from_pretrained(model_name, quantization_config=None, use_neural_speed=True, trust_remote_code=True)
itrex_generate_ids = itrex_model.generate(inputs.input_ids, do_sample=False, max_new_tokens=100)[0]
print(tokenizer.decode(itrex_generate_ids))
for i in range(len(pt_generate_ids)):
8 changes: 3 additions & 5 deletions tests/Nightly/test_neural_speed.py
@@ -59,16 +59,15 @@ def test_llm_runtime(self):
print(tokenizer.decode(pt_generate_ids))

# check output ids
- woq_config = RtnConfig(use_quant=False)
- itrex_model = AutoModel.from_pretrained(model_name, quantization_config=woq_config, use_neural_speed=True, trust_remote_code=True)
+ itrex_model = AutoModel.from_pretrained(model_name, quantization_config=None, use_neural_speed=True, trust_remote_code=True)
itrex_generate_ids = itrex_model.generate(inputs.input_ids, do_sample=False, max_new_tokens=100)[0]
print(tokenizer.decode(itrex_generate_ids))
for i in range(len(pt_generate_ids)):
self.assertEqual(pt_generate_ids[i], itrex_generate_ids[i])

# check diff of logits
woq_configs = {
"fp32": RtnConfig(use_quant=False),
"fp32": None,
# "ggml_int4": RtnConfig(compute_dtype="int8", weight_dtype="int4",use_ggml=True),
"jblas_int4": RtnConfig(bits=8, compute_dtype="int8", weight_dtype="int4"),
# "jblas_int8": RtnConfig(compute_dtype="bf16", weight_dtype="int8"),
@@ -118,9 +117,8 @@ def test_beam_search(self):
pt_generate_ids = torch.load("/tf_dataset2/inc-ut/nlptoolkit_ut_model/beam_pt_generate_ids.pth").tolist()

# llm runtime fp32
- woq_config = RtnConfig(use_quant=False)
itrex_model = AutoModelForCausalLM.from_pretrained(
-     model_name, quantization_config=woq_config, trust_remote_code=True)
+     model_name, quantization_config=None, trust_remote_code=True)
itrex_generate_ids = itrex_model.generate(
inputs.input_ids, num_beams=4, max_new_tokens=128, min_new_tokens=30, early_stopping=True, pad_token=pad_token)
for i in range(len(itrex_generate_ids)):
