From fa78e2b0e448595c6c5ca4fb6f88413be5095bb2 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Mon, 9 Jan 2023 15:57:04 +0100
Subject: [PATCH] refactor: Change PromptNode registered templates from per class to per instance (#3810)

---
 haystack/nodes/prompt/prompt_node.py | 163 ++++++++++++---------------
 1 file changed, 72 insertions(+), 91 deletions(-)

diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py
index d666775406..6eb1054ec8 100644
--- a/haystack/nodes/prompt/prompt_node.py
+++ b/haystack/nodes/prompt/prompt_node.py
@@ -53,13 +53,12 @@ class PromptTemplate(BasePromptTemplate, ABC):
     ```python
         PromptTemplate(name="sentiment-analysis",
                    prompt_text="Please give a sentiment for this context. Answer with positive, negative
-                   or neutral. Context: $documents; Answer:",
-                   prompt_params=["documents"])
+                   or neutral. Context: $documents; Answer:")
     ```
 
-    PromptTemplate declares prompt_params, which are the input parameters that need to be filled in the prompt_text.
-    For example, in the above example, the prompt_params are ["documents"] and the prompt_text is
-    "Please give a sentiment..."
+    PromptTemplate declares optional prompt_params, which are the input parameters that need to be filled in
+    the prompt_text. For example, in the above example, the prompt_params are ["documents"] and the prompt_text is
+    "Please give a sentiment...".
 
     The prompt_text contains a placeholder $documents. This variable will be filled in runtime with the
     non-keyword or keyword argument `documents` passed to this PromptTemplate's fill() method.
@@ -135,61 +134,6 @@ def fill(self, *args, **kwargs) -> Dict[str, Any]:
         return template_dict
 
 
-PREDEFINED_PROMPT_TEMPLATES = [
-    PromptTemplate(
-        name="question-answering",
-        prompt_text="Given the context please answer the question. Context: $documents; Question: $questions; Answer:",
-        prompt_params=["documents", "questions"],
-    ),
-    PromptTemplate(
-        name="question-generation",
-        prompt_text="Given the context please generate a question. Context: $documents; Question:",
-        prompt_params=["documents"],
-    ),
-    PromptTemplate(
-        name="conditioned-question-generation",
-        prompt_text="Please come up with a question for the given context and the answer. "
-        "Context: $documents; Answer: $answers; Question:",
-        prompt_params=["documents", "answers"],
-    ),
-    PromptTemplate(
-        name="summarization", prompt_text="Summarize this document: $documents Summary:", prompt_params=["documents"]
-    ),
-    PromptTemplate(
-        name="question-answering-check",
-        prompt_text="Does the following context contain the answer to the question. "
-        "Context: $documents; Question: $questions; Please answer yes or no! Answer:",
-        prompt_params=["documents", "questions"],
-    ),
-    PromptTemplate(
-        name="sentiment-analysis",
-        prompt_text="Please give a sentiment for this context. Answer with positive, "
-        "negative or neutral. Context: $documents; Answer:",
-        prompt_params=["documents"],
-    ),
-    PromptTemplate(
-        name="multiple-choice-question-answering",
-        prompt_text="Question:$questions ; Choose the most suitable option to answer the above question. "
-        "Options: $options; Answer:",
-        prompt_params=["questions", "options"],
-    ),
-    PromptTemplate(
-        name="topic-classification",
-        prompt_text="Categories: $options; What category best describes: $documents; Answer:",
-        prompt_params=["documents", "options"],
-    ),
-    PromptTemplate(
-        name="language-detection",
-        prompt_text="Detect the language in the following context and answer with the "
-        "name of the language. Context: $documents; Answer:",
-    ),
-    PromptTemplate(
-        name="translation",
-        prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
-    ),
-]
-
-
 class PromptModelInvocationLayer:
     """
     PromptModelInvocationLayer implementations execute a prompt on an underlying model.
@@ -524,6 +468,54 @@ def run_batch(
         raise NotImplementedError("This method should never be implemented in the derived class")
 
 
+def get_predefined_prompt_templates() -> List[PromptTemplate]:
+    return [
+        PromptTemplate(
+            name="question-answering",
+            prompt_text="Given the context please answer the question. Context: $documents; Question: "
+            "$questions; Answer:",
+        ),
+        PromptTemplate(
+            name="question-generation",
+            prompt_text="Given the context please generate a question. Context: $documents; Question:",
+        ),
+        PromptTemplate(
+            name="conditioned-question-generation",
+            prompt_text="Please come up with a question for the given context and the answer. "
+            "Context: $documents; Answer: $answers; Question:",
+        ),
+        PromptTemplate(name="summarization", prompt_text="Summarize this document: $documents Summary:"),
+        PromptTemplate(
+            name="question-answering-check",
+            prompt_text="Does the following context contain the answer to the question. "
+            "Context: $documents; Question: $questions; Please answer yes or no! Answer:",
+        ),
+        PromptTemplate(
+            name="sentiment-analysis",
+            prompt_text="Please give a sentiment for this context. Answer with positive, "
+            "negative or neutral. Context: $documents; Answer:",
+        ),
+        PromptTemplate(
+            name="multiple-choice-question-answering",
+            prompt_text="Question:$questions ; Choose the most suitable option to answer the above question. "
+            "Options: $options; Answer:",
+        ),
+        PromptTemplate(
+            name="topic-classification",
+            prompt_text="Categories: $options; What category best describes: $documents; Answer:",
+        ),
+        PromptTemplate(
+            name="language-detection",
+            prompt_text="Detect the language in the following context and answer with the "
+            "name of the language. Context: $documents; Answer:",
+        ),
+        PromptTemplate(
+            name="translation",
+            prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
+        ),
+    ]
+
+
 class PromptNode(BaseComponent):
     """
     The PromptNode class is the central abstraction in Haystack's large language model (LLM) support. PromptNode
@@ -546,9 +538,6 @@ class PromptNode(BaseComponent):
     """
 
     outgoing_edges: int = 1
-    prompt_templates: Dict[str, PromptTemplate] = {
-        prompt_template.name: prompt_template for prompt_template in PREDEFINED_PROMPT_TEMPLATES  # type: ignore
-    }
 
     def __init__(
         self,
@@ -562,6 +551,7 @@ def __init__(
         devices: Optional[List[Union[str, torch.device]]] = None,
     ):
         super().__init__()
+        self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()}  # type: ignore
         self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
         self.output_variable: Optional[str] = output_variable
         self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
@@ -665,34 +655,30 @@ def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, *
             results.append(item)
         return results
 
-    @classmethod
-    def add_prompt_template(cls, prompt_template: PromptTemplate) -> None:
+    def add_prompt_template(self, prompt_template: PromptTemplate) -> None:
         """
         Adds a prompt template to the list of supported prompt templates.
 
         :param prompt_template: PromptTemplate object to be added.
         :return: None
         """
-        if prompt_template.name in cls.prompt_templates:
+        if prompt_template.name in self.prompt_templates:
             raise ValueError(
                 f"Prompt template {prompt_template.name} already exists "
                 f"Please select a different name to add this prompt template."
             )
-        cls.prompt_templates[prompt_template.name] = prompt_template  # type: ignore
+        self.prompt_templates[prompt_template.name] = prompt_template  # type: ignore
 
-    @classmethod
-    def remove_prompt_template(cls, prompt_template: str) -> PromptTemplate:
+    def remove_prompt_template(self, prompt_template: str) -> PromptTemplate:
         """
         Removes a prompt template from the list of supported prompt templates.
 
         :param prompt_template: Name of the prompt template to be removed.
        :return: PromptTemplate object that was removed.
         """
-        if prompt_template in [template.name for template in PREDEFINED_PROMPT_TEMPLATES]:
-            raise ValueError(f"Cannot remove predefined prompt template {prompt_template}")
-        if prompt_template not in cls.prompt_templates:
+        if prompt_template not in self.prompt_templates:
             raise ValueError(f"Prompt template {prompt_template} does not exist")
-        return cls.prompt_templates.pop(prompt_template)
+        return self.prompt_templates.pop(prompt_template)
 
     def set_default_prompt_template(self, prompt_template: Union[str, PromptTemplate]) -> "PromptNode":
         """
@@ -708,56 +694,51 @@ def set_default_prompt_template(self, prompt_template: Union[str, PromptTemplate
         self.default_prompt_template = prompt_template
         return self
 
-    @classmethod
-    def get_prompt_templates(cls) -> List[PromptTemplate]:
+    def get_prompt_templates(self) -> List[PromptTemplate]:
         """
         Returns the list of supported prompt templates.
         :return: List of supported prompt templates.
         """
-        return list(cls.prompt_templates.values())
+        return list(self.prompt_templates.values())
 
-    @classmethod
-    def get_prompt_template_names(cls) -> List[str]:
+    def get_prompt_template_names(self) -> List[str]:
        """
         Returns the list of supported prompt template names.
         :return: List of supported prompt template names.
         """
-        return list(cls.prompt_templates.keys())
+        return list(self.prompt_templates.keys())
 
-    @classmethod
-    def is_supported_template(cls, prompt_template: Union[str, PromptTemplate]) -> bool:
+    def is_supported_template(self, prompt_template: Union[str, PromptTemplate]) -> bool:
         """
         Checks if a prompt template is supported.
         :param prompt_template: the prompt template to be checked.
         :return: True if the prompt template is supported, False otherwise.
         """
         template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name
-        return template_name in cls.prompt_templates
+        return template_name in self.prompt_templates
 
-    @classmethod
-    def get_prompt_template(cls, prompt_template_name: str) -> PromptTemplate:
+    def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
         """
         Returns a prompt template by name.
         :param prompt_template_name: the name of the prompt template to be returned.
         :return: the prompt template object.
         """
-        if prompt_template_name not in cls.prompt_templates:
+        if prompt_template_name not in self.prompt_templates:
             raise ValueError(f"Prompt template {prompt_template_name} not supported")
-        return cls.prompt_templates[prompt_template_name]
+        return self.prompt_templates[prompt_template_name]
 
-    @classmethod
-    def prompt_template_params(cls, prompt_template: str) -> List[str]:
+    def prompt_template_params(self, prompt_template: str) -> List[str]:
         """
         Returns the list of parameters for a prompt template.
 
         :param prompt_template: the name of the prompt template.
         :return: the list of parameters for the prompt template.
         """
-        if not cls.is_supported_template(prompt_template):
+        if not self.is_supported_template(prompt_template):
             raise ValueError(
-                f"{prompt_template} not supported, please select one of: {cls.get_prompt_template_names()}"
+                f"{prompt_template} not supported, please select one of: {self.get_prompt_template_names()}"
             )
-        return list(cls.prompt_templates[prompt_template].prompt_params)
+        return list(self.prompt_templates[prompt_template].prompt_params)
 
     def __eq__(self, other):
         if isinstance(other, PromptNode):
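
Below the diff: a minimal usage sketch of the per-instance template registry this patch introduces. It is not part of the patch; the `keyword-extraction` template is invented for illustration, and the model name is only an example (instantiating a PromptNode loads that model).

```python
# Illustrative sketch only, not part of the patch.
from haystack.nodes.prompt.prompt_node import PromptNode, PromptTemplate

# Two independent nodes; each gets its own copy of the predefined templates.
node_a = PromptNode(model_name_or_path="google/flan-t5-base")
node_b = PromptNode(model_name_or_path="google/flan-t5-base")

# Registering a template now mutates only this instance's dict,
# not a class-level registry shared by all PromptNode objects.
node_a.add_prompt_template(
    PromptTemplate(name="keyword-extraction", prompt_text="Extract keywords from: $documents; Keywords:")
)

assert node_a.is_supported_template("keyword-extraction")
assert not node_b.is_supported_template("keyword-extraction")

# Predefined templates can also be removed per instance now that the
# "cannot remove predefined template" guard is gone.
removed = node_b.remove_prompt_template("translation")
```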