From b5e2c0204070e5f7fb695d39c7e5d23f937dbffd Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 28 May 2024 09:42:30 -0700 Subject: [PATCH] feat: GenAI - Added the `response_schema` parameter to the `GenerationConfig` class PiperOrigin-RevId: 637930285 --- tests/system/vertexai/test_generative_models.py | 11 +++++++++++ vertexai/generative_models/_generative_models.py | 9 +++++++++ 2 files changed, 20 insertions(+) diff --git a/tests/system/vertexai/test_generative_models.py b/tests/system/vertexai/test_generative_models.py index 800189c0ff..1282e99c26 100644 --- a/tests/system/vertexai/test_generative_models.py +++ b/tests/system/vertexai/test_generative_models.py @@ -70,6 +70,16 @@ def get_current_weather(location: str, unit: str = "centigrade"): "required": ["location"], } +_RESPONSE_SCHEMA_STRUCT = { + "type": "object", + "properties": { + "location": { + "type": "string", + }, + }, + "required": ["location"], +} + class TestGenerativeModels(e2e_base.TestEndToEnd): """System tests for generative models.""" @@ -174,6 +184,7 @@ def test_generate_content_with_gemini_15_parameters(self): presence_penalty=0.0, frequency_penalty=0.0, response_mime_type="application/json", + response_schema=_RESPONSE_SCHEMA_STRUCT, ), safety_settings={ generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index b5d8877494..192928e935 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -1194,6 +1194,7 @@ def __init__( presence_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, response_mime_type: Optional[str] = None, + response_schema: Optional[Dict[str, Any]] = None, ): r"""Constructs a GenerationConfig object. @@ -1216,6 +1217,8 @@ def __init__( The model needs to be prompted to output the appropriate response type, otherwise the behavior is undefined. + response_schema: Output response schema of the genreated candidate text. Only valid when + response_mime_type is application/json. Usage: ``` @@ -1232,6 +1235,11 @@ def __init__( ) ``` """ + if response_schema is None: + raw_schema = None + else: + gapic_schema_dict = _convert_schema_dict_to_gapic(response_schema) + raw_schema = aiplatform_types.Schema(gapic_schema_dict) self._raw_generation_config = gapic_content_types.GenerationConfig( temperature=temperature, top_p=top_p, @@ -1242,6 +1250,7 @@ def __init__( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, response_mime_type=response_mime_type, + response_schema=raw_schema, ) @classmethod