From 801c705482a9982493922cdcb2b1b71d86238314 Mon Sep 17 00:00:00 2001 From: Estelle Scifo Date: Tue, 14 Oct 2025 13:51:09 +0200 Subject: [PATCH] Revert "Update LLMInterface to restore LC compatibility (#416)" This reverts commit 7b4e2d39f6931721c0899f37a1c428238990cbe6. --- CHANGELOG.md | 7 - examples/README.md | 4 +- ...atibility.py => langchain_compatiblity.py} | 0 examples/customize/llms/anthropic_llm.py | 18 +- examples/customize/llms/cohere_llm.py | 14 +- examples/customize/llms/custom_llm.py | 25 +- examples/customize/llms/mistalai_llm.py | 10 + examples/customize/llms/mistralai_llm.py | 32 -- examples/customize/llms/ollama_llm.py | 19 +- examples/customize/llms/openai_llm.py | 22 +- examples/customize/llms/vertexai_llm.py | 17 +- src/neo4j_graphrag/generation/graphrag.py | 19 +- src/neo4j_graphrag/llm/anthropic_llm.py | 76 +++-- src/neo4j_graphrag/llm/base.py | 64 +--- src/neo4j_graphrag/llm/cohere_llm.py | 75 +++-- src/neo4j_graphrag/llm/mistralai_llm.py | 79 +++-- src/neo4j_graphrag/llm/ollama_llm.py | 61 +++- src/neo4j_graphrag/llm/openai_llm.py | 113 ++++--- src/neo4j_graphrag/llm/utils.py | 70 ----- src/neo4j_graphrag/llm/vertexai_llm.py | 131 +++++--- src/neo4j_graphrag/message_history.py | 3 - tests/e2e/test_graphrag_e2e.py | 57 +--- tests/unit/llm/test_anthropic_llm.py | 166 +++++++---- tests/unit/llm/test_base.py | 216 -------------- tests/unit/llm/test_cohere_llm.py | 81 ++++- tests/unit/llm/test_mistralai_llm.py | 91 ++++++ tests/unit/llm/test_ollama_llm.py | 110 ++++++- tests/unit/llm/test_openai_llm.py | 281 ++++++++++++++++-- tests/unit/llm/test_utils.py | 144 --------- tests/unit/llm/test_vertexai_llm.py | 107 ++++++- tests/unit/test_graphrag.py | 54 +--- 31 files changed, 1177 insertions(+), 989 deletions(-) rename examples/customize/answer/{langchain_compatibility.py => langchain_compatiblity.py} (100%) create mode 100644 examples/customize/llms/mistalai_llm.py delete mode 100644 examples/customize/llms/mistralai_llm.py delete mode 100644 src/neo4j_graphrag/llm/utils.py delete mode 100644 tests/unit/llm/test_base.py delete mode 100644 tests/unit/llm/test_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 19f1f3999..1c0b6c9b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,15 +4,8 @@ ### Added -- Document node is now always created when running SimpleKGPipeline, even if `from_pdf=False`. -- Document metadata is exposed in SimpleKGPipeline run method. - Added automatic rate limiting with retry logic and exponential backoff for all Embedding providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely. -### Fixed - -- LangChain Chat models compatibility is now working again. - - ## 1.10.0 ### Added diff --git a/examples/README.md b/examples/README.md index 6cd0e758b..774739b32 100644 --- a/examples/README.md +++ b/examples/README.md @@ -69,7 +69,7 @@ are listed in [the last section of this file](#customize). - [OpenAI (GPT)](./customize/llms/openai_llm.py) - [Azure OpenAI]() - [VertexAI (Gemini)](./customize/llms/vertexai_llm.py) -- [MistralAI](customize/llms/mistralai_llm.py) +- [MistralAI](./customize/llms/mistalai_llm.py) - [Cohere](./customize/llms/cohere_llm.py) - [Anthropic (Claude)](./customize/llms/anthropic_llm.py) - [Ollama](./customize/llms/ollama_llm.py) @@ -142,7 +142,7 @@ are listed in [the last section of this file](#customize). 
### Answer: GraphRAG -- [LangChain compatibility](customize/answer/langchain_compatibility.py) +- [LangChain compatibility](./customize/answer/langchain_compatiblity.py) - [Use a custom prompt](./customize/answer/custom_prompt.py) diff --git a/examples/customize/answer/langchain_compatibility.py b/examples/customize/answer/langchain_compatiblity.py similarity index 100% rename from examples/customize/answer/langchain_compatibility.py rename to examples/customize/answer/langchain_compatiblity.py diff --git a/examples/customize/llms/anthropic_llm.py b/examples/customize/llms/anthropic_llm.py index dbd3f56fd..85c4ad03a 100644 --- a/examples/customize/llms/anthropic_llm.py +++ b/examples/customize/llms/anthropic_llm.py @@ -1,28 +1,12 @@ from neo4j_graphrag.llm import AnthropicLLM, LLMResponse -from neo4j_graphrag.types import LLMMessage # set api key here on in the ANTHROPIC_API_KEY env var api_key = None -messages: list[LLMMessage] = [ - { - "role": "system", - "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", - }, - { - "role": "user", - "content": "say something", - }, -] - - llm = AnthropicLLM( model_name="claude-3-opus-20240229", model_params={"max_tokens": 1000}, # max_tokens must be specified api_key=api_key, ) -res: LLMResponse = llm.invoke( - # "say something", - messages, -) +res: LLMResponse = llm.invoke("say something") print(res.content) diff --git a/examples/customize/llms/cohere_llm.py b/examples/customize/llms/cohere_llm.py index daa3926ef..d631d3e41 100644 --- a/examples/customize/llms/cohere_llm.py +++ b/examples/customize/llms/cohere_llm.py @@ -1,23 +1,11 @@ from neo4j_graphrag.llm import CohereLLM, LLMResponse -from neo4j_graphrag.types import LLMMessage # set api key here on in the CO_API_KEY env var api_key = None -messages: list[LLMMessage] = [ - { - "role": "system", - "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", - }, - { - "role": "user", - "content": "say something", - }, -] - llm = CohereLLM( model_name="command-r", api_key=api_key, ) -res: LLMResponse = llm.invoke(input=messages) +res: LLMResponse = llm.invoke("say something") print(res.content) diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index eccd7beec..86b3cb993 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -1,6 +1,6 @@ import random import string -from typing import Any, Awaitable, Callable, Optional, TypeVar +from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union from neo4j_graphrag.llm import LLMInterface, LLMResponse from neo4j_graphrag.utils.rate_limit import ( @@ -8,6 +8,7 @@ # rate_limit_handler, # async_rate_limit_handler, ) +from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage @@ -17,27 +18,37 @@ def __init__( ): super().__init__(model_name, **kwargs) - def _invoke( + # Optional: Apply rate limit handling to synchronous invoke method + # @rate_limit_handler + def invoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: content: str = ( self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30)) ) return LLMResponse(content=content) - async def _ainvoke( + # Optional: Apply rate limit handling to asynchronous ainvoke method + # @async_rate_limit_handler + 
async def ainvoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: raise NotImplementedError() -llm = CustomLLM("") +llm = CustomLLM( + "" +) # if rate_limit_handler and async_rate_limit_handler decorators are used, the default rate limit handler will be applied automatically (retry with exponential backoff) res: LLMResponse = llm.invoke("text") print(res.content) -# If you want to use a custom rate limit handler +# If rate_limit_handler and async_rate_limit_handler decorators are used and you want to use a custom rate limit handler # Type variables for function signatures used in rate limit handlers F = TypeVar("F", bound=Callable[..., Any]) AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) diff --git a/examples/customize/llms/mistalai_llm.py b/examples/customize/llms/mistalai_llm.py new file mode 100644 index 000000000..b829baad4 --- /dev/null +++ b/examples/customize/llms/mistalai_llm.py @@ -0,0 +1,10 @@ +from neo4j_graphrag.llm import MistralAILLM + +# set api key here on in the MISTRAL_API_KEY env var +api_key = None + +llm = MistralAILLM( + model_name="mistral-small-latest", + api_key=api_key, +) +llm.invoke("say something") diff --git a/examples/customize/llms/mistralai_llm.py b/examples/customize/llms/mistralai_llm.py deleted file mode 100644 index 66db280b1..000000000 --- a/examples/customize/llms/mistralai_llm.py +++ /dev/null @@ -1,32 +0,0 @@ -from neo4j_graphrag.llm import MistralAILLM, LLMResponse -from neo4j_graphrag.message_history import InMemoryMessageHistory -from neo4j_graphrag.types import LLMMessage - -# set api key here on in the MISTRAL_API_KEY env var -api_key = None - - -messages: list[LLMMessage] = [ - { - "role": "system", - "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", - }, - { - "role": "user", - "content": "say something", - }, -] - - -llm = MistralAILLM( - model_name="mistral-small-latest", - api_key=api_key, -) -res: LLMResponse = llm.invoke( - # "say something", - # messages, - InMemoryMessageHistory( - messages=messages, - ) -) -print(res.content) diff --git a/examples/customize/llms/ollama_llm.py b/examples/customize/llms/ollama_llm.py index 37dd1dbec..dc42f7466 100644 --- a/examples/customize/llms/ollama_llm.py +++ b/examples/customize/llms/ollama_llm.py @@ -3,26 +3,11 @@ """ from neo4j_graphrag.llm import LLMResponse, OllamaLLM -from neo4j_graphrag.types import LLMMessage - -messages: list[LLMMessage] = [ - { - "role": "system", - "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", - }, - { - "role": "user", - "content": "say something", - }, -] - llm = OllamaLLM( - model_name="orca-mini:latest", + model_name="", # model_params={"options": {"temperature": 0}, "format": "json"}, # host="...", # if using a remote server ) -res: LLMResponse = llm.invoke( - messages, -) +res: LLMResponse = llm.invoke("What is the additive color model?") print(res.content) diff --git a/examples/customize/llms/openai_llm.py b/examples/customize/llms/openai_llm.py index 501ccdb53..d4b38244e 100644 --- a/examples/customize/llms/openai_llm.py +++ b/examples/customize/llms/openai_llm.py @@ -1,28 +1,8 @@ from neo4j_graphrag.llm import LLMResponse, OpenAILLM -from neo4j_graphrag.message_history import InMemoryMessageHistory -from neo4j_graphrag.types import LLMMessage # set api key here on in the OPENAI_API_KEY env 
var api_key = None -messages: list[LLMMessage] = [ - { - "role": "system", - "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", - }, - { - "role": "user", - "content": "say something", - }, -] - - llm = OpenAILLM(model_name="gpt-4o", api_key=api_key) -res: LLMResponse = llm.invoke( - # "say something", - # messages, - InMemoryMessageHistory( - messages=messages, - ) -) +res: LLMResponse = llm.invoke("say something") print(res.content) diff --git a/examples/customize/llms/vertexai_llm.py b/examples/customize/llms/vertexai_llm.py index 34fc179ae..f43864935 100644 --- a/examples/customize/llms/vertexai_llm.py +++ b/examples/customize/llms/vertexai_llm.py @@ -1,20 +1,6 @@ from neo4j_graphrag.llm import LLMResponse, VertexAILLM from vertexai.generative_models import GenerationConfig -from neo4j_graphrag.types import LLMMessage - -messages: list[LLMMessage] = [ - { - "role": "system", - "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", - }, - { - "role": "user", - "content": "say something", - }, -] - - generation_config = GenerationConfig(temperature=1.0) llm = VertexAILLM( model_name="gemini-2.0-flash-001", @@ -23,6 +9,7 @@ # vertexai.generative_models.GenerativeModel client ) res: LLMResponse = llm.invoke( - input=messages, + "say something", + system_instruction="You are living in 3000 where AI rules the world", ) print(res.content) diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index e79622dc3..08f08a368 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -27,7 +27,6 @@ from neo4j_graphrag.generation.prompts import RagTemplate from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel from neo4j_graphrag.llm import LLMInterface -from neo4j_graphrag.llm.utils import legacy_inputs_to_messages from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.retrievers.base import Retriever from neo4j_graphrag.types import LLMMessage, RetrieverResult @@ -146,17 +145,12 @@ def search( prompt = self.prompt_template.format( query_text=query_text, context=context, examples=validated_data.examples ) - - messages = legacy_inputs_to_messages( - prompt, - message_history=message_history, - system_instruction=self.prompt_template.system_instructions, - ) - logger.debug(f"RAG: retriever_result={prettify(retriever_result)}") logger.debug(f"RAG: prompt={prompt}") llm_response = self.llm.invoke( - messages, + prompt, + message_history, + system_instruction=self.prompt_template.system_instructions, ) answer = llm_response.content result: dict[str, Any] = {"answer": answer} @@ -174,12 +168,9 @@ def _build_query( summarization_prompt = self._chat_summary_prompt( message_history=message_history ) - messages = legacy_inputs_to_messages( - summarization_prompt, - system_instruction=summary_system_message, - ) summary = self.llm.invoke( - messages, + input=summarization_prompt, + system_instruction=summary_system_message, ).content return self.conversation_prompt(summary=summary, current_query=query_text) return query_text diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 9264b005f..21560d3f2 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -13,21 +13,28 @@ # limitations under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast + +from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( + BaseMessage, LLMResponse, + MessageList, + UserMessage, ) +from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage if TYPE_CHECKING: from anthropic.types.message_param import MessageParam - from anthropic import NotGiven class AnthropicLLM(LLMInterface): @@ -77,41 +84,46 @@ def __init__( def get_messages( self, - input: list[LLMMessage], - ) -> tuple[Union[str, NotGiven], Iterable[MessageParam]]: - messages: list[MessageParam] = [] - system_instruction: Union[str, NotGiven] = self.anthropic.NOT_GIVEN - for i in input: - if i["role"] == "system": - system_instruction = i["content"] - else: - if i["role"] not in ("user", "assistant"): - raise ValueError(f"Unknown role: {i['role']}") - messages.append( - self.anthropic.types.MessageParam( - role=i["role"], - content=i["content"], - ) - ) - return system_instruction, messages - - def _invoke( + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + ) -> Iterable[MessageParam]: + messages: list[dict[str, str]] = [] + if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + messages.append(UserMessage(content=input).model_dump()) + return messages # type: ignore + + @rate_limit_handler + def invoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - system_instruction, messages = self.get_messages(input) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + messages = self.get_messages(input, message_history) response = self.client.messages.create( model=self.model_name, - system=system_instruction, + system=system_instruction or self.anthropic.NOT_GIVEN, messages=messages, **self.model_params, ) @@ -124,23 +136,31 @@ def _invoke( except self.anthropic.APIError as e: raise LLMGenerationError(e) - async def _ainvoke( + @async_rate_limit_handler + async def ainvoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. 
+ message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - system_instruction, messages = self.get_messages(input) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + messages = self.get_messages(input, message_history) response = await self.async_client.messages.create( model=self.model_name, - system=system_instruction, + system=system_instruction or self.anthropic.NOT_GIVEN, messages=messages, **self.model_params, ) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index b66713b2a..ff7af1c70 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -17,23 +17,17 @@ from abc import ABC, abstractmethod from typing import Any, List, Optional, Sequence, Union -from pydantic import ValidationError - from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage from .types import LLMResponse, ToolCallResponse from neo4j_graphrag.utils.rate_limit import ( DEFAULT_RATE_LIMIT_HANDLER, - rate_limit_handler, - async_rate_limit_handler, - RateLimitHandler, ) from neo4j_graphrag.tool import Tool -from .utils import legacy_inputs_to_messages -from ..exceptions import LLMGenerationError +from neo4j_graphrag.utils.rate_limit import RateLimitHandler class LLMInterface(ABC): @@ -61,30 +55,20 @@ def __init__( else: self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER - @rate_limit_handler + @abstractmethod def invoke( self, - input: Union[str, List[LLMMessage], MessageHistory], + input: str, message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, - ) -> LLMResponse: - try: - messages = legacy_inputs_to_messages( - input, message_history, system_instruction - ) - except ValidationError as e: - raise LLMGenerationError("Input validation failed") from e - return self._invoke(messages) - - @abstractmethod - def _invoke( - self, - input: list[LLMMessage], ) -> LLMResponse: """Sends a text input to the LLM and retrieves a response. Args: - input (MessageHistory): Text sent to the LLM. + input (str): Text sent to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. @@ -93,25 +77,20 @@ def _invoke( LLMGenerationError: If anything goes wrong. """ - @async_rate_limit_handler + @abstractmethod async def ainvoke( self, - input: Union[str, List[LLMMessage], MessageHistory], + input: str, message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, - ) -> LLMResponse: - messages = legacy_inputs_to_messages(input, message_history, system_instruction) - return await self._ainvoke(messages) - - @abstractmethod - async def _ainvoke( - self, - input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends a text input to the LLM and retrieves a response. Args: input (str): Text sent to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. 
+ system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. @@ -120,7 +99,6 @@ async def _ainvoke( LLMGenerationError: If anything goes wrong. """ - @rate_limit_handler def invoke_with_tools( self, input: str, @@ -146,20 +124,8 @@ def invoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. """ - try: - messages = legacy_inputs_to_messages( - input, message_history, system_instruction - ) - except ValidationError as e: - raise LLMGenerationError("Input validation failed") from e - return self._invoke_with_tools(messages, tools) - - def _invoke_with_tools( - self, inputs: list[LLMMessage], tools: Sequence[Tool] - ) -> ToolCallResponse: raise NotImplementedError("This LLM provider does not support tool calling.") - @async_rate_limit_handler async def ainvoke_with_tools( self, input: str, @@ -185,10 +151,4 @@ async def ainvoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. """ - messages = legacy_inputs_to_messages(input, message_history, system_instruction) - return await self._ainvoke_with_tools(messages, tools) - - async def _ainvoke_with_tools( - self, inputs: list[LLMMessage], tools: Sequence[Tool] - ) -> ToolCallResponse: raise NotImplementedError("This LLM provider does not support tool calling.") diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index d34f20f05..2e3ca0cea 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -14,16 +14,25 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast + +from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( + BaseMessage, LLMResponse, + MessageList, + SystemMessage, + UserMessage, ) +from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage if TYPE_CHECKING: @@ -75,66 +84,84 @@ def __init__( def get_messages( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> ChatMessages: - messages: ChatMessages = [] - for i in input: - if i["role"] == "system": - messages.append(self.cohere.SystemChatMessageV2(content=i["content"])) - elif i["role"] == "user": - messages.append(self.cohere.UserChatMessageV2(content=i["content"])) - elif i["role"] == "assistant": - messages.append( - self.cohere.AssistantChatMessageV2(content=i["content"]) - ) - else: - raise ValueError(f"Unknown role: {i['role']}") - return messages - - def _invoke( + messages = [] + if system_instruction: + messages.append(SystemMessage(content=system_instruction).model_dump()) + if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + 
messages.append(UserMessage(content=input).model_dump()) + return messages # type: ignore + + @rate_limit_handler + def invoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - messages = self.get_messages(input) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + messages = self.get_messages(input, message_history, system_instruction) res = self.client.chat( messages=messages, model=self.model_name, ) except self.cohere_api_error as e: - raise LLMGenerationError("Error calling cohere") from e + raise LLMGenerationError(e) return LLMResponse( content=res.message.content[0].text if res.message.content else "", ) - async def _ainvoke( + @async_rate_limit_handler + async def ainvoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
""" try: - messages = self.get_messages(input) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + messages = self.get_messages(input, message_history, system_instruction) res = await self.async_client.chat( messages=messages, model=self.model_name, ) except self.cohere_api_error as e: - raise LLMGenerationError("Error calling cohere") from e + raise LLMGenerationError(e) return LLMResponse( content=res.message.content[0].text if res.message.content else "", ) diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index 76da39a3d..3fa8663ae 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -15,31 +15,33 @@ from __future__ import annotations import os -from typing import Any, Optional +from typing import Any, Iterable, List, Optional, Union, cast + +from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( + BaseMessage, LLMResponse, + MessageList, + SystemMessage, + UserMessage, ) +from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage try: - from mistralai import ( - Messages, - UserMessage, - AssistantMessage, - SystemMessage, - Mistral, - ) + from mistralai import Messages, Mistral from mistralai.models.sdkerror import SDKError except ImportError: Mistral = None # type: ignore SDKError = None # type: ignore - Messages = None # type: ignore class MistralAILLM(LLMInterface): @@ -73,31 +75,38 @@ def __init__( def get_messages( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> list[Messages]: - messages: list[Messages] = [] - for m in input: - if m["role"] == "system": - messages.append(SystemMessage(content=m["content"])) - continue - if m["role"] == "user": - messages.append(UserMessage(content=m["content"])) - continue - if m["role"] == "assistant": - messages.append(AssistantMessage(content=m["content"])) - continue - raise ValueError(f"Unknown role: {m['role']}") - return messages - - def _invoke( + messages = [] + if system_instruction: + messages.append(SystemMessage(content=system_instruction).model_dump()) + if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + messages.append(UserMessage(content=input).model_dump()) + return cast(list[Messages], messages) + + @rate_limit_handler + def invoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the Mistral chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. 
Returns: LLMResponse: The response from MistralAI. @@ -106,7 +115,9 @@ def _invoke( LLMGenerationError: If anything goes wrong. """ try: - messages = self.get_messages(input) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + messages = self.get_messages(input, message_history, system_instruction) response = self.client.chat.complete( model=self.model_name, messages=messages, @@ -121,15 +132,21 @@ def _invoke( except SDKError as e: raise LLMGenerationError(e) - async def _ainvoke( + @async_rate_limit_handler + async def ainvoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the MistralAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from MistralAI. @@ -138,7 +155,9 @@ async def _ainvoke( LLMGenerationError: If anything goes wrong. """ try: - messages = self.get_messages(input) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + messages = self.get_messages(input, message_history, system_instruction) response = await self.client.chat.complete_async( model=self.model_name, messages=messages, diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 512db928d..94541e033 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -15,15 +15,26 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Optional, Sequence +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast + +from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError +from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage from .base import LLMInterface -from neo4j_graphrag.utils.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from .types import ( + BaseMessage, LLMResponse, + MessageList, + SystemMessage, + UserMessage, ) if TYPE_CHECKING: @@ -69,26 +80,48 @@ def __init__( def get_messages( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> Sequence[Message]: - return [self.ollama.Message(**i) for i in input] + messages = [] + if system_instruction: + messages.append(SystemMessage(content=system_instruction).model_dump()) + if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + messages.append(UserMessage(content=input).model_dump()) + return messages # type: ignore - def _invoke( + @rate_limit_handler + def invoke( self, - input: list[LLMMessage], + input: str, + message_history: 
Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages response = self.client.chat( model=self.model_name, - messages=self.get_messages(input), + messages=self.get_messages(input, message_history, system_instruction), **self.model_params, ) content = response.message.content or "" @@ -96,15 +129,21 @@ def _invoke( except self.ollama.ResponseError as e: raise LLMGenerationError(e) - async def _ainvoke( + @async_rate_limit_handler + async def ainvoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. @@ -113,9 +152,11 @@ async def _ainvoke( LLMGenerationError: If anything goes wrong. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages response = await self.async_client.chat( model=self.model_name, - messages=self.get_messages(input), + messages=self.get_messages(input, message_history, system_instruction), options=self.model_params, ) content = response.message.content or "" diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 588352009..afdf0234d 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -24,18 +24,30 @@ Optional, Iterable, Sequence, + Union, cast, - Type, ) +from pydantic import ValidationError + +from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage from ..exceptions import LLMGenerationError from .base import LLMInterface +from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from .types import ( + BaseMessage, LLMResponse, + MessageList, ToolCall, ToolCallResponse, + SystemMessage, + UserMessage, ) from neo4j_graphrag.tool import Tool @@ -46,13 +58,11 @@ ChatCompletionToolParam, ) from openai import OpenAI, AsyncOpenAI - from neo4j_graphrag.utiles.rate_limit import RateLimitHandler else: ChatCompletionMessageParam = Any ChatCompletionToolParam = Any OpenAI = Any AsyncOpenAI = Any - RateLimitHandler = Any class BaseOpenAILLM(LLMInterface, abc.ABC): @@ -87,28 +97,23 @@ def __init__( def get_messages( self, - messages: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> Iterable[ChatCompletionMessageParam]: - chat_messages = [] - for m in messages: - message_type: 
Type[ChatCompletionMessageParam] - if m["role"] == "system": - message_type = self.openai.types.chat.ChatCompletionSystemMessageParam - elif m["role"] == "user": - message_type = self.openai.types.chat.ChatCompletionUserMessageParam - elif m["role"] == "assistant": - message_type = ( - self.openai.types.chat.ChatCompletionAssistantMessageParam - ) - else: - raise ValueError(f"Unknown role: {m['role']}") - chat_messages.append( - message_type( - role=m["role"], # type: ignore - content=m["content"], - ) - ) - return chat_messages + messages = [] + if system_instruction: + messages.append(SystemMessage(content=system_instruction).model_dump()) + if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + messages.append(UserMessage(content=input).model_dump()) + return messages # type: ignore def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]: """Convert a Tool object to OpenAI's expected format. @@ -131,15 +136,21 @@ def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]: except AttributeError: raise LLMGenerationError(f"Tool {tool} is not a valid Tool object") - def _invoke( + @rate_limit_handler + def invoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. @@ -148,8 +159,10 @@ def _invoke( LLMGenerationError: If anything goes wrong. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages response = self.client.chat.completions.create( - messages=self.get_messages(input), + messages=self.get_messages(input, message_history, system_instruction), model=self.model_name, **self.model_params, ) @@ -158,10 +171,13 @@ def _invoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - def _invoke_with_tools( + @rate_limit_handler + def invoke_with_tools( self, - input: list[LLMMessage], - tools: Sequence[Tool], + input: str, + tools: Sequence[Tool], # Tools definition as a sequence of Tool objects + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> ToolCallResponse: """Sends a text input to the OpenAI chat completion model with tool definitions and retrieves a tool call response. @@ -169,6 +185,9 @@ def _invoke_with_tools( Args: input (str): Text sent to the LLM. tools (List[Tool]): List of Tools for the LLM to choose from. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: ToolCallResponse: The response from the LLM containing a tool call. 
@@ -177,6 +196,9 @@ def _invoke_with_tools( LLMGenerationError: If anything goes wrong. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + params = self.model_params.copy() if self.model_params else {} if "temperature" not in params: params["temperature"] = 0.0 @@ -188,7 +210,7 @@ def _invoke_with_tools( openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool)) response = self.client.chat.completions.create( - messages=self.get_messages(input), + messages=self.get_messages(input, message_history, system_instruction), model=self.model_name, tools=openai_tools, tool_choice="auto", @@ -224,15 +246,21 @@ def _invoke_with_tools( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - async def _ainvoke( + @async_rate_limit_handler + async def ainvoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. @@ -241,8 +269,10 @@ async def _ainvoke( LLMGenerationError: If anything goes wrong. """ try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages response = await self.async_client.chat.completions.create( - messages=self.get_messages(input), + messages=self.get_messages(input, message_history, system_instruction), model=self.model_name, **self.model_params, ) @@ -251,10 +281,13 @@ async def _ainvoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - async def _ainvoke_with_tools( + @async_rate_limit_handler + async def ainvoke_with_tools( self, - input: list[LLMMessage], + input: str, tools: Sequence[Tool], # Tools definition as a sequence of Tool objects + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> ToolCallResponse: """Asynchronously sends a text input to the OpenAI chat completion model with tool definitions and retrieves a tool call response. @@ -262,6 +295,9 @@ async def _ainvoke_with_tools( Args: input (str): Text sent to the LLM. tools (List[Tool]): List of Tools for the LLM to choose from. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: ToolCallResponse: The response from the LLM containing a tool call. @@ -270,6 +306,9 @@ async def _ainvoke_with_tools( LLMGenerationError: If anything goes wrong. 
""" try: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + params = self.model_params.copy() if "temperature" not in params: params["temperature"] = 0.0 @@ -281,7 +320,7 @@ async def _ainvoke_with_tools( openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool)) response = await self.async_client.chat.completions.create( - messages=self.get_messages(input), + messages=self.get_messages(input, message_history, system_instruction), model=self.model_name, tools=openai_tools, tool_choice="auto", diff --git a/src/neo4j_graphrag/llm/utils.py b/src/neo4j_graphrag/llm/utils.py deleted file mode 100644 index 5746ca91c..000000000 --- a/src/neo4j_graphrag/llm/utils.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations -import warnings -from typing import Union, Optional - -from pydantic import TypeAdapter - -from neo4j_graphrag.message_history import MessageHistory -from neo4j_graphrag.types import LLMMessage - - -def system_instruction_from_messages(messages: list[LLMMessage]) -> str | None: - for message in messages: - if message["role"] == "system": - return message["content"] - return None - - -llm_messages_adapter = TypeAdapter(list[LLMMessage]) - - -def legacy_inputs_to_messages( - input: Union[str, list[LLMMessage], MessageHistory], - message_history: Optional[Union[list[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, -) -> list[LLMMessage]: - if message_history: - if isinstance(message_history, MessageHistory): - messages = message_history.messages - else: # list[LLMMessage] - messages = llm_messages_adapter.validate_python(message_history) - else: - messages = [] - if system_instruction is not None: - if system_instruction_from_messages(messages) is not None: - warnings.warn( - "system_instruction provided but ignored as the message history already contains a system message", - UserWarning, - ) - else: - messages.insert( - 0, - LLMMessage( - role="system", - content=system_instruction, - ), - ) - - if isinstance(input, str): - messages.append(LLMMessage(role="user", content=input)) - return messages - if isinstance(input, list): - messages.extend(input) - return messages - # input is a MessageHistory instance - messages.extend(input.messages) - return messages diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index adb13312d..b9f1e40e8 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -13,19 +13,25 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Optional, Sequence +from typing import Any, List, Optional, Union, cast, Sequence +from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( + BaseMessage, LLMResponse, + MessageList, ToolCall, ToolCallResponse, ) +from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.tool import Tool from neo4j_graphrag.types import LLMMessage @@ -92,75 +98,92 @@ def __init__( def get_messages( self, - input: list[LLMMessage], - ) -> tuple[str | None, list[Content]]: + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + ) -> list[Content]: messages = [] - system_instruction = self.system_instruction - for message in input: - role = message.get("role") - if role == "system": - system_instruction = message.get("content") - continue - if role == "user": - messages.append( - Content( - role="user", - parts=[Part.from_text(message.get("content", ""))], + if message_history: + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + + for message in message_history: + if message.get("role") == "user": + messages.append( + Content( + role="user", + parts=[Part.from_text(message.get("content", ""))], + ) ) - ) - continue - if role == "assistant": - messages.append( - Content( - role="model", - parts=[Part.from_text(message.get("content", ""))], + elif message.get("role") == "assistant": + messages.append( + Content( + role="model", + parts=[Part.from_text(message.get("content", ""))], + ) ) - ) - continue - raise ValueError(f"Unknown role: {role}") - return system_instruction, messages - def _invoke( + messages.append(Content(role="user", parts=[Part.from_text(input)])) + return messages + + @rate_limit_handler + def invoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
""" - system_instruction, messages = self.get_messages(input) model = self._get_model( system_instruction=system_instruction, ) try: - options = self._get_call_params(messages, tools=None) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + options = self._get_call_params(input, message_history, tools=None) response = model.generate_content(**options) return self._parse_content_response(response) except ResponseValidationError as e: raise LLMGenerationError("Error calling VertexAILLM") from e - async def _ainvoke( + @async_rate_limit_handler + async def ainvoke( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - system_instruction, messages = self.get_messages(input) + if isinstance(message_history, MessageHistory): + message_history = message_history.messages model = self._get_model( system_instruction=system_instruction, ) - options = self._get_call_params(messages, tools=None) + options = self._get_call_params(input, message_history, tools=None) response = await model.generate_content_async(**options) return self._parse_content_response(response) except ResponseValidationError as e: @@ -190,6 +213,7 @@ def _get_model( self, system_instruction: Optional[str] = None, ) -> GenerativeModel: + # system_message = [system_instruction] if system_instruction is not None else [] model = GenerativeModel( model_name=self.model_name, system_instruction=system_instruction, @@ -198,7 +222,8 @@ def _get_model( def _get_call_params( self, - contents: list[Content], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]], tools: Optional[Sequence[Tool]], ) -> dict[str, Any]: options = dict(self.options) @@ -215,28 +240,32 @@ def _get_call_params( else: # no tools, remove tool_config if defined options.pop("tool_config", None) - options["contents"] = contents + + messages = self.get_messages(input, message_history) + options["contents"] = messages return options async def _acall_llm( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, tools: Optional[Sequence[Tool]] = None, ) -> GenerationResponse: - system_instruction, contents = self.get_messages(input) - model = self._get_model(system_instruction) - options = self._get_call_params(contents, tools) + model = self._get_model(system_instruction=system_instruction) + options = self._get_call_params(input, message_history, tools) response = await model.generate_content_async(**options) return response # type: ignore[no-any-return] def _call_llm( self, - input: list[LLMMessage], + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, tools: Optional[Sequence[Tool]] = None, ) -> GenerationResponse: - system_instruction, contents = self.get_messages(input) - model = self._get_model(system_instruction) - options = 
self._get_call_params(contents, tools) + model = self._get_model(system_instruction=system_instruction) + options = self._get_call_params(input, message_history, tools) response = model.generate_content(**options) return response # type: ignore[no-any-return] @@ -258,24 +287,32 @@ def _parse_content_response(self, response: GenerationResponse) -> LLMResponse: content=response.text, ) - async def _ainvoke_with_tools( + async def ainvoke_with_tools( self, - input: list[LLMMessage], + input: str, tools: Sequence[Tool], + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> ToolCallResponse: response = await self._acall_llm( input, + message_history=message_history, + system_instruction=system_instruction, tools=tools, ) return self._parse_tool_response(response) - def _invoke_with_tools( + def invoke_with_tools( self, - input: list[LLMMessage], + input: str, tools: Sequence[Tool], + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, ) -> ToolCallResponse: response = self._call_llm( input, + message_history=message_history, + system_instruction=system_instruction, tools=tools, ) return self._parse_tool_response(response) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index f4df4576f..59ba033d9 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -74,9 +74,6 @@ class MessageHistory(ABC): @abstractmethod def messages(self) -> List[LLMMessage]: ... - def is_empty(self) -> bool: - return len(self.messages) == 0 - @abstractmethod def add_message(self, message: LLMMessage) -> None: ... diff --git a/tests/e2e/test_graphrag_e2e.py b/tests/e2e/test_graphrag_e2e.py index d747aec08..895a9adb0 100644 --- a/tests/e2e/test_graphrag_e2e.py +++ b/tests/e2e/test_graphrag_e2e.py @@ -60,14 +60,7 @@ def test_graphrag_happy_path( ) llm.invoke.assert_called_once_with( - [ - { - "role": "system", - "content": "Answer the user question using the provided context.", - }, - { - "role": "user", - "content": """Context: + """Context: @@ -79,8 +72,8 @@ def test_graphrag_happy_path( Answer: """, - }, - ] + None, + system_instruction="Answer the user question using the provided context.", ) assert isinstance(result, RagResultModel) assert result.answer == "some text" @@ -155,21 +148,13 @@ def test_graphrag_happy_path_with_neo4j_message_history( llm.invoke.assert_has_calls( [ call( - [ - {"role": "system", "content": first_invocation_system_instruction}, - {"role": "user", "content": first_invocation_input}, - ] + input=first_invocation_input, + system_instruction=first_invocation_system_instruction, ), call( - [ - { - "role": "system", - "content": "Answer the user question using the provided context.", - }, - {"role": "user", "content": "initial question"}, - {"role": "assistant", "content": "answer to initial question"}, - {"role": "user", "content": second_invocation}, - ] + second_invocation, + message_history.messages, + system_instruction="Answer the user question using the provided context.", ), ] ) @@ -205,14 +190,7 @@ def test_graphrag_happy_path_return_context( ) llm.invoke.assert_called_once_with( - [ - { - "role": "system", - "content": "Answer the user question using the provided context.", - }, - { - "role": "user", - "content": """Context: + """Context: @@ -224,8 +202,8 @@ def test_graphrag_happy_path_return_context( Answer: """, - }, - ], + None, + system_instruction="Answer the user 
question using the provided context.", ) assert isinstance(result, RagResultModel) assert result.answer == "some text" @@ -258,14 +236,7 @@ def test_graphrag_happy_path_examples( ) llm.invoke.assert_called_once_with( - [ - { - "role": "system", - "content": "Answer the user question using the provided context.", - }, - { - "role": "user", - "content": """Context: + """Context: @@ -277,8 +248,8 @@ def test_graphrag_happy_path_examples( Answer: """, - }, - ] + None, + system_instruction="Answer the user question using the provided context.", ) assert result.answer == "some text" diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py index 75156c014..029d75778 100644 --- a/tests/unit/llm/test_anthropic_llm.py +++ b/tests/unit/llm/test_anthropic_llm.py @@ -19,11 +19,9 @@ import anthropic import pytest -from anthropic import NOT_GIVEN, NotGiven - -from neo4j_graphrag.llm import LLMResponse +from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM -from neo4j_graphrag.types import LLMMessage +from neo4j_graphrag.llm.types import LLMResponse @pytest.fixture @@ -42,74 +40,132 @@ def test_anthropic_llm_missing_dependency(mock_import: Mock) -> None: AnthropicLLM(model_name="claude-3-opus-20240229") -def test_anthropic_llm_get_messages_with_system_instructions() -> None: - llm = AnthropicLLM(api_key="my key", model_name="claude") - message_history = [ - LLMMessage(**{"role": "system", "content": "do something"}), - LLMMessage( - **{"role": "user", "content": "When does the sun come up in the summer?"} - ), - LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), - ] - - system_instruction, messages = llm.get_messages(message_history) - assert isinstance(system_instruction, str) - assert system_instruction == "do something" - assert isinstance(messages, list) - assert len(messages) == 2 # exclude system instruction - for actual, expected in zip(messages, message_history[1:]): - assert isinstance(actual, dict) - assert actual["role"] == expected["role"] - assert actual["content"] == expected["content"] +def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None: + mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( + content=[MagicMock(text="generated text")] + ) + model_params = {"temperature": 0.3} + llm = AnthropicLLM("claude-3-opus-20240229", model_params=model_params) + input_text = "may thy knife chip and shatter" + response = llm.invoke(input_text) + assert response.content == "generated text" + llm.client.messages.create.assert_called_once_with( # type: ignore + messages=[{"role": "user", "content": input_text}], + model="claude-3-opus-20240229", + system=anthropic.NOT_GIVEN, + **model_params, + ) -def test_anthropic_llm_get_messages_without_system_instructions() -> None: - llm = AnthropicLLM(api_key="my key", model_name="claude") +def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock) -> None: + mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( + content=[MagicMock(text="generated text")] + ) + model_params = {"temperature": 0.3} + llm = AnthropicLLM( + "claude-3-opus-20240229", + model_params=model_params, + ) message_history = [ - LLMMessage( - **{"role": "user", "content": "When does the sun come up in the summer?"} - ), - LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", 
"content": "Usually around 6am."}, ] + question = "What about next season?" - system_instruction, messages = llm.get_messages(message_history) - assert isinstance(system_instruction, NotGiven) - assert system_instruction == NOT_GIVEN - assert isinstance(messages, list) - assert len(messages) == 2 - for actual, expected in zip(messages, message_history): - assert isinstance(actual, dict) - assert actual["role"] == expected["role"] - assert actual["content"] == expected["content"] + response = llm.invoke(question, message_history) # type: ignore + assert response.content == "generated text" + message_history.append({"role": "user", "content": question}) + llm.client.messages.create.assert_called_once_with( # type: ignore[attr-defined] + messages=message_history, + model="claude-3-opus-20240229", + system=anthropic.NOT_GIVEN, + **model_params, + ) -def test_anthropic_llm_get_messages_unknown_role() -> None: - llm = AnthropicLLM(api_key="my key", model_name="claude") - message_history = [ - LLMMessage(**{"role": "unknown role", "content": "Usually around 6am."}), # type: ignore[typeddict-item] - ] - with pytest.raises(ValueError, match="Unknown role"): - llm.get_messages(message_history) +def test_anthropic_invoke_with_system_instruction( + mock_anthropic: Mock, +) -> None: + mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( + content=[MagicMock(text="generated text")] + ) + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + llm = AnthropicLLM( + "claude-3-opus-20240229", + model_params=model_params, + ) + + question = "When does it come up in the winter?" + response = llm.invoke(question, system_instruction=system_instruction) + assert isinstance(response, LLMResponse) + assert response.content == "generated text" + messages = [{"role": "user", "content": question}] + llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] + model="claude-3-opus-20240229", + system=system_instruction, + messages=messages, + **model_params, + ) + assert llm.client.messages.create.call_count == 1 # type: ignore -def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None: + +def test_anthropic_invoke_with_message_history_and_system_instruction( + mock_anthropic: Mock, +) -> None: mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( content=[MagicMock(text="generated text")] ) - mock_anthropic.types.MessageParam.return_value = {"role": "user", "content": "hi"} model_params = {"temperature": 0.3} - llm = AnthropicLLM("claude-3-opus-20240229", model_params=model_params) - input_text = "may thy knife chip and shatter" - response = llm.invoke(input_text) + system_instruction = "You are a helpful assistant." + llm = AnthropicLLM( + "claude-3-opus-20240229", + model_params=model_params, + ) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + + question = "When does it come up in the winter?" 
+ response = llm.invoke(question, message_history, system_instruction) # type: ignore assert isinstance(response, LLMResponse) assert response.content == "generated text" - llm.client.messages.create.assert_called_once_with( # type: ignore - messages=[{"role": "user", "content": "hi"}], + message_history.append({"role": "user", "content": question}) + llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] model="claude-3-opus-20240229", - system=anthropic.NOT_GIVEN, + system=system_instruction, + messages=message_history, **model_params, ) + assert llm.client.messages.create.call_count == 1 # type: ignore + + +def test_anthropic_invoke_with_message_history_validation_error( + mock_anthropic: Mock, +) -> None: + mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( + content=[MagicMock(text="generated text")] + ) + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + llm = AnthropicLLM( + "claude-3-opus-20240229", + model_params=model_params, + system_instruction=system_instruction, + ) + message_history = [ + {"role": "human", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, message_history) # type: ignore + assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value) + @pytest.mark.asyncio async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None: @@ -117,16 +173,14 @@ async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None: mock_response.content = [MagicMock(text="Return text")] mock_model = mock_anthropic.AsyncAnthropic.return_value mock_model.messages.create = AsyncMock(return_value=mock_response) - mock_anthropic.types.MessageParam.return_value = {"role": "user", "content": "hi"} model_params = {"temperature": 0.3} llm = AnthropicLLM("claude-3-opus-20240229", model_params) input_text = "may thy knife chip and shatter" response = await llm.ainvoke(input_text) - assert isinstance(response, LLMResponse) assert response.content == "Return text" llm.async_client.messages.create.assert_awaited_once_with( # type: ignore model="claude-3-opus-20240229", system=anthropic.NOT_GIVEN, - messages=[{"role": "user", "content": "hi"}], + messages=[{"role": "user", "content": input_text}], **model_params, ) diff --git a/tests/unit/llm/test_base.py b/tests/unit/llm/test_base.py deleted file mode 100644 index 4eff7cb94..000000000 --- a/tests/unit/llm/test_base.py +++ /dev/null @@ -1,216 +0,0 @@ -"""The base LLMInterface is responsible -for formatting the inputs as a list of LLMMessage objects -and handling the rate limits. This is what is being tested -in this file. 
-""" - -from typing import Type, Generator -from unittest import mock -from unittest.mock import patch, Mock, call - -import pytest -import tenacity -from joblib.testing import fixture -from pydantic import ValidationError - -from neo4j_graphrag.exceptions import LLMGenerationError -from neo4j_graphrag.llm import LLMInterface, LLMResponse -from neo4j_graphrag.types import LLMMessage - - -@fixture(scope="module") # type: ignore[misc] -def llm_interface() -> Generator[Type[LLMInterface], None, None]: - real_abstract_methods = LLMInterface.__abstractmethods__ - LLMInterface.__abstractmethods__ = frozenset() - - class CustomLLMInterface(LLMInterface): - pass - - yield CustomLLMInterface - - LLMInterface.__abstractmethods__ = real_abstract_methods - - -@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") -def test_base_llm_interface_invoke_with_input_as_str( - mock_inputs: Mock, llm_interface: Type[LLMInterface] -) -> None: - mock_inputs.return_value = [ - LLMMessage( - role="user", - content="return value of the legacy_inputs_to_messages function", - ) - ] - llm = llm_interface(model_name="test") - message_history = [ - LLMMessage( - **{"role": "user", "content": "When does the sun come up in the summer?"} - ), - LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), - ] - question = "What about next season?" - system_instruction = "You are a genius." - - with patch.object(llm, "_invoke") as mock_invoke: - llm.invoke(question, message_history, system_instruction) - mock_invoke.assert_called_once_with( - [ - LLMMessage( - role="user", - content="return value of the legacy_inputs_to_messages function", - ) - ] - ) - mock_inputs.assert_called_once_with( - question, - message_history, - system_instruction, - ) - - -@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") -def test_base_llm_interface_invoke_with_invalid_inputs( - mock_inputs: Mock, llm_interface: Type[LLMInterface] -) -> None: - mock_inputs.side_effect = [ - ValidationError.from_exception_data("Invalid data", line_errors=[]) - ] - llm = llm_interface(model_name="test") - question = "What about next season?" - - with pytest.raises(LLMGenerationError, match="Input validation failed"): - llm.invoke(question) - mock_inputs.assert_called_once_with( - question, - None, - None, - ) - - -@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") -def test_base_llm_interface_invoke_with_tools_with_input_as_str( - mock_inputs: Mock, llm_interface: Type[LLMInterface] -) -> None: - mock_inputs.return_value = [ - LLMMessage( - role="user", - content="return value of the legacy_inputs_to_messages function", - ) - ] - llm = llm_interface(model_name="test") - message_history = [ - LLMMessage( - **{"role": "user", "content": "When does the sun come up in the summer?"} - ), - LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), - ] - question = "What about next season?" - system_instruction = "You are a genius." 
- - with patch.object(llm, "_invoke_with_tools") as mock_invoke: - llm.invoke_with_tools(question, [], message_history, system_instruction) - mock_invoke.assert_called_once_with( - [ - LLMMessage( - role="user", - content="return value of the legacy_inputs_to_messages function", - ) - ], - [], # tools - ) - mock_inputs.assert_called_once_with( - question, - message_history, - system_instruction, - ) - - -@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") -def test_base_llm_interface_invoke_with_tools_with_invalid_inputs( - mock_inputs: Mock, llm_interface: Type[LLMInterface] -) -> None: - mock_inputs.side_effect = [ - ValidationError.from_exception_data("Invalid data", line_errors=[]) - ] - llm = llm_interface(model_name="test") - question = "What about next season?" - - with pytest.raises(LLMGenerationError, match="Input validation failed"): - llm.invoke_with_tools(question, []) - mock_inputs.assert_called_once_with( - question, - None, - None, - ) - - -@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") -def test_base_llm_interface_invoke_retry_ok( - mock_inputs: Mock, llm_interface: Type[LLMInterface] -) -> None: - mock_inputs.return_value = [ - LLMMessage( - role="user", - content="return value of the legacy_inputs_to_messages function", - ) - ] - llm = llm_interface(model_name="test") - question = "What about next season?" - - with mock.patch.object(llm, "_invoke") as mock_invoke_core: - mock_invoke_core.side_effect = [ - LLMGenerationError("rate limit"), - LLMResponse(content="all good"), - ] - res = llm.invoke(question, []) - assert res.content == "all good" - call_args = [ - { - "role": "user", - "content": "return value of the legacy_inputs_to_messages function", - } - ] - assert mock_invoke_core.call_count == 2 - mock_invoke_core.assert_has_calls( - [ - call(call_args), - call(call_args), - ] - ) - - -@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") -def test_base_llm_interface_invoke_retry_fail( - mock_inputs: Mock, llm_interface: Type[LLMInterface] -) -> None: - mock_inputs.return_value = [ - LLMMessage( - role="user", - content="return value of the legacy_inputs_to_messages function", - ) - ] - llm = llm_interface(model_name="test") - question = "What about next season?" 
- - with mock.patch.object(llm, "_invoke") as mock_invoke_core: - mock_invoke_core.side_effect = [ - LLMGenerationError("rate limit"), - LLMGenerationError("rate limit"), - LLMGenerationError("rate limit"), - ] - with pytest.raises(tenacity.RetryError): - llm.invoke(question, []) - call_args = [ - { - "role": "user", - "content": "return value of the legacy_inputs_to_messages function", - } - ] - assert mock_invoke_core.call_count == 3 - mock_invoke_core.assert_has_calls( - [ - call(call_args), - call(call_args), - call(call_args), - ] - ) diff --git a/tests/unit/llm/test_cohere_llm.py b/tests/unit/llm/test_cohere_llm.py index fb968d0c9..10a02ec86 100644 --- a/tests/unit/llm/test_cohere_llm.py +++ b/tests/unit/llm/test_cohere_llm.py @@ -41,17 +41,86 @@ def test_cohere_llm_happy_path(mock_cohere: Mock) -> None: chat_response_mock = MagicMock() chat_response_mock.message.content = [MagicMock(text="cohere response text")] mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock - mock_cohere.UserChatMessageV2.return_value = {"role": "user", "content": "test"} llm = CohereLLM(model_name="something") res = llm.invoke("my text") assert isinstance(res, LLMResponse) assert res.content == "cohere response text" - mock_cohere.ClientV2.return_value.chat.assert_called_once_with( - messages=[{"role": "user", "content": "test"}], + + +def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) -> None: + chat_response_mock = MagicMock() + chat_response_mock.message.content = [MagicMock(text="cohere response text")] + mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat + mock_cohere_client_chat.return_value = chat_response_mock + + system_instruction = "You are a helpful assistant." + llm = CohereLLM(model_name="something") + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "cohere response text" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + mock_cohere_client_chat.assert_called_once_with( + messages=messages, model="something", ) +def test_cohere_llm_invoke_with_message_history_and_system_instruction( + mock_cohere: Mock, +) -> None: + chat_response_mock = MagicMock() + chat_response_mock.message.content = [MagicMock(text="cohere response text")] + mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat + mock_cohere_client_chat.return_value = chat_response_mock + + system_instruction = "You are a helpful assistant." + llm = CohereLLM(model_name="gpt") + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "cohere response text" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + mock_cohere_client_chat.assert_called_once_with( + messages=messages, + model="gpt", + ) + + +def test_cohere_llm_invoke_with_message_history_validation_error( + mock_cohere: Mock, +) -> None: + chat_response_mock = MagicMock() + chat_response_mock.message.content = [MagicMock(text="cohere response text")] + mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock + + system_instruction = "You are a helpful assistant." + llm = CohereLLM(model_name="something", system_instruction=system_instruction) + message_history = [ + {"role": "robot", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, message_history) # type: ignore + assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) + + @pytest.mark.asyncio async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None: chat_response_mock = MagicMock( @@ -70,8 +139,9 @@ async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None: def test_cohere_llm_failed(mock_cohere: Mock) -> None: mock_cohere.ClientV2.return_value.chat.side_effect = cohere.core.ApiError llm = CohereLLM(model_name="something") - with pytest.raises(LLMGenerationError, match="Error calling cohere"): + with pytest.raises(LLMGenerationError) as excinfo: llm.invoke("my text") + assert "ApiError" in str(excinfo) @pytest.mark.asyncio @@ -79,5 +149,6 @@ async def test_cohere_llm_failed_async(mock_cohere: Mock) -> None: mock_cohere.AsyncClientV2.return_value.chat.side_effect = cohere.core.ApiError llm = CohereLLM(model_name="something") - with pytest.raises(LLMGenerationError, match="Error calling cohere"): + with pytest.raises(LLMGenerationError) as excinfo: await llm.ainvoke("my text") + assert "ApiError" in str(excinfo) diff --git a/tests/unit/llm/test_mistralai_llm.py b/tests/unit/llm/test_mistralai_llm.py index f22cfb3e6..324798f2f 100644 --- a/tests/unit/llm/test_mistralai_llm.py +++ b/tests/unit/llm/test_mistralai_llm.py @@ -46,6 +46,97 @@ def test_mistralai_llm_invoke(mock_mistral: Mock) -> None: assert res.content == "mistral response" +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None: + mock_mistral_instance = mock_mistral.return_value + chat_response_mock = MagicMock() + chat_response_mock.choices = [ + MagicMock(message=MagicMock(content="mistral response")) + ] + mock_mistral_instance.chat.complete.return_value = chat_response_mock + model = "mistral-model" + system_instruction = "You are a helpful assistant." + + llm = MistralAILLM(model_name=model) + + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore + + assert isinstance(res, LLMResponse) + assert res.content == "mistral response" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.complete.assert_called_once_with( # type: ignore[attr-defined] + messages=messages, + model=model, + ) + + +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +def test_mistralai_llm_invoke_with_message_history_and_system_instruction( + mock_mistral: Mock, +) -> None: + mock_mistral_instance = mock_mistral.return_value + chat_response_mock = MagicMock() + chat_response_mock.choices = [ + MagicMock(message=MagicMock(content="mistral response")) + ] + mock_mistral_instance.chat.complete.return_value = chat_response_mock + model = "mistral-model" + system_instruction = "You are a helpful assistant." + llm = MistralAILLM(model_name=model) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + # first invocation - initial instructions + res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "mistral response" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.complete.assert_called_once_with( # type: ignore[attr-defined] + messages=messages, + model=model, + ) + + assert llm.client.chat.complete.call_count == 1 # type: ignore + + +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +def test_mistralai_llm_invoke_with_message_history_validation_error( + mock_mistral: Mock, +) -> None: + mock_mistral_instance = mock_mistral.return_value + chat_response_mock = MagicMock() + chat_response_mock.choices = [ + MagicMock(message=MagicMock(content="mistral response")) + ] + mock_mistral_instance.chat.complete.return_value = chat_response_mock + model = "mistral-model" + system_instruction = "You are a helpful assistant." + + llm = MistralAILLM(model_name=model, system_instruction=system_instruction) + + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "monkey", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, message_history) # type: ignore + assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) + + @pytest.mark.asyncio @patch("neo4j_graphrag.llm.mistralai_llm.Mistral") async def test_mistralai_llm_ainvoke(mock_mistral: Mock) -> None: diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py index 6ecb9fb13..c1d3f9fdc 100644 --- a/tests/unit/llm/test_ollama_llm.py +++ b/tests/unit/llm/test_ollama_llm.py @@ -41,7 +41,6 @@ def test_ollama_llm_happy_path_deprecated_options(mock_import: Mock) -> None: mock_ollama.Client.return_value.chat.return_value = MagicMock( message=MagicMock(content="ollama chat response"), ) - mock_ollama.Message.return_value = {"role": "user", "content": "test"} model = "gpt" model_params = {"temperature": 0.3} with pytest.warns(DeprecationWarning) as record: @@ -60,10 +59,11 @@ def test_ollama_llm_happy_path_deprecated_options(mock_import: Mock) -> None: res = llm.invoke(question) assert isinstance(res, LLMResponse) assert res.content == "ollama chat response" + messages = [ + {"role": "user", "content": question}, + ] llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, - messages=[{"role": "user", "content": "test"}], - options={"temperature": 0.3}, + model=model, messages=messages, options={"temperature": 0.3} ) @@ -90,7 +90,6 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: mock_ollama.Client.return_value.chat.return_value = MagicMock( message=MagicMock(content="ollama chat response"), ) - mock_ollama.Message.return_value = {"role": "user", "content": "test"} model = "gpt" options = {"temperature": 0.3} model_params = {"options": options, "format": "json"} @@ -103,7 +102,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: assert isinstance(res, LLMResponse) assert res.content == "ollama chat response" messages = [ - {"role": "user", "content": "test"}, + {"role": "user", "content": question}, ] llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] model=model, @@ -113,6 +112,102 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: ) +@patch("builtins.__import__") +def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) -> None: + mock_ollama = get_mock_ollama() + mock_import.return_value = mock_ollama + mock_ollama.Client.return_value.chat.return_value = MagicMock( + message=MagicMock(content="ollama chat response"), + ) + model = "gpt" + options = {"temperature": 0.3} + model_params = {"options": options, "format": "json"} + llm = OllamaLLM( + model, + model_params=model_params, + ) + system_instruction = "You are a helpful assistant." + question = "What about next season?" 
+ + response = llm.invoke(question, system_instruction=system_instruction) + assert response.content == "ollama chat response" + messages = [{"role": "system", "content": system_instruction}] + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] + model=model, + messages=messages, + options=options, + format="json", + ) + + +@patch("builtins.__import__") +def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> None: + mock_ollama = get_mock_ollama() + mock_import.return_value = mock_ollama + mock_ollama.Client.return_value.chat.return_value = MagicMock( + message=MagicMock(content="ollama chat response"), + ) + model = "gpt" + options = {"temperature": 0.3} + model_params = {"options": options} + llm = OllamaLLM( + model, + model_params=model_params, + ) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + response = llm.invoke(question, message_history) # type: ignore + assert response.content == "ollama chat response" + messages = [m for m in message_history] + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] + model=model, messages=messages, options=options + ) + + +@patch("builtins.__import__") +def test_ollama_invoke_with_message_history_and_system_instruction( + mock_import: Mock, +) -> None: + mock_ollama = get_mock_ollama() + mock_import.return_value = mock_ollama + mock_ollama.Client.return_value.chat.return_value = MagicMock( + message=MagicMock(content="ollama chat response"), + ) + model = "gpt" + options = {"temperature": 0.3} + model_params = {"options": options} + system_instruction = "You are a helpful assistant." + llm = OllamaLLM( + model, + model_params=model_params, + ) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + response = llm.invoke( + question, + message_history, # type: ignore + system_instruction=system_instruction, + ) + assert response.content == "ollama chat response" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] + model=model, messages=messages, options=options + ) + assert llm.client.chat.call_count == 1 # type: ignore + + @patch("builtins.__import__") def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) -> None: mock_ollama = get_mock_ollama() @@ -133,8 +228,9 @@ def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) ] question = "What about next season?" 
- with pytest.raises(LLMGenerationError, match="Input validation failed"): + with pytest.raises(LLMGenerationError) as exc_info: llm.invoke(question, message_history) # type: ignore + assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) @pytest.mark.asyncio diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 39571989f..3c5ee1b9e 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -21,7 +21,6 @@ from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM from neo4j_graphrag.llm.types import ToolCallResponse from neo4j_graphrag.tool import Tool -from neo4j_graphrag.types import LLMMessage def get_mock_openai() -> MagicMock: @@ -37,7 +36,7 @@ def test_openai_llm_missing_dependency(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_openai_llm_happy_path_e2e(mock_import: Mock) -> None: +def test_openai_llm_happy_path(mock_import: Mock) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( @@ -50,31 +49,89 @@ def test_openai_llm_happy_path_e2e(mock_import: Mock) -> None: assert res.content == "openai chat response" -def test_openai_llm_get_messages() -> None: +@patch("builtins.__import__") +def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None: + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="openai chat response"))], + ) llm = OpenAILLM(api_key="my key", model_name="gpt") message_history = [ - LLMMessage(**{"role": "system", "content": "do something"}), - LLMMessage( - **{"role": "user", "content": "When does the sun come up in the summer?"} - ), - LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + res = llm.invoke(question, message_history) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "openai chat response" + message_history.append({"role": "user", "content": question}) + # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions + llm.client.chat.completions.create.assert_called_once() # type: ignore + # Check call arguments individually + call_args = llm.client.chat.completions.create.call_args[ # type: ignore + 1 + ] # Get the keyword arguments + assert call_args["messages"] == message_history + assert call_args["model"] == "gpt" + + +@patch("builtins.__import__") +def test_openai_llm_with_message_history_and_system_instruction( + mock_import: Mock, +) -> None: + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="openai chat response"))], + ) + system_instruction = "You are a helpful assistent." + llm = OpenAILLM( + api_key="my key", + model_name="gpt", + ) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, ] + question = "What about next season?" 
+ + res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "openai chat response" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions + llm.client.chat.completions.create.assert_called_once() # type: ignore + # Check call arguments individually + call_args = llm.client.chat.completions.create.call_args[ # type: ignore + 1 + ] # Get the keyword arguments + assert call_args["messages"] == messages + assert call_args["model"] == "gpt" - messages = llm.get_messages(message_history) - assert isinstance(messages, list) - for actual, expected in zip(messages, message_history): - assert isinstance(actual, dict) - assert actual["role"] == expected["role"] - assert actual["content"] == expected["content"] + assert llm.client.chat.completions.create.call_count == 1 # type: ignore -def test_openai_llm_get_messages_unknown_role() -> None: +@patch("builtins.__import__") +def test_openai_llm_with_message_history_validation_error(mock_import: Mock) -> None: + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="openai chat response"))], + ) llm = OpenAILLM(api_key="my key", model_name="gpt") message_history = [ - LLMMessage(**{"role": "unknown role", "content": "Usually around 6am."}), # type: ignore[typeddict-item] + {"role": "human", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, ] - with pytest.raises(ValueError, match="Unknown role"): - llm.get_messages(message_history) + question = "What about next season?" + + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, message_history) # type: ignore + assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value) @patch("builtins.__import__") @@ -119,6 +176,130 @@ def test_openai_llm_invoke_with_tools_happy_path( assert res.content == "openai tool response" +@patch("builtins.__import__") +@patch("json.loads") +def test_openai_llm_invoke_with_tools_with_message_history( + mock_json_loads: Mock, + mock_import: Mock, + test_tool: Tool, +) -> None: + # Set up json.loads to return a dictionary + mock_json_loads.return_value = {"param1": "value1"} + + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + + # Mock the tool call response + mock_function = MagicMock() + mock_function.name = "test_tool" + mock_function.arguments = '{"param1": "value1"}' + + mock_tool_call = MagicMock() + mock_tool_call.function = mock_function + + mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + message=MagicMock( + content="openai tool response", tool_calls=[mock_tool_call] + ) + ) + ], + ) + + llm = OpenAILLM(api_key="my key", model_name="gpt") + tools = [test_tool] + + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + res = llm.invoke_with_tools(question, tools, message_history) # type: ignore + assert isinstance(res, ToolCallResponse) + assert len(res.tool_calls) == 1 + assert res.tool_calls[0].name == "test_tool" + assert res.tool_calls[0].arguments == {"param1": "value1"} + + # Verify the correct messages were passed + message_history.append({"role": "user", "content": question}) + # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions + llm.client.chat.completions.create.assert_called_once() # type: ignore + # Check call arguments individually + call_args = llm.client.chat.completions.create.call_args[ # type: ignore + 1 + ] # Get the keyword arguments + assert call_args["messages"] == message_history + assert call_args["model"] == "gpt" + # Check tools content rather than direct equality + assert len(call_args["tools"]) == 1 + assert call_args["tools"][0]["type"] == "function" + assert call_args["tools"][0]["function"]["name"] == "test_tool" + assert call_args["tools"][0]["function"]["description"] == "A test tool" + assert call_args["tool_choice"] == "auto" + assert call_args["temperature"] == 0.0 + + +@patch("builtins.__import__") +@patch("json.loads") +def test_openai_llm_invoke_with_tools_with_system_instruction( + mock_json_loads: Mock, + mock_import: Mock, + test_tool: Mock, +) -> None: + # Set up json.loads to return a dictionary + mock_json_loads.return_value = {"param1": "value1"} + + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + + # Mock the tool call response + mock_function = MagicMock() + mock_function.name = "test_tool" + mock_function.arguments = '{"param1": "value1"}' + + mock_tool_call = MagicMock() + mock_tool_call.function = mock_function + + mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + message=MagicMock( + content="openai tool response", tool_calls=[mock_tool_call] + ) + ) + ], + ) + + llm = OpenAILLM(api_key="my key", model_name="gpt") + tools = [test_tool] + + system_instruction = "You are a helpful assistant." 
+ + res = llm.invoke_with_tools("my text", tools, system_instruction=system_instruction) + assert isinstance(res, ToolCallResponse) + + # Verify system instruction was included + messages = [{"role": "system", "content": system_instruction}] + messages.append({"role": "user", "content": "my text"}) + # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions + llm.client.chat.completions.create.assert_called_once() # type: ignore + # Check call arguments individually + call_args = llm.client.chat.completions.create.call_args[ # type: ignore + 1 + ] # Get the keyword arguments + assert call_args["messages"] == messages + assert call_args["model"] == "gpt" + # Check tools content rather than direct equality + assert len(call_args["tools"]) == 1 + assert call_args["tools"][0]["type"] == "function" + assert call_args["tools"][0]["function"]["name"] == "test_tool" + assert call_args["tools"][0]["function"]["description"] == "A test tool" + assert call_args["tool_choice"] == "auto" + assert call_args["temperature"] == 0.0 + + @patch("builtins.__import__") def test_openai_llm_invoke_with_tools_error(mock_import: Mock, test_tool: Tool) -> None: mock_openai = get_mock_openai() @@ -161,3 +342,67 @@ def test_azure_openai_llm_happy_path(mock_import: Mock) -> None: res = llm.invoke("my text") assert isinstance(res, LLMResponse) assert res.content == "openai chat response" + + +@patch("builtins.__import__") +def test_azure_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None: + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( + MagicMock( + choices=[MagicMock(message=MagicMock(content="openai chat response"))], + ) + ) + llm = AzureOpenAILLM( + model_name="gpt", + azure_endpoint="https://test.openai.azure.com/", + api_key="my key", + api_version="version", + ) + + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + res = llm.invoke(question, message_history) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "openai chat response" + message_history.append({"role": "user", "content": question}) + # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions + llm.client.chat.completions.create.assert_called_once() # type: ignore + # Check call arguments individually + call_args = llm.client.chat.completions.create.call_args[ # type: ignore + 1 + ] # Get the keyword arguments + assert call_args["messages"] == message_history + assert call_args["model"] == "gpt" + + +@patch("builtins.__import__") +def test_azure_openai_llm_with_message_history_validation_error( + mock_import: Mock, +) -> None: + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( + MagicMock( + choices=[MagicMock(message=MagicMock(content="openai chat response"))], + ) + ) + llm = AzureOpenAILLM( + model_name="gpt", + azure_endpoint="https://test.openai.azure.com/", + api_key="my key", + api_version="version", + ) + + message_history = [ + {"role": "user", "content": 33}, + ] + question = "What about next season?" 
+ + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, message_history) # type: ignore + assert "Input should be a valid string" in str(exc_info.value) diff --git a/tests/unit/llm/test_utils.py b/tests/unit/llm/test_utils.py deleted file mode 100644 index 6a969864d..000000000 --- a/tests/unit/llm/test_utils.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest -from pydantic import ValidationError - -from neo4j_graphrag.llm.utils import ( - system_instruction_from_messages, - legacy_inputs_to_messages, -) -from neo4j_graphrag.message_history import InMemoryMessageHistory -from neo4j_graphrag.types import LLMMessage - - -def test_system_instruction_from_messages() -> None: - messages = [ - LLMMessage(role="system", content="text"), - ] - assert system_instruction_from_messages(messages) == "text" - - messages = [] - assert system_instruction_from_messages(messages) is None - - messages = [ - LLMMessage(role="assistant", content="text"), - ] - assert system_instruction_from_messages(messages) is None - - -def test_legacy_inputs_to_messages_only_input_as_llm_message_list() -> None: - messages = legacy_inputs_to_messages( - input=[ - LLMMessage(role="user", content="text"), - ] - ) - assert messages == [ - LLMMessage(role="user", content="text"), - ] - - -def test_legacy_inputs_to_messages_only_input_as_message_history() -> None: - messages = legacy_inputs_to_messages( - input=InMemoryMessageHistory( - messages=[ - LLMMessage(role="user", content="text"), - ] - ) - ) - assert messages == [ - LLMMessage(role="user", content="text"), - ] - - -def test_legacy_inputs_to_messages_only_input_as_str() -> None: - messages = legacy_inputs_to_messages(input="text") - assert messages == [ - LLMMessage(role="user", content="text"), - ] - - -def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_llm_message_list() -> ( - None -): - messages = legacy_inputs_to_messages( - input="text", - message_history=[ - LLMMessage(role="assistant", content="How can I assist you today?"), - ], - ) - assert messages == [ - LLMMessage(role="assistant", content="How can I assist you today?"), - LLMMessage(role="user", content="text"), - ] - - -def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_message_history() -> ( - None -): - messages = legacy_inputs_to_messages( - input="text", - message_history=InMemoryMessageHistory( - messages=[ - LLMMessage(role="assistant", content="How can I assist you today?"), - ] - ), - ) - assert messages == [ - LLMMessage(role="assistant", content="How can I assist you today?"), - LLMMessage(role="user", content="text"), - ] - - -def test_legacy_inputs_to_messages_with_explicit_system_instruction() -> None: - messages = legacy_inputs_to_messages( - input="text", - message_history=[ - LLMMessage(role="assistant", content="How can I assist you today?"), - ], - system_instruction="You are a genius.", - ) - assert messages == [ 
- LLMMessage(role="system", content="You are a genius."), - LLMMessage(role="assistant", content="How can I assist you today?"), - LLMMessage(role="user", content="text"), - ] - - -def test_legacy_inputs_to_messages_do_not_duplicate_system_instruction() -> None: - with pytest.warns( - UserWarning, - match="system_instruction provided but ignored as the message history already contains a system message", - ): - messages = legacy_inputs_to_messages( - input="text", - message_history=[ - LLMMessage(role="system", content="You are super smart."), - ], - system_instruction="You are a genius.", - ) - assert messages == [ - LLMMessage(role="system", content="You are super smart."), - LLMMessage(role="user", content="text"), - ] - - -def test_legacy_inputs_to_messages_wrong_type_in_message_list() -> None: - with pytest.raises(ValidationError, match="Input should be a valid string"): - legacy_inputs_to_messages( - input="text", - message_history=[ - {"role": "system", "content": 10}, # type: ignore - ], - ) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index caef9c25c..5d0e9b959 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -13,9 +13,11 @@ # limitations under the License. from __future__ import annotations +from typing import cast from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest +from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.types import ToolCallResponse from neo4j_graphrag.llm.vertexai_llm import VertexAILLM from neo4j_graphrag.tool import Tool @@ -57,14 +59,79 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None: assert content[0].parts[0].text == input_text +@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") +@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages") +def test_vertexai_invoke_with_system_instruction( + mock_get_messages: MagicMock, + GenerativeModelMock: MagicMock, +) -> None: + system_instruction = "You are a helpful assistant." + model_name = "gemini-1.5-flash-001" + input_text = "may thy knife chip and shatter" + mock_response = Mock() + mock_response.text = "Return text" + mock_model = GenerativeModelMock.return_value + mock_model.generate_content.return_value = mock_response + + mock_get_messages.return_value = [{"text": "some text"}] + + model_params = {"temperature": 0.5} + llm = VertexAILLM(model_name, model_params) + + response = llm.invoke(input_text, system_instruction=system_instruction) + assert response.content == "Return text" + GenerativeModelMock.assert_called_once_with( + model_name=model_name, + system_instruction=system_instruction, + ) + mock_model.generate_content.assert_called_once_with( + contents=[{"text": "some text"}] + ) + + +@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") +def test_vertexai_invoke_with_message_history_and_system_instruction( + GenerativeModelMock: MagicMock, +) -> None: + system_instruction = "You are a helpful assistant." + model_name = "gemini-1.5-flash-001" + mock_response = Mock() + mock_response.text = "Return text" + mock_model = GenerativeModelMock.return_value + mock_model.generate_content.return_value = mock_response + model_params = {"temperature": 0.5} + llm = VertexAILLM(model_name, model_params) + + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + response = llm.invoke( + question, + message_history, # type: ignore + system_instruction=system_instruction, + ) + assert response.content == "Return text" + GenerativeModelMock.assert_called_once_with( + model_name=model_name, + system_instruction=system_instruction, + ) + last_call = mock_model.generate_content.call_args_list[0] + content = last_call.kwargs["contents"] + assert len(content) == 3 # question + 2 messages in history + + @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None: model_name = "gemini-1.5-flash-001" + question = "When does it set?" message_history: list[LLMMessage] = [ - {"role": "system", "content": "Answer to a 3yo kid"}, {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, {"role": "user", "content": "What about next season?"}, + {"role": "assistant", "content": "Around 8am."}, ] expected_response = [ Content( @@ -73,29 +140,33 @@ def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None: ), Content(role="model", parts=[Part.from_text("Usually around 6am.")]), Content(role="user", parts=[Part.from_text("What about next season?")]), + Content(role="model", parts=[Part.from_text("Around 8am.")]), + Content(role="user", parts=[Part.from_text("When does it set?")]), ] llm = VertexAILLM(model_name=model_name) - system_instructions, messages = llm.get_messages(message_history) + response = llm.get_messages(question, message_history) GenerativeModelMock.assert_not_called() - assert system_instructions == "Answer to a 3yo kid" - assert len(messages) == len(expected_response) - for actual, expected in zip(messages, expected_response): + assert len(response) == len(expected_response) + for actual, expected in zip(response, expected_response): assert actual.role == expected.role assert actual.parts[0].text == expected.parts[0].text @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock) -> None: + system_instruction = "You are a helpful assistant." model_name = "gemini-1.5-flash-001" + question = "hi!" 
message_history = [ - LLMMessage(**{"role": "model", "content": "hello!"}), # type: ignore[typeddict-item] + {"role": "model", "content": "hello!"}, ] - llm = VertexAILLM(model_name=model_name) - with pytest.raises(ValueError, match="Unknown role"): - llm.get_messages(message_history) + llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction) + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, cast(list[LLMMessage], message_history)) + assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) @pytest.mark.asyncio @@ -108,7 +179,7 @@ async def test_vertexai_ainvoke_happy_path( mock_response.text = "Return text" mock_model = GenerativeModelMock.return_value mock_model.generate_content_async = AsyncMock(return_value=mock_response) - mock_get_messages.return_value = None, [{"text": "Return text"}] + mock_get_messages.return_value = [{"text": "Return text"}] model_params = {"temperature": 0.5} llm = VertexAILLM("gemini-1.5-flash-001", model_params) input_text = "may thy knife chip and shatter" @@ -152,7 +223,9 @@ def test_vertexai_invoke_with_tools( res = llm.invoke_with_tools("my text", tools) mock_call_llm.assert_called_once_with( - [{"role": "user", "content": "my text"}], + "my text", + message_history=None, + system_instruction=None, tools=tools, ) mock_parse_tool.assert_called_once() @@ -171,11 +244,11 @@ def test_vertexai_call_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None tools = [test_tool] with patch.object(llm, "_get_llm_tools", return_value=["my tools"]): - res = llm._call_llm([{"role": "user", "content": "my text"}], tools=tools) + res = llm._call_llm("my text", tools=tools) assert isinstance(res, GenerationResponse) mock_model.assert_called_once_with( - None, + system_instruction=None, ) calls = mock_generate_content.call_args_list assert len(calls) == 1 @@ -204,7 +277,9 @@ def test_vertexai_ainvoke_with_tools( res = llm.invoke_with_tools("my text", tools) mock_call_llm.assert_called_once_with( - [{"role": "user", "content": "my text"}], + "my text", + message_history=None, + system_instruction=None, tools=tools, ) mock_parse_tool.assert_called_once() @@ -226,8 +301,8 @@ async def test_vertexai_acall_llm_with_tools(mock_model: Mock, test_tool: Tool) llm = VertexAILLM(model_name="gemini") tools = [test_tool] - res = await llm._acall_llm([{"role": "user", "content": "my text"}], tools=tools) + res = await llm._acall_llm("my text", tools=tools) mock_model.assert_called_once_with( - None, + system_instruction=None, ) assert isinstance(res, GenerationResponse) diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index b58d8a9e8..925b48b78 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -63,14 +63,7 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: retriever_mock.search.assert_called_once_with(query_text="question", top_k=111) llm.invoke.assert_called_once_with( - [ - { - "role": "system", - "content": "Answer the user question using the provided context.", - }, - { - "role": "user", - "content": """Context: + """Context: item content 1 item content 2 @@ -82,8 +75,8 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: Answer: """, - }, - ] + None, # message history + system_instruction="Answer the user question using the provided context.", ) assert isinstance(res, RagResultModel) @@ -149,20 +142,13 @@ def test_graphrag_happy_path_with_message_history( llm.invoke.assert_has_calls( [ call( - [ - 
{"role": "system", "content": first_invocation_system_instruction}, - {"role": "user", "content": first_invocation_input}, - ] + input=first_invocation_input, + system_instruction=first_invocation_system_instruction, ), call( - [ - { - "role": "system", - "content": "Answer the user question using the provided context.", - }, - *message_history, - {"role": "user", "content": second_invocation}, - ] + second_invocation, + message_history, + system_instruction="Answer the user question using the provided context.", ), ] ) @@ -232,20 +218,13 @@ def test_graphrag_happy_path_with_in_memory_message_history( llm.invoke.assert_has_calls( [ call( - [ - {"role": "system", "content": first_invocation_system_instruction}, - {"role": "user", "content": first_invocation_input}, - ] + input=first_invocation_input, + system_instruction=first_invocation_system_instruction, ), call( - [ - { - "role": "system", - "content": "Answer the user question using the provided context.", - }, - *message_history.messages, - {"role": "user", "content": second_invocation}, - ] + second_invocation, + message_history.messages, + system_instruction="Answer the user question using the provided context.", ), ] ) @@ -274,10 +253,9 @@ def test_graphrag_happy_path_custom_system_instruction( llm.invoke.assert_has_calls( [ call( - [ - {"role": "system", "content": "Custom instruction"}, - {"role": "user", "content": mock.ANY}, - ] + mock.ANY, + None, # no message history + system_instruction="Custom instruction", ), ] )