Skip to content

Commit

Permalink
feat: Make get_embeddings work both for foundational & tuned models.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629254179
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Apr 30, 2024
1 parent 3ce0126 commit b8b589c
Showing 1 changed file with 44 additions and 37 deletions.
81 changes: 44 additions & 37 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import abc
import dataclasses
import collections.abc
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -975,6 +976,7 @@ class TuningEvaluationSpec:
enable_checkpoint_selection: Optional[bool] = None
tensorboard: Optional[Union[aiplatform.Tensorboard, str]] = None


# Evaluation spec fields that are not supported by RLHF tuning
_UNUSED_RLHF_EVAL_SPECS = (
"evaluation_interval",
Expand Down Expand Up @@ -2053,30 +2055,12 @@ def _prepare_text_embedding_request(
parameters=parameters,
)

def _parse_text_embedding_response(
self,
prediction_response: aiplatform.models.Prediction,
prediction_idx: int = 0,
) -> "TextEmbedding":
"""Parses the text embedding model response."""
prediction = prediction_response.predictions[prediction_idx]
embeddings = prediction["embeddings"]
statistics = embeddings["statistics"]
return TextEmbedding(
values=embeddings["values"],
statistics=TextEmbeddingStatistics(
token_count=statistics["token_count"],
truncated=statistics["truncated"],
),
_prediction_response=prediction_response,
)

def get_embeddings(
self,
texts: List[Union[str, TextEmbeddingInput]],
*,
auto_truncate: bool = True,
output_dimensionality: Optional[int] = None
output_dimensionality: Optional[int] = None,
) -> List["TextEmbedding"]:
"""Calculates embeddings for the given texts.
Expand All @@ -2099,15 +2083,12 @@ def get_embeddings(
parameters=prediction_request.parameters,
)

results = []
for prediction_idx in range(len(prediction_response.predictions)):
result = self._parse_text_embedding_response(
prediction_response=prediction_response,
prediction_idx=prediction_idx,
return [
TextEmbedding._parse_text_embedding_response(
prediction_response, i_prediction
)
results.append(result)

return results
for i_prediction, _ in enumerate(prediction_response.predictions)
]

async def get_embeddings_async(
self,
Expand All @@ -2129,23 +2110,20 @@ async def get_embeddings_async(
prediction_request = self._prepare_text_embedding_request(
texts=texts,
auto_truncate=auto_truncate,
output_dimensionality=output_dimensionality
output_dimensionality=output_dimensionality,
)

prediction_response = await self._endpoint.predict_async(
instances=prediction_request.instances,
parameters=prediction_request.parameters,
)

results = []
for prediction_idx in range(len(prediction_response.predictions)):
result = self._parse_text_embedding_response(
prediction_response=prediction_response,
prediction_idx=prediction_idx,
return [
TextEmbedding._parse_text_embedding_response(
prediction_response, i_prediction
)
results.append(result)

return results
for i_prediction, _ in enumerate(prediction_response.predictions)
]


class _PreviewTextEmbeddingModel(
Expand Down Expand Up @@ -2175,6 +2153,36 @@ class TextEmbedding:
statistics: Optional[TextEmbeddingStatistics] = None
_prediction_response: Optional[aiplatform.models.Prediction] = None

@classmethod
def _parse_text_embedding_response(
cls, prediction_response: aiplatform.models.Prediction, prediction_index: int
) -> "TextEmbedding":
"""Creates a `TextEmbedding` object from a prediction.
Args:
prediction_response: `aiplatform.models.Prediction` object.
Returns:
`TextEmbedding` object.
"""
prediction = prediction_response.predictions[prediction_index]
is_prediction_from_pretrained_models = isinstance(
prediction, collections.abc.Mapping
)
if is_prediction_from_pretrained_models:
embeddings = prediction["embeddings"]
embedding_stats = embeddings["statistics"]
return cls(
values=embeddings["values"],
statistics=TextEmbeddingStatistics(
token_count=embedding_stats["token_count"],
truncated=embedding_stats["truncated"],
),
_prediction_response=prediction_response,
)
else:
return cls(values=prediction, _prediction_response=prediction_response)


@dataclasses.dataclass
class InputOutputTextPair:
Expand Down Expand Up @@ -3146,7 +3154,6 @@ class _CodeGenerationModel(_LanguageModel):

_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"


def _create_prediction_request(
self,
prefix: str,
Expand Down

0 comments on commit b8b589c

Please sign in to comment.