Commit

[NeuralChat] Support compatible stats format (#1112)
lvliang-intel committed Jan 16, 2024
1 parent d753cb8 commit c0a89c5
Showing 5 changed files with 57 additions and 12 deletions.
2 changes: 2 additions & 0 deletions intel_extension_for_transformers/neural_chat/config.py
@@ -406,6 +406,8 @@ class GenerationConfig:
     max_gpu_memory: int = None
     use_fp16: bool = False
     ipex_int8: bool = False
+    return_stats: bool = False
+    format_version: str = "v2"
     task: str = ""

 @dataclass
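For illustration (not part of the commit), a minimal sketch of how a caller might set the two new GenerationConfig fields; the import path is assumed from the package layout above:

from intel_extension_for_transformers.neural_chat.config import GenerationConfig

# Default behavior is unchanged: no stats are emitted, and format_version stays "v2".
default_config = GenerationConfig()

# Opt in to the machine-readable "v1" stats format introduced by this commit.
v1_config = GenerationConfig(return_stats=True, format_version="v1")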
@@ -52,6 +52,8 @@ def construct_parameters(query, model_name, device, assistant_model, config):
     params["use_hpu_graphs"] = config.use_hpu_graphs
     params["use_cache"] = config.use_cache
     params["ipex_int8"] = config.ipex_int8
+    params["return_stats"] = config.return_stats
+    params["format_version"] = config.format_version
     params["assistant_model"] = assistant_model
     params["device"] = device
     return params
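Illustrative only (not from the commit): the rough shape of the params dict after this change, with the two new keys forwarded from the config object; the values below are placeholders:

params = {
    "prompt": "Tell me about Intel Xeon Scalable Processors.",
    "model_name": "facebook/opt-125m",
    "device": "cpu",
    "use_cache": True,
    "ipex_int8": False,
    "return_stats": True,       # new key, copied from config.return_stats
    "format_version": "v1",     # new key, copied from config.format_version
}
# The dict is then consumed downstream as keyword arguments, e.g. predict_stream(**params).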
34 changes: 23 additions & 11 deletions intel_extension_for_transformers/neural_chat/models/model_utils.py
@@ -794,6 +794,7 @@ def predict_stream(**params):
             Determines whether to utilize Habana Processing Units (HPUs) for accelerated generation.
         `use_cache` (bool): Determines whether to utilize kv cache for accelerated generation.
         `ipex_int8` (bool): Whether to use the IPEX int8 model for inference.
+        `format_version` (str): The format version of the returned stats.
     Returns:
         generator: A generator that yields the generated streaming text.
@@ -822,6 +823,7 @@ def predict_stream(**params):
     use_hpu_graphs = params["use_hpu_graphs"] if "use_hpu_graphs" in params else False
     use_cache = params["use_cache"] if "use_cache" in params else True
     return_stats = params["return_stats"] if "return_stats" in params else False
+    format_version = params["format_version"] if "format_version" in params else "v2"
     prompt = params["prompt"]
     ipex_int8 = params["ipex_int8"] if "ipex_int8" in params else False
     model = MODELS[model_name]["model"]
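The fallback in the added line is what keeps older callers working: a caller that never passes format_version still gets the v2 table output. A small self-contained sketch of that lookup pattern (resolve_stats_options is a hypothetical name, not a function in the repo):

def resolve_stats_options(params):
    # Mirrors the conditional expressions used in predict_stream above.
    return_stats = params["return_stats"] if "return_stats" in params else False
    format_version = params["format_version"] if "format_version" in params else "v2"
    return return_stats, format_version

print(resolve_stats_options({}))                                               # (False, 'v2') -> old behavior
print(resolve_stats_options({"return_stats": True}))                           # (True, 'v2')  -> table stats
print(resolve_stats_options({"return_stats": True, "format_version": "v1"}))   # (True, 'v1')  -> sentinel line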
@@ -1017,17 +1019,27 @@ def generate_output():
             0
         )
     if return_stats:
-        stats = {
-            "input_token_len": str(input_token_len),
-            "output_token_len": str(output_token_len),
-            "duration": str(duration) + " ms",
-            "first_token_latency": str(first_token_latency) + " ms",
-            "msecond_per_token": str(msecond_per_token) + " ms",
-        }
-        yield "\n| {:<22} | {:<27} |\n".format("Key", "Value")
-        yield "| " + "-"*22 + " | " + "-"*27 + " |" + "\n"
-        for key, value in stats.items():
-            yield "| {:<22} | {:<27} |\n".format(key, value)
+        if format_version == "v1":
+            stats = {
+                "input_token_len": input_token_len,
+                "output_token_len": output_token_len,
+                "duration": duration,
+                "first_token_latency": first_token_latency,
+                "msecond_per_token": msecond_per_token,
+            }
+            yield "END_OF_STREAM_STATS={}".format(stats)
+        else:
+            stats = {
+                "input_token_len": str(input_token_len),
+                "output_token_len": str(output_token_len),
+                "duration": str(duration) + " ms",
+                "first_token_latency": str(first_token_latency) + " ms",
+                "msecond_per_token": str(msecond_per_token) + " ms",
+            }
+            yield "\n| {:<22} | {:<27} |\n".format("Key", "Value")
+            yield "| " + "-"*22 + " | " + "-"*27 + " |" + "\n"
+            for key, value in stats.items():
+                yield "| {:<22} | {:<27} |\n".format(key, value)


 def predict(**params):
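One motivation for the v1 format is that it is easier to consume programmatically than the v2 table. A minimal client-side sketch (not part of the commit; split_text_and_stats is a hypothetical helper) that separates the generated text from the stats, assuming the v1 branch above yields the formatted repr of a plain dict prefixed with END_OF_STREAM_STATS=:

import ast

def split_text_and_stats(stream_chunks):
    text_parts, stats = [], None
    for chunk in stream_chunks:
        if chunk.startswith("END_OF_STREAM_STATS="):
            # The v1 branch formats a plain dict into the string, so literal_eval can read it back.
            stats = ast.literal_eval(chunk[len("END_OF_STREAM_STATS="):])
        else:
            text_parts.append(chunk)
    return "".join(text_parts), stats

# Fabricated chunk sequence with placeholder numbers, purely for illustration:
chunks = [
    "Hello", " world",
    "END_OF_STREAM_STATS={'input_token_len': 5, 'output_token_len': 2, 'duration': 12.3, "
    "'first_token_latency': 4.5, 'msecond_per_token': 6.1}",
]
text, stats = split_text_and_stats(chunks)
print(text)   # Hello world
print(stats)  # {'input_token_len': 5, ...}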
@@ -167,6 +167,27 @@ def _run_retrieval(local_dir):
         _run_retrieval(local_dir="/tf_dataset2/inc-ut/instructor-large")
         _run_retrieval(local_dir="/tf_dataset2/inc-ut/bge-base-en-v1.5")

+    def test_text_chat_stream_return_stats_with_v1_format(self):
+        config = PipelineConfig(model_name_or_path="facebook/opt-125m")
+        chatbot = build_chatbot(config)
+        stream_text = ""
+        gen_config = GenerationConfig(return_stats=True, format_version="v1")
+        results, _ = chatbot.predict_stream("Tell me about Intel Xeon Scalable Processors.", config=gen_config)
+        for text in results:
+            stream_text += text
+            print(text)
+        self.assertIn("END_OF_STREAM_STATS=", stream_text)
+
+    def test_text_chat_stream_return_stats(self):
+        config = PipelineConfig(model_name_or_path="facebook/opt-125m")
+        chatbot = build_chatbot(config)
+        stream_text = ""
+        gen_config = GenerationConfig(return_stats=True)
+        results, _ = chatbot.predict_stream("Tell me about Intel Xeon Scalable Processors.", config=gen_config)
+        for text in results:
+            stream_text += text
+            print(text)
+        self.assertIn("| Key | Value |", stream_text)
+
 if __name__ == '__main__':
     unittest.main()
10 changes: 9 additions & 1 deletion workflows/chatbot/inference/generate.py
@@ -162,6 +162,12 @@ def parse_args():
     )
     parser.add_argument(
         "--return_stats", action='store_true', default=False,)
+    parser.add_argument(
+        "--format_version",
+        type=str,
+        default="v2",
+        help="The format version of the returned stats.",
+    )
     args = parser.parse_args()
     return args

@@ -232,7 +238,9 @@ def main():
         use_hpu_graphs=args.use_hpu_graphs,
         use_cache=args.use_kv_cache,
         num_return_sequences=args.num_return_sequences,
-        ipex_int8=args.ipex_int8
+        ipex_int8=args.ipex_int8,
+        return_stats=args.return_stats,
+        format_version=args.format_version
     )

     if args.habana:
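A self-contained sketch of the two new flags in isolation (illustrative only; the real parser in generate.py defines many more arguments):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--return_stats", action='store_true', default=False)
parser.add_argument("--format_version", type=str, default="v2",
                    help="The format version of the returned stats.")

args = parser.parse_args(["--return_stats", "--format_version", "v1"])
print(args.return_stats, args.format_version)  # True v1

On the actual script this corresponds to adding --return_stats --format_version v1 to the existing command line.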
