Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Voyage AI updates default model and batch size #17655

Merged
merged 16 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/docs/integrations/text_embedding/voyageai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"id": "137cfde9-b88c-409a-9394-a9e31a6bf30d",
"metadata": {},
"source": [
"Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key."
"Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key. Please refer to the documentation for further details on the available models: https://docs.voyageai.com/embeddings/"
]
},
{
Expand All @@ -37,7 +37,9 @@
"metadata": {},
"outputs": [],
"source": [
"embeddings = VoyageEmbeddings(voyage_api_key=\"[ Your Voyage API key ]\")"
"embeddings = VoyageEmbeddings(\n",
" voyage_api_key=\"[ Your Voyage API key ]\", model=\"voyage-2\"\n",
")"
]
},
{
Expand Down
42 changes: 30 additions & 12 deletions libs/community/langchain_community/embeddings/voyageai.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ class VoyageEmbeddings(BaseModel, Embeddings):

from langchain_community.embeddings import VoyageEmbeddings

voyage = VoyageEmbeddings(voyage_api_key="your-api-key")
voyage = VoyageEmbeddings(voyage_api_key="your-api-key", model="voyage-2")
text = "This is a test query."
query_result = voyage.embed_query(text)
"""

model: str = "voyage-01"
model: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change we probably don't want to make.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, we may want to make a breaking change here. We want users to specify the model explicitly when using the embeddings: we want to encourage them to adopt the new models, but we can't frequently update the default model here.
Do you have any suggestions?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend giving folks at least a month or two of heads-up before doing this: for now, add a warning that the breaking change will happen by a certain date. Here's an example: #14614

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And for now, update all of the documentation to show the latest models being passed in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that giving a warning could be a better idea.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the default value into `validate_environment`, so it should be backward-compatible now.

voyage_api_base: str = "https://api.voyageai.com/v1/embeddings"
voyage_api_key: Optional[SecretStr] = None
batch_size: int = 8
batch_size: int
"""Maximum number of texts to embed in each API request."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
Expand All @@ -86,15 +86,12 @@ class VoyageEmbeddings(BaseModel, Embeddings):
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding. Must have tqdm installed if set
to True."""
truncation: Optional[bool] = None
truncation: bool = True
"""Whether to truncate the input texts to fit within the context length.

If True, over-length input texts will be truncated to fit within the context
length, before vectorized by the embedding model. If False, an error will be
raised if any given text exceeds the context length. If not specified
(defaults to None), we will truncate the input text before sending it to the
embedding model if it slightly exceeds the context window length. If it
significantly exceeds the context window length, an error will be raised."""
raised if any given text exceeds the context length."""

class Config:
"""Configuration for this pydantic object."""
Expand All @@ -107,6 +104,22 @@ def validate_environment(cls, values: Dict) -> Dict:
values["voyage_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY")
)

if "model" not in values:
values["model"] = "voyage-01"
logger.warning(
"model will become a required arg for VoyageAIEmbeddings, "
"we recommend to specify it when using this class. "
"Currently the default is set to voyage-01."
)

if "batch_size" not in values:
values["batch_size"] = (
72
if "model" in values and (values["model"] in ["voyage-2", "voyage-02"])
else 7
)

return values

def _invocation_params(
Expand All @@ -116,11 +129,14 @@ def _invocation_params(
params: Dict = {
"url": self.voyage_api_base,
"headers": {"Authorization": f"Bearer {api_key}"},
"json": {"model": self.model, "input": input, "input_type": input_type},
"json": {
"model": self.model,
"input": input,
"input_type": input_type,
"truncation": self.truncation,
},
"timeout": self.request_timeout,
}
if self.truncation is not None:
params["json"]["truncation"] = self.truncation
return params

def _get_embeddings(
Expand Down Expand Up @@ -186,7 +202,9 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
return self._get_embeddings([text], input_type="query")[0]
return self._get_embeddings(
[text], batch_size=self.batch_size, input_type="query"
)[0]

def embed_general_texts(
self, texts: List[str], *, input_type: Optional[str] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from langchain_community.embeddings.voyageai import VoyageEmbeddings

# Please set VOYAGE_API_KEY in the environment variables
MODEL = "voyage-01"
MODEL = "voyage-2"


def test_voyagi_embedding_documents() -> None:
Expand All @@ -14,10 +14,22 @@ def test_voyagi_embedding_documents() -> None:
assert len(output[0]) == 1024


def test_voyagi_with_default_model() -> None:
"""Test voyage embeddings."""
embedding = VoyageEmbeddings()
assert embedding.model == "voyage-01"
assert embedding.batch_size == 7
documents = [f"foo bar {i}" for i in range(72)]
output = embedding.embed_documents(documents)
assert len(output) == 72
assert len(output[0]) == 1024


def test_voyage_embedding_documents_multiple() -> None:
"""Test voyage embeddings."""
documents = ["foo bar", "bar foo", "foo"]
embedding = VoyageEmbeddings(model=MODEL, batch_size=2)
assert embedding.model == MODEL
output = embedding.embed_documents(documents)
assert len(output) == 3
assert len(output[0]) == 1024
Expand Down
Loading