From 6e0000732df9d38e2e9f4ab0511c05651826dcb1 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Date: Tue, 16 May 2023 10:57:41 +0200
Subject: [PATCH 01/13] feat: add BLIP support in `TransformersImageToText` (#4912)

* add blip support

* fix typo

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
---
 haystack/nodes/image_to_text/transformers.py | 15 ++++++++-------
 test/nodes/test_image_to_text.py             |  9 ---------
 2 files changed, 8 insertions(+), 16 deletions(-)

diff --git a/haystack/nodes/image_to_text/transformers.py b/haystack/nodes/image_to_text/transformers.py
index 25357c5411..d7f9c04bce 100644
--- a/haystack/nodes/image_to_text/transformers.py
+++ b/haystack/nodes/image_to_text/transformers.py
@@ -17,7 +17,11 @@
 
 # supported model classes should be extended when the HF image-to-text pipeline will support more classes
 # see https://github.com/huggingface/transformers/issues/21110
-SUPPORTED_MODELS_CLASSES = ["VisionEncoderDecoderModel"]
+SUPPORTED_MODELS_CLASSES = [
+    "VisionEncoderDecoderModel",
+    "BlipForConditionalGeneration",
+    "Blip2ForConditionalGeneration",
+]
 
 UNSUPPORTED_MODEL_MESSAGE = (
     f"The supported classes are: {SUPPORTED_MODELS_CLASSES}. \n"
@@ -33,8 +37,6 @@ class TransformersImageToText(BaseImageToText):
     """
     A transformer-based model to generate captions for images using Hugging Face's transformers framework.
 
-    Currently, this node supports `VisionEncoderDecoderModel` models.
-
     **Example**
 
     ```python
@@ -64,7 +66,7 @@ class TransformersImageToText(BaseImageToText):
 
     def __init__(
         self,
-        model_name_or_path: str = "nlpconnect/vit-gpt2-image-captioning",
+        model_name_or_path: str = "Salesforce/blip-image-captioning-base",
         model_version: Optional[str] = None,
         generation_kwargs: Optional[dict] = None,
         use_gpu: bool = True,
@@ -74,15 +76,14 @@ def __init__(
         devices: Optional[List[Union[str, torch.device]]] = None,
     ):
         """
-        Load a `VisionEncoderDecoderModel` model from transformers.
+        Load an Image-to-Text model from transformers.
 
         :param model_name_or_path: Directory of a saved model or the name of a public model.
-                                   Currently, only `VisionEncoderDecoderModel` models are supported.
 
                                    To find these models:
                                    1. Visit [Hugging Face image to text models](https://huggingface.co/models?pipeline_tag=image-to-text).
                                    2. Open the model you want to check.
                                    3. On the model page, go to the "Files and Versions" tab.
-                                   4. Open the `config.json` file and make sure the `architectures` field contains `VisionEncoderDecoderModel`.
+                                   4. Open the `config.json` file and make sure the `architectures` field contains `VisionEncoderDecoderModel`, `BlipForConditionalGeneration`, or `Blip2ForConditionalGeneration`.
         :param model_version: The version of the model to use from the Hugging Face model hub. This can be the tag name, branch name, or commit hash.
         :param generation_kwargs: Dictionary containing arguments for the `generate()` method of the Hugging Face model. See [generate()](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate) in Hugging Face documentation.
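A minimal usage sketch of the node after this change, for orientation (not part of the patch): it assumes the `Salesforce/blip-image-captioning-base` checkpoint can be downloaded from the Hugging Face Hub, and the image path and generation settings are illustrative placeholders.

```python
from haystack.nodes import TransformersImageToText

# After this patch, the node defaults to the BLIP base captioning checkpoint.
image_to_text = TransformersImageToText(
    model_name_or_path="Salesforce/blip-image-captioning-base",
    generation_kwargs={"max_new_tokens": 50},  # illustrative generation settings
)

# generate_captions() returns one Document per image; the caption is stored in Document.content.
documents = image_to_text.generate_captions(image_file_paths=["./example.jpg"])  # illustrative path
print(documents[0].content)
```

The `generation_kwargs` passed at construction time are forwarded to the model's `generate()` call, as described in the docstring above.
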
diff --git a/test/nodes/test_image_to_text.py b/test/nodes/test_image_to_text.py index 91ffe11f64..689fe7e13f 100644 --- a/test/nodes/test_image_to_text.py +++ b/test/nodes/test_image_to_text.py @@ -91,12 +91,3 @@ def test_image_to_text_unsupported_model_after_loading(): match="The model 'deepset/minilm-uncased-squad2' \(class 'BertForQuestionAnswering'\) is not supported for ImageToText", ): _ = TransformersImageToText(model_name_or_path="deepset/minilm-uncased-squad2") - - -@pytest.mark.integration -def test_image_to_text_unsupported_model_before_loading(): - with pytest.raises( - ValueError, - match=r"The model '.*' \(class '.*'\) is not supported for ImageToText. The supported classes are: \['VisionEncoderDecoderModel'\]", - ): - _ = TransformersImageToText(model_name_or_path="Salesforce/blip-image-captioning-base") From 37cadd702a6706305652847612a700ff27cd7c10 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 16 May 2023 13:35:19 +0200 Subject: [PATCH 02/13] fix: Make sure summary memory is cumulative (#4932) * Fix summary memory not being cummulative * PR feedback - Julian --- .../memory/conversation_summary_memory.py | 15 ++++++-- test/agents/test_summary_memory.py | 37 +++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/haystack/agents/memory/conversation_summary_memory.py b/haystack/agents/memory/conversation_summary_memory.py index f6aaf7ac07..9d4d7bbd73 100644 --- a/haystack/agents/memory/conversation_summary_memory.py +++ b/haystack/agents/memory/conversation_summary_memory.py @@ -50,18 +50,27 @@ def load(self, keys: Optional[List[str]] = None, **kwargs) -> str: :return: A formatted string containing the conversation history with the latest summary. """ if self.has_unsummarized_snippets(): - unsummarized = super().load(keys=keys, window_size=self.unsummarized_snippets()) + unsummarized = self.load_recent_snippets(window_size=self.unsummarized_snippets()) return f"{self.summary}\n{unsummarized}" else: return self.summary + def load_recent_snippets(self, window_size: int = 1) -> str: + """ + Load the most recent conversation snippets as a formatted string. + + :param window_size: integer specifying the number of most recent conversation snippets to load. + :return: A formatted string containing the most recent conversation snippets. + """ + return super().load(window_size=window_size) + def summarize(self) -> str: """ Generate a summary of the conversation history and clear the history. :return: A string containing the generated summary. 
""" - most_recent_chat_snippets = self.load(window_size=self.summary_frequency) + most_recent_chat_snippets = self.load_recent_snippets(window_size=self.summary_frequency) pn_response = self.prompt_node.prompt(self.template, chat_transcript=most_recent_chat_snippets) return pn_response[0] @@ -97,7 +106,7 @@ def save(self, data: Dict[str, Any]) -> None: super().save(data) self.save_count += 1 if self.needs_summary(): - self.summary = self.summarize() + self.summary += self.summarize() def clear(self) -> None: """ diff --git a/test/agents/test_summary_memory.py b/test/agents/test_summary_memory.py index 99bc50fd0f..15b7281f6e 100644 --- a/test/agents/test_summary_memory.py +++ b/test/agents/test_summary_memory.py @@ -88,6 +88,43 @@ def test_conversation_summary_memory_lower_summary_frequency(mocked_prompt_node) assert summary_mem.unsummarized_snippets() == 0 +@pytest.mark.unit +def test_conversation_summary_is_accumulating(mocked_prompt_node): + # ensure that the summary memory works after being triggered twice + summary = "This is a fake summary definitely." + mocked_prompt_node.prompt.return_value = [summary] + summary_mem = ConversationSummaryMemory(mocked_prompt_node, summary_frequency=2) + + data1: Dict[str, Any] = {"input": "Hello", "output": "Hi there"} + summary_mem.save(data1) + assert summary_mem.load() == "\nHuman: Hello\nAI: Hi there\n" + assert summary_mem.has_unsummarized_snippets() + assert summary_mem.unsummarized_snippets() == 1 + + # Test summarization + data2: Dict[str, Any] = {"input": "How are you?", "output": "I'm doing well, thanks."} + summary_mem.save(data2) + assert summary_mem.load() == summary + assert not summary_mem.has_unsummarized_snippets() + assert summary_mem.unsummarized_snippets() == 0 + + # Add more snippets + new_snippet = "\nHuman: What's the weather like?\nAI: It's sunny outside.\n" + data3: Dict[str, Any] = {"input": "What's the weather like?", "output": "It's sunny outside."} + summary_mem.save(data3) + assert summary_mem.load() == summary + new_snippet + assert summary_mem.has_unsummarized_snippets() + assert summary_mem.unsummarized_snippets() == 1 + + # Trigger summarization again + data3: Dict[str, Any] = {"input": "What's the weather tomorrow?", "output": "It will be sunny."} + summary_mem.save(data3) + + # Ensure that the summary is accumulating + assert summary_mem.load() == summary + summary + assert not summary_mem.has_unsummarized_snippets() + + @pytest.mark.unit def test_conversation_summary_memory_with_template(mocked_prompt_node): pt = PromptTemplate("conversational-summary", "Summarize the conversation: {chat_transcript}") From 7625829684675113aa79bb7a92edc0b7d0dc5576 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Tue, 16 May 2023 16:03:09 +0200 Subject: [PATCH 03/13] fix: `EvaluationResult` serialization changes dataframes (#4906) * fix nan and index values * add test * make test for None values after evalresult read explicit --- haystack/schema.py | 3 ++- test/others/test_schema.py | 18 +++++++++++++++++- test/pipelines/test_eval.py | 21 ++++++++++++++++++++- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/haystack/schema.py b/haystack/schema.py index 7fc587b999..db4c32bc3b 100644 --- a/haystack/schema.py +++ b/haystack/schema.py @@ -932,7 +932,7 @@ def __len__(self): def append(self, key: str, value: DataFrame): if value is not None and len(value) > 0: if key in self.node_results: - self.node_results[key] = pd.concat([self.node_results[key], value]) + self.node_results[key] = 
pd.concat([self.node_results[key], value]).reset_index(drop=True) else: self.node_results[key] = value @@ -1620,6 +1620,7 @@ def safe_literal_eval(x: str) -> Any: node_results = {file.stem: pd.read_csv(file, **read_csv_kwargs) for file in csv_files} # backward compatibility mappings for df in node_results.values(): + df.replace(to_replace=np.nan, value=None, inplace=True) df.rename(columns={"gold_document_contents": "gold_contexts", "content": "context"}, inplace=True) # convert single document_id to list if "answer" in df.columns and "document_id" in df.columns and not "document_ids" in df.columns: diff --git a/test/others/test_schema.py b/test/others/test_schema.py index ba1a52b702..cde25cbf79 100644 --- a/test/others/test_schema.py +++ b/test/others/test_schema.py @@ -1,6 +1,6 @@ import json -from haystack.schema import Document, Label, Answer, Span, MultiLabel, TableCell, _dict_factory +from haystack.schema import Document, EvaluationResult, Label, Answer, Span, MultiLabel, TableCell, _dict_factory import pytest import numpy as np import pandas as pd @@ -1062,3 +1062,19 @@ def test_dict_factory(): assert result["key1"] == "some_value" assert result["key2"] == ["val1", "val2"] assert result["key3"] == [["col1", "col2"], [1, 3], [2, 4]] + + +@pytest.mark.unit +def test_evaluation_result_append(): + df1 = pd.DataFrame({"col1": [1, 2], "index": [3, 4]}) + df2 = pd.DataFrame({"col1": [5, 6], "index": [7, 8]}) + df_expected = pd.DataFrame({"col1": [1, 2, 5, 6], "index": [3, 4, 7, 8]}) + + eval_result = EvaluationResult() + eval_result.append("test", df1) + pd.testing.assert_frame_equal(eval_result["test"], df1) + assert isinstance(eval_result["test"].index, pd.RangeIndex) + + eval_result.append("test", df2) + pd.testing.assert_frame_equal(eval_result["test"], df_expected) + assert isinstance(eval_result["test"].index, pd.RangeIndex) diff --git a/test/pipelines/test_eval.py b/test/pipelines/test_eval.py index 3c3e8738b4..cc57012abd 100644 --- a/test/pipelines/test_eval.py +++ b/test/pipelines/test_eval.py @@ -598,6 +598,10 @@ def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path): eval_result.save(tmp_path) saved_eval_result = EvaluationResult.load(tmp_path) + + for key, df in eval_result.node_results.items(): + pd.testing.assert_frame_equal(df, saved_eval_result[key]) + metrics = saved_eval_result.calculate_metrics(document_scope="document_id") assert ( @@ -718,6 +722,10 @@ def test_generative_qa_eval(retriever_with_docs, tmp_path): eval_result.save(tmp_path) saved_eval_result = EvaluationResult.load(tmp_path) + + for key, df in eval_result.node_results.items(): + pd.testing.assert_frame_equal(df, saved_eval_result[key]) + loaded_metrics = saved_eval_result.calculate_metrics(document_scope="document_id") assert metrics == loaded_metrics @@ -815,6 +823,10 @@ def test_generative_qa_w_promptnode_eval(retriever_with_docs, tmp_path): eval_result.save(tmp_path) saved_eval_result = EvaluationResult.load(tmp_path) + + for key, df in eval_result.node_results.items(): + pd.testing.assert_frame_equal(df, saved_eval_result[key]) + loaded_metrics = saved_eval_result.calculate_metrics(document_scope="document_id") assert metrics == loaded_metrics @@ -864,6 +876,10 @@ def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_pa eval_result.save(tmp_path) saved_eval_result = EvaluationResult.load(tmp_path) + + for key, df in eval_result.node_results.items(): + pd.testing.assert_frame_equal(df, saved_eval_result[key]) + metrics = 
saved_eval_result.calculate_metrics(document_scope="document_id") assert ( @@ -2084,7 +2100,7 @@ def test_load_legacy_evaluation_result(tmp_path): assert "content" not in eval_result["legacy"] -def test_load_evaluation_result_w_empty_document_ids(tmp_path): +def test_load_evaluation_result_w_none_values(tmp_path): eval_result_csv = Path(tmp_path) / "Reader.csv" with open(eval_result_csv, "w") as eval_result_csv: columns = [ @@ -2158,3 +2174,6 @@ def test_load_evaluation_result_w_empty_document_ids(tmp_path): eval_result = EvaluationResult.load(tmp_path) assert "Reader" in eval_result assert len(eval_result) == 1 + assert eval_result["Reader"].iloc[0].answer is None + assert eval_result["Reader"].iloc[0].context is None + assert eval_result["Reader"].iloc[0].document_ids is None From ca68601ec7ddeee7004d72a9def8a65b1e4749f5 Mon Sep 17 00:00:00 2001 From: yuanwu2017 Date: Wed, 17 May 2023 14:48:11 +0800 Subject: [PATCH 04/13] fix: shaper exception when retriever return 0 docs. (#4929) * When retriever retrieves 0 documents from the documentStore, shaper will raise an exception. Signed-off-by: root Co-authored-by: root --- haystack/nodes/other/shaper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/nodes/other/shaper.py b/haystack/nodes/other/shaper.py index 26ae6162bd..a229234c69 100644 --- a/haystack/nodes/other/shaper.py +++ b/haystack/nodes/other/shaper.py @@ -740,7 +740,7 @@ def run( # type: ignore if labels and "labels" not in invocation_context.keys(): invocation_context["labels"] = labels - if documents and "documents" not in invocation_context.keys(): + if documents != None and "documents" not in invocation_context.keys(): invocation_context["documents"] = documents if meta and "meta" not in invocation_context.keys(): From 9d52998b2546a135a469d6659c687f2e393a7c05 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 17 May 2023 15:19:09 +0200 Subject: [PATCH 05/13] feat: Add conversational agent (#4931) --- examples/agent_multihop_qa.py | 1 + examples/conversational_agent.py | 17 ++++ haystack/agents/agent_step.py | 100 +++++++++++---------- haystack/agents/base.py | 100 ++++++++++++--------- haystack/agents/conversational.py | 62 +++++++++++++ haystack/nodes/prompt/prompt_template.py | 6 +- test/agents/test_agent.py | 27 +++--- test/agents/test_agent_step.py | 109 +++++++++++++++++++++++ test/agents/test_conversational_agent.py | 39 ++++++++ test/agents/test_tools_manager.py | 37 ++++++++ test/prompt/test_prompt_node.py | 2 +- 11 files changed, 398 insertions(+), 102 deletions(-) create mode 100644 examples/conversational_agent.py create mode 100644 haystack/agents/conversational.py create mode 100644 test/agents/test_agent_step.py create mode 100644 test/agents/test_conversational_agent.py diff --git a/examples/agent_multihop_qa.py b/examples/agent_multihop_qa.py index 076843c02f..3cd764f897 100644 --- a/examples/agent_multihop_qa.py +++ b/examples/agent_multihop_qa.py @@ -82,6 +82,7 @@ ## Question: {query} Thought: +{transcript} """ few_shot_agent_template = PromptTemplate("few-shot-react", prompt_text=few_shot_prompt) prompt_node = PromptNode( diff --git a/examples/conversational_agent.py b/examples/conversational_agent.py new file mode 100644 index 0000000000..4efa973e56 --- /dev/null +++ b/examples/conversational_agent.py @@ -0,0 +1,17 @@ +import os + +from haystack.agents.conversational import ConversationalAgent +from haystack.nodes import PromptNode + +pn = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), 
max_length=256) +agent = ConversationalAgent(pn) + +while True: + user_input = input("Human (type 'exit' or 'quit' to quit, 'memory' for agent's memory): ") + if user_input.lower() == "exit" or user_input.lower() == "quit": + break + if user_input.lower() == "memory": + print("\nMemory:\n", agent.memory.load()) + else: + assistant_response = agent.run(user_input) + print("\nAssistant:", assistant_response) diff --git a/haystack/agents/agent_step.py b/haystack/agents/agent_step.py index ec263c3f3c..2981974744 100644 --- a/haystack/agents/agent_step.py +++ b/haystack/agents/agent_step.py @@ -20,49 +20,38 @@ def __init__( self, current_step: int = 1, max_steps: int = 10, - final_answer_pattern: str = r"Final Answer\s*:\s*(.*)", + final_answer_pattern: Optional[str] = None, prompt_node_response: str = "", transcript: str = "", ): """ :param current_step: The current step in the execution of the agent. :param max_steps: The maximum number of steps the agent can execute. - :param final_answer_pattern: The regex pattern to extract the final answer from the PromptNode response. + :param final_answer_pattern: The regex pattern to extract the final answer from the PromptNode response. If no + pattern is provided, entire prompt node response is considered the final answer. :param prompt_node_response: The PromptNode response received. - :param transcript: The full Agent execution transcript based on the Agent's initial prompt template and the text it generated during execution up to this step. The transcript is used to generate the next prompt. """ self.current_step = current_step self.max_steps = max_steps - self.final_answer_pattern = final_answer_pattern + self.final_answer_pattern = final_answer_pattern or r"^([\s\S]+)$" self.prompt_node_response = prompt_node_response self.transcript = transcript - def prepare_prompt(self): - """ - Prepares the prompt for the next step. - """ - return self.transcript - - def create_next_step(self, prompt_node_response: Any) -> AgentStep: + def create_next_step(self, prompt_node_response: Any, current_step: Optional[int] = None) -> AgentStep: """ Creates the next agent step based on the current step and the PromptNode response. :param prompt_node_response: The PromptNode response received. + :param current_step: The current step in the execution of the agent. """ - if not isinstance(prompt_node_response, list): + if not isinstance(prompt_node_response, list) or not prompt_node_response: raise AgentError( - f"Agent output must be a list of str, but {prompt_node_response} received. " + f"Agent output must be a non-empty list of str, but {prompt_node_response} received. " f"Transcript:\n{self.transcript}" ) - - if not prompt_node_response: - raise AgentError( - f"Agent output must be a non empty list of str, but {prompt_node_response} received. " - f"Transcript:\n{self.transcript}" - ) - - return AgentStep( - current_step=self.current_step + 1, + cls = type(self) + return cls( + current_step=current_step if current_step else self.current_step + 1, max_steps=self.max_steps, final_answer_pattern=self.final_answer_pattern, prompt_node_response=prompt_node_response[0], @@ -81,7 +70,7 @@ def final_answer(self, query: str) -> Dict[str, Any]: "answers": [Answer(answer="", type="generative")], "transcript": self.transcript, } - if self.current_step >= self.max_steps: + if self.current_step > self.max_steps: logger.warning( "Maximum number of iterations (%s) reached for query (%s). 
Increase max_steps " "or no answer can be provided for this query.", @@ -89,10 +78,10 @@ def final_answer(self, query: str) -> Dict[str, Any]: query, ) else: - final_answer = self.extract_final_answer() + final_answer = self.parse_final_answer() if not final_answer: logger.warning( - "Final answer pattern (%s) not found in PromptNode response (%s).", + "Final answer parser (%s) could not parse PromptNode response (%s).", self.final_answer_pattern, self.prompt_node_response, ) @@ -104,35 +93,14 @@ def final_answer(self, query: str) -> Dict[str, Any]: } return answer - def extract_final_answer(self) -> Optional[str]: - """ - Parse the final answer from the PromptNode response. - :return: The final answer. - """ - if not self.is_last(): - raise AgentError("Cannot extract final answer from non terminal step.") - - final_answer_match = re.search(self.final_answer_pattern, self.prompt_node_response) - if final_answer_match: - final_answer = final_answer_match.group(1) - return final_answer.strip('" ') - return None - - def is_final_answer_pattern_found(self) -> bool: - """ - Check if the final answer pattern was found in PromptNode response. - :return: True if the final answer pattern was found in PromptNode response, False otherwise. - """ - return bool(re.search(self.final_answer_pattern, self.prompt_node_response)) - def is_last(self) -> bool: """ Check if this is the last step of the Agent. :return: True if this is the last step of the Agent, False otherwise. """ - return self.is_final_answer_pattern_found() or self.current_step >= self.max_steps + return bool(self.parse_final_answer()) or self.current_step > self.max_steps - def completed(self, observation: Optional[str]): + def completed(self, observation: Optional[str]) -> None: """ Update the transcript with the observation :param observation: received observation from the Agent environment. @@ -142,3 +110,39 @@ def completed(self, observation: Optional[str]): if observation else self.prompt_node_response ) + + def __repr__(self) -> str: + """ + Return a string representation of the AgentStep object. + + :return: A string that represents the AgentStep object. + """ + return ( + f"AgentStep(current_step={self.current_step}, max_steps={self.max_steps}, " + f"prompt_node_response={self.prompt_node_response}, final_answer_pattern={self.final_answer_pattern}, " + f"transcript={self.transcript})" + ) + + def parse_final_answer(self) -> Optional[str]: + """ + Parse the final answer from the response of the prompt node. + + This function searches the prompt node's response for a match with the + pre-defined final answer pattern. If a match is found, it's returned as the + final answer after removing leading/trailing quotes and whitespaces. + If no match is found, it returns None. + + :return: The final answer as a string if a match is found, otherwise None. 
+ """ + # Search for a match with the final answer pattern in the prompt node response + final_answer_match = re.search(self.final_answer_pattern, self.prompt_node_response) + + if final_answer_match: + # If a match is found, get the first group (i.e., the content inside the parentheses of the regex pattern) + final_answer = final_answer_match.group(1) + + # Remove leading/trailing quotes and whitespaces, then return the final answer + return final_answer.strip('" ') # type: ignore + else: + # If no match is found, return None + return None diff --git a/haystack/agents/base.py b/haystack/agents/base.py index 166789b501..6cd11ff7b3 100644 --- a/haystack/agents/base.py +++ b/haystack/agents/base.py @@ -2,17 +2,18 @@ import logging import re +from collections.abc import Iterable, Callable from hashlib import md5 from typing import List, Optional, Union, Dict, Any, Tuple from events import Events from haystack import Pipeline, BaseComponent, Answer, Document +from haystack.agents.memory import Memory, NoMemory from haystack.telemetry import send_event from haystack.agents.agent_step import AgentStep from haystack.agents.types import Color, AgentTokenStreamingHandler from haystack.agents.utils import print_text, STREAMING_CAPABLE_MODELS -from haystack.errors import AgentError from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate from haystack.pipelines import ( BaseStandardPipeline, @@ -221,8 +222,10 @@ class Agent: def __init__( self, prompt_node: PromptNode, - prompt_template: Union[str, PromptTemplate] = "zero-shot-react", + prompt_template: Optional[Union[str, PromptTemplate]] = None, tools_manager: Optional[ToolsManager] = None, + memory: Optional[Memory] = None, + prompt_parameters_resolver: Optional[Callable] = None, max_steps: int = 8, final_answer_pattern: str = r"Final Answer\s*:\s*(.*)", ): @@ -235,24 +238,40 @@ def __init__( choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a new template in a similar format. with `add_tool()` before running the Agent. - :param tools: A list of tools to add to the Agent. Each tool must have a unique name. You can also add tools - with `add_tool()` before running the Agent. + :param tools_manager: A ToolsManager instance that the Agent uses to run tools. Each tool must have a unique name. + You can also add tools with `add_tool()` before running the Agent. + :param memory: A Memory instance that the Agent uses to store information between iterations. + :param prompt_parameters_resolver: A callable that takes query, agent, and agent_step as parameters and returns + a dictionary of parameters to pass to the prompt_template. The default is a callable that returns a dictionary + of keys and values needed for the React agent prompt template. :param max_steps: The number of times the Agent can run a tool +1 to let it infer it knows the final answer. Set it to at least 2, so that the Agent can run one a tool once and then infer it knows the final answer. - The default is 5. - text the Agent generated. + The default is 8. :param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated. 
""" self.max_steps = max_steps self.tm = tools_manager or ToolsManager() + self.memory = memory or NoMemory() self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_new_token")) self.prompt_node = prompt_node + prompt_template = prompt_template or "zero-shot-react" resolved_prompt_template = prompt_node.get_prompt_template(prompt_template) if not resolved_prompt_template: raise ValueError( f"Prompt template '{prompt_template}' not found. Please check the spelling of the template name." ) self.prompt_template = resolved_prompt_template + react_parameter_resolver: Callable[ + [str, Agent, AgentStep, Dict[str, Any]], Dict[str, Any] + ] = lambda query, agent, agent_step, **kwargs: { + "query": query, + "tool_names": agent.tm.get_tool_names(), + "tool_names_with_descriptions": agent.tm.get_tool_names_with_descriptions(), + "transcript": agent_step.transcript, + } + self.prompt_parameters_resolver = ( + prompt_parameters_resolver if prompt_parameters_resolver else react_parameter_resolver + ) self.final_answer_pattern = final_answer_pattern # Resolve model name to check if it's a streaming model if isinstance(self.prompt_node.model_name_or_path, str): @@ -350,37 +369,18 @@ def run( except Exception as exc: logger.debug("Telemetry exception: %s", exc) - if max_steps is None: - max_steps = self.max_steps - if max_steps < 2: - raise AgentError( - f"max_steps must be at least 2 to let the Agent use a tool once and then infer it knows the final " - f"answer. It was set to {max_steps}." - ) self.callback_manager.on_agent_start(name=self.prompt_template.name, query=query, params=params) - agent_step = self._create_first_step(query, max_steps) + agent_step = self.create_agent_step(max_steps) try: while not agent_step.is_last(): - agent_step = self._step(agent_step, params) + agent_step = self._step(query, agent_step, params) finally: self.callback_manager.on_agent_finish(agent_step) return agent_step.final_answer(query=query) - def _create_first_step(self, query: str, max_steps: int = 10): - transcript = self._get_initial_transcript(query=query) - return AgentStep( - current_step=1, - max_steps=max_steps, - final_answer_pattern=self.final_answer_pattern, - prompt_node_response="", # no LLM response for the first step - transcript=transcript, - ) - - def _step(self, current_step: AgentStep, params: Optional[dict] = None): + def _step(self, query: str, current_step: AgentStep, params: Optional[dict] = None): # plan next step using the LLM - prompt_node_response = self.prompt_node( - current_step.prepare_prompt(), stream_handler=AgentTokenStreamingHandler(self.callback_manager) - ) + prompt_node_response = self._plan(query, current_step) # from the LLM response, create the next step next_step = current_step.create_next_step(prompt_node_response) @@ -389,21 +389,41 @@ def _step(self, current_step: AgentStep, params: Optional[dict] = None): # run the tool selected by the LLM observation = self.tm.run_tool(next_step.prompt_node_response, params) if not next_step.is_last() else None + # save the input, output and observation to memory (if memory is enabled) + memory_data = self.prepare_data_for_memory(input=query, output=prompt_node_response, observation=observation) + self.memory.save(data=memory_data) + # update the next step with the observation next_step.completed(observation) return next_step - def _get_initial_transcript(self, query: str): + def _plan(self, query, current_step): + # first resolve prompt template params + template_params = 
self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step) + + # if prompt node has no default prompt template, use agent's prompt template + if self.prompt_node.default_prompt_template is None: + prepared_prompt = next(self.prompt_template.fill(**template_params)) + prompt_node_response = self.prompt_node( + prepared_prompt, stream_handler=AgentTokenStreamingHandler(self.callback_manager) + ) + # otherwise, if prompt node has default prompt template, use it + else: + prompt_node_response = self.prompt_node( + stream_handler=AgentTokenStreamingHandler(self.callback_manager), **template_params + ) + return prompt_node_response + + def create_agent_step(self, max_steps: Optional[int] = None) -> AgentStep: """ - Fills the Agent's PromptTemplate with the query, tool names, and descriptions. + Create an AgentStep object. Override this method to customize the AgentStep class used by the Agent. + """ + return AgentStep(max_steps=max_steps or self.max_steps, final_answer_pattern=self.final_answer_pattern) - :param query: The search query. + def prepare_data_for_memory(self, **kwargs) -> dict: """ - return next( - self.prompt_template.fill( - query=query, - tool_names=self.tm.get_tool_names(), - tool_names_with_descriptions=self.tm.get_tool_names_with_descriptions(), - ), - "", - ) + Prepare data for saving to the Agent's memory. Override this method to customize the data saved to the memory. + """ + return { + k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable)) + } diff --git a/haystack/agents/conversational.py b/haystack/agents/conversational.py new file mode 100644 index 0000000000..adb42c9802 --- /dev/null +++ b/haystack/agents/conversational.py @@ -0,0 +1,62 @@ +from typing import Optional, Callable + +from haystack.agents import Agent +from haystack.agents.memory import Memory, ConversationMemory +from haystack.nodes import PromptNode + + +class ConversationalAgent(Agent): + """ + A conversational agent that can be used to build a conversational chat applications. + + Here is an example of how you can create a simple chat application: + ``` + import os + + from haystack.agents.base import ConversationalAgent + from haystack.agents.memory import ConversationSummaryMemory + from haystack.nodes import PromptNode + + pn = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256) + agent = ConversationalAgent(pn, memory=ConversationSummaryMemory(pn)) + + while True: + user_input = input("Human (type 'exit' or 'quit' to quit): ") + if user_input.lower() == "exit" or user_input.lower() == "quit": + break + elif user_input.lower() == "memory": + print("\nMemory:\n", agent.memory.load()) + else: + assistant_response = agent.run(user_input) + print("\nAssistant:", assistant_response) + + ``` + """ + + def __init__( + self, + prompt_node: PromptNode, + memory: Optional[Memory] = None, + prompt_parameters_resolver: Optional[Callable] = None, + ): + """ + Creates a new ConversationalAgent instance + + :param prompt_node: A PromptNode used to communicate with LLM. + :param memory: A memory instance for storing conversation history and other relevant data, defaults to + ConversationMemory. + :param prompt_parameters_resolver: An optional callable for resolving prompt template parameters, + defaults to a callable that returns a dictionary with the query and the conversation history. 
+ """ + super().__init__( + prompt_node=prompt_node, + prompt_template=prompt_node.default_prompt_template + if prompt_node.default_prompt_template is not None + else "conversational-agent", + max_steps=2, + memory=memory if memory else ConversationMemory(), + prompt_parameters_resolver=prompt_parameters_resolver + if prompt_parameters_resolver + else lambda query, agent, **kwargs: {"query": query, "history": agent.memory.load()}, + final_answer_pattern=r"^([\s\S]+)$", + ) diff --git a/haystack/nodes/prompt/prompt_template.py b/haystack/nodes/prompt/prompt_template.py index 4811d7d28d..c180135641 100644 --- a/haystack/nodes/prompt/prompt_template.py +++ b/haystack/nodes/prompt/prompt_template.py @@ -432,7 +432,11 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]: "Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n" "---\n\n" "Question: {query}\n" - "Thought: Let's think step-by-step, I first need to ", + "Thought: Let's think step-by-step, I first need to {transcript}", + ), + PromptTemplate( + name="conversational-agent", + prompt_text="The following is a conversation between a human and an AI.\n{history}\nHuman: {query}\nAI:", ), PromptTemplate( name="conversational-summary", diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py index 6ddba22804..f1b1a48255 100644 --- a/test/agents/test_agent.py +++ b/test/agents/test_agent.py @@ -128,17 +128,9 @@ def test_extract_final_answer(): ] for example, expected_answer in zip(match_examples, expected_answers): - agent_step = AgentStep(prompt_node_response=example) - final_answer = agent_step.extract_final_answer() - assert final_answer == expected_answer - - -@pytest.mark.unit -def test_format_answer(): - step = AgentStep(prompt_node_response="have the final answer to the question.\nFinal Answer: Florida") - formatted_answer = step.final_answer(query="query") - assert formatted_answer["query"] == "query" - assert formatted_answer["answers"] == [Answer(answer="Florida", type="generative")] + agent_step = AgentStep(prompt_node_response=example, final_answer_pattern=r"Final Answer\s*:\s*(.*)") + final_answer = agent_step.final_answer(query="irrelevant") + assert final_answer["answers"][0].answer == expected_answer @pytest.mark.unit @@ -229,7 +221,7 @@ def test_agent_run(reader, retriever_with_docs, document_store_with_docs): ), ) - agent = Agent(prompt_node=prompt_node) + agent = Agent(prompt_node=prompt_node, max_steps=12) agent.add_tool( Tool( name="Search", @@ -282,3 +274,14 @@ def test_update_hash(): assert agent.hash == "d41d8cd98f00b204e9800998ecf8427e" agent.update_hash() assert agent.hash == "5ac8eca2f92c9545adcce3682b80d4c5" + + +def test_invalid_agent_template(): + pn = PromptNode() + with pytest.raises(ValueError, match="some_non_existing_template not supported"): + Agent(prompt_node=pn, prompt_template="some_non_existing_template") + + # if prompt_template is None, then we'll use zero-shot-react + a = Agent(prompt_node=pn, prompt_template=None) + assert isinstance(a.prompt_template, PromptTemplate) + assert a.prompt_template.name == "zero-shot-react" diff --git a/test/agents/test_agent_step.py b/test/agents/test_agent_step.py new file mode 100644 index 0000000000..3d0f0cb4e7 --- /dev/null +++ b/test/agents/test_agent_step.py @@ -0,0 +1,109 @@ +import pytest + +from haystack import Answer +from haystack.agents import AgentStep +from haystack.errors import AgentError + + +@pytest.fixture +def agent_step(): + return AgentStep( + 
current_step=1, max_steps=10, final_answer_pattern=None, prompt_node_response="Hello", transcript="Hello" + ) + + +@pytest.mark.unit +def test_create_next_step(agent_step): + # Test normal case + next_step = agent_step.create_next_step(["Hello again"]) + assert next_step.current_step == 2 + assert next_step.prompt_node_response == "Hello again" + assert next_step.transcript == "Hello" + + # Test with invalid prompt_node_response + with pytest.raises(AgentError): + agent_step.create_next_step({}) + + # Test with empty prompt_node_response + with pytest.raises(AgentError): + agent_step.create_next_step([]) + + +@pytest.mark.unit +def test_final_answer(agent_step): + # Test normal case + result = agent_step.final_answer("query") + assert result["query"] == "query" + assert isinstance(result["answers"][0], Answer) + assert result["answers"][0].answer == "Hello" + assert result["answers"][0].type == "generative" + assert result["transcript"] == "Hello" + + # Test with max_steps reached + agent_step.current_step = 11 + result = agent_step.final_answer("query") + assert result["answers"][0].answer == "" + + +@pytest.mark.unit +def test_is_last(): + # Test is last, and it is last because of valid prompt_node_response and default final_answer_pattern + agent_step = AgentStep(current_step=1, max_steps=10, prompt_node_response="Hello", transcript="Hello") + assert agent_step.is_last() + + # Test not last + agent_step.current_step = 1 + agent_step.prompt_node_response = "final answer not satisfying pattern" + agent_step.final_answer_pattern = r"Final Answer\s*:\s*(.*)" + assert not agent_step.is_last() + + # Test border cases for max_steps + agent_step.current_step = 9 + assert not agent_step.is_last() + agent_step.current_step = 10 + assert not agent_step.is_last() + + # Test when last due to max_steps + agent_step.current_step = 11 + assert agent_step.is_last() + + +@pytest.mark.unit +def test_completed(agent_step): + # Test without observation + agent_step.completed(None) + assert agent_step.transcript == "HelloHello" + + # Test with observation, adds Hello from prompt_node_response + agent_step.completed("observation") + assert agent_step.transcript == "HelloHelloHello\nObservation: observation\nThought:" + + +@pytest.mark.unit +def test_repr(agent_step): + assert repr(agent_step) == ( + "AgentStep(current_step=1, max_steps=10, " + "prompt_node_response=Hello, final_answer_pattern=^([\\s\\S]+)$, " + "transcript=Hello)" + ) + + +@pytest.mark.unit +def test_parse_final_answer(agent_step): + # Test when pattern matches + assert agent_step.parse_final_answer() == "Hello" + + # Test when pattern does not match + agent_step.final_answer_pattern = "goodbye" + assert agent_step.parse_final_answer() is None + + +@pytest.mark.unit +def test_format_react_answer(): + step = AgentStep( + final_answer_pattern=r"Final Answer\s*:\s*(.*)", + prompt_node_response="have the final answer to the question.\nFinal Answer: Florida", + ) + formatted_answer = step.final_answer(query="query") + assert formatted_answer["query"] == "query" + assert formatted_answer["answers"] == [Answer(answer="Florida", type="generative")] diff --git a/test/agents/test_conversational_agent.py b/test/agents/test_conversational_agent.py new file mode 100644 index 0000000000..8ffcd1802b --- /dev/null +++ b/test/agents/test_conversational_agent.py @@ -0,0 +1,39 @@ +import pytest +from unittest.mock import MagicMock + +from haystack.agents.conversational import ConversationalAgent +from haystack.agents.memory import ConversationSummaryMemory, 
ConversationMemory, NoMemory +from haystack.nodes import PromptNode + + +@pytest.mark.unit +def test_init(): + prompt_node = PromptNode() + agent = ConversationalAgent(prompt_node) + # Test normal case + assert isinstance(agent.memory, ConversationMemory) + assert callable(agent.prompt_parameters_resolver) + assert agent.max_steps == 2 + assert agent.final_answer_pattern == r"^([\s\S]+)$" + + # ConversationalAgent doesn't have tools + assert not agent.tm.tools + + # Test with summary memory + agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node)) + assert isinstance(agent.memory, ConversationSummaryMemory) + + # Test with no memory + agent = ConversationalAgent(prompt_node, memory=NoMemory()) + assert isinstance(agent.memory, NoMemory) + + +@pytest.mark.unit +def test_run(): + prompt_node = PromptNode() + agent = ConversationalAgent(prompt_node) + + # Mock the Agent run method + agent.run = MagicMock(return_value="Hello") + assert agent.run("query") == "Hello" + agent.run.assert_called_once_with("query") diff --git a/test/agents/test_tools_manager.py b/test/agents/test_tools_manager.py index b144d4f9f5..48ddbe1259 100644 --- a/test/agents/test_tools_manager.py +++ b/test/agents/test_tools_manager.py @@ -1,6 +1,9 @@ +import unittest from unittest import mock import pytest + +from haystack import Pipeline, Answer, Document from haystack.agents.base import ToolsManager, Tool @@ -61,3 +64,37 @@ def test_extract_tool_name_and_tool_input(tools_manager): for example in negative_examples: tool_name, tool_input = tools_manager.extract_tool_name_and_tool_input(example) assert tool_name is None and tool_input is None + + +@pytest.mark.unit +def test_invalid_tool_creation(): + with pytest.raises(ValueError, match="Invalid"): + Tool(name="Tool-A", pipeline_or_node=mock.Mock(), description="Tool A Description") + + +@pytest.mark.unit +def test_tool_invocation(): + # by default for pipelines as tools we look for results key in the output + p = Pipeline() + tool = Tool(name="ToolA", pipeline_or_node=p, description="Tool A Description") + with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value={"results": "mock"}): + assert tool.run("input") == "mock" + + # now fail if results key is not present + with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value={"no_results": "mock"}): + with pytest.raises(ValueError, match="Tool ToolA returned result"): + assert tool.run("input") + + # now try tool with a correct output variable + tool = Tool(name="ToolA", pipeline_or_node=p, description="Tool A Description", output_variable="no_results") + with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value={"no_results": "mock_no_results"}): + assert tool.run("input") == "mock_no_results" + + # try tool that internally returns an Answer object but we extract the string + tool = Tool(name="ToolA", pipeline_or_node=p, description="Tool A Description") + with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value=[Answer("mocked_answer")]): + assert tool.run("input") == "mocked_answer" + + # same but for the document + with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value=[Document("mocked_document")]): + assert tool.run("input") == "mocked_document" diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py index 33acf5edcc..3d43b2b9f4 100644 --- a/test/prompt/test_prompt_node.py +++ b/test/prompt/test_prompt_node.py @@ -28,7 +28,7 @@ def get_api_key(request): def 
test_add_and_remove_template(): with patch("haystack.nodes.prompt.prompt_node.PromptModel"): node = PromptNode() - total_count = 15 + total_count = 16 # Verifies default assert len(node.get_prompt_template_names()) == total_count From 34b7d1edb09729f15f73955b6c0892e2f8c9c033 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Wed, 17 May 2023 18:51:21 +0200 Subject: [PATCH 06/13] Small fix to PromptTemplate API docs (#4870) --- haystack/nodes/prompt/prompt_template.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/nodes/prompt/prompt_template.py b/haystack/nodes/prompt/prompt_template.py index c180135641..fd80b5fb2c 100644 --- a/haystack/nodes/prompt/prompt_template.py +++ b/haystack/nodes/prompt/prompt_template.py @@ -164,8 +164,8 @@ class PromptTemplate(BasePromptTemplate, ABC): ```python PromptTemplate(name="sentiment-analysis", - prompt_text="Give a sentiment for this context. Answer with positive, negative - or neutral. Context: {documents}; Answer:") + prompt_text="Give a sentiment for this context. Answer with positive, negative" + "or neutral. Context: {documents}; Answer:") ``` Optionally, you can declare prompt parameters using f-string syntax in the PromptTemplate. Prompt parameters are input parameters that need to be filled in From df46e7fadd9a41e761f621ea6e59e75e67b30f54 Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Wed, 17 May 2023 18:54:34 +0200 Subject: [PATCH 07/13] fix: Use `AutoTokenizer` instead of DPR specific tokenizer (#4898) * Use AutoTokenizer instead of DPR specific tokenizer * Adapt TableTextRetriever * Adapt tests * Adapt tests --- haystack/nodes/retriever/dense.py | 22 ++++++---------------- test/nodes/test_retriever.py | 12 ++++++------ 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index 1ee803d88a..a2dbe19455 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -19,13 +19,7 @@ from torch.utils.data.sampler import SequentialSampler import pandas as pd from huggingface_hub import hf_hub_download -from transformers import ( - AutoConfig, - DPRContextEncoderTokenizerFast, - DPRQuestionEncoderTokenizerFast, - DPRContextEncoderTokenizer, - DPRQuestionEncoderTokenizer, -) +from transformers import AutoConfig, AutoTokenizer from haystack.errors import HaystackError from haystack.schema import Document, FilterType @@ -191,7 +185,7 @@ def __init__( ) # Init & Load Encoders - self.query_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained( + self.query_tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path=query_embedding_model, revision=model_version, do_lower_case=True, @@ -203,7 +197,7 @@ def __init__( model_type="DPRQuestionEncoder", use_auth_token=use_auth_token, ) - self.passage_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained( + self.passage_tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path=passage_embedding_model, revision=model_version, do_lower_case=True, @@ -873,12 +867,8 @@ def __init__( self.embed_meta_fields = embed_meta_fields self.scale_score = scale_score - query_tokenizer_class = DPRQuestionEncoderTokenizerFast if use_fast else DPRQuestionEncoderTokenizer - passage_tokenizer_class = DPRContextEncoderTokenizerFast if use_fast else DPRContextEncoderTokenizer - table_tokenizer_class = DPRContextEncoderTokenizerFast if use_fast else DPRContextEncoderTokenizer - # Init & Load Encoders - self.query_tokenizer = 
query_tokenizer_class.from_pretrained( + self.query_tokenizer = AutoTokenizer.from_pretrained( query_embedding_model, revision=model_version, do_lower_case=True, @@ -888,7 +878,7 @@ def __init__( self.query_encoder = get_language_model( pretrained_model_name_or_path=query_embedding_model, revision=model_version, use_auth_token=use_auth_token ) - self.passage_tokenizer = passage_tokenizer_class.from_pretrained( + self.passage_tokenizer = AutoTokenizer.from_pretrained( passage_embedding_model, revision=model_version, do_lower_case=True, @@ -898,7 +888,7 @@ def __init__( self.passage_encoder = get_language_model( pretrained_model_name_or_path=passage_embedding_model, revision=model_version, use_auth_token=use_auth_token ) - self.table_tokenizer = table_tokenizer_class.from_pretrained( + self.table_tokenizer = AutoTokenizer.from_pretrained( table_embedding_model, revision=model_version, do_lower_case=True, diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index 3915b4e36d..64dec513ff 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -10,7 +10,7 @@ import requests from boilerpy3.extractors import ArticleExtractor from pandas.testing import assert_frame_equal -from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast +from transformers import PreTrainedTokenizerFast try: @@ -578,8 +578,8 @@ def sum_params(model): assert loaded_retriever.processor.max_seq_len_query == 64 # Tokenizer - assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast) - assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast) + assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast) + assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast) assert loaded_retriever.passage_tokenizer.do_lower_case == True assert loaded_retriever.query_tokenizer.do_lower_case == True assert loaded_retriever.passage_tokenizer.vocab_size == 30522 @@ -621,9 +621,9 @@ def sum_params(model): assert loaded_retriever.processor.max_seq_len_query == 64 # Tokenizer - assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast) - assert isinstance(loaded_retriever.table_tokenizer, DPRContextEncoderTokenizerFast) - assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast) + assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast) + assert isinstance(loaded_retriever.table_tokenizer, PreTrainedTokenizerFast) + assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast) assert loaded_retriever.passage_tokenizer.do_lower_case == True assert loaded_retriever.table_tokenizer.do_lower_case == True assert loaded_retriever.query_tokenizer.do_lower_case == True From 8cfeed095db538d0b9cfa12c138998cb466e39ed Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Wed, 17 May 2023 21:31:08 +0200 Subject: [PATCH 08/13] build: Remove mmh3 dependency (#4896) * build: Remove mmh3 dependency * resolve circular import * pylint * make mmh3.py sibling of schema.py * pylint import order * pylint * undo example changes * increase coverage in modeling module * increase coverage further * rename new unit tests --- haystack/__init__.py | 2 +- haystack/mmh3.py | 344 ++++++++++++++++++++++++++++++++ haystack/schema.py | 6 +- haystack/utils/squad_data.py | 5 +- pyproject.toml | 1 - test/agents/test_agent.py | 16 +- test/modeling/test_processor.py | 14 +- test/others/test_squad_data.py | 34 +++- test/utils/test_mmh3.py 
| 10 + 9 files changed, 421 insertions(+), 11 deletions(-) create mode 100644 haystack/mmh3.py create mode 100644 test/utils/test_mmh3.py diff --git a/haystack/__init__.py b/haystack/__init__.py index b527cee87b..8db3d83112 100644 --- a/haystack/__init__.py +++ b/haystack/__init__.py @@ -31,7 +31,6 @@ "magic", "markdown", "mlflow", - "mmh3", "more_itertools", "networkx", "nltk", @@ -94,6 +93,7 @@ def is_imported(module_name: str) -> bool: from haystack.nodes.base import BaseComponent from haystack.pipelines.base import Pipeline from haystack.environment import set_pytorch_secure_model_loading +from haystack.mmh3 import hash128 # Enables torch's secure model loading through setting an env var. diff --git a/haystack/mmh3.py b/haystack/mmh3.py new file mode 100644 index 0000000000..7dc482dedf --- /dev/null +++ b/haystack/mmh3.py @@ -0,0 +1,344 @@ +import sys as _sys + +# based on https://github.com/wc-duck/pymmh3/blob/master/pymmh3.py + +if _sys.version_info > (3, 0): + + def xrange(a, b, c): + return range(a, b, c) + + def xencode(x): + if isinstance(x, (bytes, bytearray)): + return x + else: + return x.encode() + +else: + + def xencode(x): + return x + + +del _sys + + +def hash128(key, seed=0x0, x64arch=True): + """Implements 128bit murmur3 hash.""" + + def hash128_x64(key, seed): + """Implements 128bit murmur3 hash for x64.""" + + def fmix(k): + k ^= k >> 33 + k = (k * 0xFF51AFD7ED558CCD) & 0xFFFFFFFFFFFFFFFF + k ^= k >> 33 + k = (k * 0xC4CEB9FE1A85EC53) & 0xFFFFFFFFFFFFFFFF + k ^= k >> 33 + return k + + length = len(key) + nblocks = int(length / 16) + + h1 = seed + h2 = seed + + c1 = 0x87C37B91114253D5 + c2 = 0x4CF5AD432745937F + + # body + for block_start in xrange(0, nblocks * 8, 8): + # ??? big endian? + k1 = ( + key[2 * block_start + 7] << 56 + | key[2 * block_start + 6] << 48 + | key[2 * block_start + 5] << 40 + | key[2 * block_start + 4] << 32 + | key[2 * block_start + 3] << 24 + | key[2 * block_start + 2] << 16 + | key[2 * block_start + 1] << 8 + | key[2 * block_start + 0] + ) + + k2 = ( + key[2 * block_start + 15] << 56 + | key[2 * block_start + 14] << 48 + | key[2 * block_start + 13] << 40 + | key[2 * block_start + 12] << 32 + | key[2 * block_start + 11] << 24 + | key[2 * block_start + 10] << 16 + | key[2 * block_start + 9] << 8 + | key[2 * block_start + 8] + ) + + k1 = (c1 * k1) & 0xFFFFFFFFFFFFFFFF + k1 = (k1 << 31 | k1 >> 33) & 0xFFFFFFFFFFFFFFFF # inlined ROTL64 + k1 = (c2 * k1) & 0xFFFFFFFFFFFFFFFF + h1 ^= k1 + + h1 = (h1 << 27 | h1 >> 37) & 0xFFFFFFFFFFFFFFFF # inlined ROTL64 + h1 = (h1 + h2) & 0xFFFFFFFFFFFFFFFF + h1 = (h1 * 5 + 0x52DCE729) & 0xFFFFFFFFFFFFFFFF + + k2 = (c2 * k2) & 0xFFFFFFFFFFFFFFFF + k2 = (k2 << 33 | k2 >> 31) & 0xFFFFFFFFFFFFFFFF # inlined ROTL64 + k2 = (c1 * k2) & 0xFFFFFFFFFFFFFFFF + h2 ^= k2 + + h2 = (h2 << 31 | h2 >> 33) & 0xFFFFFFFFFFFFFFFF # inlined ROTL64 + h2 = (h1 + h2) & 0xFFFFFFFFFFFFFFFF + h2 = (h2 * 5 + 0x38495AB5) & 0xFFFFFFFFFFFFFFFF + + # tail + tail_index = nblocks * 16 + k1 = 0 + k2 = 0 + tail_size = length & 15 + + if tail_size >= 15: + k2 ^= key[tail_index + 14] << 48 + if tail_size >= 14: + k2 ^= key[tail_index + 13] << 40 + if tail_size >= 13: + k2 ^= key[tail_index + 12] << 32 + if tail_size >= 12: + k2 ^= key[tail_index + 11] << 24 + if tail_size >= 11: + k2 ^= key[tail_index + 10] << 16 + if tail_size >= 10: + k2 ^= key[tail_index + 9] << 8 + if tail_size >= 9: + k2 ^= key[tail_index + 8] + + if tail_size > 8: + k2 = (k2 * c2) & 0xFFFFFFFFFFFFFFFF + k2 = (k2 << 33 | k2 >> 31) & 0xFFFFFFFFFFFFFFFF # inlined ROTL64 + k2 = (k2 * 
c1) & 0xFFFFFFFFFFFFFFFF + h2 ^= k2 + + if tail_size >= 8: + k1 ^= key[tail_index + 7] << 56 + if tail_size >= 7: + k1 ^= key[tail_index + 6] << 48 + if tail_size >= 6: + k1 ^= key[tail_index + 5] << 40 + if tail_size >= 5: + k1 ^= key[tail_index + 4] << 32 + if tail_size >= 4: + k1 ^= key[tail_index + 3] << 24 + if tail_size >= 3: + k1 ^= key[tail_index + 2] << 16 + if tail_size >= 2: + k1 ^= key[tail_index + 1] << 8 + if tail_size >= 1: + k1 ^= key[tail_index + 0] + + if tail_size > 0: + k1 = (k1 * c1) & 0xFFFFFFFFFFFFFFFF + k1 = (k1 << 31 | k1 >> 33) & 0xFFFFFFFFFFFFFFFF # inlined ROTL64 + k1 = (k1 * c2) & 0xFFFFFFFFFFFFFFFF + h1 ^= k1 + + # finalization + h1 ^= length + h2 ^= length + + h1 = (h1 + h2) & 0xFFFFFFFFFFFFFFFF + h2 = (h1 + h2) & 0xFFFFFFFFFFFFFFFF + + h1 = fmix(h1) + h2 = fmix(h2) + + h1 = (h1 + h2) & 0xFFFFFFFFFFFFFFFF + h2 = (h1 + h2) & 0xFFFFFFFFFFFFFFFF + + return h2 << 64 | h1 + + def hash128_x86(key, seed): + """Implements 128bit murmur3 hash for x86.""" + + def fmix(h): + h ^= h >> 16 + h = (h * 0x85EBCA6B) & 0xFFFFFFFF + h ^= h >> 13 + h = (h * 0xC2B2AE35) & 0xFFFFFFFF + h ^= h >> 16 + return h + + length = len(key) + nblocks = int(length / 16) + + h1 = seed + h2 = seed + h3 = seed + h4 = seed + + c1 = 0x239B961B + c2 = 0xAB0E9789 + c3 = 0x38B34AE5 + c4 = 0xA1E38B93 + + # body + for block_start in xrange(0, nblocks * 16, 16): + k1 = ( + key[block_start + 3] << 24 + | key[block_start + 2] << 16 + | key[block_start + 1] << 8 + | key[block_start + 0] + ) + + k2 = ( + key[block_start + 7] << 24 + | key[block_start + 6] << 16 + | key[block_start + 5] << 8 + | key[block_start + 4] + ) + + k3 = ( + key[block_start + 11] << 24 + | key[block_start + 10] << 16 + | key[block_start + 9] << 8 + | key[block_start + 8] + ) + + k4 = ( + key[block_start + 15] << 24 + | key[block_start + 14] << 16 + | key[block_start + 13] << 8 + | key[block_start + 12] + ) + + k1 = (c1 * k1) & 0xFFFFFFFF + k1 = (k1 << 15 | k1 >> 17) & 0xFFFFFFFF # inlined ROTL32 + k1 = (c2 * k1) & 0xFFFFFFFF + h1 ^= k1 + + h1 = (h1 << 19 | h1 >> 13) & 0xFFFFFFFF # inlined ROTL32 + h1 = (h1 + h2) & 0xFFFFFFFF + h1 = (h1 * 5 + 0x561CCD1B) & 0xFFFFFFFF + + k2 = (c2 * k2) & 0xFFFFFFFF + k2 = (k2 << 16 | k2 >> 16) & 0xFFFFFFFF # inlined ROTL32 + k2 = (c3 * k2) & 0xFFFFFFFF + h2 ^= k2 + + h2 = (h2 << 17 | h2 >> 15) & 0xFFFFFFFF # inlined ROTL32 + h2 = (h2 + h3) & 0xFFFFFFFF + h2 = (h2 * 5 + 0x0BCAA747) & 0xFFFFFFFF + + k3 = (c3 * k3) & 0xFFFFFFFF + k3 = (k3 << 17 | k3 >> 15) & 0xFFFFFFFF # inlined ROTL32 + k3 = (c4 * k3) & 0xFFFFFFFF + h3 ^= k3 + + h3 = (h3 << 15 | h3 >> 17) & 0xFFFFFFFF # inlined ROTL32 + h3 = (h3 + h4) & 0xFFFFFFFF + h3 = (h3 * 5 + 0x96CD1C35) & 0xFFFFFFFF + + k4 = (c4 * k4) & 0xFFFFFFFF + k4 = (k4 << 18 | k4 >> 14) & 0xFFFFFFFF # inlined ROTL32 + k4 = (c1 * k4) & 0xFFFFFFFF + h4 ^= k4 + + h4 = (h4 << 13 | h4 >> 19) & 0xFFFFFFFF # inlined ROTL32 + h4 = (h1 + h4) & 0xFFFFFFFF + h4 = (h4 * 5 + 0x32AC3B17) & 0xFFFFFFFF + + # tail + tail_index = nblocks * 16 + k1 = 0 + k2 = 0 + k3 = 0 + k4 = 0 + tail_size = length & 15 + + if tail_size >= 15: + k4 ^= key[tail_index + 14] << 16 + if tail_size >= 14: + k4 ^= key[tail_index + 13] << 8 + if tail_size >= 13: + k4 ^= key[tail_index + 12] + + if tail_size > 12: + k4 = (k4 * c4) & 0xFFFFFFFF + k4 = (k4 << 18 | k4 >> 14) & 0xFFFFFFFF # inlined ROTL32 + k4 = (k4 * c1) & 0xFFFFFFFF + h4 ^= k4 + + if tail_size >= 12: + k3 ^= key[tail_index + 11] << 24 + if tail_size >= 11: + k3 ^= key[tail_index + 10] << 16 + if tail_size >= 10: + k3 ^= key[tail_index + 9] << 8 + if 
tail_size >= 9: + k3 ^= key[tail_index + 8] + + if tail_size > 8: + k3 = (k3 * c3) & 0xFFFFFFFF + k3 = (k3 << 17 | k3 >> 15) & 0xFFFFFFFF # inlined ROTL32 + k3 = (k3 * c4) & 0xFFFFFFFF + h3 ^= k3 + + if tail_size >= 8: + k2 ^= key[tail_index + 7] << 24 + if tail_size >= 7: + k2 ^= key[tail_index + 6] << 16 + if tail_size >= 6: + k2 ^= key[tail_index + 5] << 8 + if tail_size >= 5: + k2 ^= key[tail_index + 4] + + if tail_size > 4: + k2 = (k2 * c2) & 0xFFFFFFFF + k2 = (k2 << 16 | k2 >> 16) & 0xFFFFFFFF # inlined ROTL32 + k2 = (k2 * c3) & 0xFFFFFFFF + h2 ^= k2 + + if tail_size >= 4: + k1 ^= key[tail_index + 3] << 24 + if tail_size >= 3: + k1 ^= key[tail_index + 2] << 16 + if tail_size >= 2: + k1 ^= key[tail_index + 1] << 8 + if tail_size >= 1: + k1 ^= key[tail_index + 0] + + if tail_size > 0: + k1 = (k1 * c1) & 0xFFFFFFFF + k1 = (k1 << 15 | k1 >> 17) & 0xFFFFFFFF # inlined ROTL32 + k1 = (k1 * c2) & 0xFFFFFFFF + h1 ^= k1 + + # finalization + h1 ^= length + h2 ^= length + h3 ^= length + h4 ^= length + + h1 = (h1 + h2) & 0xFFFFFFFF + h1 = (h1 + h3) & 0xFFFFFFFF + h1 = (h1 + h4) & 0xFFFFFFFF + h2 = (h1 + h2) & 0xFFFFFFFF + h3 = (h1 + h3) & 0xFFFFFFFF + h4 = (h1 + h4) & 0xFFFFFFFF + + h1 = fmix(h1) + h2 = fmix(h2) + h3 = fmix(h3) + h4 = fmix(h4) + + h1 = (h1 + h2) & 0xFFFFFFFF + h1 = (h1 + h3) & 0xFFFFFFFF + h1 = (h1 + h4) & 0xFFFFFFFF + h2 = (h1 + h2) & 0xFFFFFFFF + h3 = (h1 + h3) & 0xFFFFFFFF + h4 = (h1 + h4) & 0xFFFFFFFF + + return h4 << 96 | h3 << 64 | h2 << 32 | h1 + + key = bytearray(xencode(key)) + + if x64arch: + return hash128_x64(key, seed) + else: + return hash128_x86(key, seed) diff --git a/haystack/schema.py b/haystack/schema.py index db4c32bc3b..40992ac49e 100644 --- a/haystack/schema.py +++ b/haystack/schema.py @@ -18,7 +18,6 @@ import ast from dataclasses import asdict -import mmh3 import numpy as np from numpy import ndarray import pandas as pd @@ -32,6 +31,7 @@ from pydantic.dataclasses import dataclass from haystack import is_imported +from haystack.mmh3 import hash128 logger = logging.getLogger(__name__) @@ -147,7 +147,7 @@ def _get_id(self, id_hash_keys: Optional[List[str]] = None): """ if id_hash_keys is None: - return "{:02x}".format(mmh3.hash128(str(self.content), signed=False)) + return "{:02x}".format(hash128(str(self.content))) final_hash_key = "" for attr in id_hash_keys: @@ -163,7 +163,7 @@ def _get_id(self, id_hash_keys: Optional[List[str]] = None): "Can't create 'Document': 'id_hash_keys' must contain at least one of ['content', 'meta'] or be set to None." 
) - return "{:02x}".format(mmh3.hash128(final_hash_key, signed=False)) + return "{:02x}".format(hash128(final_hash_key)) def to_dict(self, field_map: Optional[Dict[str, Any]] = None) -> Dict: """ diff --git a/haystack/utils/squad_data.py b/haystack/utils/squad_data.py index 80ee9ac45f..037be09e84 100644 --- a/haystack/utils/squad_data.py +++ b/haystack/utils/squad_data.py @@ -5,13 +5,12 @@ import random import pandas as pd from tqdm.auto import tqdm -import mmh3 from haystack import is_imported +from haystack.mmh3 import hash128 from haystack.schema import Document, Label, Answer from haystack.modeling.data_handler.processor import _read_squad_file - logger = logging.getLogger(__name__) @@ -112,7 +111,7 @@ def to_df(data): title = document.get("title", "") for paragraph in document["paragraphs"]: context = paragraph["context"] - document_id = paragraph.get("document_id", "{:02x}".format(mmh3.hash128(str(context), signed=False))) + document_id = paragraph.get("document_id", "{:02x}".format(hash128(str(context)))) for question in paragraph["qas"]: q = question["question"] id = question["id"] diff --git a/pyproject.toml b/pyproject.toml index 61e8969162..7518afb4ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ dependencies = [ "dill", # pickle extension for (de-)serialization "tqdm", # progress bars in model download and training scripts "networkx", # graphs library - "mmh3", # fast hashing function (murmurhash3) "quantulum3", # quantities extraction from text "posthog", # telemetry "azure-ai-formrecognizer>=3.2.0b2", # forms reader diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py index f1b1a48255..7bdaff5ba6 100644 --- a/test/agents/test_agent.py +++ b/test/agents/test_agent.py @@ -7,7 +7,7 @@ from unittest import mock import pytest -from haystack import BaseComponent, Answer +from haystack import BaseComponent, Answer, Document from haystack.agents import Agent, AgentStep from haystack.agents.base import Tool, ToolsManager from haystack.nodes import PromptModel, PromptNode, PromptTemplate @@ -276,6 +276,20 @@ def test_update_hash(): assert agent.hash == "5ac8eca2f92c9545adcce3682b80d4c5" +@pytest.mark.unit +def test_tool_fails_processing_dict_result(): + tool = Tool(name="name", pipeline_or_node=MockPromptNode(), description="description") + with pytest.raises(ValueError): + tool._process_result({"answer": "answer"}) + + +@pytest.mark.unit +def test_tool_processes_answer_result_and_document_result(): + tool = Tool(name="name", pipeline_or_node=MockPromptNode(), description="description") + assert tool._process_result(Answer(answer="answer")) == "answer" + assert tool._process_result(Document(content="content")) == "content" + + def test_invalid_agent_template(): pn = PromptNode() with pytest.raises(ValueError, match="some_non_existing_template not supported"): diff --git a/test/modeling/test_processor.py b/test/modeling/test_processor.py index 2f053fefc6..9a45de953c 100644 --- a/test/modeling/test_processor.py +++ b/test/modeling/test_processor.py @@ -1,10 +1,11 @@ import copy import logging +from pathlib import Path import pytest from transformers import AutoTokenizer -from haystack.modeling.data_handler.processor import SquadProcessor +from haystack.modeling.data_handler.processor import SquadProcessor, _is_json # during inference (parameter return_baskets = False) we do not convert labels @@ -300,6 +301,17 @@ def test_dataset_from_dicts_qa_label_conversion(samples_path, caplog=None): ], f"Processing labels for {model} has changed." 
+@pytest.mark.unit +def test_is_json_identifies_json_objects(): + """Test that _is_json correctly identifies json objects""" + # Paths to json files should be considered json + assert _is_json(Path("processor_config.json")) + # dicts should be considered json + assert _is_json({"a": 1}) + # non-serializable objects should not be considered json + assert not _is_json(AutoTokenizer) + + @pytest.mark.integration def test_dataset_from_dicts_auto_determine_max_answers(samples_path, caplog=None): """ diff --git a/test/others/test_squad_data.py b/test/others/test_squad_data.py index 7aa1564587..216c9ff83d 100644 --- a/test/others/test_squad_data.py +++ b/test/others/test_squad_data.py @@ -1,4 +1,6 @@ import pandas as pd +import pytest + from haystack.utils.squad_data import SquadData from haystack.utils.augment_squad import augment_squad from haystack.schema import Document, Label, Answer @@ -22,7 +24,8 @@ def test_squad_augmentation(samples_path): assert original_squad.count(unit="paragraph") == augmented_squad.count(unit="paragraph") * multiplication_factor -def test_squad_to_df(): +@pytest.mark.unit +def test_squad_data_converts_df_to_data(): df = pd.DataFrame( [["title", "context", "question", "id", "answer", 1, False]], columns=["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"], @@ -51,6 +54,35 @@ def test_squad_to_df(): assert result == expected_result +@pytest.mark.unit +def test_squad_data_converts_data_to_df(): + data = [ + { + "title": "title", + "paragraphs": [ + { + "context": "context", + "document_id": "document_id", + "qas": [ + { + "question": "question", + "id": "id", + "answers": [{"text": "answer", "answer_start": 1}], + "is_impossible": False, + } + ], + } + ], + } + ] + expected_result = pd.DataFrame( + [["title", "context", "question", "id", "answer", 1, False, "document_id"]], + columns=["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible", "document_id"], + ) + result = SquadData.to_df(data) + assert result.equals(expected_result) + + def test_to_label_object(): squad_data_list = [ { diff --git a/test/utils/test_mmh3.py b/test/utils/test_mmh3.py new file mode 100644 index 0000000000..127b1161a8 --- /dev/null +++ b/test/utils/test_mmh3.py @@ -0,0 +1,10 @@ +import pytest + +from haystack.mmh3 import hash128 + + +@pytest.mark.unit +def test_mmh3(): + content = "This is the document text" * 100 + hashed_content = hash128(content) + assert hashed_content == 305042678480070366459393623793278501577 From 3ea784464ae3f367b3c4da951d47af0598c58967 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 18 May 2023 09:12:03 +0200 Subject: [PATCH 09/13] add test case for #4929 (#4936) --- test/nodes/test_shaper.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/nodes/test_shaper.py b/test/nodes/test_shaper.py index 9c29ff29cb..a61f8e87c0 100644 --- a/test/nodes/test_shaper.py +++ b/test/nodes/test_shaper.py @@ -29,6 +29,13 @@ def test_basic_invocation_only_inputs(mock_function): assert results["invocation_context"]["c"] == ["test query", "test query", "test query"] +@pytest.mark.unit +def test_basic_invocation_empty_documents_list(mock_function): + shaper = Shaper(func="test_function", inputs={"a": "query", "b": "documents"}, outputs=["c"]) + results, _ = shaper.run(query="test query", documents=[]) + assert results["invocation_context"]["c"] == [] + + @pytest.mark.unit def test_multiple_outputs(mock_function_two_outputs): shaper = Shaper(func="two_output_test_function", inputs={"a": "query"}, 
outputs=["c", "d"]) From ad162f2e65530303ee1923672c5526c8254becc5 Mon Sep 17 00:00:00 2001 From: Shukri Date: Thu, 18 May 2023 10:17:11 +0200 Subject: [PATCH 10/13] feat: Support authentication using AuthBearerToken and AuthClientCredentials in Weaviate (#4028) * refactor: make the scope param configurable the scope parameter is used when authenticating using AuthClientPassword and AuthClientCredentials * feat: add support for AuthClientCredentials add support for authenticating using the OIDC Client Credentials authentication flow * feat: add support for AuthBearerToken Add support for authenticating using OIDC and bearer tokens * Update lg * refactor how client is built Signed-off-by: hsm207 * unit test the auth methods Signed-off-by: hsm207 * Update test_weaviate.py * revert formatting change * Fix type hints --------- Signed-off-by: hsm207 Co-authored-by: John Doe Co-authored-by: agnieszka-m Co-authored-by: Massimiliano Pippi --- haystack/document_stores/weaviate.py | 102 ++++++++++++++++---------- test/document_stores/test_weaviate.py | 20 +++++ 2 files changed, 83 insertions(+), 39 deletions(-) diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py index b49faac641..eafc52737e 100644 --- a/haystack/document_stores/weaviate.py +++ b/haystack/document_stores/weaviate.py @@ -11,7 +11,7 @@ try: import weaviate - from weaviate import client, AuthClientPassword, gql + from weaviate import client, AuthClientPassword, gql, AuthClientCredentials, AuthBearerToken except (ImportError, ModuleNotFoundError) as ie: from haystack.utils.import_utils import _optional_component_not_installed @@ -78,6 +78,11 @@ def __init__( timeout_config: tuple = (5, 15), username: Optional[str] = None, password: Optional[str] = None, + client_secret: Optional[str] = None, + scope: Optional[str] = "offline_access", + access_token: Optional[str] = None, + expires_in: Optional[int] = 60, + refresh_token: Optional[str] = None, additional_headers: Optional[Dict[str, Any]] = None, index: str = "Document", embedding_dim: int = 768, @@ -95,58 +100,58 @@ def __init__( ): """ :param host: Weaviate server connection URL for storing and processing documents and vectors. - For more details, refer "https://weaviate.io/developers/weaviate/current/getting-started/installation.html" - :param port: port of Weaviate instance - :param timeout_config: Weaviate Timeout config as a tuple of (retries, time out seconds). - :param username: username (standard authentication via http_auth) - :param password: password (standard authentication via http_auth) - :param additional_headers: additional headers to be included in the requests sent to Weaviate e.g. bearer token - :param index: Index name for document text, embedding and metadata (in Weaviate terminology, this is a "Class" in Weaviate schema). + For more details, see [Weaviate installation](https://weaviate.io/developers/weaviate/current/getting-started/installation.html). + :param port: The port of the Weaviate instance. + :param timeout_config: The Weaviate timeout config as a tuple of (retries, time out seconds). + :param username: The Weaviate username (standard authentication using http_auth). + :param password: Weaviate password (standard authentication using http_auth). + :param client_secret: The client secret to use when using the OIDC Client Credentials authentication flow. + :param scope: The scope of the credentials when using the OIDC Resource Owner Password or Client Credentials authentication flow. 
+ :param access_token: Access token to use when using OIDC and bearer tokens to authenticate. + :param expires_in: The time in seconds after which the access token expires. + :param refresh_token: The refresh token to use when using OIDC and bearer tokens to authenticate. + :param additional_headers: Additional headers to be included in the requests sent to Weaviate, for example the bearer token. + :param index: Index name for document text, embedding, and metadata (in Weaviate terminology, this is a "Class" in the Weaviate schema). :param embedding_dim: The embedding vector size. Default: 768. - :param content_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text"). - If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned. - :param name_field: Name of field that contains the title of the doc - :param similarity: The similarity function used to compare document vectors. Available options are 'cosine' (default), 'dot_product' and 'l2'. + :param content_field: Name of the field that might contain the answer and is passed to the Reader model (for example, "full_text"). + If no Reader is used (for example, in FAQ-Style QA), the plain content of this field is returned. + :param name_field: Name of the field that contains the title of the doc. + :param similarity: The similarity function used to compare document vectors. Available options are 'cosine' (default), 'dot_product', and 'l2'. 'cosine' is recommended for Sentence Transformers. - :param index_type: Index type of any vector object defined in weaviate schema. The vector index type is pluggable. - Currently, HSNW is only supported. - See: https://weaviate.io/developers/weaviate/current/more-resources/performance.html - :param custom_schema: Allows to create custom schema in Weaviate, for more details - See https://weaviate.io/developers/weaviate/current/schema/schema-configuration.html + :param index_type: Index type of any vector object defined in the Weaviate schema. The vector index type is pluggable. + Currently, only HSNW is supported. + See also [Weaviate documentation](https://weaviate.io/developers/weaviate/current/more-resources/performance.html). + :param custom_schema: Allows to create a custom schema in Weaviate. For more details, + see [Weaviate documentation](https://weaviate.io/developers/weaviate/current/schema/schema-configuration.html). :param module_name: Vectorization module to convert data into vectors. Default is "text2vec-trasnformers" - For more details, See https://weaviate.io/developers/weaviate/current/modules/ - :param return_embedding: To return document embedding. - :param embedding_field: Name of field containing an embedding vector. + For more details, see [Weaviate documentation](https://weaviate.io/developers/weaviate/current/modules/). + :param return_embedding: Returns document embedding. + :param embedding_field: Name of the field containing an embedding vector. :param progress_bar: Whether to show a tqdm progress bar or not. Can be helpful to disable in production deployments to keep the logs clean. :param duplicate_documents:Handle duplicates document based on parameter options. - Parameter options : ( 'skip','overwrite','fail') + Parameter options: 'skip','overwrite','fail' skip: Ignore the duplicates documents overwrite: Update any existing documents with the same ID when adding documents. - fail: an error is raised if the document ID of the document being added already exists. 
- :param recreate_index: If set to True, an existing Weaviate index will be deleted and a new one will be - created using the config you are using for initialization. Be aware that all data in the old index will be + fail: Raises an error if the document ID of the document being added already exists. + :param recreate_index: If set to True, deletes an existing Weaviate index and creates a new one using the config you are using for initialization. Note that all data in the old index is lost if you choose to recreate the index. - :param replication_factor: It sets the Weaviate Class's replication factor in Weaviate at the time of Class creation. - See: https://weaviate.io/developers/weaviate/current/configuration/replication.html + :param replication_factor: Sets the Weaviate Class's replication factor in Weaviate at the time of Class creation. + See also [Weaviate documentation](https://weaviate.io/developers/weaviate/current/configuration/replication.html). """ super().__init__() # Connect to Weaviate server using python binding weaviate_url = f"{host}:{port}" - if username and password: - secret = AuthClientPassword(username, password) - self.weaviate_client = client.Client( - url=weaviate_url, - auth_client_secret=secret, - timeout_config=timeout_config, - additional_headers=additional_headers, - ) - else: - self.weaviate_client = client.Client( - url=weaviate_url, timeout_config=timeout_config, additional_headers=additional_headers - ) - + secret = self._get_auth_secret( + username, password, client_secret, access_token, expires_in, refresh_token, scope + ) + self.weaviate_client = client.Client( + url=weaviate_url, + auth_client_secret=secret, + timeout_config=timeout_config, + additional_headers=additional_headers, + ) # Test Weaviate connection try: status = self.weaviate_client.is_ready() @@ -185,6 +190,25 @@ def __init__( self._create_schema_and_index(self.index, recreate_index=recreate_index) self.uuid_format_warning_raised = False + @staticmethod + def _get_auth_secret( + username: Optional[str] = None, + password: Optional[str] = None, + client_secret: Optional[str] = None, + access_token: Optional[str] = None, + expires_in: Optional[int] = 60, + refresh_token: Optional[str] = None, + scope: Optional[str] = "offline_access", + ) -> Optional[Union["AuthClientPassword", "AuthClientCredentials", "AuthBearerToken"]]: + if username and password: + return AuthClientPassword(username, password, scope=scope) + elif client_secret: + return AuthClientCredentials(client_secret, scope=scope) + elif access_token: + return AuthBearerToken(access_token, expires_in=expires_in, refresh_token=refresh_token) + + return None + def _sanitize_index_name(self, index: Optional[str]) -> Optional[str]: if index is None: return None diff --git a/test/document_stores/test_weaviate.py b/test/document_stores/test_weaviate.py index 830b031833..54c989dbcc 100644 --- a/test/document_stores/test_weaviate.py +++ b/test/document_stores/test_weaviate.py @@ -274,6 +274,26 @@ def test_get_embedding_count(self, ds, documents): ds.write_documents(documents) assert ds.get_embedding_count() == 9 + @pytest.mark.unit + def test__get_auth_secret(self): + # Test with username and password + secret = WeaviateDocumentStore._get_auth_secret("user", "pass", scope="some_scope") + assert isinstance(secret, weaviate.AuthClientPassword) + + # Test with client_secret + secret = WeaviateDocumentStore._get_auth_secret(client_secret="client_secret_value", scope="some_scope") + assert isinstance(secret, weaviate.AuthClientCredentials) + 
+ # Test with access_token + secret = WeaviateDocumentStore._get_auth_secret( + access_token="access_token_value", expires_in=3600, refresh_token="refresh_token_value" + ) + assert isinstance(secret, weaviate.AuthBearerToken) + + # Test with no authentication method + secret = WeaviateDocumentStore._get_auth_secret() + assert secret is None + @pytest.mark.unit def test__get_current_properties(self, mocked_ds): mocked_ds.weaviate_client.schema.get.return_value = json.loads( From df55ec5e61d7163e15eb88e1936f9009a37e65cb Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 18 May 2023 12:22:16 +0200 Subject: [PATCH 11/13] Pin Weaviate client (#4952) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7518afb4ac..d1b95441ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,7 +117,7 @@ milvus = [ "farm-haystack[sql,only-milvus]", ] weaviate = [ - "weaviate-client>=3.10.0,<4", + "weaviate-client<3.19.0", ] only-pinecone = [ "pinecone-client>=2.0.11,<3", From 401520b1d221733d6aa62f2e29e14e39b908e53a Mon Sep 17 00:00:00 2001 From: Daria Fokina Date: Thu, 18 May 2023 14:20:51 +0200 Subject: [PATCH 12/13] web.py docstring update (#4921) Corrected spelling and added GoogleAPI --- haystack/nodes/search_engine/web.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/haystack/nodes/search_engine/web.py b/haystack/nodes/search_engine/web.py index 9a49713677..087e16cbf3 100644 --- a/haystack/nodes/search_engine/web.py +++ b/haystack/nodes/search_engine/web.py @@ -12,10 +12,11 @@ class WebSearch(BaseComponent): of the underlying search engine provider, provides common interface for all providers, and makes it possible to use various search engines. - WebSerach currently supports the following search engines providers (bridges): + WebSearch currently supports the following search engines providers (bridges): - SerperDev (default) - SerpAPI - BingAPI + - GoogleAPI """ From 5d7ee2e5e616e1e4b216dfa6d5d5547d68699982 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 18 May 2023 15:19:29 +0200 Subject: [PATCH 13/13] feat: Add max_tokens to BaseGenerator params (#4168) * Add max_tokens to BaseGenerator params * Make mypy happy * Rebase and resolve conflicts * Fix signature issues * Update lg * Add a mocked unit test method * end-of-file-fixer corrected file * Convert to unit test * Mark test as integration * make the test unit --------- Co-authored-by: agnieszka-m Co-authored-by: Massimiliano Pippi --- haystack/nodes/answer_generator/base.py | 36 +++++++++++++------ haystack/nodes/answer_generator/openai.py | 4 ++- .../nodes/answer_generator/transformers.py | 14 +++++--- test/conftest.py | 2 +- test/nodes/test_generator.py | 27 ++++++++++++-- 5 files changed, 65 insertions(+), 18 deletions(-) diff --git a/haystack/nodes/answer_generator/base.py b/haystack/nodes/answer_generator/base.py index 4e98bdc05a..01068605de 100644 --- a/haystack/nodes/answer_generator/base.py +++ b/haystack/nodes/answer_generator/base.py @@ -20,27 +20,38 @@ def __init__(self, progress_bar: bool = True): self.progress_bar = progress_bar @abstractmethod - def predict(self, query: str, documents: List[Document], top_k: Optional[int]) -> Dict: + def predict(self, query: str, documents: List[Document], top_k: Optional[int], max_tokens: Optional[int]) -> Dict: """ Abstract method to generate answers. :param query: Query - :param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on. 
- :param top_k: Number of returned answers + :param documents: Related documents (for example, coming from a retriever) the answer should be based on. + :param top_k: Number of returned answers. + :param max_tokens: THe maximum number of tokens the generated answer can have. :return: Generated answers plus additional infos in a dict """ pass - def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, labels: Optional[MultiLabel] = None, add_isolated_node_eval: bool = False): # type: ignore + def run( # type: ignore + self, + query: str, + documents: List[Document], + top_k: Optional[int] = None, + labels: Optional[MultiLabel] = None, + add_isolated_node_eval: bool = False, + max_tokens: Optional[int] = None, + ): # type: ignore if documents: - results = self.predict(query=query, documents=documents, top_k=top_k) + results = self.predict(query=query, documents=documents, top_k=top_k, max_tokens=max_tokens) else: results = {"answers": []} # run evaluation with "perfect" labels as node inputs to calculate "upper bound" metrics for just this node if add_isolated_node_eval and labels is not None: relevant_documents = list({label.document.id: label.document for label in labels.labels}.values()) - results_label_input = self.predict(query=query, documents=relevant_documents, top_k=top_k) + results_label_input = self.predict( + query=query, documents=relevant_documents, top_k=top_k, max_tokens=max_tokens + ) results["answers_isolated"] = results_label_input["answers"] return results, "output_1" @@ -51,8 +62,11 @@ def run_batch( # type: ignore documents: Union[List[Document], List[List[Document]]], top_k: Optional[int] = None, batch_size: Optional[int] = None, + max_tokens: Optional[int] = None, ): - results = self.predict_batch(queries=queries, documents=documents, top_k=top_k, batch_size=batch_size) + results = self.predict_batch( + queries=queries, documents=documents, top_k=top_k, batch_size=batch_size, max_tokens=max_tokens + ) return results, "output_1" def _flatten_docs(self, documents: List[Document]): @@ -92,6 +106,7 @@ def predict_batch( documents: Union[List[Document], List[List[Document]]], top_k: Optional[int] = None, batch_size: Optional[int] = None, + max_tokens: Optional[int] = None, ): """ Generate the answer to the input queries. The generation will be conditioned on the supplied documents. @@ -110,10 +125,11 @@ def predict_batch( and the Answers will be aggregated per query-Document pair. :param queries: List of queries. - :param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on. + :param documents: Related documents (for example, coming from a retriever) the answer should be based on. Can be a single list of Documents or a list of lists of Documents. :param top_k: Number of returned answers per query. :param batch_size: Not applicable. + :param max_tokens: The maximum number of tokens the generated answer can have. 
:return: Generated answers plus additional infos in a dict like this: ```python @@ -142,7 +158,7 @@ def predict_batch( for doc in documents: if not isinstance(doc, Document): raise HaystackError(f"doc was of type {type(doc)}, but expected a Document.") - preds = self.predict(query=query, documents=[doc], top_k=top_k) + preds = self.predict(query=query, documents=[doc], top_k=top_k, max_tokens=max_tokens) results["answers"].append(preds["answers"]) pb.update(1) pb.close() @@ -158,7 +174,7 @@ def predict_batch( for query, cur_docs in zip(queries, documents): if not isinstance(cur_docs, list): raise HaystackError(f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents.") - preds = self.predict(query=query, documents=cur_docs, top_k=top_k) + preds = self.predict(query=query, documents=cur_docs, top_k=top_k, max_tokens=max_tokens) results["answers"].append(preds["answers"]) pb.update(1) pb.close() diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py index ffe5ae9690..fc1558c880 100644 --- a/haystack/nodes/answer_generator/openai.py +++ b/haystack/nodes/answer_generator/openai.py @@ -170,6 +170,7 @@ def predict( query: str, documents: List[Document], top_k: Optional[int] = None, + max_tokens: Optional[int] = None, timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT, ): """ @@ -193,6 +194,7 @@ def predict( :param query: The query you want to provide. It's a string. :param documents: List of Documents in which to search for the Answer. :param top_k: The maximum number of Answers to return. + :param max_tokens: The maximum number of tokens the generated Answer can have. :param timeout: How many seconds to wait for the server to send data before giving up, as a float, or a :ref:`(connect timeout, read timeout) ` tuple. Defaults to 10 seconds. @@ -208,7 +210,7 @@ def predict( payload = { "model": self.model, "prompt": prompt, - "max_tokens": self.max_tokens, + "max_tokens": max_tokens or self.max_tokens, "stop": self.stop_words, "n": top_k, "temperature": self.temperature, diff --git a/haystack/nodes/answer_generator/transformers.py b/haystack/nodes/answer_generator/transformers.py index 07ce857e9d..d7d832201d 100644 --- a/haystack/nodes/answer_generator/transformers.py +++ b/haystack/nodes/answer_generator/transformers.py @@ -214,7 +214,9 @@ def _prepare_passage_embeddings(self, docs: List[Document], embeddings: numpy.nd return embeddings_in_tensor.to(self.devices[0]) - def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict: + def predict( + self, query: str, documents: List[Document], top_k: Optional[int] = None, max_tokens: Optional[int] = None + ) -> Dict: """ Generate the answer to the input query. The generation will be conditioned on the supplied documents. These documents can for example be retrieved via the Retriever. @@ -222,6 +224,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = :param query: Query :param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on. 
:param top_k: Number of returned answers + :param max_tokens: Maximum number of tokens to generate :return: Generated answers plus additional infos in a dict like this: ```python @@ -279,7 +282,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = doc_scores=doc_scores, num_return_sequences=top_k, num_beams=self.num_beams, - max_length=self.max_length, + max_length=max_tokens or self.max_length, min_length=self.min_length, n_docs=len(flat_docs_dict["content"]), ) @@ -430,7 +433,9 @@ def _get_converter(cls, model_name_or_path: str) -> Optional[Callable]: model_name_or_path = "yjernite/bart_eli5" return cls._model_input_converters.get(model_name_or_path) - def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict: + def predict( + self, query: str, documents: List[Document], top_k: Optional[int] = None, max_tokens: Optional[int] = None + ) -> Dict: """ Generate the answer to the input query. The generation will be conditioned on the supplied documents. These document can be retrieved via the Retriever or supplied directly via predict method. @@ -438,6 +443,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = :param query: Query :param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on. :param top_k: Number of returned answers + :param max_tokens: Maximum number of tokens in the generated answer :return: Generated answers """ @@ -474,7 +480,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = input_ids=query_and_docs_encoded["input_ids"], attention_mask=query_and_docs_encoded["attention_mask"], min_length=self.min_length, - max_length=self.max_length, + max_length=max_tokens or self.max_length, do_sample=True if self.num_beams == 1 else False, early_stopping=True, num_beams=self.num_beams, diff --git a/test/conftest.py b/test/conftest.py index 1b3d3d19f8..3d2bd80f7b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -329,7 +329,7 @@ def embed_documents(self, documents: List[Document]): class MockSeq2SegGenerator(BaseGenerator): - def predict(self, query: str, documents: List[Document], top_k: Optional[int]) -> Dict: + def predict(self, query: str, documents: List[Document], top_k: Optional[int], max_tokens: Optional[int]) -> Dict: pass diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py index dde821fb0f..a224294e47 100644 --- a/test/nodes/test_generator.py +++ b/test/nodes/test_generator.py @@ -1,9 +1,9 @@ import sys from typing import List -from unittest.mock import patch +from unittest.mock import patch, create_autospec import pytest - +from haystack import Pipeline from haystack.schema import Document from haystack.nodes.answer_generator import Seq2SeqGenerator, OpenAIAnswerGenerator, RAGenerator from haystack.pipelines import GenerativeQAPipeline @@ -218,3 +218,26 @@ def test_build_prompt_within_max_length(): assert len(prompt_docs) == 1 assert prompt_docs[0] == documents[0] + + +@pytest.mark.unit +def test_openai_answer_generator_pipeline_max_tokens(): + """ + tests that the max_tokens parameter is passed to the generator component in the pipeline + """ + question = "What is New York City like?" + mocked_response = "Forget NYC, I was generated by the mock method." 
+ nyc_docs = [Document(content="New York is a cool and amazing city to live in the United States of America.")] + pipeline = Pipeline() + + # mock load_openai_tokenizer to avoid accessing the internet to init tiktoken + with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): + openai_generator = OpenAIAnswerGenerator(api_key="fake_api_key", model="text-babbage-001", top_k=1) + + pipeline.add_node(component=openai_generator, name="generator", inputs=["Query"]) + openai_generator.run = create_autospec(openai_generator.run) + openai_generator.run.return_value = ({"answers": mocked_response}, "output_1") + + result = pipeline.run(query=question, documents=nyc_docs, params={"generator": {"max_tokens": 3}}) + assert result["answers"] == mocked_response + openai_generator.run.assert_called_with(query=question, documents=nyc_docs, max_tokens=3)
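
---

For reference, a minimal usage sketch of the `max_tokens` parameter introduced in PATCH 13/13, modeled on the unit test above. This is not part of the patch itself; the API key and model name are placeholders, and the sketch assumes a Haystack installation that already includes these changes.

```python
from haystack import Document, Pipeline
from haystack.nodes.answer_generator import OpenAIAnswerGenerator

docs = [Document(content="New York is a cool and amazing city to live in the United States of America.")]

# Placeholder credentials and model: substitute a real OpenAI API key to run this.
generator = OpenAIAnswerGenerator(api_key="YOUR_API_KEY", model="text-babbage-001", top_k=1)

pipeline = Pipeline()
pipeline.add_node(component=generator, name="generator", inputs=["Query"])

# max_tokens is forwarded from the pipeline params to BaseGenerator.run(),
# and from run() to predict(), capping the length of the generated answer.
result = pipeline.run(
    query="What is New York City like?",
    documents=docs,
    params={"generator": {"max_tokens": 32}},
)
print(result["answers"])
```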