
Commit

Update run_generation_gpu_woq.py (#1553)
* Update run_generation_gpu_woq.py

Signed-off-by: Dong, Bo <bo1.dong@intel.com>
a32543254 committed May 21, 2024
1 parent 61bee27 commit f4b3a7b
Showing 1 changed file with 9 additions and 4 deletions.
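In short, this commit removes the manual --disable_optimize_transformers switch and instead enables ipex.optimize_transformers automatically whenever the loaded model's config.model_type appears on a small GPU allow-list (llama, gptj, mistral, qwen). Below is a minimal, self-contained sketch of that gate; the SimpleNamespace config is a stand-in for the real Hugging Face model config, and should_optimize is a hypothetical helper name, not part of the script:

from types import SimpleNamespace

# Allow-list copied from the diff: model types the script now treats as
# supported by ipex.optimize_transformers on GPU.
opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen"]

def should_optimize(config) -> bool:
    # Mirrors the commit's logic: optimize iff the model type is listed.
    return config.model_type in opt_gpu_model_type_list

print(should_optimize(SimpleNamespace(model_type="qwen")))    # True
print(should_optimize(SimpleNamespace(model_type="falcon")))  # False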
@@ -38,7 +38,6 @@
 # ============Benchmark configs==============
 parser.add_argument("--benchmark", action="store_true")
 parser.add_argument("--do_profiling", action="store_true")
-parser.add_argument("--disable_optimize_transformers", action="store_true")
 parser.add_argument("--profile_token_latency", action="store_true")
 parser.add_argument("--iters", default=10, type=int, help="num iter")
 parser.add_argument("--num_warmup", default=3, type=int, help="num warmup")
@@ -211,6 +210,12 @@
     user_model.save_pretrained(args.output_dir)
     tokenizer.save_pretrained(args.output_dir)
 
+enable_optimize_transformers = False
+opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen"]
+
+if config.model_type in opt_gpu_model_type_list:
+    enable_optimize_transformers = True
+
 if args.benchmark:
     if config.model_type == "qwen":
         prompt = "它完成了,并提交了。你可以在Android和网络上玩美味生存。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子."
@@ -226,7 +231,7 @@
     user_model = user_model.to(memory_format=torch.channels_last)
     if quantization_config is None:
         quantization_config = user_model.quantization_config if hasattr(user_model, "quantization_config") else None
-    if not args.disable_optimize_transformers:
+    if enable_optimize_transformers:
         print("Optimize with IPEX...")
         user_model = ipex.optimize_transformers(
             user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype)
@@ -243,7 +248,7 @@
     if args.profile_token_latency:
         ipex.transformers.optimize.convert_function(user_model, "greedy_search", _greedy_search)
         ipex.transformers.optimize.convert_function(user_model, "_greedy_search", _greedy_search)
-        if args.disable_optimize_transformers:
+        if not enable_optimize_transformers:
             ipex.transformers.optimize.convert_function(user_model, "beam_search", _beam_search)
             ipex.transformers.optimize.convert_function(user_model, "_beam_search", _beam_search)
         user_model.config.token_latency = True
@@ -320,7 +325,7 @@
         if user_model is None else user_model
     if quantization_config is None:
         quantization_config = user_model.quantization_config if hasattr(user_model, "quantization_config") else None
-    if not args.disable_optimize_transformers:
+    if enable_optimize_transformers:
         print("Optimize with IPEX...")
         user_model = ipex.optimize_transformers(
             user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype)
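One practical consequence of the removal: the parser no longer defines --disable_optimize_transformers, so existing invocations that still pass the flag will fail argument parsing. A quick stand-alone check, using a cut-down stand-in for the script's real parser:

import argparse

# Cut-down stand-in for the script's argument parser after this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--benchmark", action="store_true")

# parse_known_args() leaves unrecognized flags in `unknown` instead of
# exiting, which makes the removal easy to demonstrate.
args, unknown = parser.parse_known_args(["--disable_optimize_transformers"])
print(unknown)  # ['--disable_optimize_transformers']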
