-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Nomic Text Embed function #2182
base: main
Are you sure you want to change the base?
Changes from 2 commits
1172bd3
af22d27
6ba0864
96b41d1
718344b
1a1d30d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
import requests | ||
from requests import HTTPError | ||
from requests.exceptions import ConnectionError | ||
from pytest_httpserver import HTTPServer | ||
import json | ||
from unittest.mock import patch | ||
|
||
from chromadb.utils.embedding_functions import NomicEmbeddingFunction | ||
|
||
|
||
@pytest.mark.skipif( | ||
"NOMIC_API_KEY" not in os.environ, | ||
reason="NOMIC_API_KEY environment variable not set, skipping test.", | ||
) | ||
def test_nomic() -> None: | ||
""" | ||
To learn more about the Nomic API: https://docs.nomic.ai/reference/endpoints/nomic-embed-text | ||
Export the NOMIC_API_KEY and optionally the NOMIC_MODEL environment variables. | ||
""" | ||
try: | ||
response = requests.get("https://api-atlas.nomic.ai/v1/health", timeout=10) | ||
# If the response was successful, no Exception will be raised | ||
response.raise_for_status() | ||
except (HTTPError, ConnectionError): | ||
pytest.skip("Nomic API server can't be reached. Skipping test.") | ||
ef = NomicEmbeddingFunction( | ||
api_key=os.environ.get("NOMIC_API_KEY") or "", | ||
model_name=os.environ.get("NOMIC_MODEL") or "nomic-embed-text-v1.5", | ||
) | ||
embeddings = ef( | ||
["Henceforth, it is the map that precedes the territory", "nom nom Nomic"] | ||
) | ||
assert len(embeddings) == 2 | ||
|
||
|
||
def test_nomic_no_api_key() -> None: | ||
""" | ||
To learn more about the Nomic API: https://docs.nomic.ai/reference/endpoints/nomic-embed-text | ||
Test intentionaly excludes the NOMIC_API_KEY. | ||
""" | ||
with pytest.raises(ValueError, match="No Nomic API key provided"): | ||
NomicEmbeddingFunction( | ||
api_key="", | ||
model_name=os.environ.get("NOMIC_MODEL") or "nomic-embed-text-v1.5", | ||
) | ||
|
||
|
||
@pytest.mark.skipif( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can remove the decorator here, as this is a negative test. It is supposed to fail regardless of whether we're testing with the Nomic API key or not. |
||
"NOMIC_API_KEY" not in os.environ, | ||
reason="NOMIC_API_KEY environment variable not set, skipping test.", | ||
) | ||
def test_nomic_no_model() -> None: | ||
""" | ||
To learn more about the Nomic API: https://docs.nomic.ai/reference/endpoints/nomic-embed-text | ||
Test intentionaly excludes the NOMIC_MODEL. | ||
""" | ||
with pytest.raises(ValueError, match="No Nomic embedding model provided"): | ||
NomicEmbeddingFunction( | ||
api_key=os.environ.get("NOMIC_API_KEY") or "", | ||
model_name="", | ||
) | ||
|
||
|
||
@pytest.mark.skipif( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe also remove this decorator and let the test run, as this is a local mock test. I think it makes sense to add some sort of activation flags (other than API keys) for EF testing, but this requires further consideration. |
||
"NOMIC_API_KEY" not in os.environ, | ||
reason="NOMIC_API_KEY environment variable not set, skipping test.", | ||
) | ||
def test_handle_nomic_api_returns_error() -> None: | ||
""" | ||
To learn more about the Nomic API: https://docs.nomic.ai/reference/endpoints/nomic-embed-text | ||
""" | ||
with HTTPServer() as httpserver: | ||
httpserver.expect_oneshot_request( | ||
"/embedding/text", method="POST" | ||
).respond_with_data( | ||
json.dumps({"detail": "error"}), | ||
status=400, | ||
) | ||
nomic_ef = NomicEmbeddingFunction( | ||
api_key=os.environ.get("NOMIC_API_KEY") or "", | ||
model_name=os.environ.get("NOMIC_MODEL") or "nomic-embed-text-v1.5", | ||
) | ||
with patch.object( | ||
nomic_ef, | ||
"_api_url", | ||
f"http://{httpserver.host}:{httpserver.port}/embedding/text", | ||
): | ||
with pytest.raises(Exception): | ||
nomic_ef(["test text"]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -907,7 +907,8 @@ def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore | |
) | ||
|
||
class ChromaLangchainEmbeddingFunction( | ||
LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore | ||
LangchainEmbeddings, | ||
EmbeddingFunction[Union[Documents, Images]], # type: ignore | ||
): | ||
""" | ||
This class is used as bridge between langchain embedding functions and custom chroma embedding functions. | ||
|
@@ -1017,6 +1018,68 @@ def __call__(self, input: Documents) -> Embeddings: | |
) | ||
|
||
|
||
class NomicEmbeddingFunction(EmbeddingFunction[Documents]): | ||
""" | ||
This class is used to generate embeddings for a list of texts using the Nomic Embedding API (https://docs.nomic.ai/reference/endpoints/nomic-embed-text). | ||
""" | ||
|
||
def __init__(self, api_key: str, model_name: str) -> None: | ||
""" | ||
Initialize the Nomic Embedding Function. | ||
|
||
Args: | ||
model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text-v1.5" (see https://docs.nomic.ai/atlas/models/text-embedding for available models). | ||
""" | ||
try: | ||
import requests | ||
except ImportError: | ||
raise ValueError( | ||
"The requests python package is not installed. Please install it with `pip install requests`" | ||
) | ||
|
||
if not api_key: | ||
raise ValueError("No Nomic API key provided") | ||
if not model_name: | ||
raise ValueError("No Nomic embedding model provided") | ||
|
||
self._api_url = "https://api-atlas.nomic.ai/v1/embedding/text" | ||
self._api_key = api_key | ||
self._model_name = model_name | ||
self._session = requests.Session() | ||
|
||
def __call__(self, input: Documents) -> Embeddings: | ||
""" | ||
Get the embeddings for a list of texts. | ||
|
||
Args: | ||
input (Documents): A list of texts to get embeddings for. | ||
|
||
Returns: | ||
Embeddings: The embeddings for the texts. | ||
|
||
Example: | ||
>>> nomic_ef = NomicEmbeddingFunction(model_name="nomic-embed-text-v1.5") | ||
>>> texts = ["Hello, world!", "How are you?"] | ||
>>> embeddings = nomic_ef(texts) | ||
""" | ||
texts = input if isinstance(input, list) else [input] | ||
headers = { | ||
"Content-Type": "application/json", | ||
"Authorization": f"Bearer {self._api_key}", | ||
} | ||
response = self._session.post( | ||
self._api_url, | ||
headers=headers, | ||
json={"model": self._model_name, "texts": texts}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the limit to the maximum number of texts that can be sent at once? If there's a limit, let's implement it on the client side so we don't do the round trip to raise an error. Important: Do not add any loop logic here. If a chunk fails, throw an exception and let users implement their own chunking and subsequent error handling. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nomic responded: You can send as many as you'd like per request, they will get processed in parallel. It's recommended you break it up yourself into several requests or use the Nomic python client because you will see network latency due to the large request/response size if you send dozens of megabytes of text in a single request" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's leave it as-is for now. We have some ideas that we're trying to develop to make error handling a little bit more consistent across all EFs. |
||
) | ||
response.raise_for_status() | ||
response_json = response.json() | ||
if "embeddings" not in response_json: | ||
raise RuntimeError("Nomic API did not return embeddings") | ||
|
||
return cast(Embeddings, response_json["embeddings"]) | ||
|
||
|
||
# List of all classes in this module | ||
_classes = [ | ||
name | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add couple of negative test cases:
For the mock server:
pip install pytest_httpserver>=1.0.10
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added these tests as well as one for missing model name, but did not include the "too many texts" test. Nomic responded saying:
So I am not sure 1) what we want to consider too large 2) best way to check for that 3) if you still want this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also -- should I add pytest_httpserver to the projects requirements_dev file and check it in as part of this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, please add it to the
requirements_dev
.We do have a limitation in Chroma API regarding the maximum number of embeddings, but we generally don't tie that to the embeddings functions batch size. So you can leave it as is for now.