Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Support langchain embedding functions with chroma #1880

Merged
merged 2 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions chromadb/utils/data_loaders.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import importlib
import multiprocessing
from typing import Optional, Sequence, List
from typing import Optional, Sequence, List, Tuple
import numpy as np
from chromadb.api.types import URI, DataLoader, Image
from chromadb.api.types import URI, DataLoader, Image, URIs
from concurrent.futures import ThreadPoolExecutor


Expand All @@ -22,3 +22,10 @@ def _load_image(self, uri: Optional[URI]) -> Optional[Image]:
def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]:
    """
    Load every entry of `uris` through `self._load_image` on a thread pool.

    Args:
        uris: The URIs to load; entries may be None.

    Returns:
        The results of `self._load_image`, in the same order as `uris`.
    """
    pool = ThreadPoolExecutor(max_workers=self._max_workers)
    with pool:
        # Materialize the lazy map iterator while the pool is still open.
        loaded = pool.map(self._load_image, uris)
        return list(loaded)


class ChromaLangchainPassthroughDataLoader(DataLoader[List[Optional[Image]]]):
    """
    Simple pass-through data loader for langchain embedding functions.

    Nothing is loaded: the URIs are handed back unchanged, paired with an "images"
    flag that lets the langchain embedding function know the data is image uris.
    """

    def __call__(self, uris: URIs) -> Tuple[str, URIs]:  # type: ignore
        tagged_uris = ("images", uris)
        return tagged_uris
64 changes: 64 additions & 0 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,69 @@ def __call__(self, input: Documents) -> Embeddings:
)


def create_langchain_embedding(langchain_embdding_fn: Any):  # type: ignore
    """
    Wrap a langchain embedding function so it can be used as a chroma embedding function.

    Args:
        langchain_embdding_fn: The embedding function implementing Embeddings from langchain_core.

    Returns:
        A ChromaLangchainEmbeddingFunction instance that delegates to `langchain_embdding_fn`.

    Raises:
        ValueError: If the langchain_core python package is not installed.
    """
    try:
        from langchain_core.embeddings import Embeddings as LangchainEmbeddings
    except ImportError as err:
        # Chain the import failure so the root cause stays visible in the traceback.
        raise ValueError(
            "The langchain_core python package is not installed. Please install it with `pip install langchain-core`"
        ) from err

    class ChromaLangchainEmbeddingFunction(
        LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]]  # type: ignore
    ):
        """
        This class is used as bridge between langchain embedding functions and custom chroma embedding functions.
        """

        def __init__(self, embedding_function: LangchainEmbeddings) -> None:
            """
            Initialize the ChromaLangchainEmbeddingFunction

            Args:
                embedding_function : The embedding function implementing Embeddings from langchain_core.
            """
            self.embedding_function = embedding_function

        def embed_documents(self, documents: Documents) -> List[List[float]]:
            """Embed a list of texts by delegating to the wrapped embedding function."""
            return self.embedding_function.embed_documents(documents)  # type: ignore

        def embed_query(self, query: str) -> List[float]:
            """Embed a single query text by delegating to the wrapped embedding function."""
            return self.embedding_function.embed_query(query)  # type: ignore

        def embed_image(self, uris: List[str]) -> List[List[float]]:
            """
            Embed images given by their URIs.

            Raises:
                ValueError: If the wrapped embedding function has no `embed_image` method.
            """
            if hasattr(self.embedding_function, "embed_image"):
                return self.embedding_function.embed_image(uris)  # type: ignore
            else:
                raise ValueError(
                    "The provided embedding function does not support image embeddings."
                )

        def __call__(self, input: Documents) -> Embeddings:  # type: ignore
            """
            Get the embeddings for a list of texts or images.

            Args:
                input (Documents | Images): A list of texts or images to get embeddings for.
                Images should be provided as a list of URIs passed through the langchain data loader

            Returns:
                Embeddings: The embeddings for the texts or images.

            Example:
                >>> langchain_embedding = create_langchain_embedding(OpenAIEmbeddings(model="text-embedding-3-large"))
                >>> texts = ["Hello, world!", "How are you?"]
                >>> embeddings = langchain_embedding(texts)
            """
            # Due to langchain quirks, the dataloader returns an ("images", uris) pair if the
            # input is uris of images. Checking the whole shape (two elements, non-string
            # payload) keeps an empty input from raising IndexError and keeps a document
            # list whose first text happens to be "images" on the text path.
            is_image_payload = (
                len(input) == 2
                and input[0] == "images"
                and not isinstance(input[1], str)
            )
            if is_image_payload:
                return self.embed_image(list(input[1]))  # type: ignore

            return self.embed_documents(list(input))  # type: ignore

    return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
Expand Down Expand Up @@ -955,6 +1018,7 @@ def __call__(self, input: Documents) -> Embeddings:
],
)


# List of all classes in this module
_classes = [
name
Expand Down
Loading