Skip to content

Commit

Permalink
[ENH] Support langchain embedding functions with chroma (#1880)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - New functionality
- Adding a function to create a chroma langchain embedding interface.
This interface acts as a bridge between the langchain embedding function
and the chroma custom embedding function.
- Native Langchain multimodal support: The PR adds a Passthrough data
loader that lets langchain users use OpenClip and other multi-modal
embedding functions from langchain with chroma without having to handle
storing images themselves.

## Test plan
*How are these changes tested?*


- installing chroma as an editable package locally and passing langchain
integration tests
-  pytest test_api.py test_client.py succeeds

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*

Co-authored-by: Anton Troynikov <atroyn@users.noreply.github.com>
  • Loading branch information
Mihir1003 and atroyn committed Apr 2, 2024
1 parent fdfda56 commit 193988d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
11 changes: 9 additions & 2 deletions chromadb/utils/data_loaders.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import importlib
import multiprocessing
from typing import Optional, Sequence, List
from typing import Optional, Sequence, List, Tuple
import numpy as np
from chromadb.api.types import URI, DataLoader, Image
from chromadb.api.types import URI, DataLoader, Image, URIs
from concurrent.futures import ThreadPoolExecutor


Expand All @@ -22,3 +22,10 @@ def _load_image(self, uri: Optional[URI]) -> Optional[Image]:
def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]:
with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
return list(executor.map(self._load_image, uris))


class ChromaLangchainPassthroughDataLoader(DataLoader[List[Optional[Image]]]):
# This is a simple pass through data loader that just returns the input data with "images"
# flag which lets the langchain embedding function know that the data is image uris
def __call__(self, uris: URIs) -> Tuple[str, URIs]: # type: ignore
return ("images", uris)
64 changes: 64 additions & 0 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,69 @@ def __call__(self, input: Documents) -> Embeddings:
)


def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore
try:
from langchain_core.embeddings import Embeddings as LangchainEmbeddings
except ImportError:
raise ValueError(
"The langchain_core python package is not installed. Please install it with `pip install langchain-core`"
)

class ChromaLangchainEmbeddingFunction(
LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore
):
"""
This class is used as bridge between langchain embedding functions and custom chroma embedding functions.
"""

def __init__(self, embedding_function: LangchainEmbeddings) -> None:
"""
Initialize the ChromaLangchainEmbeddingFunction
Args:
embedding_function : The embedding function implementing Embeddings from langchain_core.
"""
self.embedding_function = embedding_function

def embed_documents(self, documents: Documents) -> List[List[float]]:
return self.embedding_function.embed_documents(documents) # type: ignore

def embed_query(self, query: str) -> List[float]:
return self.embedding_function.embed_query(query) # type: ignore

def embed_image(self, uris: List[str]) -> List[List[float]]:
if hasattr(self.embedding_function, "embed_image"):
return self.embedding_function.embed_image(uris) # type: ignore
else:
raise ValueError(
"The provided embedding function does not support image embeddings."
)

def __call__(self, input: Documents) -> Embeddings: # type: ignore
"""
Get the embeddings for a list of texts or images.
Args:
input (Documents | Images): A list of texts or images to get embeddings for.
Images should be provided as a list of URIs passed through the langchain data loader
Returns:
Embeddings: The embeddings for the texts or images.
Example:
>>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large"))
>>> texts = ["Hello, world!", "How are you?"]
>>> embeddings = langchain_embedding(texts)
"""
# Due to langchain quirks, the dataloader returns a tuple if the input is uris of images
if input[0] == "images":
return self.embed_image(list(input[1])) # type: ignore

return self.embed_documents(list(input)) # type: ignore

return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
Expand Down Expand Up @@ -955,6 +1018,7 @@ def __call__(self, input: Documents) -> Embeddings:
],
)


# List of all classes in this module
_classes = [
name
Expand Down

0 comments on commit 193988d

Please sign in to comment.