diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 1c4781c9bc..0291ccb7f6 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -577,6 +577,10 @@ def test_text_generation(self): ) assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"] + assert ( + response.raw_prediction_response.predictions[0] + == _TEST_TEXT_GENERATION_PREDICTION + ) assert ( response.safety_attributes["Violent"] == _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0] diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 620bc4e708..588748fbef 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -222,6 +222,11 @@ class TextGenerationResponse: def __repr__(self): return self.text + @property + def raw_prediction_response(self) -> aiplatform.models.Prediction: + """Raw prediction response.""" + return self._prediction_response + class _TextGenerationModel(_LanguageModel): """TextGenerationModel represents a general language model.