diff --git a/intel_extension_for_transformers/neural_chat/models/base_model.py b/intel_extension_for_transformers/neural_chat/models/base_model.py index 213b4bbcb27..cf94ffe6863 100644 --- a/intel_extension_for_transformers/neural_chat/models/base_model.py +++ b/intel_extension_for_transformers/neural_chat/models/base_model.py @@ -391,19 +391,19 @@ def get_conv_template(self, model_path: str, task: str = "") -> Conversation: if self.conv_template: return if not task: - self.conv_template = PromptTemplate(self.get_default_conv_template(model_path).name) + self.conv_template = PromptTemplate(self.get_default_conv_template(model_path).name, clear_history=True) else: - clear_after_gen = True + clear_history = True if task == "completion": name = "alpaca_without_input" elif task == "chat": name = "neural-chat-7b-v2" - clear_after_gen = False + clear_history = False elif task == "summarization": name = "summarization" else: raise NotImplementedError(f"Unsupported task {task}.") - self.conv_template = PromptTemplate(name, clear_after_gen=clear_after_gen) + self.conv_template = PromptTemplate(name, clear_history=clear_history) def prepare_prompt(self, prompt: str, model_path: str, task: str = ""): self.get_conv_template(model_path, task) diff --git a/intel_extension_for_transformers/neural_chat/prompts/prompt.py b/intel_extension_for_transformers/neural_chat/prompts/prompt.py index 89458e9a0bf..25d505401b8 100644 --- a/intel_extension_for_transformers/neural_chat/prompts/prompt.py +++ b/intel_extension_for_transformers/neural_chat/prompts/prompt.py @@ -202,9 +202,9 @@ ) class PromptTemplate: - def __init__(self, name="one_shot", clear_after_gen=False): + def __init__(self, name="one_shot", clear_history=False): self.conv = get_conv_template(name) - self.clear_after_gen = clear_after_gen + self.clear_history = clear_history @property def roles(self): @@ -215,7 +215,7 @@ def append_message(self, role: str, message: str): def get_prompt(self) -> str: res = self.conv.get_prompt() - if self.clear_after_gen: + if self.clear_history: self.clear_messages() return res