From 96e7f7d9243c36fa991dd147fe66b3a7e545b3bb Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Fri, 20 Oct 2023 14:25:26 -0700 Subject: [PATCH] feat: add preview count_tokens method to CodeGenerationModel PiperOrigin-RevId: 575318395 --- tests/unit/aiplatform/test_language_models.py | 37 ++++++++++++++++ vertexai/language_models/_language_models.py | 42 ++++++++++++++++++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 4d29e0ac73..3766eb8a7d 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -2771,6 +2771,43 @@ def test_code_generation_multiple_candidates(self): response.candidates[0].text == _TEST_CODE_GENERATION_PREDICTION["content"] ) + def test_code_generation_preview_count_tokens(self): + """Tests the count_tokens method in CodeGenerationModel.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _CODE_COMPLETION_BISON_PUBLISHER_MODEL_DICT + ), + ): + model = preview_language_models.CodeGenerationModel.from_pretrained( + "code-gecko@001" + ) + + gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse( + total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"], + total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[ + "total_billable_characters" + ], + ) + + with mock.patch.object( + target=prediction_service_client_v1beta1.PredictionServiceClient, + attribute="count_tokens", + return_value=gca_count_tokens_response, + ): + response = model.count_tokens("def reverse_string(s):") + + assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"] + assert ( + response.total_billable_characters + == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"] + ) + def test_code_completion(self): """Tests code completion with the code generation model.""" aiplatform.init( diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 4b4d0ec0ff..f0ebae0477 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -2648,7 +2648,47 @@ async def predict_streaming_async( yield _parse_text_generation_model_response(prediction_obj) -class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin): +class _CountTokensCodeGenerationMixin(_LanguageModel): + """Mixin for code generation models that support the CountTokens API""" + + def count_tokens( + self, + prefix: str, + *, + suffix: Optional[str] = None, + ) -> CountTokensResponse: + """Counts the tokens and billable characters for a given code generation prompt. + + Note: this does not make a prediction request to the model, it only counts the tokens + in the request. + + Args: + prefix (str): Code before the current point. + suffix (str): Code after the current point. + + Returns: + A `CountTokensResponse` object that contains the number of tokens + in the text and the number of billable characters. + """ + prediction_request = {"prefix": prefix, "suffix": suffix} + + count_tokens_response = self._endpoint._prediction_client.select_version( + "v1beta1" + ).count_tokens( + endpoint=self._endpoint_name, + instances=[prediction_request], + ) + + return CountTokensResponse( + total_tokens=count_tokens_response.total_tokens, + total_billable_characters=count_tokens_response.total_billable_characters, + _count_tokens_response=count_tokens_response, + ) + + +class _PreviewCodeGenerationModel( + CodeGenerationModel, _TunableModelMixin, _CountTokensCodeGenerationMixin +): __name__ = "CodeGenerationModel" __module__ = "vertexai.preview.language_models"