diff --git a/src/neo4j_graphrag/embeddings/ollama.py b/src/neo4j_graphrag/embeddings/ollama.py index 88f85096..cfb43dd6 100644 --- a/src/neo4j_graphrag/embeddings/ollama.py +++ b/src/neo4j_graphrag/embeddings/ollama.py @@ -19,7 +19,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler class OllamaEmbeddings(Embedder): @@ -47,6 +47,7 @@ def __init__( super().__init__(rate_limit_handler) self.model = model self.client = ollama.Client(**kwargs) + self.async_client = ollama.AsyncClient(**kwargs) @rate_limit_handler def embed_query(self, text: str, **kwargs: Any) -> list[float]: @@ -73,3 +74,29 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: raise EmbeddingsGenerationError("Embedding is not a list of floats.") return embedding + + @async_rate_limit_handler + async def async_embed_query(self, text: str, **kwargs: Any) -> list[float]: + """ + Generate embeddings for a given query using an Ollama text embedding model. + + Args: + text (str): The text to generate an embedding for. + **kwargs (Any): Additional keyword arguments to pass to the Ollama client. + """ + embeddings_response = await self.async_client.embed( + model=self.model, + input=text, + **kwargs, + ) + + if embeddings_response is None or not embeddings_response.embeddings: + raise EmbeddingsGenerationError("Failed to retrieve embeddings.") + + embeddings = embeddings_response.embeddings + # client always returns a sequence of sequences + embedding = embeddings[0] + if not isinstance(embedding, list): + raise EmbeddingsGenerationError("Embedding is not a list of floats.") + + return embedding \ No newline at end of file diff --git a/src/neo4j_graphrag/experimental/components/embedder.py b/src/neo4j_graphrag/experimental/components/embedder.py index f113ecde..cdf24d76 100644 --- a/src/neo4j_graphrag/experimental/components/embedder.py +++ b/src/neo4j_graphrag/experimental/components/embedder.py @@ -14,6 +14,9 @@ # limitations under the License. from pydantic import validate_call +import asyncio +from typing import Any, List, Optional, Union + from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks from neo4j_graphrag.experimental.pipeline.component import Component @@ -24,6 +27,7 @@ class TextChunkEmbedder(Component): Args: embedder (Embedder): The embedder to use to create the embeddings. + max_concurrency (int): The maximum number of concurrent embedding requests. Defaults to 5. Example: @@ -34,14 +38,21 @@ class TextChunkEmbedder(Component): from neo4j_graphrag.experimental.pipeline import Pipeline embedder = OpenAIEmbeddings(model="text-embedding-3-large") - chunk_embedder = TextChunkEmbedder(embedder) + chunk_embedder = TextChunkEmbedder(embedder=embedder, max_concurrency=10) pipeline = Pipeline() pipeline.add_component(chunk_embedder, "chunk_embedder") """ - def __init__(self, embedder: Embedder): + def __init__( + self, + *args: Any, + embedder: Embedder, + max_concurrency: int = 5, + **kwargs: Any, + ) -> None: self._embedder = embedder + self._max_concurrency = max_concurrency def _embed_chunk(self, text_chunk: TextChunk) -> TextChunk: """Embed a single text chunk. @@ -62,9 +73,36 @@ def _embed_chunk(self, text_chunk: TextChunk) -> TextChunk: metadata=metadata, uid=text_chunk.uid, ) + + async def _async_embed_chunk( + self, + sem: asyncio.Semaphore, + text_chunk: TextChunk) -> TextChunk: + """Asynchronously embed a single text chunk. + + Args: + text_chunk (TextChunk): The text chunk to embed. + + Returns: + TextChunk: The text chunk with an added "embedding" key in its + metadata containing the embeddings of the text chunk's text. + """ + async with sem: + embedding = await self._embedder.async_embed_query(text_chunk.text) + metadata = text_chunk.metadata if text_chunk.metadata else {} + metadata["embedding"] = embedding + return TextChunk( + text=text_chunk.text, + index=text_chunk.index, + metadata=metadata, + uid=text_chunk.uid, + ) @validate_call - async def run(self, text_chunks: TextChunks) -> TextChunks: + async def run( + self, + text_chunks: TextChunks + ) -> TextChunks: """Embed a list of text chunks. Args: @@ -73,6 +111,13 @@ async def run(self, text_chunks: TextChunks) -> TextChunks: Returns: TextChunks: The input text chunks with each one having an added embedding. """ - return TextChunks( - chunks=[self._embed_chunk(text_chunk) for text_chunk in text_chunks.chunks] - ) + sem = asyncio.Semaphore(self._max_concurrency) + tasks = [ + self._async_embed_chunk( + sem, + text_chunk, + ) + for text_chunk in text_chunks.chunks + ] + text_chunks: TextChunks = list(await asyncio.gather(*tasks)) + return TextChunks(chunks=text_chunks) \ No newline at end of file