From cbf9b6ee806d7eb89725f53c4509858a272b3141 Mon Sep 17 00:00:00 2001
From: Alexey Volkov
Date: Sat, 19 Aug 2023 17:01:20 -0700
Subject: [PATCH] feat: LLM - TextEmbeddingModel - Added support for structural inputs (`TextEmbeddingInput`), `auto_truncate` parameter and result `statistics`

PiperOrigin-RevId: 558465128
---
 .../system/aiplatform/test_language_models.py | 16 ++-
 tests/unit/aiplatform/test_language_models.py | 46 ++++++++-
 vertexai/language_models/__init__.py          |  2 +
 vertexai/language_models/_language_models.py  | 97 ++++++++++++++++---
 vertexai/preview/language_models.py           |  2 +
 5 files changed, 139 insertions(+), 24 deletions(-)

diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py
index 5f6e9c369c..45c7a81a93 100644
--- a/tests/system/aiplatform/test_language_models.py
+++ b/tests/system/aiplatform/test_language_models.py
@@ -143,11 +143,17 @@ def test_text_embedding(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
 
         model = TextEmbeddingModel.from_pretrained("google/textembedding-gecko@001")
-        embeddings = model.get_embeddings(["What is life?"])
-        assert embeddings
-        for embedding in embeddings:
-            vector = embedding.values
-            assert len(vector) == 768
+        # One short text, one long text (to check truncation)
+        texts = ["What is life?", "What is life?" * 1000]
+        embeddings = model.get_embeddings(texts)
+        assert len(embeddings) == 2
+        assert len(embeddings[0].values) == 768
+        assert embeddings[0].statistics.token_count > 0
+        assert not embeddings[0].statistics.truncated
+
+        assert len(embeddings[1].values) == 768
+        assert embeddings[1].statistics.token_count > 1000
+        assert embeddings[1].statistics.truncated
 
     def test_tuning(self, shared_state):
         """Test tuning, listing and loading models."""
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index cfc2e2383c..4c27513fb5 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -298,6 +298,7 @@ def reverse_string_2(s):
 """,
 _TEST_TEXT_EMBEDDING_PREDICTION = {
     "embeddings": {
         "values": list([1.0] * _TEXT_EMBEDDING_VECTOR_LENGTH),
+        "statistics": {"truncated": False, "token_count": 4.0},
     }
 }
@@ -2170,18 +2171,57 @@ def test_text_embedding(self):
         gca_predict_response = gca_prediction_service.PredictResponse()
         gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)
+        gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)
+        gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)
+        expected_embedding = _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]
 
         with mock.patch.object(
             target=prediction_service_client.PredictionServiceClient,
             attribute="predict",
             return_value=gca_predict_response,
-        ):
-            embeddings = model.get_embeddings(["What is life?"])
+        ) as mock_predict:
+            embeddings = model.get_embeddings(
+                [
+                    "What is life?",
+                    language_models.TextEmbeddingInput(
+                        text="Foo",
+                        task_type="RETRIEVAL_DOCUMENT",
+                        title="Bar",
+                    ),
+                    language_models.TextEmbeddingInput(
+                        text="Baz",
+                        task_type="CLASSIFICATION",
+                    ),
+                ],
+                auto_truncate=False,
+            )
+            prediction_instances = mock_predict.call_args[1]["instances"]
+            assert prediction_instances == [
+                {"content": "What is life?"},
+                {
+                    "content": "Foo",
+                    "taskType": "RETRIEVAL_DOCUMENT",
+                    "title": "Bar",
+                },
+                {
+                    "content": "Baz",
+                    "taskType": "CLASSIFICATION",
+                },
+            ]
+            prediction_parameters = mock_predict.call_args[1]["parameters"]
+            assert not prediction_parameters["autoTruncate"]
 
         assert embeddings
         for embedding in embeddings:
             vector = embedding.values
             assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
-            assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
+            assert vector == expected_embedding["values"]
+            assert (
+                embedding.statistics.token_count
+                == expected_embedding["statistics"]["token_count"]
+            )
+            assert (
+                embedding.statistics.truncated
+                == expected_embedding["statistics"]["truncated"]
+            )
 
     def test_text_embedding_ga(self):
         """Tests the text embedding model."""
diff --git a/vertexai/language_models/__init__.py b/vertexai/language_models/__init__.py
index 9566691f29..8d16584ecb 100644
--- a/vertexai/language_models/__init__.py
+++ b/vertexai/language_models/__init__.py
@@ -23,6 +23,7 @@
     CodeGenerationModel,
     InputOutputTextPair,
     TextEmbedding,
+    TextEmbeddingInput,
     TextEmbeddingModel,
     TextGenerationModel,
     TextGenerationResponse,
@@ -37,6 +38,7 @@
     "CodeGenerationModel",
     "InputOutputTextPair",
     "TextEmbedding",
+    "TextEmbeddingInput",
     "TextEmbeddingModel",
     "TextGenerationModel",
     "TextGenerationResponse",
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 85e63d830f..e974d2421b 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -692,8 +692,33 @@ def send_message(
         return response_obj
 
 
+@dataclasses.dataclass
+class TextEmbeddingInput:
+    """Structural text embedding input.
+
+    Attributes:
+        text: The main text content to embed.
+        task_type: The name of the downstream task the embeddings will be used for.
+            Valid values:
+            RETRIEVAL_QUERY
+                Specifies the given text is a query in a search/retrieval setting.
+            RETRIEVAL_DOCUMENT
+                Specifies the given text is a document from the corpus being searched.
+            SEMANTIC_SIMILARITY
+                Specifies the given text will be used for STS.
+            CLASSIFICATION
+                Specifies that the given text will be classified.
+            CLUSTERING
+                Specifies that the embeddings will be used for clustering.
+        title: Optional identifier of the text content.
+    """
+    text: str
+    task_type: Optional[str] = None
+    title: Optional[str] = None
+
+
 class TextEmbeddingModel(_LanguageModel):
-    """TextEmbeddingModel converts text into a vector of floating-point numbers.
+    """TextEmbeddingModel class calculates embeddings for the given texts.
 
     Examples::
 
@@ -711,36 +736,76 @@ class TextEmbeddingModel(_LanguageModel):
         "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml"
     )
 
-    def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]:
-        instances = [{"content": str(text)} for text in texts]
+    def get_embeddings(self,
+        texts: List[Union[str, TextEmbeddingInput]],
+        *,
+        auto_truncate: bool = True,
+    ) -> List["TextEmbedding"]:
+        """Calculates embeddings for the given texts.
+
+        Args:
+            texts(List[Union[str, TextEmbeddingInput]]): A list of texts or `TextEmbeddingInput` objects to embed.
+            auto_truncate(bool): Whether to automatically truncate long texts. Default: True.
+
+        Returns:
+            A list of `TextEmbedding` objects.
+ """ + instances = [] + for text in texts: + if isinstance(text, TextEmbeddingInput): + instance = {"content": text.text} + if text.task_type: + instance["taskType"] = text.task_type + if text.title: + instance["title"] = text.title + elif isinstance(text, str): + instance = {"content": text} + else: + raise TypeError(f"Unsupported text embedding input type: {text}.") + instances.append(instance) + parameters = {"autoTruncate": auto_truncate} prediction_response = self._endpoint.predict( instances=instances, + parameters=parameters, ) - return [ - TextEmbedding( - values=prediction["embeddings"]["values"], + results = [] + for prediction in prediction_response.predictions: + embeddings = prediction["embeddings"] + statistics = embeddings["statistics"] + result = TextEmbedding( + values=embeddings["values"], + statistics=TextEmbeddingStatistics( + token_count=statistics["token_count"], + truncated=statistics["truncated"], + ), _prediction_response=prediction_response, ) - for prediction in prediction_response.predictions - ] + results.append(result) + + return results class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict): _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE +@dataclasses.dataclass +class TextEmbeddingStatistics: + """Text embedding statistics.""" + + token_count: int + truncated: bool + + +@dataclasses.dataclass class TextEmbedding: - """Contains text embedding vector.""" + """Text embedding vector and statistics.""" - def __init__( - self, - values: List[float], - _prediction_response: Any = None, - ): - self.values = values - self._prediction_response = _prediction_response + values: List[float] + statistics: TextEmbeddingStatistics + _prediction_response: aiplatform.models.Prediction = None @dataclasses.dataclass diff --git a/vertexai/preview/language_models.py b/vertexai/preview/language_models.py index 447b3a0f9f..7089091456 100644 --- a/vertexai/preview/language_models.py +++ b/vertexai/preview/language_models.py @@ -26,6 +26,7 @@ CodeChatSession, InputOutputTextPair, TextEmbedding, + TextEmbeddingInput, TextGenerationResponse, ) @@ -60,6 +61,7 @@ "EvaluationTextClassificationSpec", "InputOutputTextPair", "TextEmbedding", + "TextEmbeddingInput", "TextEmbeddingModel", "TextGenerationModel", "TextGenerationResponse",