From 419038033e3667d4916af7fbdd1d526ac0ad49d7 Mon Sep 17 00:00:00 2001
From: bogdankostic
Date: Mon, 8 May 2023 10:34:29 +0200
Subject: [PATCH] Revert "refactor!: Deprecate `name` param in `PromptTemplate` and introduce `template_name` instead (#4810)"

This reverts commit f660f41c0615e6b3064ef3e321f1e5a295fafc1b.
---
 haystack/nodes/answer_generator/openai.py |  2 +-
 haystack/nodes/prompt/prompt_node.py      | 14 ++---
 haystack/nodes/prompt/prompt_template.py  | 67 +++++++----------
 test/agents/test_agent.py                 |  4 +-
 test/nodes/test_generator.py              |  2 +-
 test/nodes/test_shaper.py                 |  2 +-
 test/prompt/test_prompt_node.py           | 52 +++++++++---------
 test/prompt/test_prompt_template.py       | 66 +++++++---------
 8 files changed, 74 insertions(+), 135 deletions(-)

diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py
index 1b80492e64..ffe5ae9690 100644
--- a/haystack/nodes/answer_generator/openai.py
+++ b/haystack/nodes/answer_generator/openai.py
@@ -115,7 +115,7 @@ def __init__(
             stop_words = ["\n", "<|endoftext|>"]
         if prompt_template is None:
             prompt_template = PromptTemplate(
-                template_name="question-answering-with-examples",
+                name="question-answering-with-examples",
                 prompt_text="Please answer the question according to the above context."
                 "\n===\nContext: {examples_context}\n===\n{examples}\n\n"
                 "===\nContext: {context}\n===\n{query}",
diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py
index 00c451e8be..4defbfef9e 100644
--- a/haystack/nodes/prompt/prompt_node.py
+++ b/haystack/nodes/prompt/prompt_node.py
@@ -97,7 +97,7 @@ def __init__(
             },
         )
         super().__init__()
-        self.prompt_templates: Dict[str, PromptTemplate] = {pt.template_name: pt for pt in get_predefined_prompt_templates()}  # type: ignore
+        self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()}  # type: ignore
         self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
         self.output_variable: Optional[str] = output_variable
         self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
@@ -193,13 +193,13 @@ def add_prompt_template(self, prompt_template: PromptTemplate) -> None:
         :param prompt_template: The PromptTemplate object to be added.
         :return: None
         """
-        if prompt_template.template_name in self.prompt_templates:
+        if prompt_template.name in self.prompt_templates:
             raise ValueError(
-                f"Prompt template {prompt_template.template_name} already exists. "
+                f"Prompt template {prompt_template.name} already exists. "
                 f"Select a different name for this prompt template."
             )

-        self.prompt_templates[prompt_template.template_name] = prompt_template  # type: ignore
+        self.prompt_templates[prompt_template.name] = prompt_template  # type: ignore

     def remove_prompt_template(self, prompt_template: str) -> PromptTemplate:
         """
@@ -244,7 +244,7 @@ def is_supported_template(self, prompt_template: Union[str, PromptTemplate]) ->
         :param prompt_template: The prompt template to be checked.
         :return: True if the prompt template is supported, False otherwise.
         """
-        template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.template_name
+        template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name
         return template_name in self.prompt_templates

     def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None] = None) -> Optional[PromptTemplate]:
@@ -288,9 +288,7 @@ def get_prompt_template(self, prompt_template: Union[str, PromptTemplate, None]
             default_prompt_template = self.get_prompt_template()
             if default_prompt_template:
                 output_parser = default_prompt_template.output_parser
-        return PromptTemplate(
-            template_name="custom-at-query-time", prompt_text=prompt_text, output_parser=output_parser
-        )
+        return PromptTemplate(name="custom-at-query-time", prompt_text=prompt_text, output_parser=output_parser)

     def prompt_template_params(self, prompt_template: str) -> List[str]:
         """
diff --git a/haystack/nodes/prompt/prompt_template.py b/haystack/nodes/prompt/prompt_template.py
index c7d2f30fff..871b8eb219 100644
--- a/haystack/nodes/prompt/prompt_template.py
+++ b/haystack/nodes/prompt/prompt_template.py
@@ -1,4 +1,3 @@
-import warnings
 from typing import Optional, List, Union, Tuple, Dict, Iterator, Any
 import logging
 import os
