From 0c498c5e4226b2a16adb0ff3cf7e6698a05aa5c7 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 27 Mar 2024 12:45:22 -0700 Subject: [PATCH] feat: Add safety filter levels, watermark support and person generation support for Imagen 2 Changelog: Changed the following parameters internally: - `editConfig` to `editConfigV6` - `segmentationClasses` to `classes` - `safetyFilterLevel` to `safetySetting` PiperOrigin-RevId: 619620158 --- tests/unit/aiplatform/test_vision_models.py | 4 ++-- vertexai/vision_models/_vision_models.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/aiplatform/test_vision_models.py b/tests/unit/aiplatform/test_vision_models.py index a567e7424d..accf61daa9 100644 --- a/tests/unit/aiplatform/test_vision_models.py +++ b/tests/unit/aiplatform/test_vision_models.py @@ -499,7 +499,7 @@ def test_generate_images_gcs(self): actual_parameters["outputOptions"]["compressionQuality"] == compression_quality ) - assert actual_parameters["safetyFilterLevel"] == safety_filter_level + assert actual_parameters["safetySetting"] == safety_filter_level assert actual_parameters["personGeneration"] == person_generation assert len(image_response2.images) == number_of_images @@ -639,7 +639,7 @@ def test_generate_images_requests_safety_filter_level(self): ) predict_kwargs = mock_predict.call_args[1] actual_parameters = predict_kwargs["parameters"] - assert actual_parameters["safetyFilterLevel"] == level + assert actual_parameters["safetySetting"] == level def test_generate_images_requests_person_generation(self): """Tests that the model class generates person images.""" diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py index 9e0c45ebdd..506c7820f1 100644 --- a/vertexai/vision_models/_vision_models.py +++ b/vertexai/vision_models/_vision_models.py @@ -502,8 +502,8 @@ class ID shared_generation_parameters["mask_mode"] = mask_mode if segmentation_classes is not None: - parameters["editConfig"]["segmentationClasses"] = segmentation_classes - shared_generation_parameters["segmentation_classes"] = segmentation_classes + parameters["editConfig"]["classes"] = segmentation_classes + shared_generation_parameters["classes"] = segmentation_classes if mask_dilation is not None: parameters["editConfig"]["maskDilation"] = mask_dilation @@ -527,7 +527,7 @@ class ID shared_generation_parameters["add_watermark"] = add_watermark if safety_filter_level is not None: - parameters["safetyFilterLevel"] = safety_filter_level + parameters["safetySetting"] = safety_filter_level shared_generation_parameters["safety_filter_level"] = safety_filter_level if person_generation is not None: