Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions .github/workflows/pr-e2e-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,17 @@ jobs:
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Docker Prune
- name: Free up disk space (ubuntu-latest)
run: |
docker system prune -af
docker volume prune -f
sudo rm -rf /usr/local/lib/android \
/usr/share/dotnet \
/opt/ghc \
/opt/hostedtoolcache
docker system prune -af || true
docker volume prune -f || true
docker builder prune -af || true
sudo apt-get clean || true
df -h
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
Expand Down
7 changes: 7 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,13 @@ Text2CypherRetriever
:members: search


ToolsRetriever
==============

.. autoclass:: neo4j_graphrag.retrievers.ToolsRetriever
:members: search


*******************
External Retrievers
*******************
Expand Down
24 changes: 24 additions & 0 deletions docs/source/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,30 @@ LLMMessage
===========

.. autoclass:: neo4j_graphrag.types.LLMMessage
:members:
:undoc-members:

Tool
====

.. autoclass:: neo4j_graphrag.tool.Tool


ToolParameter
=============

.. autoclass:: neo4j_graphrag.tool.ToolParameter


ObjectParameter
===============

.. autoclass:: neo4j_graphrag.tool.ObjectParameter

ParameterType
=============

.. autoenum:: neo4j_graphrag.tool.ParameterType


RagResultModel
Expand Down
16 changes: 8 additions & 8 deletions docs/source/user_guide_rag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ Its interface is compatible with our `GraphRAG` interface, facilitating integrat

It is however not mandatory to use LangChain.

.. warning:: ToolsRetriever

LangChain models are not compatible with the :ref:`toolsretriever`.

Using a Custom Model
--------------------

Expand All @@ -265,21 +269,17 @@ Here's an example using the Python Ollama client:

import ollama
from neo4j_graphrag.llm import LLMInterface, LLMResponse
from neo4j_graphrag.types import LLMMessage

class OllamaLLM(LLMInterface):

def invoke(self, input: str) -> LLMResponse:
response = ollama.chat(model=self.model_name, messages=[
{
'role': 'user',
'content': input,
},
])
def _invoke(self, input: list[LLMMessage]) -> LLMResponse:
response = ollama.chat(model=self.model_name, messages=input)
return LLMResponse(
content=response["message"]["content"]
)

async def ainvoke(self, input: str) -> LLMResponse:
async def _ainvoke(self, input: list[LLMMessage]) -> LLMResponse:
return self.invoke(input) # TODO: implement async with ollama.AsyncClient


Expand Down
1 change: 1 addition & 0 deletions examples/customize/embeddings/custom_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

class CustomEmbeddings(Embedder):
def __init__(self, dimension: int = 10, **kwargs: Any):
super().__init__(**kwargs)
self.dimension = dimension

def _embed_query(self, input: str) -> list[float]:
Expand Down
5 changes: 2 additions & 3 deletions src/neo4j_graphrag/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

from neo4j_graphrag.utils.rate_limit import (
Expand All @@ -24,7 +23,7 @@
)


class Embedder(ABC):
class Embedder:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we care about things lost when we remove inheritance from ABC? less safety and early error handling?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the only other solution I can think of is to create a LLMInterfaceV2. If you prefer this option, I'm fine with it.

"""
Interface for embedding models.
An embedder passed into a retriever must implement this interface.
Expand All @@ -51,7 +50,6 @@ def embed_query(self, text: str) -> list[float]:
"""
return self._embed_query(text)

