Skip to content

Commit

Permalink
feat: Deploy a tuned text embedding model -- it doesn't matter, if it…
Browse files Browse the repository at this point in the history
…'s tuned using Node.js, or curl.

PiperOrigin-RevId: 628574207
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Apr 27, 2024
1 parent 1341e2c commit 61fe7dc
Showing 1 changed file with 99 additions and 22 deletions.
121 changes: 99 additions & 22 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import abc
import dataclasses
import collections.abc
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -975,6 +976,7 @@ class TuningEvaluationSpec:
enable_checkpoint_selection: Optional[bool] = None
tensorboard: Optional[Union[aiplatform.Tensorboard, str]] = None


# Evaluation spec fields that are not supported by RLHF tuning
_UNUSED_RLHF_EVAL_SPECS = (
"evaluation_interval",
Expand Down Expand Up @@ -1994,7 +1996,7 @@ class TextEmbeddingInput:
title: Optional[str] = None


class TextEmbeddingModel(_LanguageModel):
class _TextEmbeddingModel(_LanguageModel):
"""TextEmbeddingModel class calculates embeddings for the given texts.
Examples::
Expand Down Expand Up @@ -2076,7 +2078,7 @@ def get_embeddings(
texts: List[Union[str, TextEmbeddingInput]],
*,
auto_truncate: bool = True,
output_dimensionality: Optional[int] = None
output_dimensionality: Optional[int] = None,
) -> List["TextEmbedding"]:
"""Calculates embeddings for the given texts.
Expand All @@ -2099,15 +2101,10 @@ def get_embeddings(
parameters=prediction_request.parameters,
)

results = []
for prediction_idx in range(len(prediction_response.predictions)):
result = self._parse_text_embedding_response(
prediction_response=prediction_response,
prediction_idx=prediction_idx,
)
results.append(result)

return results
return [
TextEmbedding.from_prediction(prediction_response, i_prediction)
for i_prediction, _ in enumerate(prediction_response.predictions)
]

async def get_embeddings_async(
self,
Expand All @@ -2129,29 +2126,80 @@ async def get_embeddings_async(
prediction_request = self._prepare_text_embedding_request(
texts=texts,
auto_truncate=auto_truncate,
output_dimensionality=output_dimensionality
output_dimensionality=output_dimensionality,
)

prediction_response = await self._endpoint.predict_async(
instances=prediction_request.instances,
parameters=prediction_request.parameters,
)

results = []
for prediction_idx in range(len(prediction_response.predictions)):
result = self._parse_text_embedding_response(
prediction_response=prediction_response,
prediction_idx=prediction_idx,
return [
TextEmbedding.from_prediction(prediction_response, i_prediction)
for i_prediction, _ in enumerate(prediction_response.predictions)
]


class _TunedEmbeddingModelMixin(_TunableModelMixin):
@classmethod
def get_tuned_model(
cls,
tuned_model_name: str,
machine_type: Optional[str] = None,
accelerator: Optional[str] = None,
accelerator_count: Optional[int] = None,
) -> "_LanguageModel":
"""Loads the specified tuned language model.
Args:
tuned_model_name: Tuned model's resource name.
machine_type: Machine type. E.g., "a2-highgpu-1g". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
accelerator: Kind of accelerator. E.g., "NVIDIA_TESLA_A100". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
accelerator_count: Count of accelerators.
Returns:
Tuned `LanguageModel` object.
"""
tuned_vertex_model = aiplatform.Model(tuned_model_name)
tuned_model_labels = tuned_vertex_model.labels

if _TUNING_BASE_MODEL_ID_LABEL_KEY not in tuned_model_labels:
raise ValueError(
f"The provided model {tuned_model_name} does not have a base model ID."
)
results.append(result)

return results
tuning_model_id = tuned_vertex_model.labels[_TUNING_BASE_MODEL_ID_LABEL_KEY]
tuned_model_deployments = tuned_vertex_model.gca_resource.deployed_models
if len(tuned_model_deployments) == 0:
# Deploying a model to an endpoint requires a resource quota.
endpoint_name = tuned_vertex_model.deploy(
machine_type=machine_type,
accelerator_type=accelerator,
accelerator_count=accelerator_count,
).resource_name
else:
endpoint_name = tuned_model_deployments[0].endpoint

base_model_id = _get_model_id_from_tuning_model_id(tuning_model_id)
model_info = _model_garden_models._get_model_info(
model_id=base_model_id,
schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls},
)
model = model_info.interface_class(
model_id=base_model_id,
endpoint_name=endpoint_name,
)
return model


class TextEmbeddingModel(_TextEmbeddingModel, _TunedEmbeddingModelMixin):
pass


class _PreviewTextEmbeddingModel(
TextEmbeddingModel, _ModelWithBatchPredict, _CountTokensMixin
_TextEmbeddingModel, _ModelWithBatchPredict, _CountTokensMixin
):
__name__ = "TextEmbeddingModel"
__name__ = "_TextEmbeddingModel"
__module__ = "vertexai.preview.language_models"


Expand All @@ -2175,6 +2223,36 @@ class TextEmbedding:
statistics: Optional[TextEmbeddingStatistics] = None
_prediction_response: Optional[aiplatform.models.Prediction] = None

@classmethod
def from_prediction(
cls, prediction_response: aiplatform.models.Prediction, i_prediction: int
) -> "TextEmbedding":
"""Creates a `TextEmbedding` object from a prediction.
Args:
prediction_response: `aiplatform.models.Prediction` object.
Returns:
`TextEmbedding` object.
"""
prediction = prediction_response.predictions[i_prediction]
is_prediction_from_pretrained_models = isinstance(
prediction, collections.abc.Mapping
)
if is_prediction_from_pretrained_models:
embeddings = prediction["embeddings"]
embedding_stats = embeddings["statistics"]
return cls(
values=embeddings["values"],
statistics=TextEmbeddingStatistics(
token_count=embedding_stats["token_count"],
truncated=embedding_stats["truncated"],
),
_prediction_response=prediction_response,
)
else:
return cls(values=prediction, _prediction_response=prediction_response)


@dataclasses.dataclass
class InputOutputTextPair:
Expand Down Expand Up @@ -3146,7 +3224,6 @@ class _CodeGenerationModel(_LanguageModel):

_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"


def _create_prediction_request(
self,
prefix: str,
Expand Down

0 comments on commit 61fe7dc

Please sign in to comment.