From e1e77c580cc5900181b7dae44e3e0ab7cd52b722 Mon Sep 17 00:00:00 2001 From: XuhuiRen Date: Tue, 26 Dec 2023 18:06:29 +0800 Subject: [PATCH] fix: append the user query after context and history in RAG prompt templates Signed-off-by: XuhuiRen --- .../pipeline/plugins/prompt/prompt_template.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/prompt/prompt_template.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/prompt/prompt_template.py index 6e01432c29b..14a8f24adc5 100644 --- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/prompt/prompt_template.py +++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/prompt/prompt_template.py @@ -22,14 +22,14 @@ def generate_qa_prompt(query, context=None, history=None): if context and history: conv = PromptTemplate("rag_with_context_memory") - conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[1], context) conv.append_message(conv.roles[2], history) + conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[3], None) elif context: conv = PromptTemplate("rag_with_context_memory") - conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[1], context) + conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[3], None) else: conv = PromptTemplate("rag_without_context") @@ -40,14 +40,14 @@ def generate_qa_prompt(query, context=None, history=None): def generate_qa_enterprise(query, context=None, history=None): if context and history: conv = PromptTemplate("rag_with_threshold") - conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[1], context) conv.append_message(conv.roles[2], history) + conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[3], None) else: conv = PromptTemplate("rag_with_threshold") - conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[1], context) + conv.append_message(conv.roles[0], query) 
conv.append_message(conv.roles[3], None) return conv.get_prompt() @@ -55,8 +55,8 @@ def generate_qa_enterprise(query, context=None, history=None): def generate_prompt(query, history=None): if history: conv = PromptTemplate("rag_without_context_memory") - conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[1], history) + conv.append_message(conv.roles[0], query) conv.append_message(conv.roles[2], None) else: conv = PromptTemplate("rag_without_context")