Skip to content

Commit

Permalink
feat: Add additional parameters for GenerationConfig
Browse files Browse the repository at this point in the history
- `presence_penalty`
- `frequency_penalty`
- `response_mime_type`

PiperOrigin-RevId: 627067043
  • Loading branch information
holtskinner authored and Copybara-Service committed Apr 22, 2024
1 parent c0e7acc commit 0599ca1
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
27 changes: 27 additions & 0 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=protected-access, g-multiple-import
"""System tests for generative models."""

import json
import pytest

# Google imports
Expand All @@ -30,6 +31,7 @@

GEMINI_MODEL_NAME = "gemini-1.0-pro-002"
GEMINI_VISION_MODEL_NAME = "gemini-1.0-pro-vision"
GEMINI_15_MODEL_NAME = "gemini-1.5-pro-preview-0409"


# A dummy function for function calling
Expand Down Expand Up @@ -150,6 +152,31 @@ def test_generate_content_with_parameters(self):
)
assert response.text

def test_generate_content_with_gemini_15_parameters(self):
model = generative_models.GenerativeModel(GEMINI_15_MODEL_NAME)
response = model.generate_content(
contents="Why is sky blue? Respond in JSON Format.",
generation_config=generative_models.GenerationConfig(
temperature=0,
top_p=0.95,
top_k=20,
candidate_count=1,
max_output_tokens=100,
stop_sequences=["STOP!"],
presence_penalty=0.0,
frequency_penalty=0.0,
response_mime_type="application/json",
),
safety_settings={
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_NONE,
},
)
assert response.text
assert json.loads(response.text)

def test_generate_content_from_list_of_content_dict(self):
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
response = model.generate_content(
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ def test_generate_content(self, generative_models: generative_models):
candidate_count=1,
max_output_tokens=200,
stop_sequences=["\n\n\n"],
presence_penalty=0.0,
frequency_penalty=0.0,
),
safety_settings=[
generative_models.SafetySetting(
Expand All @@ -420,6 +422,33 @@ def test_generate_content(self, generative_models: generative_models):
)
assert response2.text

model3 = generative_models.GenerativeModel("gemini-1.5-pro-preview-0409")
response3 = model3.generate_content(
"Why is sky blue? Respond in JSON.",
generation_config=generative_models.GenerationConfig(
temperature=0.2,
top_p=0.9,
top_k=20,
candidate_count=1,
max_output_tokens=200,
stop_sequences=["\n\n\n"],
response_mime_type="application/json",
),
safety_settings=[
generative_models.SafetySetting(
category=generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=generative_models.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,
),
generative_models.SafetySetting(
category=generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=generative_models.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
method=generative_models.SafetySetting.HarmBlockMethod.PROBABILITY,
),
],
)
assert response3.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="stream_generate_content",
Expand Down
20 changes: 19 additions & 1 deletion vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,9 @@ def __init__(
candidate_count: Optional[int] = None,
max_output_tokens: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
response_mime_type: Optional[str] = None,
):
r"""Constructs a GenerationConfig object.
Expand All @@ -1199,6 +1202,18 @@ def __init__(
candidate_count: Number of candidates to generate.
max_output_tokens: The maximum number of output tokens to generate per message.
stop_sequences: A list of stop sequences.
presence_penalty: Positive values penalize tokens that have appeared in the generated text,
thus increasing the possibility of generating more diverse topics. Range: [-2.0, 2.0]
frequency_penalty: Positive values penalize tokens that repeatedly appear in the generated
text, thus decreasing the possibility of repeating the same content. Range: [-2.0, 2.0]
response_mime_type: Output response mimetype of the generated
candidate text. Supported mimetypes:
- ``text/plain``: (default) Text output.
- ``application/json``: JSON response in the candidates.
The model needs to be prompted to output the appropriate
response type, otherwise the behavior is undefined.
Usage:
```
Expand All @@ -1222,6 +1237,9 @@ def __init__(
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
response_mime_type=response_mime_type,
)

@classmethod
Expand Down Expand Up @@ -1650,7 +1668,7 @@ def prompt_feedback(

@property
def usage_metadata(
self
self,
) -> gapic_prediction_service_types.GenerateContentResponse.UsageMetadata:
return self._raw_response.usage_metadata

Expand Down

0 comments on commit 0599ca1

Please sign in to comment.