@abstractmethod
def _embed_query(self, text: str) -> list[float]:
"""Embed query text.

Expand All @@ -61,3 +59,4 @@ def _embed_query(self, text: str) -> list[float]:
Returns:
list[float]: A vector embedding.
"""
raise NotImplementedError()
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/experimental/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class DocumentInfo(DataModel):
"""

path: str
metadata: Optional[Dict[str, str]] = None
metadata: Optional[Dict[str, Any]] = None
uid: str = Field(default_factory=lambda: str(uuid.uuid4()))
document_type: Optional[str] = None

Expand Down
28 changes: 24 additions & 4 deletions src/neo4j_graphrag/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
from __future__ import annotations

from abc import ABC, abstractmethod
import warnings
from typing import Any, List, Optional, Sequence, Union

from pydantic import ValidationError
Expand All @@ -36,7 +36,7 @@
from ..exceptions import LLMGenerationError


class LLMInterface(ABC):
class LLMInterface:
"""Interface for large language models.

Args:
Expand Down Expand Up @@ -68,6 +68,16 @@ def invoke(
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
if message_history:
warnings.warn(
"Using 'message_history' in the llm.invoke method is deprecated. Please use invoke(list[LLMMessage]) instead.",
DeprecationWarning,
)
if system_instruction:
warnings.warn(
"Using 'system_instruction' in the llm.invoke method is deprecated. Please use invoke(list[LLMMessage]) instead.",
DeprecationWarning,
)
try:
messages = legacy_inputs_to_messages(
input, message_history, system_instruction
Expand All @@ -76,7 +86,6 @@ def invoke(
raise LLMGenerationError("Input validation failed") from e
return self._invoke(messages)

@abstractmethod
def _invoke(
self,
input: list[LLMMessage],
Expand All @@ -92,6 +101,7 @@ def _invoke(
Raises:
LLMGenerationError: If anything goes wrong.
"""
raise NotImplementedError()

@async_rate_limit_handler
async def ainvoke(
Expand All @@ -100,10 +110,19 @@ async def ainvoke(
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
if message_history:
warnings.warn(
"Using 'message_history' in the llm.ainvoke method is deprecated. Please use invoke(list[LLMMessage]) instead.",
DeprecationWarning,
)
if system_instruction:
warnings.warn(
"Using 'system_instruction' in the llm.ainvoke method is deprecated. Please use invoke(list[LLMMessage]) instead.",
DeprecationWarning,
)
messages = legacy_inputs_to_messages(input, message_history, system_instruction)
return await self._ainvoke(messages)

@abstractmethod
async def _ainvoke(
self,
input: list[LLMMessage],
Expand All @@ -119,6 +138,7 @@ async def _ainvoke(
Raises:
LLMGenerationError: If anything goes wrong.
"""
raise NotImplementedError()

@rate_limit_handler
def invoke_with_tools(
Expand Down
17 changes: 11 additions & 6 deletions tests/unit/llm/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,11 @@

@fixture(scope="module") # type: ignore[misc]
def llm_interface() -> Generator[Type[LLMInterface], None, None]:
real_abstract_methods = LLMInterface.__abstractmethods__
LLMInterface.__abstractmethods__ = frozenset()

class CustomLLMInterface(LLMInterface):
pass

yield CustomLLMInterface

LLMInterface.__abstractmethods__ = real_abstract_methods


@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages")
def test_base_llm_interface_invoke_with_input_as_str(
Expand All @@ -52,7 +47,8 @@ def test_base_llm_interface_invoke_with_input_as_str(
system_instruction = "You are a genius."

with patch.object(llm, "_invoke") as mock_invoke:
llm.invoke(question, message_history, system_instruction)
with pytest.warns(DeprecationWarning) as record:
llm.invoke(question, message_history, system_instruction)
mock_invoke.assert_called_once_with(
[
LLMMessage(
Expand All @@ -66,6 +62,15 @@ def test_base_llm_interface_invoke_with_input_as_str(
message_history,
system_instruction,
)
assert len(record) == 2
assert (
"Using 'message_history' in the llm.invoke method is deprecated"
in record[0].message.args[0] # type: ignore[union-attr]
)
assert (
"Using 'system_instruction' in the llm.invoke method is deprecated"
in record[1].message.args[0] # type: ignore[union-attr]
)


@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages")
Expand Down