Skip to content

Commit

Permalink
feat: GenAI - Added the response_schema parameter to the `Generatio…
Browse files Browse the repository at this point in the history
…nConfig` class

PiperOrigin-RevId: 637930285
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed May 28, 2024
1 parent ac17d87 commit b5e2c02
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ def get_current_weather(location: str, unit: str = "centigrade"):
"required": ["location"],
}

_RESPONSE_SCHEMA_STRUCT = {
"type": "object",
"properties": {
"location": {
"type": "string",
},
},
"required": ["location"],
}


class TestGenerativeModels(e2e_base.TestEndToEnd):
"""System tests for generative models."""
Expand Down Expand Up @@ -174,6 +184,7 @@ def test_generate_content_with_gemini_15_parameters(self):
presence_penalty=0.0,
frequency_penalty=0.0,
response_mime_type="application/json",
response_schema=_RESPONSE_SCHEMA_STRUCT,
),
safety_settings={
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
Expand Down
9 changes: 9 additions & 0 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,7 @@ def __init__(
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
response_mime_type: Optional[str] = None,
response_schema: Optional[Dict[str, Any]] = None,
):
r"""Constructs a GenerationConfig object.
Expand All @@ -1216,6 +1217,8 @@ def __init__(
The model needs to be prompted to output the appropriate
response type, otherwise the behavior is undefined.
response_schema: Output response schema of the genreated candidate text. Only valid when
response_mime_type is application/json.
Usage:
```
Expand All @@ -1232,6 +1235,11 @@ def __init__(
)
```
"""
if response_schema is None:
raw_schema = None
else:
gapic_schema_dict = _convert_schema_dict_to_gapic(response_schema)
raw_schema = aiplatform_types.Schema(gapic_schema_dict)
self._raw_generation_config = gapic_content_types.GenerationConfig(
temperature=temperature,
top_p=top_p,
Expand All @@ -1242,6 +1250,7 @@ def __init__(
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
response_mime_type=response_mime_type,
response_schema=raw_schema,
)

@classmethod
Expand Down

0 comments on commit b5e2c02

Please sign in to comment.