Skip to content

Commit

Permalink
feat: Add safety filter levels, watermark support and person generati…
Browse files Browse the repository at this point in the history
…on support for Imagen 2

Changelog:
Changed the following parameters internally:

  - `editConfig` to `editConfigV6`
  - `segmentationClasses` to `classes`
  - `safetyFilterLevel` to `safetySetting`

PiperOrigin-RevId: 619620158
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Mar 27, 2024
1 parent 2e56acc commit 0c498c5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/unit/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def test_generate_images_gcs(self):
actual_parameters["outputOptions"]["compressionQuality"]
== compression_quality
)
assert actual_parameters["safetyFilterLevel"] == safety_filter_level
assert actual_parameters["safetySetting"] == safety_filter_level
assert actual_parameters["personGeneration"] == person_generation

assert len(image_response2.images) == number_of_images
Expand Down Expand Up @@ -639,7 +639,7 @@ def test_generate_images_requests_safety_filter_level(self):
)
predict_kwargs = mock_predict.call_args[1]
actual_parameters = predict_kwargs["parameters"]
assert actual_parameters["safetyFilterLevel"] == level
assert actual_parameters["safetySetting"] == level

def test_generate_images_requests_person_generation(self):
"""Tests that the model class generates person images."""
Expand Down
6 changes: 3 additions & 3 deletions vertexai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,8 @@ class ID
shared_generation_parameters["mask_mode"] = mask_mode

if segmentation_classes is not None:
parameters["editConfig"]["segmentationClasses"] = segmentation_classes
shared_generation_parameters["segmentation_classes"] = segmentation_classes
parameters["editConfig"]["classes"] = segmentation_classes
shared_generation_parameters["classes"] = segmentation_classes

if mask_dilation is not None:
parameters["editConfig"]["maskDilation"] = mask_dilation
Expand All @@ -527,7 +527,7 @@ class ID
shared_generation_parameters["add_watermark"] = add_watermark

if safety_filter_level is not None:
parameters["safetyFilterLevel"] = safety_filter_level
parameters["safetySetting"] = safety_filter_level
shared_generation_parameters["safety_filter_level"] = safety_filter_level

if person_generation is not None:
Expand Down

0 comments on commit 0c498c5

Please sign in to comment.