Skip to content

Commit

Permalink
refactor: Change PromptNode registered templates from per class to pe…
Browse files Browse the repository at this point in the history
…r instance (#3810)
  • Loading branch information
vblagoje committed Jan 9, 2023
1 parent 6ca88bf commit fa78e2b
Showing 1 changed file with 72 additions and 91 deletions.
163 changes: 72 additions & 91 deletions haystack/nodes/prompt/prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,12 @@ class PromptTemplate(BasePromptTemplate, ABC):
```python
PromptTemplate(name="sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, negative
or neutral. Context: $documents; Answer:",
prompt_params=["documents"])
or neutral. Context: $documents; Answer:")
```
PromptTemplate declares prompt_params, which are the input parameters that need to be filled in the prompt_text.
For example, in the above example, the prompt_params are ["documents"] and the prompt_text is
"Please give a sentiment..."
PromptTemplate declares optional prompt_params, which are the input parameters that need to be filled in
the prompt_text. For example, in the above example, the prompt_params are ["documents"] and the prompt_text is
"Please give a sentiment...".
The prompt_text contains a placeholder $documents. This variable will be filled in runtime with the non-keyword
or keyword argument `documents` passed to this PromptTemplate's fill() method.
Expand Down Expand Up @@ -135,61 +134,6 @@ def fill(self, *args, **kwargs) -> Dict[str, Any]:
return template_dict


PREDEFINED_PROMPT_TEMPLATES = [
PromptTemplate(
name="question-answering",
prompt_text="Given the context please answer the question. Context: $documents; Question: $questions; Answer:",
prompt_params=["documents", "questions"],
),
PromptTemplate(
name="question-generation",
prompt_text="Given the context please generate a question. Context: $documents; Question:",
prompt_params=["documents"],
),
PromptTemplate(
name="conditioned-question-generation",
prompt_text="Please come up with a question for the given context and the answer. "
"Context: $documents; Answer: $answers; Question:",
prompt_params=["documents", "answers"],
),
PromptTemplate(
name="summarization", prompt_text="Summarize this document: $documents Summary:", prompt_params=["documents"]
),
PromptTemplate(
name="question-answering-check",
prompt_text="Does the following context contain the answer to the question. "
"Context: $documents; Question: $questions; Please answer yes or no! Answer:",
prompt_params=["documents", "questions"],
),
PromptTemplate(
name="sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: $documents; Answer:",
prompt_params=["documents"],
),
PromptTemplate(
name="multiple-choice-question-answering",
prompt_text="Question:$questions ; Choose the most suitable option to answer the above question. "
"Options: $options; Answer:",
prompt_params=["questions", "options"],
),
PromptTemplate(
name="topic-classification",
prompt_text="Categories: $options; What category best describes: $documents; Answer:",
prompt_params=["documents", "options"],
),
PromptTemplate(
name="language-detection",
prompt_text="Detect the language in the following context and answer with the "
"name of the language. Context: $documents; Answer:",
),
PromptTemplate(
name="translation",
prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
),
]


class PromptModelInvocationLayer:
"""
PromptModelInvocationLayer implementations execute a prompt on an underlying model.
Expand Down Expand Up @@ -524,6 +468,54 @@ def run_batch(
raise NotImplementedError("This method should never be implemented in the derived class")


def get_predefined_prompt_templates() -> List[PromptTemplate]:
return [
PromptTemplate(
name="question-answering",
prompt_text="Given the context please answer the question. Context: $documents; Question: "
"$questions; Answer:",
),
PromptTemplate(
name="question-generation",
prompt_text="Given the context please generate a question. Context: $documents; Question:",
),
PromptTemplate(
name="conditioned-question-generation",
prompt_text="Please come up with a question for the given context and the answer. "
"Context: $documents; Answer: $answers; Question:",
),
PromptTemplate(name="summarization", prompt_text="Summarize this document: $documents Summary:"),
PromptTemplate(
name="question-answering-check",
prompt_text="Does the following context contain the answer to the question. "
"Context: $documents; Question: $questions; Please answer yes or no! Answer:",
),
PromptTemplate(
name="sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: $documents; Answer:",
),
PromptTemplate(
name="multiple-choice-question-answering",
prompt_text="Question:$questions ; Choose the most suitable option to answer the above question. "
"Options: $options; Answer:",
),
PromptTemplate(
name="topic-classification",
prompt_text="Categories: $options; What category best describes: $documents; Answer:",
),
PromptTemplate(
name="language-detection",
prompt_text="Detect the language in the following context and answer with the "
"name of the language. Context: $documents; Answer:",
),
PromptTemplate(
name="translation",
prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
),
]


class PromptNode(BaseComponent):
"""
The PromptNode class is the central abstraction in Haystack's large language model (LLM) support. PromptNode
Expand All @@ -546,9 +538,6 @@ class PromptNode(BaseComponent):
"""

outgoing_edges: int = 1
prompt_templates: Dict[str, PromptTemplate] = {
prompt_template.name: prompt_template for prompt_template in PREDEFINED_PROMPT_TEMPLATES # type: ignore
}

def __init__(
self,
Expand All @@ -562,6 +551,7 @@ def __init__(
devices: Optional[List[Union[str, torch.device]]] = None,
):
super().__init__()
self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()} # type: ignore
self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
self.output_variable: Optional[str] = output_variable
self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
Expand Down Expand Up @@ -665,34 +655,30 @@ def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, *
results.append(item)
return results

@classmethod
def add_prompt_template(cls, prompt_template: PromptTemplate) -> None:
def add_prompt_template(self, prompt_template: PromptTemplate) -> None:
"""
Adds a prompt template to the list of supported prompt templates.
:param prompt_template: PromptTemplate object to be added.
:return: None
"""
if prompt_template.name in cls.prompt_templates:
if prompt_template.name in self.prompt_templates:
raise ValueError(
f"Prompt template {prompt_template.name} already exists "
f"Please select a different name to add this prompt template."
)

cls.prompt_templates[prompt_template.name] = prompt_template # type: ignore
self.prompt_templates[prompt_template.name] = prompt_template # type: ignore

@classmethod
def remove_prompt_template(cls, prompt_template: str) -> PromptTemplate:
def remove_prompt_template(self, prompt_template: str) -> PromptTemplate:
"""
Removes a prompt template from the list of supported prompt templates.
:param prompt_template: Name of the prompt template to be removed.
:return: PromptTemplate object that was removed.
"""
if prompt_template in [template.name for template in PREDEFINED_PROMPT_TEMPLATES]:
raise ValueError(f"Cannot remove predefined prompt template {prompt_template}")
if prompt_template not in cls.prompt_templates:
if prompt_template not in self.prompt_templates:
raise ValueError(f"Prompt template {prompt_template} does not exist")

return cls.prompt_templates.pop(prompt_template)
return self.prompt_templates.pop(prompt_template)

def set_default_prompt_template(self, prompt_template: Union[str, PromptTemplate]) -> "PromptNode":
"""
Expand All @@ -708,56 +694,51 @@ def set_default_prompt_template(self, prompt_template: Union[str, PromptTemplate
self.default_prompt_template = prompt_template
return self

@classmethod
def get_prompt_templates(cls) -> List[PromptTemplate]:
def get_prompt_templates(self) -> List[PromptTemplate]:
"""
Returns the list of supported prompt templates.
:return: List of supported prompt templates.
"""
return list(cls.prompt_templates.values())
return list(self.prompt_templates.values())

@classmethod
def get_prompt_template_names(cls) -> List[str]:
def get_prompt_template_names(self) -> List[str]:
"""
Returns the list of supported prompt template names.
:return: List of supported prompt template names.
"""
return list(cls.prompt_templates.keys())
return list(self.prompt_templates.keys())

@classmethod
def is_supported_template(cls, prompt_template: Union[str, PromptTemplate]) -> bool:
def is_supported_template(self, prompt_template: Union[str, PromptTemplate]) -> bool:
"""
Checks if a prompt template is supported.
:param prompt_template: the prompt template to be checked.
:return: True if the prompt template is supported, False otherwise.
"""
template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name
return template_name in cls.prompt_templates
return template_name in self.prompt_templates

@classmethod
def get_prompt_template(cls, prompt_template_name: str) -> PromptTemplate:
def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
"""
Returns a prompt template by name.
:param prompt_template_name: the name of the prompt template to be returned.
:return: the prompt template object.
"""
if prompt_template_name not in cls.prompt_templates:
if prompt_template_name not in self.prompt_templates:
raise ValueError(f"Prompt template {prompt_template_name} not supported")
return cls.prompt_templates[prompt_template_name]
return self.prompt_templates[prompt_template_name]

@classmethod
def prompt_template_params(cls, prompt_template: str) -> List[str]:
def prompt_template_params(self, prompt_template: str) -> List[str]:
"""
Returns the list of parameters for a prompt template.
:param prompt_template: the name of the prompt template.
:return: the list of parameters for the prompt template.
"""
if not cls.is_supported_template(prompt_template):
if not self.is_supported_template(prompt_template):
raise ValueError(
f"{prompt_template} not supported, please select one of: {cls.get_prompt_template_names()}"
f"{prompt_template} not supported, please select one of: {self.get_prompt_template_names()}"
)

return list(cls.prompt_templates[prompt_template].prompt_params)
return list(self.prompt_templates[prompt_template].prompt_params)

def __eq__(self, other):
if isinstance(other, PromptNode):
Expand Down

0 comments on commit fa78e2b

Please sign in to comment.