Skip to content

Commit

Permalink
feat: Add support for output_dimensionality parameter through get_embeddings.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 617251035
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Mar 19, 2024
1 parent be4922a commit b1cab3f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
2 changes: 2 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4164,6 +4164,7 @@ def test_text_embedding(self):
),
],
auto_truncate=False,
output_dimensionality=3,
)
prediction_instances = mock_predict.call_args[1]["instances"]
assert prediction_instances == [
Expand All @@ -4180,6 +4181,7 @@ def test_text_embedding(self):
]
prediction_parameters = mock_predict.call_args[1]["parameters"]
assert not prediction_parameters["autoTruncate"]
assert prediction_parameters["outputDimensionality"] == 3
assert embeddings
for embedding in embeddings:
vector = embedding.values
Expand Down
18 changes: 14 additions & 4 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2003,12 +2003,14 @@ def _prepare_text_embedding_request(
texts: List[Union[str, TextEmbeddingInput]],
*,
auto_truncate: bool = True,
output_dimensionality: Optional[int] = None,
) -> _MultiInstancePredictionRequest:
"""Asynchronously calculates embeddings for the given texts.
Args:
texts(str): A list of texts or `TextEmbeddingInput` objects to embed.
auto_truncate(bool): Whether to automatically truncate long texts. Default: True.
output_dimensionality: Optional dimensions of embeddings. Range: [1, 768]. Default: None.
Returns:
A `_MultiInstancePredictionRequest` object.
Expand All @@ -2029,6 +2031,8 @@ def _prepare_text_embedding_request(
raise TypeError(f"Unsupported text embedding input type: {text}.")
instances.append(instance)
parameters = {"autoTruncate": auto_truncate}
if output_dimensionality is not None:
parameters["outputDimensionality"] = output_dimensionality
return _MultiInstancePredictionRequest(
instances=instances,
parameters=parameters,
Expand Down Expand Up @@ -2057,19 +2061,22 @@ def get_embeddings(
texts: List[Union[str, TextEmbeddingInput]],
*,
auto_truncate: bool = True,
output_dimensionality: Optional[int] = None
) -> List["TextEmbedding"]:
"""Calculates embeddings for the given texts.
Args:
texts(str): A list of texts or `TextEmbeddingInput` objects to embed.
auto_truncate(bool): Whether to automatically truncate long texts. Default: True.
texts: A list of texts or `TextEmbeddingInput` objects to embed.
auto_truncate: Whether to automatically truncate long texts. Default: True.
output_dimensionality: Optional dimensions of embeddings. Range: [1, 768]. Default: None.
Returns:
A list of `TextEmbedding` objects.
"""
prediction_request = self._prepare_text_embedding_request(
texts=texts,
auto_truncate=auto_truncate,
output_dimensionality=output_dimensionality,
)

prediction_response = self._endpoint.predict(
Expand All @@ -2092,19 +2099,22 @@ async def get_embeddings_async(
texts: List[Union[str, TextEmbeddingInput]],
*,
auto_truncate: bool = True,
output_dimensionality: Optional[int] = None,
) -> List["TextEmbedding"]:
"""Asynchronously calculates embeddings for the given texts.
Args:
texts(str): A list of texts or `TextEmbeddingInput` objects to embed.
auto_truncate(bool): Whether to automatically truncate long texts. Default: True.
texts: A list of texts or `TextEmbeddingInput` objects to embed.
auto_truncate: Whether to automatically truncate long texts. Default: True.
output_dimensionality: Optional dimensions of embeddings. Range: [1, 768]. Default: None.
Returns:
A list of `TextEmbedding` objects.
"""
prediction_request = self._prepare_text_embedding_request(
texts=texts,
auto_truncate=auto_truncate,
output_dimensionality=output_dimensionality
)

prediction_response = await self._endpoint.predict_async(
Expand Down

0 comments on commit b1cab3f

Please sign in to comment.