@@ -185,49 +184,22 @@ class PromptTemplate(BasePromptTemplate, ABC):
     """

     def __init__(
-        self,
-        prompt_text: str,
-        template_name: Optional[str] = None,
-        output_parser: Optional[Union[BaseOutputParser, Dict[str, Any]]] = None,
-        name: Optional[str] = None,
+        self, name: str, prompt_text: str, output_parser: Optional[Union[BaseOutputParser, Dict[str, Any]]] = None
     ):
         """
         Creates a PromptTemplate instance.

+        :param name: The name of the prompt template (for example, "sentiment-analysis", "question-generation"). You can specify your own name but it must be unique.
         :param prompt_text: The prompt text, including prompt parameters.
-        :param template_name: The name of the prompt template (for example, "sentiment-analysis", "question-generation").
-                              You can specify your own name but it must be unique.
         :param output_parser: A parser that applied to the model output.
                 For example, to convert the model output to an Answer object, you can use `AnswerParser`.
                 Instead of BaseOutputParser instances, you can also pass dictionaries defining the output parsers. For example:
                 ```
                 output_parser={"type": "AnswerParser", "params": {"pattern": "Answer: (.*)"}},
                 ```
-        :param name: This parameter is deprecated. Use `template_name` instead.
         """
         super().__init__()

-        if template_name is None and name is None:
-            raise ValueError("Specify the parameter `template_name`.")
-
-        if name is not None and template_name is None:
-            warnings.warn(
-                "The parameter `name` is deprecated and will be removed in Haystack 2.0. Use the parameter "
-                "`template_name` instead.",
-                category=DeprecationWarning,
-                stacklevel=2,
-            )
-            template_name = name
-
-        if name is not None and template_name is not None:
-            warnings.warn(
-                "You are using both `name` and `template_name` parameters. The parameter `name` is deprecated and will be "
-                "removed in Haystack 2.0. Use only the parameter `template_name`. "
-                f"PromptTemplate will be initialized using the parameter `template_name` ('{template_name}').",
-                category=DeprecationWarning,
-                stacklevel=2,
-            )
-
         # use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes
         for strip in PROMPT_TEMPLATE_STRIPS:
             prompt_text = prompt_text.strip(strip)
@@ -240,8 +212,7 @@ def __init__(

         self._ast_expression = ast.parse(f'f"{prompt_text}"', mode="eval")

-        template_name = str(template_name)
-        ast_validator = _ValidationVisitor(prompt_template_name=template_name)
+        ast_validator = _ValidationVisitor(prompt_template_name=name)
         ast_validator.visit(self._ast_expression)

         ast_transformer = _FstringParamsTransformer()
@@ -249,7 +220,7 @@ def __init__(
         self._prompt_params_functions = ast_transformer.prompt_params_functions
         self._used_functions = ast_validator.used_functions

-        self.template_name = template_name
+        self.name = name
         self.prompt_text = prompt_text
         self.prompt_params: List[str] = sorted(
             param for param in ast_validator.prompt_params if param not in PROMPT_TEMPLATE_SPECIAL_CHAR_ALIAS
@@ -364,25 +335,25 @@ def fill(self, *args, **kwargs) -> Iterator[str]:
             yield prompt_prepared

     def __repr__(self):
-        return f"PromptTemplate(prompt_text={self.prompt_text}, template_name={self.template_name}, prompt_params={self.prompt_params})"
+        return f"PromptTemplate(name={self.name}, prompt_text={self.prompt_text}, prompt_params={self.prompt_params})"


 def get_predefined_prompt_templates() -> List[PromptTemplate]:
     return [
         PromptTemplate(
-            template_name="question-answering",
+            name="question-answering",
             prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
             "{query}; Answer:",
             output_parser=AnswerParser(),
         ),
         PromptTemplate(
-            template_name="question-answering-per-document",
+            name="question-answering-per-document",
             prompt_text="Given the context please answer the question. Context: {documents}; Question: "
             "{query}; Answer:",
             output_parser=AnswerParser(),
         ),
         PromptTemplate(
-            template_name="question-answering-with-references",
+            name="question-answering-with-references",
             prompt_text="Create a concise and informative answer (no more than 50 words) for a given question "
             "based solely on the given documents. You must only use information from the given documents. "
             "Use an unbiased and journalistic tone. Do not repeat text. Cite the documents using Document[number] notation. "
@@ -392,7 +363,7 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
             output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
         ),
         PromptTemplate(
-            template_name="question-answering-with-document-scores",
+            name="question-answering-with-document-scores",
             prompt_text="Answer the following question using the paragraphs below as sources. "
             "An answer should be short, a few words at most.\n"
             "Paragraphs:\n{documents}\n"
            "Question: {query}\n\n"
             "Instructions: Consider all the paragraphs above and their corresponding scores to generate "
             "the answer. While a single paragraph may have a high score, it's important to consider all "
             "paragraphs for the same answer candidate to answer accurately.\n\n"
@@ -403,47 +374,47 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
             "After having considered all possibilities, the final answer is:\n",
         ),
         PromptTemplate(
-            template_name="question-generation",
+            name="question-generation",
             prompt_text="Given the context please generate a question. Context: {documents}; Question:",
         ),
         PromptTemplate(
-            template_name="conditioned-question-generation",
+            name="conditioned-question-generation",
             prompt_text="Please come up with a question for the given context and the answer. "
             "Context: {documents}; Answer: {answers}; Question:",
         ),
-        PromptTemplate(template_name="summarization", prompt_text="Summarize this document: {documents} Summary:"),
+        PromptTemplate(name="summarization", prompt_text="Summarize this document: {documents} Summary:"),
         PromptTemplate(
-            template_name="question-answering-check",
+            name="question-answering-check",
             prompt_text="Does the following context contain the answer to the question? "
             "Context: {documents}; Question: {query}; Please answer yes or no! Answer:",
             output_parser=AnswerParser(),
         ),
         PromptTemplate(
-            template_name="sentiment-analysis",
+            name="sentiment-analysis",
             prompt_text="Please give a sentiment for this context. Answer with positive, "
             "negative or neutral. Context: {documents}; Answer:",
         ),
         PromptTemplate(
-            template_name="multiple-choice-question-answering",
+            name="multiple-choice-question-answering",
             prompt_text="Question:{query} ; Choose the most suitable option to answer the above question. "
             "Options: {options}; Answer:",
             output_parser=AnswerParser(),
         ),
         PromptTemplate(
-            template_name="topic-classification",
+            name="topic-classification",
             prompt_text="Categories: {options}; What category best describes: {documents}; Answer:",
         ),
         PromptTemplate(
-            template_name="language-detection",
+            name="language-detection",
             prompt_text="Detect the language in the following context and answer with the "
             "name of the language. Context: {documents}; Answer:",
         ),
         PromptTemplate(
-            template_name="translation",
+            name="translation",
             prompt_text="Translate the following context to {target_language}. Context: {documents}; Translation:",
         ),
         PromptTemplate(
-            template_name="zero-shot-react",
+            name="zero-shot-react",
             prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
             "correctly, you have access to the following tools:\n\n"
             "{tool_names_with_descriptions}\n\n"
diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py
index 5ae86f15bf..dce4bf5d39 100644
--- a/test/agents/test_agent.py
+++ b/test/agents/test_agent.py
@@ -185,7 +185,7 @@ def test_tool_result_extraction(reader, retriever_with_docs):
     assert result == "Paris" or result == "Madrid"

     # PromptNode as a Tool
-    pt = PromptTemplate("Here is a question: {query}, Answer:", "test")
+    pt = PromptTemplate("test", "Here is a question: {query}, Answer:")
     pn = PromptNode(default_prompt_template=pt)

     t = Tool(name="Search", pipeline_or_node=pn, description="N/A", output_variable="results")
@@ -219,7 +219,7 @@ def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
     country_finder = PromptNode(
         model_name_or_path=prompt_model,
         default_prompt_template=PromptTemplate(
-            template_name="country_finder",
+            name="country_finder",
             prompt_text="When I give you a name of the city, respond with the country where the city is located.\n"
             "City: Rome\nCountry: Italy\n"
             "City: Berlin\nCountry: Germany\n"
diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py
index feb58cd8ae..dde821fb0f 100644
--- a/test/nodes/test_generator.py
+++ b/test/nodes/test_generator.py
@@ -158,7 +158,7 @@ def test_openai_answer_generator_custom_template(haystack_openai_config, docs):
         pytest.skip("No API key found, skipping test")

     lfqa_prompt = PromptTemplate(
-        template_name="lfqa",
+        name="lfqa",
         prompt_text="""
         Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and the given question.
         \n===\Paragraphs: {context}\n===\n{query}""",
diff --git a/test/nodes/test_shaper.py b/test/nodes/test_shaper.py
index 6eae21be05..9c29ff29cb 100644
--- a/test/nodes/test_shaper.py
+++ b/test/nodes/test_shaper.py
@@ -838,7 +838,7 @@ def test_strings_to_answers_after_prompt_node_yaml(tmp_path):
            - name: prompt_template_raw_qa_per_document
              type: PromptTemplate
              params:
-                template_name: raw-question-answering-per-document
+                name: raw-question-answering-per-document
                prompt_text: 'Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer:'

            - name: prompt_node_raw_qa
diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py
index 34ecc36657..11a3834414 100644
--- a/test/prompt/test_prompt_node.py
+++ b/test/prompt/test_prompt_node.py
@@ -33,7 +33,7 @@ def test_add_and_remove_template():
     assert len(node.get_prompt_template_names()) == 14

     # Add a fake template
-    fake_template = PromptTemplate(template_name="fake-template", prompt_text="Fake prompt")
+    fake_template = PromptTemplate(name="fake-template", prompt_text="Fake prompt")
     node.add_prompt_template(fake_template)
     assert len(node.get_prompt_template_names()) == 15
     assert "fake-template" in node.get_prompt_template_names()
@@ -64,7 +64,7 @@ def test_prompt_after_adding_template(mock_model):

     # Create a template
     template = PromptTemplate(
-        template_name="fake-sentiment-analysis",
+        name="fake-sentiment-analysis",
         prompt_text="Please give a sentiment for this context. Answer with positive, "
         "negative or neutral. Context: {documents}; Answer:",
     )
@@ -85,7 +85,7 @@ def test_prompt_passing_template(mock_model):

     # Create a template
     template = PromptTemplate(
-        template_name="fake-sentiment-analysis",
+        name="fake-sentiment-analysis",
         prompt_text="Please give a sentiment for this context. Answer with positive, "
         "negative or neutral. Context: {documents}; Answer:",
     )
@@ -142,10 +142,10 @@ def test_get_prompt_template_without_default_template(mock_model):
     assert node.get_prompt_template() is None

     template = node.get_prompt_template("question-answering")
-    assert template.template_name == "question-answering"
+    assert template.name == "question-answering"

-    template = node.get_prompt_template(PromptTemplate(template_name="fake-template", prompt_text=""))
-    assert template.template_name == "fake-template"
+    template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
+    assert template.name == "fake-template"

     with pytest.raises(ValueError) as e:
         node.get_prompt_template("some-unsupported-template")
@@ -153,14 +153,14 @@ def test_get_prompt_template_without_default_template(mock_model):

     fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
     template = node.get_prompt_template(fake_yaml_prompt)
-    assert template.template_name == "fake-yaml-template"
+    assert template.name == "fake-yaml-template"

     fake_yaml_prompt = "- prompt_text: fake prompt text"
     template = node.get_prompt_template(fake_yaml_prompt)
-    assert template.template_name == "custom-at-query-time"
+    assert template.name == "custom-at-query-time"

     template = node.get_prompt_template("some prompt")
-    assert template.template_name == "custom-at-query-time"
+    assert template.name == "custom-at-query-time"


 @pytest.mark.unit
@@ -170,13 +170,13 @@ def test_get_prompt_template_with_default_template(mock_model):
     node.set_default_prompt_template("question-answering")

     template = node.get_prompt_template()
-    assert template.template_name == "question-answering"
+    assert template.name == "question-answering"

     template = node.get_prompt_template("sentiment-analysis")
-    assert template.template_name == "sentiment-analysis"
+    assert template.name == "sentiment-analysis"

-    template = node.get_prompt_template(PromptTemplate(template_name="fake-template", prompt_text=""))
-    assert template.template_name == "fake-template"
+    template = node.get_prompt_template(PromptTemplate(name="fake-template", prompt_text=""))
+    assert template.name == "fake-template"

     with pytest.raises(ValueError) as e:
         node.get_prompt_template("some-unsupported-template")
@@ -184,14 +184,14 @@ def test_get_prompt_template_with_default_template(mock_model):

     fake_yaml_prompt = "name: fake-yaml-template\nprompt_text: fake prompt text"
     template = node.get_prompt_template(fake_yaml_prompt)
-    assert template.template_name == "fake-yaml-template"
+    assert template.name == "fake-yaml-template"

     fake_yaml_prompt = "- prompt_text: fake prompt text"
     template = node.get_prompt_template(fake_yaml_prompt)
-    assert template.template_name == "custom-at-query-time"
+    assert template.name == "custom-at-query-time"

     template = node.get_prompt_template("some prompt")
-    assert template.template_name == "custom-at-query-time"
+    assert template.name == "custom-at-query-time"


 @pytest.mark.integration
@@ -290,7 +290,7 @@ def test_stop_words(prompt_model):
     assert "capital" in r[0] or "Germany" in r[0]

     tt = PromptTemplate(
-        template_name="question-generation-copy",
+        name="question-generation-copy",
         prompt_text="Given the context please generate a question. Context: {documents}; Question:",
     )
     # with custom prompt template
@@ -588,7 +588,7 @@ def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path):
            - name: template_with_nested_shaper
              type: PromptTemplate
              params:
-                template_name: custom-template-with-nested-shaper
+                name: custom-template-with-nested-shaper
                prompt_text: "Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer: "
                output_parser:
                  type: AnswerParser
@@ -653,7 +653,7 @@ def test_complex_pipeline_with_qa(prompt_model):
     skip_test_for_invalid_key(prompt_model)

     prompt_template = PromptTemplate(
-        template_name="question-answering-new",
+        name="question-answering-new",
         prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
     )
     node = PromptNode(prompt_model, default_prompt_template=prompt_template)
@@ -853,7 +853,7 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
            - name: question_generation_template
              type: PromptTemplate
              params:
-                template_name: question-generation-new
+                name: question-generation-new
                prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
            - name: p1
              params:
@@ -933,7 +933,7 @@ def run_batch(
            - name: question_generation_template
              type: PromptTemplate
              params:
-                template_name: question-generation-new
+                name: question-generation-new
                prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
            - name: p1
              params:
@@ -1007,7 +1007,7 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config):
            - name: question_generation_template
              type: PromptTemplate
              params:
-                template_name: question-generation-new
+                name: question-generation-new
                prompt_text: "Given the context please generate a question. Context: {{documents}}; Question:"
            - name: p1
              params:
@@ -1084,7 +1084,7 @@ class TestTokenLimit:
     @pytest.mark.integration
     def test_hf_token_limit_warning(self, caplog):
         prompt_template = PromptTemplate(
-            template_name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:"
+            name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:"
         )
         with caplog.at_level(logging.WARNING):
             node = PromptNode("google/flan-t5-small", devices=["cpu"])
@@ -1098,9 +1098,7 @@ def test_hf_token_limit_warning(self, caplog):
         reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
     )
     def test_openai_token_limit_warning(self, caplog):
-        tt = PromptTemplate(
-            template_name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:"
-        )
+        tt = PromptTemplate(name="too-long-temp", prompt_text="Repeating text" * 200 + "Docs: {documents}; Answer:")
         prompt_node = PromptNode("text-ada-001", max_length=2000, api_key=os.environ.get("OPENAI_API_KEY", ""))
         with caplog.at_level(logging.WARNING):
             _ = prompt_node.prompt(tt, documents=["Berlin is an amazing city."])
@@ -1155,7 +1153,7 @@ def test_simple_pipeline_batch_query_multiple_doc_list(self, prompt_model):
     skip_test_for_invalid_key(prompt_model)

     prompt_template = PromptTemplate(
-        template_name="question-answering-new",
+        name="question-answering-new",
         prompt_text="Given the context please answer the question. Context: {documents}; Question: {query}; Answer:",
     )
     node = PromptNode(prompt_model, default_prompt_template=prompt_template)
diff --git a/test/prompt/test_prompt_template.py b/test/prompt/test_prompt_template.py
index dc56ce114e..d86eba6c36 100644
--- a/test/prompt/test_prompt_template.py
+++ b/test/prompt/test_prompt_template.py
@@ -10,32 +10,30 @@
 from haystack.pipelines.base import Pipeline
 from haystack.schema import Answer, Document

-from ..conftest import fail_at_version
-

 @pytest.mark.unit
 def test_prompt_templates():
-    p = PromptTemplate("Here is some fake template with variable {foo}", "t1")
+    p = PromptTemplate("t1", "Here is some fake template with variable {foo}")
     assert set(p.prompt_params) == {"foo"}

-    p = PromptTemplate("Here is some fake template with variable {foo} and {bar}", "t3")
+    p = PromptTemplate("t3", "Here is some fake template with variable {foo} and {bar}")
     assert set(p.prompt_params) == {"foo", "bar"}

-    p = PromptTemplate("Here is some fake template with variable {foo1} and {bar2}", "t4")
+    p = PromptTemplate("t4", "Here is some fake template with variable {foo1} and {bar2}")
     assert set(p.prompt_params) == {"foo1", "bar2"}

-    p = PromptTemplate("Here is some fake template with variable {foo_1} and {bar_2}", "t4")
+    p = PromptTemplate("t4", "Here is some fake template with variable {foo_1} and {bar_2}")
     assert set(p.prompt_params) == {"foo_1", "bar_2"}

-    p = PromptTemplate("Here is some fake template with variable {Foo_1} and {Bar_2}", "t4")
+    p = PromptTemplate("t4", "Here is some fake template with variable {Foo_1} and {Bar_2}")
     assert set(p.prompt_params) == {"Foo_1", "Bar_2"}

-    p = PromptTemplate("'Here is some fake template with variable {baz}'", "t4")
+    p = PromptTemplate("t4", "'Here is some fake template with variable {baz}'")
     assert set(p.prompt_params) == {"baz"}
     # strip single quotes, happens in YAML as we need to use single quotes for the template string
     assert p.prompt_text == "Here is some fake template with variable {baz}"

-    p = PromptTemplate('"Here is some fake template with variable {baz}"', "t4")
+    p = PromptTemplate("t4", '"Here is some fake template with variable {baz}"')
     assert set(p.prompt_params) == {"baz"}
     # strip double quotes, happens in YAML as we need to use single quotes for the template string
     assert p.prompt_text == "Here is some fake template with variable {baz}"
@@ -43,7 +41,7 @@ def test_prompt_templates():

 @pytest.mark.unit
 def test_missing_prompt_template_params():
-    template = PromptTemplate("Here is some fake template with variable {foo} and {bar}", "missing_params")
+    template = PromptTemplate("missing_params", "Here is some fake template with variable {foo} and {bar}")

     # both params provided - ok
     template.prepare(foo="foo", bar="bar")
@@ -64,8 +62,8 @@ def test_missing_prompt_template_params():

 @pytest.mark.unit
 def test_prompt_template_repr():
-    p = PromptTemplate("Here is variable {baz}", "t")
-    desired_repr = "PromptTemplate(prompt_text=Here is variable {baz}, template_name=t, prompt_params=['baz'])"
+    p = PromptTemplate("t", "Here is variable {baz}")
+    desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable {baz}, prompt_params=['baz'])"
     assert repr(p) == desired_repr
     assert str(p) == desired_repr

@@ -74,7 +72,7 @@ def test_prompt_template_repr():
 @patch("haystack.nodes.prompt.prompt_node.PromptModel")
 def test_prompt_template_deserialization(mock_prompt_model):
     custom_prompt_template = PromptTemplate(
-        template_name="custom-question-answering",
+        name="custom-question-answering",
         prompt_text="Given the context please answer the question. Context: {context}; Question: {query}; Answer:",
         output_parser=AnswerParser(),
     )
@@ -90,7 +88,7 @@ def test_prompt_template_deserialization(mock_prompt_model):
     loaded_generator = loaded_pipe.get_node("Generator")
     assert isinstance(loaded_generator, PromptNode)
     assert isinstance(loaded_generator.default_prompt_template, PromptTemplate)
-    assert loaded_generator.default_prompt_template.template_name == "custom-question-answering"
+    assert loaded_generator.default_prompt_template.name == "custom-question-answering"
     assert (
         loaded_generator.default_prompt_template.prompt_text
         == "Given the context please answer the question. Context: {context}; Question: {query}; Answer:"
@@ -137,7 +135,7 @@ class TestPromptTemplateSyntax:
     def test_prompt_template_syntax_parser(
         self, prompt_text: str, expected_prompt_params: Set[str], expected_used_functions: Set[str]
     ):
-        prompt_template = PromptTemplate(template_name="test", prompt_text=prompt_text)
+        prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
         assert set(prompt_template.prompt_params) == expected_prompt_params
         assert set(prompt_template._used_functions) == expected_used_functions

@@ -218,7 +216,7 @@ def test_prompt_template_syntax_parser(
     def test_prompt_template_syntax_fill(
         self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
     ):
-        prompt_template = PromptTemplate(template_name="test", prompt_text=prompt_text)
+        prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
         prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
         assert prompts == expected_prompts

@@ -245,7 +243,7 @@ def test_prompt_template_syntax_fill(
         ],
     )
     def test_join(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
-        prompt_template = PromptTemplate(template_name="test", prompt_text=prompt_text)
+        prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
         prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
         assert prompts == expected_prompts

@@ -278,7 +276,7 @@ def test_join(self, prompt_text: str, documents: List[Document], expected_prompt
         ],
     )
     def test_to_strings(self, prompt_text: str, documents: List[Document], expected_prompts: List[str]):
-        prompt_template = PromptTemplate(template_name="test", prompt_text=prompt_text)
+        prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
         prompts = [prompt for prompt in prompt_template.fill(documents=documents)]
         assert prompts == expected_prompts

@@ -302,7 +300,7 @@ def test_prompt_template_syntax_init_raises(
         self, prompt_text: str, exc_type: Type[BaseException], expected_exc_match: str
     ):
         with pytest.raises(exc_type, match=expected_exc_match):
-            PromptTemplate(template_name="test", prompt_text=prompt_text)
+            PromptTemplate(name="test", prompt_text=prompt_text)

     @pytest.mark.unit
     @pytest.mark.parametrize(
@@ -318,7 +316,7 @@ def test_prompt_template_syntax_fill_raises(
         expected_exc_match: str,
     ):
         with pytest.raises(exc_type, match=expected_exc_match):
-            prompt_template = PromptTemplate(template_name="test", prompt_text=prompt_text)
+            prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
             next(prompt_template.fill(documents=documents, query=query))

     @pytest.mark.unit
@@ -339,32 +337,6 @@ def test_prompt_template_syntax_fill_raises(
     def test_prompt_template_syntax_fill_ignores_dangerous_input(
         self, prompt_text: str, documents: List[Document], query: str, expected_prompts: List[str]
     ):
-        prompt_template = PromptTemplate(template_name="test", prompt_text=prompt_text)
+        prompt_template = PromptTemplate(name="test", prompt_text=prompt_text)
         prompts = [prompt for prompt in prompt_template.fill(documents=documents, query=query)]
         assert prompts == expected_prompts
-
-    @pytest.mark.unit
-    @fail_at_version(2, 0)
-    def test_name_parameter_deprecated(self):
-        with pytest.warns(DeprecationWarning) as w:
-            prompt_template = PromptTemplate(name="test", prompt_text="test")
-        assert "Use the parameter `template_name` instead" in str(w[0].message)
-        assert prompt_template.template_name == "test"
-
-    @pytest.mark.unit
-    @fail_at_version(2, 0)
-    def test_passing_name_and_template_name_parameter(self):
-        with pytest.warns(DeprecationWarning) as w:
-            prompt_template = PromptTemplate(name="test", template_name="test2", prompt_text="test")
-        assert (
-            "Use only the parameter `template_name`. PromptTemplate will be initialized "
-            "using the parameter `template_name`" in str(w[0].message)
-        )
-        assert prompt_template.template_name == "test2"
-
-    @pytest.mark.unit
-    @fail_at_version(2, 0)
-    def test_passing_neither_name_nor_template_name_parameter(self):
-        with pytest.raises(ValueError) as e:
-            PromptTemplate(prompt_text="test")
-            assert "Please specify the parameter `template_name`" in str(e[0].message)