Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/neo4j_graphrag/embeddings/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler
from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler


class OllamaEmbeddings(Embedder):
Expand Down Expand Up @@ -47,6 +47,7 @@ def __init__(
super().__init__(rate_limit_handler)
self.model = model
self.client = ollama.Client(**kwargs)
self.async_client = ollama.AsyncClient(**kwargs)

@rate_limit_handler
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
Expand All @@ -73,3 +74,29 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
raise EmbeddingsGenerationError("Embedding is not a list of floats.")

return embedding

@async_rate_limit_handler
async def async_embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate embeddings for a given query using an Ollama text embedding model.

Args:
text (str): The text to generate an embedding for.
**kwargs (Any): Additional keyword arguments to pass to the Ollama client.
"""
embeddings_response = await self.async_client.embed(
model=self.model,
input=text,
**kwargs,
)

if embeddings_response is None or not embeddings_response.embeddings:
raise EmbeddingsGenerationError("Failed to retrieve embeddings.")

embeddings = embeddings_response.embeddings
# client always returns a sequence of sequences
embedding = embeddings[0]
if not isinstance(embedding, list):
raise EmbeddingsGenerationError("Embedding is not a list of floats.")

return embedding
57 changes: 51 additions & 6 deletions src/neo4j_graphrag/experimental/components/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# limitations under the License.
from pydantic import validate_call

import asyncio
from typing import Any, List, Optional, Union

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks
from neo4j_graphrag.experimental.pipeline.component import Component
Expand All @@ -24,6 +27,7 @@ class TextChunkEmbedder(Component):

Args:
embedder (Embedder): The embedder to use to create the embeddings.
max_concurrency (int): The maximum number of concurrent embedding requests. Defaults to 5.

Example:

Expand All @@ -34,14 +38,21 @@ class TextChunkEmbedder(Component):
from neo4j_graphrag.experimental.pipeline import Pipeline

embedder = OpenAIEmbeddings(model="text-embedding-3-large")
chunk_embedder = TextChunkEmbedder(embedder)
chunk_embedder = TextChunkEmbedder(embedder=embedder, max_concurrency=10)
pipeline = Pipeline()
pipeline.add_component(chunk_embedder, "chunk_embedder")

"""

def __init__(self, embedder: Embedder):
    def __init__(
        self,
        *args: Any,
        embedder: Embedder,
        max_concurrency: int = 5,
        **kwargs: Any,
    ) -> None:
        """Create a text-chunk embedder.

        Args:
            embedder (Embedder): The embedder used to compute each chunk's
                embedding (keyword-only).
            max_concurrency (int): Maximum number of concurrent embedding
                requests when running asynchronously. Defaults to 5.

        NOTE(review): ``*args`` and ``**kwargs`` are accepted but silently
        discarded here — they are not forwarded to a superclass ``__init__``.
        Confirm this is intentional; otherwise extra arguments passed by
        callers are lost without error.
        """
        self._embedder = embedder
        self._max_concurrency = max_concurrency

def _embed_chunk(self, text_chunk: TextChunk) -> TextChunk:
"""Embed a single text chunk.
Expand All @@ -62,9 +73,36 @@ def _embed_chunk(self, text_chunk: TextChunk) -> TextChunk:
metadata=metadata,
uid=text_chunk.uid,
)

async def _async_embed_chunk(
self,
sem: asyncio.Semaphore,
text_chunk: TextChunk) -> TextChunk:
"""Asynchronously embed a single text chunk.

Args:
text_chunk (TextChunk): The text chunk to embed.

Returns:
TextChunk: The text chunk with an added "embedding" key in its
metadata containing the embeddings of the text chunk's text.
"""
async with sem:
embedding = await self._embedder.async_embed_query(text_chunk.text)
metadata = text_chunk.metadata if text_chunk.metadata else {}
metadata["embedding"] = embedding
return TextChunk(
text=text_chunk.text,
index=text_chunk.index,
metadata=metadata,
uid=text_chunk.uid,
)

@validate_call
async def run(self, text_chunks: TextChunks) -> TextChunks:
async def run(
self,
text_chunks: TextChunks
) -> TextChunks:
"""Embed a list of text chunks.

Args:
Expand All @@ -73,6 +111,13 @@ async def run(self, text_chunks: TextChunks) -> TextChunks:
Returns:
TextChunks: The input text chunks with each one having an added embedding.
"""
return TextChunks(
chunks=[self._embed_chunk(text_chunk) for text_chunk in text_chunks.chunks]
)
sem = asyncio.Semaphore(self._max_concurrency)
tasks = [
self._async_embed_chunk(
sem,
text_chunk,
)
for text_chunk in text_chunks.chunks
]
text_chunks: TextChunks = list(await asyncio.gather(*tasks))
return TextChunks(chunks=text_chunks)