Skip to content

Commit

Permalink
Revert "refactor!: Deprecate name param in PromptTemplate and int…
Browse files Browse the repository at this point in the history
…roduce `template_name` instead (#4810)" (#4834)

This reverts commit f660f41.
  • Loading branch information
bogdankostic committed May 8, 2023
1 parent 6e982e9 commit 5b2ef2a
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 135 deletions.
2 changes: 1 addition & 1 deletion haystack/nodes/answer_generator/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
stop_words = ["\n", "<|endoftext|>"]
if prompt_template is None:
prompt_template = PromptTemplate(
template_name="question-answering-with-examples",
name="question-answering-with-examples",
prompt_text="Please answer the question according to the above context."
"\n===\nContext: {examples_context}\n===\n{examples}\n\n"
"===\nContext: {context}\n===\n{query}",
Expand Down
14 changes: 6 additions & 8 deletions haystack/nodes/prompt/prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
},
)
super().__init__()
self.prompt_templates: Dict[str, PromptTemplate] = {pt.template_name: pt for pt in get_predefined_prompt_templates()} # type: ignore
self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()} # type: ignore
self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
self.output_variable: Optional[str] = output_variable
self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
Expand Down Expand Up @@ -193,13 +193,13 @@ def add_prompt_template(self, prompt_template: PromptTemplate) -> None:
:param prompt_template: The PromptTemplate object to be added.
:return: None
"""
if prompt_template.template_name in self.prompt_templates:
if prompt_template.name in self.prompt_templates:
raise ValueError(
f"Prompt template {prompt_template.template_name} already exists. "
f"Prompt template {prompt_template.name} already exists. "
f"Select a different name for this prompt template."
)

self.prompt_templates[prompt_template.template_name] = prompt_template # type: ignore
self.prompt_templates[prompt_template.name] = prompt_template # type: ignore

def remove_prompt_template(self, prompt_template: str) -> PromptTemplate:
"""
Expand Down Expand Up @@ -244,7 +244,7 @@ def is_supported_template(self, prompt_template: Union[str, PromptTemplate]) ->
:param prompt_template: The prompt template to be checked.
:return: True if the prompt template is supported, False otherwise.
"""
template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.template_name
template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name
return template_name in self.prompt_templates

def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None] = None) -> Optional[PromptTemplate]:
Expand Down Expand Up @@ -288,9 +288,7 @@ def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]
default_prompt_template = self.get_prompt_template()
if default_prompt_template:
output_parser = default_prompt_template.output_parser
return PromptTemplate(
template_name="custom-at-query-time", prompt_text=prompt_text, output_parser=output_parser
)
return PromptTemplate(name="custom-at-query-time", prompt_text=prompt_text, output_parser=output_parser)

def prompt_template_params(self, prompt_template: str) -> List[str]:
"""
Expand Down
67 changes: 19 additions & 48 deletions haystack/nodes/prompt/prompt_template.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Optional, List, Union, Tuple, Dict, Iterator, Any
import logging
import os
Expand Down Expand Up @@ -185,49 +184,22 @@ class PromptTemplate(BasePromptTemplate, ABC):
"""

def __init__(
self,
prompt_text: str,
template_name: Optional[str] = None,
output_parser: Optional[Union[BaseOutputParser, Dict[str, Any]]] = None,
name: Optional[str] = None,
self, name: str, prompt_text: str, output_parser: Optional[Union[BaseOutputParser, Dict[str, Any]]] = None
):
"""
Creates a PromptTemplate instance.
:param name: The name of the prompt template (for example, "sentiment-analysis", "question-generation"). You can specify your own name but it must be unique.
:param prompt_text: The prompt text, including prompt parameters.
:param template_name: The name of the prompt template (for example, "sentiment-analysis", "question-generation").
You can specify your own name but it must be unique.
:param output_parser: A parser that applied to the model output.
For example, to convert the model output to an Answer object, you can use `AnswerParser`.
Instead of BaseOutputParser instances, you can also pass dictionaries defining the output parsers. For example:
```
output_parser={"type": "AnswerParser", "params": {"pattern": "Answer: (.*)"}},
```
:param name: This parameter is deprecated. Use `template_name` instead.
"""
super().__init__()

if template_name is None and name is None:
raise ValueError("Specify the parameter `template_name`.")

if name is not None and template_name is None:
warnings.warn(
"The parameter `name` is deprecated and will be removed in Haystack 2.0. Use the parameter "
"`template_name` instead.",
category=DeprecationWarning,
stacklevel=2,
)
template_name = name

if name is not None and template_name is not None:
warnings.warn(
"You are using both `name` and `template_name` parameters. The parameter `name` is deprecated and will be "
"removed in Haystack 2.0. Use only the parameter `template_name`. "
f"PromptTemplate will be initialized using the parameter `template_name` ('{template_name}').",
category=DeprecationWarning,
stacklevel=2,
)

# use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes
for strip in PROMPT_TEMPLATE_STRIPS:
prompt_text = prompt_text.strip(strip)
Expand All @@ -240,16 +212,15 @@ def __init__(

self._ast_expression = ast.parse(f'f"{prompt_text}"', mode="eval")

template_name = str(template_name)
ast_validator = _ValidationVisitor(prompt_template_name=template_name)
ast_validator = _ValidationVisitor(prompt_template_name=name)
ast_validator.visit(self._ast_expression)

ast_transformer = _FstringParamsTransformer()
self._ast_expression = ast.fix_missing_locations(ast_transformer.visit(self._ast_expression))
self._prompt_params_functions = ast_transformer.prompt_params_functions
self._used_functions = ast_validator.used_functions

self.template_name = template_name
self.name = name
self.prompt_text = prompt_text
self.prompt_params: List[str] = sorted(
param for param in ast_validator.prompt_params if param not in PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS
Expand Down Expand Up @@ -364,25 +335,25 @@ def fill(self, *args, **kwargs) -> Iterator[str]:
yield prompt_prepared

def __repr__(self):
return f"PromptTemplate(prompt_text={self.prompt_text}, template_name={self.template_name}, prompt_params={self.prompt_params})"
return f"PromptTemplate(name={self.name}, prompt_text={self.prompt_text}, prompt_params={self.prompt_params})"


def get_predefined_prompt_templates() -> List[PromptTemplate]:
return [
PromptTemplate(
template_name="question-answering",
name="question-answering",
prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
"{query}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
template_name="question-answering-per-document",
name="question-answering-per-document",
prompt_text="Given the context please answer the question. Context: {documents}; Question: "
"{query}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
template_name="question-answering-with-references",
name="question-answering-with-references",
prompt_text="Create a concise and informative answer (no more than 50 words) for a given question "
"based solely on the given documents. You must only use information from the given documents. "
"Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
Expand All @@ -392,7 +363,7 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
),
PromptTemplate(
template_name="question-answering-with-document-scores",
name="question-answering-with-document-scores",
prompt_text="Answer the following question using the paragraphs below as sources. "
"An answer should be short, a few words at most.\n"
"Paragraphs:\n{documents}\n"
Expand All @@ -403,47 +374,47 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
"After having considered all possibilities, the final answer is:\n",
),
PromptTemplate(
template_name="question-generation",
name="question-generation",
prompt_text="Given the context please generate a question. Context: {documents}; Question:",
),
PromptTemplate(
template_name="conditioned-question-generation",
name="conditioned-question-generation",
prompt_text="Please come up with a question for the given context and the answer. "
"Context: {documents}; Answer: {answers}; Question:",
),
PromptTemplate(template_name="summarization", prompt_text="Summarize this document: {documents} Summary:"),
PromptTemplate(name="summarization", prompt_text="Summarize this document: {documents} Summary:"),
PromptTemplate(
template_name="question-answering-check",
name="question-answering-check",
prompt_text="Does the following context contain the answer to the question? "
"Context: {documents}; Question: {query}; Please answer yes or no! Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
template_name="sentiment-analysis",
name="sentiment-analysis",
prompt_text="Please give a sentiment for this context. Answer with positive, "
"negative or neutral. Context: {documents}; Answer:",
),
PromptTemplate(
template_name="multiple-choice-question-answering",
name="multiple-choice-question-answering",
prompt_text="Question:{query} ; Choose the most suitable option to answer the above question. "
"Options: {options}; Answer:",
output_parser=AnswerParser(),
),
PromptTemplate(
template_name="topic-classification",
name="topic-classification",
prompt_text="Categories: {options}; What category best describes: {documents}; Answer:",
),
PromptTemplate(
template_name="language-detection",
name="language-detection",
prompt_text="Detect the language in the following context and answer with the "
"name of the language. Context: {documents}; Answer:",
),
PromptTemplate(
template_name="translation",
name="translation",
prompt_text="Translate the following context to {target_language}. Context: {documents}; Translation:",
),
PromptTemplate(
template_name="zero-shot-react",
name="zero-shot-react",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"{tool_names_with_descriptions}\n\n"
Expand Down
4 changes: 2 additions & 2 deletions test/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_tool_result_extraction(reader, retriever_with_docs):
assert result == "Paris" or result == "Madrid"

# PromptNode as a Tool
pt = PromptTemplate("Here is a question: {query}, Answer:", "test")
pt = PromptTemplate("test", "Here is a question: {query}, Answer:")
pn = PromptNode(default_prompt_template=pt)

t = Tool(name="Search", pipeline_or_node=pn, description="N/A", output_variable="results")
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
country_finder = PromptNode(
model_name_or_path=prompt_model,
default_prompt_template=PromptTemplate(
template_name="country_finder",
name="country_finder",
prompt_text="When I give you a name of the city, respond with the country where the city is located.\n"
"City: Rome\nCountry: Italy\n"
"City: Berlin\nCountry: Germany\n"
Expand Down
2 changes: 1 addition & 1 deletion test/nodes/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_openai_answer_generator_custom_template(haystack_openai_config, docs):
pytest.skip("No API key found, skipping test")

lfqa_prompt = PromptTemplate(
template_name="lfqa",
name="lfqa",
prompt_text="""
Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and the given question.
\n===\Paragraphs: {context}\n===\n{query}""",
Expand Down
2 changes: 1 addition & 1 deletion test/nodes/test_shaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ def test_strings_to_answers_after_prompt_node_yaml(tmp_path):
- name: prompt_template_raw_qa_per_document
type: PromptTemplate
params:
template_name: raw-question-answering-per-document
name: raw-question-answering-per-document
prompt_text: 'Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer:'
- name: prompt_node_raw_qa
Expand Down
Loading

0 comments on commit 5b2ef2a

Please sign in to comment.