diff --git a/chromadb/test/ef/test_nomic_ef.py b/chromadb/test/ef/test_nomic_ef.py new file mode 100644 index 00000000000..7e8aa8bc801 --- /dev/null +++ b/chromadb/test/ef/test_nomic_ef.py @@ -0,0 +1,83 @@ +import os +import pytest +import requests +from requests import HTTPError +from requests.exceptions import ConnectionError +from pytest_httpserver import HTTPServer +import json +from unittest.mock import patch +from chromadb.utils.embedding_functions import NomicEmbeddingFunction + + +@pytest.mark.skipif( + "NOMIC_API_KEY" not in os.environ, + reason="NOMIC_API_KEY environment variable not set, skipping test.", +) +def test_nomic() -> None: + """ + To learn more about the Nomic API: https://docs.nomic.ai/reference/endpoints/nomic-embed-text + Export the NOMIC_API_KEY and optionally the NOMIC_MODEL environment variables. + """ + try: + response = requests.get("https://api-atlas.nomic.ai/v1/health", timeout=10) + # If the response was successful, no Exception will be raised + response.raise_for_status() + except (HTTPError, ConnectionError): + pytest.skip("Nomic API server can't be reached. Skipping test.") + ef = NomicEmbeddingFunction( + api_key=os.environ.get("NOMIC_API_KEY") or "", + model_name=os.environ.get("NOMIC_MODEL") or "nomic-embed-text-v1.5", + ) + embeddings = ef( + ["Henceforth, it is the map that precedes the territory", "nom nom Nomic"] + ) + assert len(embeddings) == 2 + + +def test_nomic_no_api_key() -> None: + """ + To learn more about the Nomic API: https://docs.nomic.ai/reference/endpoints/nomic-embed-text + Test intentionaly excludes the NOMIC_API_KEY. + """ + with pytest.raises(ValueError, match="No Nomic API key provided"): + NomicEmbeddingFunction( + api_key="", + model_name=os.environ.get("NOMIC_MODEL") or "nomic-embed-text-v1.5", + ) + + +def test_nomic_no_model() -> None: + """ + To learn more about the Nomic API: https://docs.nomic.ai/reference/endpoints/nomic-embed-text + Test intentionally excludes the NOMIC_MODEL. api_key does not matter since we expect an error before hitting API. + """ + with pytest.raises(ValueError, match="No Nomic embedding model provided"): + NomicEmbeddingFunction( + api_key="does-not-matter", + model_name="", + ) + + +def test_handle_nomic_api_returns_error() -> None: + """ + To learn more about the Nomic API: https://docs.nomic.ai/reference/endpoints/nomic-embed-text + Mocks an error from the Nomic API, so model and api key don't matter. + """ + with HTTPServer() as httpserver: + httpserver.expect_oneshot_request( + "/embedding/text", method="POST" + ).respond_with_data( + json.dumps({"detail": "error"}), + status=400, + ) + nomic_ef = NomicEmbeddingFunction( + api_key="does-not-matter", + model_name="does-not-matter", + ) + with patch.object( + nomic_ef, + "_api_url", + f"http://{httpserver.host}:{httpserver.port}/embedding/text", + ): + with pytest.raises(Exception): + nomic_ef(["test text"]) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index cc779865675..7275b07fd79 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -907,7 +907,8 @@ def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore ) class ChromaLangchainEmbeddingFunction( - LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore + LangchainEmbeddings, + EmbeddingFunction[Union[Documents, Images]], # type: ignore ): """ This class is used as bridge between langchain embedding functions and custom chroma embedding functions. @@ -1017,6 +1018,68 @@ def __call__(self, input: Documents) -> Embeddings: ) +class NomicEmbeddingFunction(EmbeddingFunction[Documents]): + """ + This class is used to generate embeddings for a list of texts using the Nomic Embedding API (https://docs.nomic.ai/reference/endpoints/nomic-embed-text). + """ + + def __init__(self, api_key: str, model_name: str) -> None: + """ + Initialize the Nomic Embedding Function. + + Args: + model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text-v1.5" (see https://docs.nomic.ai/atlas/models/text-embedding for available models). + """ + try: + import requests + except ImportError: + raise ValueError( + "The requests python package is not installed. Please install it with `pip install requests`" + ) + + if not api_key: + raise ValueError("No Nomic API key provided") + if not model_name: + raise ValueError("No Nomic embedding model provided") + + self._api_url = "https://api-atlas.nomic.ai/v1/embedding/text" + self._api_key = api_key + self._model_name = model_name + self._session = requests.Session() + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + input (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. + + Example: + >>> nomic_ef = NomicEmbeddingFunction(model_name="nomic-embed-text-v1.5") + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = nomic_ef(texts) + """ + texts = input if isinstance(input, list) else [input] + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self._api_key}", + } + response = self._session.post( + self._api_url, + headers=headers, + json={"model": self._model_name, "texts": texts}, + ) + response.raise_for_status() + response_json = response.json() + if "embeddings" not in response_json: + raise RuntimeError("Nomic API did not return embeddings") + + return cast(Embeddings, response_json["embeddings"]) + + # List of all classes in this module _classes = [ name diff --git a/requirements_dev.txt b/requirements_dev.txt index 9622fbb7c0c..8f6a9a9d83f 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -9,6 +9,7 @@ pre-commit pytest pytest-asyncio pytest-xdist +pytest_httpserver==1.0.10 setuptools_scm types-protobuf types-requests==2.30.0.0