1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@

- Added automatic rate limiting with retry logic and exponential backoff for all Embedding providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely.
- JSON response returned to `SchemaFromTextExtractor` is cleansed of any markdown code blocks before being loaded.
- Tool calling support for Ollama in LLMInterface.

## 1.10.0

1 change: 1 addition & 0 deletions examples/README.md
@@ -81,6 +81,7 @@ are listed in [the last section of this file](#customize).

- [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py)
- [Tool Calling with VertexAI](./customize/llms/vertexai_tool_calls.py)
- [Tool Calling with Ollama](./customize/llms/ollama_tool_calls.py)


### Prompts
99 changes: 99 additions & 0 deletions examples/customize/llms/ollama_tool_calls.py
@@ -0,0 +1,99 @@
"""
Example showing how to use Ollama tool calls with parameter extraction.
Both synchronous and asynchronous examples are provided.

To run this example:
1. Make sure you have `ollama serve` running
2. Run: python examples/customize/llms/ollama_tool_calls.py
"""

import asyncio
import json
from typing import Dict, Any

from neo4j_graphrag.llm import OllamaLLM
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import (
    Tool,
    ObjectParameter,
    StringParameter,
    IntegerParameter,
)


# Create a custom Tool implementation for person info extraction
parameters = ObjectParameter(
    description="Parameters for extracting person information",
    properties={
        "name": StringParameter(description="The person's full name"),
        "age": IntegerParameter(description="The person's age"),
        "occupation": StringParameter(description="The person's occupation"),
    },
    required_properties=["name"],
    additional_properties=False,
)
person_info_tool = Tool(
    name="extract_person_info",
    description="Extract information about a person from text",
    parameters=parameters,
    execute_func=lambda **kwargs: kwargs,
)

# Collect the tools to pass to the LLM
TOOLS = [person_info_tool]


def process_tool_calls(response: ToolCallResponse) -> Dict[str, Any]:
    """Process all tool calls in the response and return the extracted parameters."""
    if not response.tool_calls:
        raise ValueError("No tool calls found in response")

    print(f"\nNumber of tool calls: {len(response.tool_calls)}")
    print(f"Additional content: {response.content or 'None'}")

    results = []
    for i, tool_call in enumerate(response.tool_calls):
        print(f"\nTool call #{i + 1}: {tool_call.name}")
        print(f"Arguments: {tool_call.arguments}")
        results.append(tool_call.arguments)

    # For backward compatibility, return the first tool call's arguments
    return results[0] if results else {}


async def main() -> None:
    # Initialize the Ollama LLM
    llm = OllamaLLM(
        model_name="mistral:latest",
        model_params={"temperature": 0},
    )

    # Example text containing information about a person
    text = "Stella Hane is a 35-year-old software engineer who loves coding."

    print("\n=== Synchronous Tool Call ===")
    # Make a synchronous tool call
    sync_response = llm.invoke_with_tools(
        input=f"Extract information about the person from this text: {text}",
        tools=TOOLS,
    )
    sync_result = process_tool_calls(sync_response)
    print("\n=== Synchronous Tool Call Result ===")
    print(json.dumps(sync_result, indent=2))

    print("\n=== Asynchronous Tool Call ===")
    # Make an asynchronous tool call with a different text
    text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
    async_response = await llm.ainvoke_with_tools(
        input=f"Extract information about the person from this text: {text2}",
        tools=TOOLS,
    )
    async_result = process_tool_calls(async_response)
    print("\n=== Asynchronous Tool Call Result ===")
    print(json.dumps(async_result, indent=2))


if __name__ == "__main__":
    # Run the async main function
    asyncio.run(main())
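As a quick variation on the script above, the sketch below swaps in a different local model. It assumes a tool-capable model (for example `llama3.1:latest`) has already been pulled with `ollama pull`, and that the `TOOLS` list and `process_tool_calls` helper defined in the example are in scope; the model name and input text are illustrative only.

```python
# Minimal sketch, reusing TOOLS and process_tool_calls from the example above.
# Assumes `ollama pull llama3.1:latest` has been run beforehand.
from neo4j_graphrag.llm import OllamaLLM

llm = OllamaLLM(model_name="llama3.1:latest", model_params={"temperature": 0})
response = llm.invoke_with_tools(
    input="Extract information about the person: Ada Lovelace, 36, mathematician.",
    tools=TOOLS,
)
print(process_tool_calls(response))
```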
158 changes: 157 additions & 1 deletion src/neo4j_graphrag/llm/ollama_llm.py
@@ -15,7 +15,17 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast
from typing import (
    TYPE_CHECKING,
    Any,
    Iterable,
    List,
    Optional,
    Sequence,
    Union,
    cast,
    Dict,
)

from pydantic import ValidationError

@@ -33,9 +43,12 @@
    BaseMessage,
    LLMResponse,
    MessageList,
    ToolCall,
    ToolCallResponse,
    SystemMessage,
    UserMessage,
)
from neo4j_graphrag.tool import Tool

if TYPE_CHECKING:
    from ollama import Message
@@ -163,3 +176,146 @@ async def ainvoke(
            return LLMResponse(content=content)
        except self.ollama.ResponseError as e:
            raise LLMGenerationError(e)

    @rate_limit_handler
    def invoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],  # Tool definitions as a sequence of Tool objects
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        """Sends a text input to the LLM with tool definitions
        and retrieves a tool call response.

        Args:
            input (str): Text sent to the LLM.
            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from.
            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
                with each message having a specific role assigned.
            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

        Returns:
            ToolCallResponse: The response from the LLM containing a tool call.

        Raises:
            LLMGenerationError: If anything goes wrong.
        """
        try:
            if isinstance(message_history, MessageHistory):
                message_history = message_history.messages

            # Convert tools to Ollama's expected format
            ollama_tools = []
            for tool in tools:
                ollama_tool_format = self._convert_tool_to_ollama_format(tool)
                ollama_tools.append(ollama_tool_format)

            response = self.client.chat(
                model=self.model_name,
                messages=self.get_messages(input, message_history, system_instruction),
                tools=ollama_tools,
                **self.model_params,
            )
            message = response.message

            # If there's no tool call, return the content as a regular response
            if not message.tool_calls or len(message.tool_calls) == 0:
                return ToolCallResponse(
                    tool_calls=[],
                    content=message.content,
                )

            # Process all tool calls
            tool_calls = []
            for tool_call in message.tool_calls:
                args = tool_call.function.arguments
                tool_calls.append(
                    ToolCall(name=tool_call.function.name, arguments=args)
                )

            return ToolCallResponse(tool_calls=tool_calls, content=message.content)
        except self.ollama.ResponseError as e:
            raise LLMGenerationError(e)

    @async_rate_limit_handler
    async def ainvoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],  # Tool definitions as a sequence of Tool objects
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        """Sends a text input to the LLM with tool definitions
        and retrieves a tool call response.

        Args:
            input (str): Text sent to the LLM.
            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from.
            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
                with each message having a specific role assigned.
            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

        Returns:
            ToolCallResponse: The response from the LLM containing a tool call.

        Raises:
            LLMGenerationError: If anything goes wrong.
        """
        try:
            if isinstance(message_history, MessageHistory):
                message_history = message_history.messages

            # Convert tools to Ollama's expected format
            ollama_tools = []
            for tool in tools:
                ollama_tool_format = self._convert_tool_to_ollama_format(tool)
                ollama_tools.append(ollama_tool_format)

            response = await self.async_client.chat(
                model=self.model_name,
                messages=self.get_messages(input, message_history, system_instruction),
                tools=ollama_tools,
                **self.model_params,
            )
            message = response.message

            # If there's no tool call, return the content as a regular response
            if not message.tool_calls or len(message.tool_calls) == 0:
                return ToolCallResponse(
                    tool_calls=[],
                    content=message.content,
                )

            # Process all tool calls
            tool_calls = []
            for tool_call in message.tool_calls:
                args = tool_call.function.arguments
                tool_calls.append(
                    ToolCall(name=tool_call.function.name, arguments=args)
                )

            return ToolCallResponse(tool_calls=tool_calls, content=message.content)
        except self.ollama.ResponseError as e:
            raise LLMGenerationError(e)

    def _convert_tool_to_ollama_format(self, tool: Tool) -> Dict[str, Any]:
        """Convert a Tool object to Ollama's expected format.

        Args:
            tool: A Tool object to convert to Ollama's format.

        Returns:
            A dictionary in Ollama's tool format.
        """
        try:
            return {
                "type": "function",
                "function": {
                    "name": tool.get_name(),
                    "description": tool.get_description(),
                    "parameters": tool.get_parameters(),
                },
            }
        except AttributeError:
            raise LLMGenerationError(f"Tool {tool} is not a valid Tool object")