NeuralChat support IPEX int8 model (#486)
* to support ipex int8 model

Signed-off-by: changwangss <chang1.wang@intel.com>
changwangss committed Oct 17, 2023
1 parent aa5d8af commit e133632
Showing 5 changed files with 75 additions and 21 deletions.
4 changes: 3 additions & 1 deletion intel_extension_for_transformers/neural_chat/chatbot.py
@@ -70,7 +70,8 @@ def build_chatbot(config: PipelineConfig=None):
elif "opt" in config.model_name_or_path or \
"gpt" in config.model_name_or_path or \
"flan-t5" in config.model_name_or_path or \
"bloom" in config.model_name_or_path:
"bloom" in config.model_name_or_path or \
"starcoder" in config.model_name_or_path:
from .models.base_model import BaseModel
adapter = BaseModel()
else:
@@ -125,6 +126,7 @@ def build_chatbot(config: PipelineConfig=None):
parameters["device"] = config.device
parameters["use_hpu_graphs"] = config.loading_config.use_hpu_graphs
parameters["cpu_jit"] = config.loading_config.cpu_jit
parameters["ipex_int8"] = config.loading_config.ipex_int8
parameters["use_cache"] = config.loading_config.use_cache
parameters["peft_path"] = config.loading_config.peft_path
parameters["use_deepspeed"] = config.loading_config.use_deepspeed
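Together with the new LoadingModelConfig field added in config.py below, the flag can be turned on from the pipeline configuration. A minimal usage sketch, not part of the diff (the import paths, constructor arguments, and checkpoint directory are assumptions; the directory stands in for a model that already ships an IPEX int8 best_model.pt):

    # Sketch: enable the IPEX int8 loading path through the pipeline config.
    from intel_extension_for_transformers.neural_chat.chatbot import build_chatbot
    from intel_extension_for_transformers.neural_chat.config import PipelineConfig, LoadingModelConfig

    config = PipelineConfig(
        model_name_or_path="/path/to/starcoder-int8",        # placeholder checkpoint directory
        loading_config=LoadingModelConfig(ipex_int8=True),    # new flag added by this commit
    )
    chatbot = build_chatbot(config)

With ipex_int8 left at its default of False, build_chatbot keeps using the existing transformers loading paths.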
1 change: 1 addition & 0 deletions intel_extension_for_transformers/neural_chat/config.py
@@ -405,6 +405,7 @@ class LoadingModelConfig:
use_hpu_graphs: bool = False
use_cache: bool = True
use_deepspeed: bool = False
+    ipex_int8: bool = False

@dataclass
class WeightOnlyQuantizationConfig:
@@ -42,6 +42,7 @@ def construct_parameters(query, model_name, device, config):
params["force_words_ids"] = config.force_words_ids
params["use_hpu_graphs"] = config.use_hpu_graphs
params["use_cache"] = config.use_cache
params["ipex_int8"] = config.ipex_int8
params["device"] = device
return params

@@ -93,6 +94,7 @@ def load_model(self, kwargs: dict):
"device": "cuda",
"use_hpu_graphs": True,
"cpu_jit": False,
"ipex_int8": False,
"use_cache": True,
"peft_path": "/path/to/peft",
"use_deepspeed": False
@@ -109,6 +111,7 @@ def load_model(self, kwargs: dict):
device=kwargs["device"],
use_hpu_graphs=kwargs["use_hpu_graphs"],
cpu_jit=kwargs["cpu_jit"],
+        ipex_int8=kwargs["ipex_int8"],
use_cache=kwargs["use_cache"],
peft_path=kwargs["peft_path"],
use_deepspeed=kwargs["use_deepspeed"],
81 changes: 61 additions & 20 deletions intel_extension_for_transformers/neural_chat/models/model_utils.py
@@ -264,6 +264,7 @@ def load_model(
device="cpu",
use_hpu_graphs=False,
cpu_jit=False,
+    ipex_int8=False,
use_cache=True,
peft_path=None,
use_deepspeed=False,
@@ -340,7 +341,7 @@ def load_model(
if device == "hpu" and use_deepspeed and load_to_meta:
with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
elif re.search("flan-t5", model_name, re.IGNORECASE):
elif re.search("flan-t5", model_name, re.IGNORECASE) and not ipex_int8:
with smart_context_manager(use_deepspeed=use_deepspeed):
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
@@ -358,7 +359,8 @@ def load_model(
or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)
or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)
or re.search("qwen", model_name, re.IGNORECASE)
-        ):
+        or re.search("starcoder", model_name, re.IGNORECASE)
+        ) and not ipex_int8:
with smart_context_manager(use_deepspeed=use_deepspeed):
model = AutoModelForCausalLM.from_pretrained(
model_name,
@@ -367,6 +369,17 @@ def load_model(
low_cpu_mem_usage=True,
quantization_config=bitsandbytes_quant_config,
)
+    elif (
+        (re.search("starcoder", model_name, re.IGNORECASE)
+        ) and ipex_int8
+    ):
+        with smart_context_manager(use_deepspeed=use_deepspeed):
+            import intel_extension_for_pytorch
+            from optimum.intel.generation.modeling import TSModelForCausalLM
+            model = TSModelForCausalLM.from_pretrained(
+                model_name,
+                file_name="best_model.pt",
+            )
else:
raise ValueError(
f"Unsupported model {model_name}, only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT now."
@@ -434,7 +447,7 @@ def load_model(
model = model.to(dtype=torch_dtype)

if device == "cpu":
-        if torch_dtype == torch.bfloat16:
+        if torch_dtype == torch.bfloat16 and not ipex_int8:
import intel_extension_for_pytorch as intel_ipex

model = intel_ipex.optimize(
@@ -471,7 +484,13 @@ def load_model(
if tokenizer.pad_token is None and tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

+    # warmup for int8 model
+    if ipex_int8:
+        input_ids = tokenizer("A chat between a curious human and an artificial intelligence assistant.\n"
+            " Human: Tell me about Intel.\n Assistant:", return_tensors="pt").input_ids.to('cpu')
+        with torch.inference_mode(), torch.no_grad():
+            for i in range(2):
+                model.generate(input_ids, max_new_tokens=32, do_sample=False, temperature=0.9)
MODELS[model_name]["model"] = model
MODELS[model_name]["tokenizer"] = tokenizer
print("Model loaded.")
@@ -550,6 +569,7 @@ def predict_stream(**params):
`use_hpu_graphs` (bool):
Determines whether to utilize Habana Processing Units (HPUs) for accelerated generation.
`use_cache` (bool): Determines whether to utilize kv cache for accelerated generation.
+        `ipex_int8` (bool): Whether to use IPEX int8 model to inference.
Returns:
generator: A generator that yields the generated streaming text.
@@ -579,6 +599,7 @@ def predict_stream(**params):
use_cache = params["use_cache"] if "use_cache" in params else True
return_stats = params["return_stats"] if "return_stats" in params else False
prompt = params["prompt"]
+    ipex_int8 = params["ipex_int8"] if "ipex_int8" in params else False
model = MODELS[model_name]["model"]
tokenizer = MODELS[model_name]["tokenizer"]
errors_queue = Queue()
@@ -624,17 +645,28 @@ def generate_output():
                context = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=True)
            elif device == "xpu":
                context = torch.xpu.amp.autocast(enabled=True, dtype=dtype, cache_enabled=True)
-           with context:
+           if ipex_int8:
                global output_token_len
                output_token=model.generate(
-                   **input_tokens,
-                   **generate_kwargs,
-                   streamer=streamer,
-                   generation_config=generation_config,
-                   return_dict_in_generate=True,
-               )
-               output_token_len=output_token.sequences[0].shape[-1]
-               return output_token
+                   **input_tokens,
+                   **generate_kwargs,
+                   streamer=streamer,
+                   generation_config=generation_config,
+                   return_dict_in_generate=True,
+               )
+
+           else:
+               with context:
+                   global output_token_len
+                   output_token=model.generate(
+                       **input_tokens,
+                       **generate_kwargs,
+                       streamer=streamer,
+                       generation_config=generation_config,
+                       return_dict_in_generate=True,
+                   )
+                   output_token_len=output_token.sequences[0].shape[-1]
+               return output_token
        except Exception as e:
            errors_queue.put(e)

@@ -759,6 +791,7 @@ def predict(**params):
`use_hpu_graphs` (bool):
Determines whether to utilize Habana Processing Units (HPUs) for accelerated generation.
`use_cache` (bool): Determines whether to utilize kv cache for accelerated generation.
+        `ipex_int8` (bool): Whether to use IPEX int8 model to inference.
Returns:
generator: A generator that yields the generated streaming text.
@@ -785,6 +818,7 @@ def predict(**params):
force_words_ids = params["force_words_ids"] if "force_words_ids" in params else None
use_hpu_graphs = params["use_hpu_graphs"] if "use_hpu_graphs" in params else False
use_cache = params["use_cache"] if "use_cache" in params else False
+    ipex_int8 = params["ipex_int8"] if "ipex_int8" in params else False
prompt = params["prompt"]
model = MODELS[model_name]["model"]
tokenizer = MODELS[model_name]["tokenizer"]
@@ -824,14 +858,21 @@ def predict(**params):
            context = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=True)
        elif device == "xpu":
            context = torch.xpu.amp.autocast(enabled=True, dtype=dtype, cache_enabled=True)
-
-       with context:
+       if ipex_int8:
            generation_output = model.generate(
-               **input_tokens,
-               **generate_kwargs,
-               generation_config=generation_config,
-               return_dict_in_generate=True
-           )
+                   **input_tokens,
+                   **generate_kwargs,
+                   generation_config=generation_config,
+                   return_dict_in_generate=True
+               )
+       else:
+           with context:
+               generation_output = model.generate(
+                   **input_tokens,
+                   **generate_kwargs,
+                   generation_config=generation_config,
+                   return_dict_in_generate=True
+               )
    elif device == "hpu":
        # Move inputs to target device(s)
        input_tokens = prepare_inputs(input_tokens, model.device)
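In both predict_stream() and predict(), the int8 path deliberately skips the bfloat16 autocast context, presumably because the traced int8 graph already fixes its own precision. A hedged sketch of calling the updated predict() directly, not part of the diff; parameter names are taken from the visible code where possible, while model_name and the registered model key are assumptions:

    # Sketch: invoke predict() with the new ipex_int8 switch.
    from intel_extension_for_transformers.neural_chat.models.model_utils import predict

    response = predict(
        model_name="starcoder-int8",   # assumed to match the key registered by load_model()
        prompt="Human: Tell me about Intel.\nAssistant:",
        device="cpu",
        use_cache=True,
        ipex_int8=True,                # new parameter added by this commit
    )
    print(response)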
7 changes: 7 additions & 0 deletions workflows/chatbot/inference/generate.py
@@ -116,6 +116,11 @@ def parse_args():
action="store_true",
help="Whether to use jit trace. It should speed up generation.",
)
+    parser.add_argument(
+        "--ipex_int8",
+        action="store_true",
+        help="Whether to use int8 IPEX quantized model. It should speed up generation.",
+    )
parser.add_argument(
"--seed",
default=27,
@@ -206,6 +211,7 @@ def main():
loading_config=LoadingModelConfig(
use_hpu_graphs=args.use_hpu_graphs,
cpu_jit=args.jit,
+        ipex_int8=args.ipex_int8,
use_cache=args.use_kv_cache,
peft_path=args.peft_model_path,
use_deepspeed=True if use_deepspeed and args.habana else False,
@@ -225,6 +231,7 @@ def main():
use_hpu_graphs=args.use_hpu_graphs,
use_cache=args.use_kv_cache,
num_return_sequences=args.num_return_sequences,
+        ipex_int8=args.ipex_int8
)

if args.habana:
