Skip to content

Commit

Permalink
feat: Deploy a tuned text embedding model -- it doesn't matter, if it…
Browse files Browse the repository at this point in the history
…'s tuned using Node.js, or curl.

PiperOrigin-RevId: 629619980
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed May 1, 2024
1 parent b22a8b8 commit 8ca9cdf
Showing 1 changed file with 64 additions and 1 deletion.
65 changes: 64 additions & 1 deletion vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,7 +1996,7 @@ class TextEmbeddingInput:
title: Optional[str] = None


class TextEmbeddingModel(_LanguageModel):
class _TextEmbeddingModel(_LanguageModel):
"""TextEmbeddingModel class calculates embeddings for the given texts.
Examples::
Expand Down Expand Up @@ -2126,6 +2126,69 @@ async def get_embeddings_async(
]


class _TunableTextEmbeddingModelMixin(_TunableModelMixin):
    """Mixin that lets a text embedding model class serve a tuned model.

    Tuned embedding models are obtained through `deploy_tuned_model`;
    `get_tuned_model` is intentionally disabled here.
    """

    @classmethod
    def get_tuned_model(cls, *args, **kwargs):
        """Always raises: use `deploy_tuned_model` for tuned embedding models.

        Any positional or keyword arguments are accepted and ignored, so
        callers always receive the explanatory `NotImplementedError` rather
        than a `TypeError` about an unexpected argument.

        Raises:
            NotImplementedError: Always.
        """
        del args, kwargs  # Unused.
        raise NotImplementedError(
            "Use deploy_tuned_model instead to get the tuned model."
        )

    # IMPORTANT: Keep this method supported even if you end up deploying the
    # tuned model as part of the tuning pipeline template.
    @classmethod
    def deploy_tuned_model(
        cls,
        tuned_model_name: str,
        machine_type: Optional[str] = None,
        accelerator: Optional[str] = None,
        accelerator_count: Optional[int] = None,
    ) -> "_LanguageModel":
        """Returns a model object backed by an endpoint serving the tuned model.

        If the tuned model has no deployments yet, it is deployed to an
        endpoint first; otherwise the endpoint of its first existing
        deployment is reused and the compute arguments are not used.

        Args:
            tuned_model_name: Tuned model's resource name.
            machine_type: Machine type. E.g., "a2-highgpu-1g". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
            accelerator: Kind of accelerator. E.g., "NVIDIA_TESLA_A100". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
            accelerator_count: Count of accelerators.

        Returns:
            Model object constructed with the base model ID and the endpoint
            that serves the tuned model.

        Raises:
            ValueError: If the model's labels do not contain the base model ID.
        """
        tuned_vertex_model = aiplatform.Model(tuned_model_name)
        tuned_model_labels = tuned_vertex_model.labels

        if _TUNING_BASE_MODEL_ID_LABEL_KEY not in tuned_model_labels:
            raise ValueError(
                f"The provided model {tuned_model_name} does not have a base model ID."
            )

        tuning_model_id = tuned_model_labels[_TUNING_BASE_MODEL_ID_LABEL_KEY]
        tuned_model_deployments = tuned_vertex_model.gca_resource.deployed_models
        if not tuned_model_deployments:
            # Deploying a model to an endpoint requires a resource quota.
            endpoint_name = tuned_vertex_model.deploy(
                machine_type=machine_type,
                accelerator_type=accelerator,
                accelerator_count=accelerator_count,
            ).resource_name
        else:
            # Already deployed: reuse the endpoint of the first deployment.
            endpoint_name = tuned_model_deployments[0].endpoint

        base_model_id = _get_model_id_from_tuning_model_id(tuning_model_id)
        model_info = _model_garden_models._get_model_info(
            model_id=base_model_id,
            schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls},
        )
        return model_info.interface_class(
            model_id=base_model_id,
            endpoint_name=endpoint_name,
        )


class TextEmbeddingModel(_TextEmbeddingModel, _TunableTextEmbeddingModelMixin):
    """Text embedding model class that also supports tuned-model deployment.

    Combines `_TextEmbeddingModel` with `_TunableTextEmbeddingModelMixin`.
    """

    # Report the class as belonging to the `vertexai.language_models` module.
    __module__ = "vertexai.language_models"


class _PreviewTextEmbeddingModel(
TextEmbeddingModel, _ModelWithBatchPredict, _CountTokensMixin
):
Expand Down

0 comments on commit 8ca9cdf

Please sign in to comment.