diff --git a/docs/docs/integrations/text_embedding/voyageai.ipynb b/docs/docs/integrations/text_embedding/voyageai.ipynb index 8dd992f16a710b..e3d213db371db3 100644 --- a/docs/docs/integrations/text_embedding/voyageai.ipynb +++ b/docs/docs/integrations/text_embedding/voyageai.ipynb @@ -27,7 +27,7 @@ "id": "137cfde9-b88c-409a-9394-a9e31a6bf30d", "metadata": {}, "source": [ - "Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key." + "Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key. Please refer to the documentation for further details on the available models: https://docs.voyageai.com/embeddings/" ] }, { @@ -37,7 +37,9 @@ "metadata": {}, "outputs": [], "source": [ - "embeddings = VoyageEmbeddings(voyage_api_key=\"[ Your Voyage API key ]\")" + "embeddings = VoyageEmbeddings(\n", + " voyage_api_key=\"[ Your Voyage API key ]\", model=\"voyage-2\"\n", + ")" ] }, { diff --git a/libs/community/langchain_community/embeddings/voyageai.py b/libs/community/langchain_community/embeddings/voyageai.py index f8b1a4059e6d5e..ceb930c987f0d8 100644 --- a/libs/community/langchain_community/embeddings/voyageai.py +++ b/libs/community/langchain_community/embeddings/voyageai.py @@ -69,15 +69,15 @@ class VoyageEmbeddings(BaseModel, Embeddings): from langchain_community.embeddings import VoyageEmbeddings - voyage = VoyageEmbeddings(voyage_api_key="your-api-key") + voyage = VoyageEmbeddings(voyage_api_key="your-api-key", model="voyage-2") text = "This is a test query." query_result = voyage.embed_query(text) """ - model: str = "voyage-01" + model: str voyage_api_base: str = "https://api.voyageai.com/v1/embeddings" voyage_api_key: Optional[SecretStr] = None - batch_size: int = 8 + batch_size: int """Maximum number of texts to embed in each API request.""" max_retries: int = 6 """Maximum number of retries to make when generating.""" @@ -86,15 +86,12 @@ class VoyageEmbeddings(BaseModel, Embeddings): show_progress_bar: bool = False """Whether to show a progress bar when embedding. Must have tqdm installed if set to True.""" - truncation: Optional[bool] = None + truncation: bool = True """Whether to truncate the input texts to fit within the context length. If True, over-length input texts will be truncated to fit within the context length, before vectorized by the embedding model. If False, an error will be - raised if any given text exceeds the context length. If not specified - (defaults to None), we will truncate the input text before sending it to the - embedding model if it slightly exceeds the context window length. If it - significantly exceeds the context window length, an error will be raised.""" + raised if any given text exceeds the context length.""" class Config: """Configuration for this pydantic object.""" @@ -107,6 +104,22 @@ def validate_environment(cls, values: Dict) -> Dict: values["voyage_api_key"] = convert_to_secret_str( get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY") ) + + if "model" not in values: + values["model"] = "voyage-01" + logger.warning( + "model will become a required arg for VoyageAIEmbeddings, " + "we recommend to specify it when using this class. " + "Currently the default is set to voyage-01." + ) + + if "batch_size" not in values: + values["batch_size"] = ( + 72 + if "model" in values and (values["model"] in ["voyage-2", "voyage-02"]) + else 7 + ) + return values def _invocation_params( @@ -116,11 +129,14 @@ def _invocation_params( params: Dict = { "url": self.voyage_api_base, "headers": {"Authorization": f"Bearer {api_key}"}, - "json": {"model": self.model, "input": input, "input_type": input_type}, + "json": { + "model": self.model, + "input": input, + "input_type": input_type, + "truncation": self.truncation, + }, "timeout": self.request_timeout, } - if self.truncation is not None: - params["json"]["truncation"] = self.truncation return params def _get_embeddings( @@ -186,7 +202,9 @@ def embed_query(self, text: str) -> List[float]: Returns: Embedding for the text. """ - return self._get_embeddings([text], input_type="query")[0] + return self._get_embeddings( + [text], batch_size=self.batch_size, input_type="query" + )[0] def embed_general_texts( self, texts: List[str], *, input_type: Optional[str] = None diff --git a/libs/community/tests/integration_tests/embeddings/test_voyageai.py b/libs/community/tests/integration_tests/embeddings/test_voyageai.py index b23dbd7f538b30..c14c08c5db7943 100644 --- a/libs/community/tests/integration_tests/embeddings/test_voyageai.py +++ b/libs/community/tests/integration_tests/embeddings/test_voyageai.py @@ -2,7 +2,7 @@ from langchain_community.embeddings.voyageai import VoyageEmbeddings # Please set VOYAGE_API_KEY in the environment variables -MODEL = "voyage-01" +MODEL = "voyage-2" def test_voyagi_embedding_documents() -> None: @@ -14,10 +14,22 @@ def test_voyagi_embedding_documents() -> None: assert len(output[0]) == 1024 +def test_voyagi_with_default_model() -> None: + """Test voyage embeddings.""" + embedding = VoyageEmbeddings() + assert embedding.model == "voyage-01" + assert embedding.batch_size == 7 + documents = [f"foo bar {i}" for i in range(72)] + output = embedding.embed_documents(documents) + assert len(output) == 72 + assert len(output[0]) == 1024 + + def test_voyage_embedding_documents_multiple() -> None: """Test voyage embeddings.""" documents = ["foo bar", "bar foo", "foo"] embedding = VoyageEmbeddings(model=MODEL, batch_size=2) + assert embedding.model == MODEL output = embedding.embed_documents(documents) assert len(output) == 3 assert len(output[0]) == 1024