From e0b5aca0ea6c43228903e16645181d464a2883ec Mon Sep 17 00:00:00 2001
From: Samridhi Kalra
Date: Sat, 25 Oct 2025 16:40:41 +0200
Subject: [PATCH] Implement tool calling for Ollama and include examples

---
 CHANGELOG.md                                 |   1 +
 examples/README.md                           |   1 +
 examples/customize/llms/ollama_tool_calls.py |  99 +++++++++++
 src/neo4j_graphrag/llm/ollama_llm.py         | 158 +++++++++++++++++-
 tests/unit/llm/test_ollama_llm.py            | 167 +++++++++++++++++++
 5 files changed, 425 insertions(+), 1 deletion(-)
 create mode 100644 examples/customize/llms/ollama_tool_calls.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 721403ecf..4cbbb3b6d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@
 
 - Added automatic rate limiting with retry logic and exponential backoff for all Embedding providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely.
 - JSON response returned to `SchemaFromTextExtractor` is cleansed of any markdown code blocks before being loaded.
+- Added tool calling support for Ollama in LLMInterface.
 
 ## 1.10.0
 
diff --git a/examples/README.md b/examples/README.md
index 774739b32..9c82c2a5f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -81,6 +81,7 @@ are listed in [the last section of this file](#customize).
 
 - [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py)
 - [Tool Calling with VertexAI](./customize/llms/vertexai_tool_calls.py)
+- [Tool Calling with Ollama](./customize/llms/ollama_tool_calls.py)
 
 ### Prompts
 
diff --git a/examples/customize/llms/ollama_tool_calls.py b/examples/customize/llms/ollama_tool_calls.py
new file mode 100644
index 000000000..d355ae60f
--- /dev/null
+++ b/examples/customize/llms/ollama_tool_calls.py
@@ -0,0 +1,99 @@
+"""
+Example showing how to use Ollama tool calls with parameter extraction.
+Both synchronous and asynchronous examples are provided.
+
+To run this example:
+1. Make sure you have `ollama serve` running
+2. Run: python examples/customize/llms/ollama_tool_calls.py
+"""
+
+import asyncio
+import json
+from typing import Dict, Any
+
+from neo4j_graphrag.llm import OllamaLLM
+from neo4j_graphrag.llm.types import ToolCallResponse
+from neo4j_graphrag.tool import (
+    Tool,
+    ObjectParameter,
+    StringParameter,
+    IntegerParameter,
+)
+
+
+# Create a custom Tool implementation for person info extraction
+parameters = ObjectParameter(
+    description="Parameters for extracting person information",
+    properties={
+        "name": StringParameter(description="The person's full name"),
+        "age": IntegerParameter(description="The person's age"),
+        "occupation": StringParameter(description="The person's occupation"),
+    },
+    required_properties=["name"],
+    additional_properties=False,
+)
+person_info_tool = Tool(
+    name="extract_person_info",
+    description="Extract information about a person from text",
+    parameters=parameters,
+    execute_func=lambda **kwargs: kwargs,
+)
+
+# The list of tools made available to the LLM
+TOOLS = [person_info_tool]
+
+
+def process_tool_calls(response: ToolCallResponse) -> Dict[str, Any]:
+    """Process all tool calls in the response and return the extracted parameters."""
+    if not response.tool_calls:
+        raise ValueError("No tool calls found in response")
+
+    print(f"\nNumber of tool calls: {len(response.tool_calls)}")
+    print(f"Additional content: {response.content or 'None'}")
+
+    results = []
+    for i, tool_call in enumerate(response.tool_calls):
+        print(f"\nTool call #{i + 1}: {tool_call.name}")
+        print(f"Arguments: {tool_call.arguments}")
+        results.append(tool_call.arguments)
+
+    # Return the first tool call's arguments so the result is easy to print
+    return results[0] if results else {}
+
+
+async def main() -> None:
+    # Initialize the Ollama LLM
+    llm = OllamaLLM(
+        model_name="mistral:latest",
+        model_params={"temperature": 0},
+    )
+
+    # Example text containing information about a person
+    text = "Stella Hane is a 35-year-old software engineer who loves coding."
+
+    print("\n=== Synchronous Tool Call ===")
+    # Make a synchronous tool call
+    sync_response = llm.invoke_with_tools(
+        input=f"Extract information about the person from this text: {text}",
+        tools=TOOLS,
+    )
+    sync_result = process_tool_calls(sync_response)
+    print("\n=== Synchronous Tool Call Result ===")
+    print(json.dumps(sync_result, indent=2))
+
+    print("\n=== Asynchronous Tool Call ===")
+    # Make an asynchronous tool call with a different text
+    text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
+    async_response = await llm.ainvoke_with_tools(
+        input=f"Extract information about the person from this text: {text2}",
+        tools=TOOLS,
+    )
+    async_result = process_tool_calls(async_response)
+    print("\n=== Asynchronous Tool Call Result ===")
+    print(json.dumps(async_result, indent=2))
+
+
+if __name__ == "__main__":
+    # Run the async main function
+    asyncio.run(main())
diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py
index 94541e033..d962fc073 100644
--- a/src/neo4j_graphrag/llm/ollama_llm.py
+++ b/src/neo4j_graphrag/llm/ollama_llm.py
@@ -15,7 +15,17 @@
 from __future__ import annotations
 
 import warnings
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Union,
+    cast,
+    Dict,
+)
 
 from pydantic import ValidationError
 
@@ -33,9 +43,12 @@
     BaseMessage,
     LLMResponse,
     MessageList,
+    ToolCall,
+    ToolCallResponse,
     SystemMessage,
     UserMessage,
 )
+from neo4j_graphrag.tool import Tool
 
 if TYPE_CHECKING:
     from ollama import Message
@@ -163,3 +176,146 @@ async def ainvoke(
             return LLMResponse(content=content)
         except self.ollama.ResponseError as e:
             raise LLMGenerationError(e)
+
+    @rate_limit_handler
+    def invoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],  # Tools definition as a sequence of Tool objects
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        """Sends a text input to the LLM with tool definitions
+        and retrieves a tool call response.
+
+        Args:
+            input (str): Text sent to the LLM.
+            tools (Sequence[Tool]): Tools for the LLM to choose from.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
+                with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            ToolCallResponse: The response from the LLM containing a tool call.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+        try:
+            if isinstance(message_history, MessageHistory):
+                message_history = message_history.messages
+
+            # Convert tools to Ollama's expected type
+            ollama_tools = []
+            for tool in tools:
+                ollama_tool_format = self._convert_tool_to_ollama_format(tool)
+                ollama_tools.append(ollama_tool_format)
+            response = self.client.chat(
+                model=self.model_name,
+                messages=self.get_messages(input, message_history, system_instruction),
+                tools=ollama_tools,
+                **self.model_params,
+            )
+            message = response.message
+            # If there's no tool call, return the content as a regular response
+            if not message.tool_calls:
+                return ToolCallResponse(
+                    tool_calls=[],
+                    content=message.content,
+                )
+
+            # Process all tool calls
+            tool_calls = []
+
+            for tool_call in message.tool_calls:
+                args = tool_call.function.arguments
+                tool_calls.append(
+                    ToolCall(name=tool_call.function.name, arguments=args)
+                )
+
+            return ToolCallResponse(tool_calls=tool_calls, content=message.content)
+        except self.ollama.ResponseError as e:
+            raise LLMGenerationError(e)
+
+    @async_rate_limit_handler
+    async def ainvoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],  # Tools definition as a sequence of Tool objects
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        """Sends a text input to the LLM with tool definitions
+        and retrieves a tool call response.
+
+        Args:
+            input (str): Text sent to the LLM.
+            tools (Sequence[Tool]): Tools for the LLM to choose from.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
+                with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            ToolCallResponse: The response from the LLM containing a tool call.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+        try:
+            if isinstance(message_history, MessageHistory):
+                message_history = message_history.messages
+
+            # Convert tools to Ollama's expected type
+            ollama_tools = []
+            for tool in tools:
+                ollama_tool_format = self._convert_tool_to_ollama_format(tool)
+                ollama_tools.append(ollama_tool_format)
+
+            response = await self.async_client.chat(
+                model=self.model_name,
+                messages=self.get_messages(input, message_history, system_instruction),
+                tools=ollama_tools,
+                **self.model_params,
+            )
+            message = response.message
+
+            # If there's no tool call, return the content as a regular response
+            if not message.tool_calls:
+                return ToolCallResponse(
+                    tool_calls=[],
+                    content=message.content,
+                )
+
+            # Process all tool calls
+            tool_calls = []
+
+            for tool_call in message.tool_calls:
+                args = tool_call.function.arguments
+                tool_calls.append(
+                    ToolCall(name=tool_call.function.name, arguments=args)
+                )
+
+            return ToolCallResponse(tool_calls=tool_calls, content=message.content)
+        except self.ollama.ResponseError as e:
+            raise LLMGenerationError(e)
+
+    def _convert_tool_to_ollama_format(self, tool: Tool) -> Dict[str, Any]:
+        """Convert a Tool object to Ollama's expected format.
+
+        Args:
+            tool: A Tool object to convert to Ollama's format.
+
+        Returns:
+            A dictionary in Ollama's tool format.
+        """
+        try:
+            return {
+                "type": "function",
+                "function": {
+                    "name": tool.get_name(),
+                    "description": tool.get_description(),
+                    "parameters": tool.get_parameters(),
+                },
+            }
+        except AttributeError:
+            raise LLMGenerationError(f"Tool {tool} is not a valid Tool object")
diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py
index c1d3f9fdc..289449472 100644
--- a/tests/unit/llm/test_ollama_llm.py
+++ b/tests/unit/llm/test_ollama_llm.py
@@ -20,6 +20,8 @@
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm import LLMResponse
 from neo4j_graphrag.llm.ollama_llm import OllamaLLM
+from neo4j_graphrag.llm.types import ToolCallResponse
+from neo4j_graphrag.tool import Tool
 
 
 def get_mock_ollama() -> MagicMock:
@@ -257,3 +259,168 @@ async def mock_chat_async(*args: Any, **kwargs: Any) -> MagicMock:
     res = await llm.ainvoke(question)
     assert isinstance(res, LLMResponse)
     assert res.content == "ollama chat response"
+
+
+@patch("builtins.__import__")
+@patch("json.loads")
+def test_ollama_llm_invoke_with_tools_happy_path(
+    mock_json_loads: Mock,
+    mock_import: Mock,
+    test_tool: Tool,
+) -> None:
+    # Set up json.loads to return a dictionary
+    mock_json_loads.return_value = {"param1": "value1"}
+
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+
+    # Mock the tool call response
+    mock_function = MagicMock()
+    mock_function.name = "test_tool"
+    mock_function.arguments = {"param1": "value1"}
+
+    mock_tool_call = MagicMock()
+    mock_tool_call.function = mock_function
+
+    mock_ollama.Client.return_value.chat.return_value = MagicMock(
+        message=MagicMock(content="ollama tool response", tool_calls=[mock_tool_call])
+    )
+
+    llm = OllamaLLM(model_name="gpt", model_params={"options": {"temperature": 0}})
+    tools = [test_tool]
+
+    res = llm.invoke_with_tools("my text", tools)
+    assert isinstance(res, ToolCallResponse)
+    assert len(res.tool_calls) == 1
+    assert res.tool_calls[0].name == "test_tool"
+    assert res.tool_calls[0].arguments == {"param1": "value1"}
+    assert res.content == "ollama tool response"
+
+
+@patch("builtins.__import__")
+@patch("json.loads")
+def test_ollama_llm_invoke_with_tools_with_message_history(
+    mock_json_loads: Mock,
+    mock_import: Mock,
+    test_tool: Tool,
+) -> None:
+    # Set up json.loads to return a dictionary
+    mock_json_loads.return_value = {"param1": "value1"}
+
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+
+    # Mock the tool call response
+    mock_function = MagicMock()
+    mock_function.name = "test_tool"
+    mock_function.arguments = {"param1": "value1"}
+
+    mock_tool_call = MagicMock()
+    mock_tool_call.function = mock_function
+    mock_ollama.Client.return_value.chat.return_value = MagicMock(
+        message=MagicMock(content="ollama tool response", tool_calls=[mock_tool_call])
+    )
+    llm = OllamaLLM(
+        api_key="my key", model_name="gpt", model_params={"options": {"temperature": 0}}
+    )
+    tools = [test_tool]
+
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    res = llm.invoke_with_tools(question, tools, message_history)  # type: ignore
+    assert isinstance(res, ToolCallResponse)
+    assert len(res.tool_calls) == 1
+    assert res.tool_calls[0].name == "test_tool"
+    assert res.tool_calls[0].arguments == {"param1": "value1"}
+
+    # Verify the correct messages were passed
+    message_history.append({"role": "user", "content": question})
+    # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions
+    llm.client.chat.assert_called_once()  # type: ignore
+    # Check call arguments individually
+    call_args = llm.client.chat.call_args[  # type: ignore
+        1
+    ]  # Get the keyword arguments
+    assert call_args["messages"] == message_history
+    assert call_args["model"] == "gpt"
+    # Check tools content rather than direct equality
+    assert len(call_args["tools"]) == 1
+    assert call_args["tools"][0]["type"] == "function"
+    assert call_args["tools"][0]["function"]["name"] == "test_tool"
+    assert call_args["tools"][0]["function"]["description"] == "A test tool"
+
+
+@patch("builtins.__import__")
+@patch("json.loads")
+def test_ollama_llm_invoke_with_tools_with_system_instruction(
+    mock_json_loads: Mock,
+    mock_import: Mock,
+    test_tool: Tool,
+) -> None:
+    # Set up json.loads to return a dictionary
+    mock_json_loads.return_value = {"param1": "value1"}
+
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+
+    # Mock the tool call response
+    mock_function = MagicMock()
+    mock_function.name = "test_tool"
+    mock_function.arguments = {"param1": "value1"}
+
+    mock_tool_call = MagicMock()
+    mock_tool_call.function = mock_function
+
+    mock_ollama.Client.return_value.chat.return_value = MagicMock(
+        message=MagicMock(content="ollama tool response", tool_calls=[mock_tool_call])
+    )
+
+    llm = OllamaLLM(
+        api_key="my key", model_name="gpt", model_params={"options": {"temperature": 0}}
+    )
+    tools = [test_tool]
+
+    system_instruction = "You are a helpful assistant."
+
+    res = llm.invoke_with_tools("my text", tools, system_instruction=system_instruction)
+    assert isinstance(res, ToolCallResponse)
+
+    # Verify system instruction was included
+    messages = [{"role": "system", "content": system_instruction}]
+    messages.append({"role": "user", "content": "my text"})
+    # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions
+    llm.client.chat.assert_called_once()  # type: ignore
+    # Check call arguments individually
+    call_args = llm.client.chat.call_args[  # type: ignore
+        1
+    ]  # Get the keyword arguments
+    assert call_args["messages"] == messages
+    assert call_args["model"] == "gpt"
+    # Check tools content rather than direct equality
+    assert len(call_args["tools"]) == 1
+    assert call_args["tools"][0]["type"] == "function"
+    assert call_args["tools"][0]["function"]["name"] == "test_tool"
+    assert call_args["tools"][0]["function"]["description"] == "A test tool"
+
+
+@patch("builtins.__import__")
+def test_ollama_llm_invoke_with_tools_error(mock_import: Mock, test_tool: Tool) -> None:
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+
+    # Mock an Ollama response error
+    mock_ollama.Client.return_value.chat.side_effect = ollama.ResponseError(
+        "Test error"
+    )
+
+    llm = OllamaLLM(
+        api_key="my key", model_name="gpt", model_params={"options": {"temperature": 0}}
+    )
+    tools = [test_tool]
+
+    with pytest.raises(LLMGenerationError):
+        llm.invoke_with_tools("my text", tools)
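
A possible follow-up to the example above, sketched here rather than included in the patch: once `invoke_with_tools` returns a `ToolCallResponse`, the extracted arguments can be routed to plain Python callables. Only `ToolCallResponse`, `ToolCall.name`, and `ToolCall.arguments` come from the library as used in the patch; the handler names (`handle_person_info`, `HANDLERS`, `dispatch`) are illustrative assumptions.

# Sketch only: dispatch LLM tool calls to local handlers (names below are
# hypothetical, not part of neo4j-graphrag or of this patch).
from typing import Any, Callable, Dict, List, Optional

from neo4j_graphrag.llm.types import ToolCallResponse


def handle_person_info(
    name: str, age: Optional[int] = None, occupation: Optional[str] = None
) -> Dict[str, Any]:
    # Toy handler: simply echo the fields extracted by the LLM.
    return {"name": name, "age": age, "occupation": occupation}


# Map tool names (as declared on the Tool objects) to local callables.
HANDLERS: Dict[str, Callable[..., Dict[str, Any]]] = {
    "extract_person_info": handle_person_info,
}


def dispatch(response: ToolCallResponse) -> List[Dict[str, Any]]:
    # Route every tool call returned by the LLM to its registered handler.
    results = []
    for tool_call in response.tool_calls:
        handler = HANDLERS.get(tool_call.name)
        if handler is None:
            raise ValueError(f"No handler registered for tool {tool_call.name!r}")
        results.append(handler(**tool_call.arguments))
    return results

With the example's `sync_response`, `dispatch(sync_response)` would return a one-element list mirroring the extracted person fields.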