diff --git a/tests/system/aiplatform/test_vision_models.py b/tests/system/aiplatform/test_vision_models.py
index 75ab2fd8df..33e5720d37 100644
--- a/tests/system/aiplatform/test_vision_models.py
+++ b/tests/system/aiplatform/test_vision_models.py
@@ -162,6 +162,83 @@ def test_image_generation_model_generate_images(self):
             assert image.generation_parameters["index_of_image_in_batch"] == idx
             assert image.generation_parameters["language"] == language
 
+        for width, height in [(1, 1), (9, 16), (16, 9), (4, 3), (3, 4)]:
+            prompt_aspect_ratio = "A street lit up on a rainy night"
+            model = vision_models.ImageGenerationModel.from_pretrained(
+                "imagegeneration@006"
+            )
+
+            number_of_images = 4
+            seed = 1
+            guidance_scale = 15
+            language = "en"
+            aspect_ratio = f"{width}:{height}"
+
+            image_response = model.generate_images(
+                prompt=prompt_aspect_ratio,
+                number_of_images=number_of_images,
+                aspect_ratio=aspect_ratio,
+                seed=seed,
+                guidance_scale=guidance_scale,
+                language=language,
+            )
+
+            assert len(image_response.images) == number_of_images
+            for idx, image in enumerate(image_response):
+                assert image.generation_parameters
+                assert image.generation_parameters["prompt"] == prompt_aspect_ratio
+                assert image.generation_parameters["aspect_ratio"] == aspect_ratio
+                assert image.generation_parameters["seed"] == seed
+                assert image.generation_parameters["guidance_scale"] == guidance_scale
+                assert image.generation_parameters["index_of_image_in_batch"] == idx
+                assert image.generation_parameters["language"] == language
+                assert (
+                    abs(
+                        float(image.size[0]) / float(image.size[1])
+                        - float(width) / float(height)
+                    )
+                    <= 0.001
+                )
+
+        person_generation_prompts = [
+            "A street lit up on a rainy night",
+            "A woman walking down a street lit up on a rainy night",
+            "A child walking down a street lit up on a rainy night",
+            "A man walking down a street lit up on a rainy night",
+        ]
+
+        person_generation_levels = ["dont_allow", "allow_adult", "allow_all"]
+
+        # Prompt index i sets how restrictive the subject is; filter index j
+        # (j <= i) sets the safety level, so only i == j requests allow people.
+        for i in range(0, 3):
+            for j in range(0, i + 1):
+                image_response = model.generate_images(
+                    prompt=person_generation_prompts[i],
+                    number_of_images=number_of_images,
+                    seed=seed,
+                    guidance_scale=guidance_scale,
+                    language=language,
+                    person_generation=person_generation_levels[j],
+                )
+                if i == j:
+                    assert len(image_response.images) == number_of_images
+                else:
+                    assert len(image_response.images) < number_of_images
+                for idx, image in enumerate(image_response):
+                    assert (
+                        image.generation_parameters["person_generation"]
+                        == person_generation_levels[j]
+                    )
+                    assert (
+                        image.generation_parameters["prompt"]
+                        == person_generation_prompts[i]
+                    )
+                    assert image.generation_parameters["seed"] == seed
+                    assert (
+                        image.generation_parameters["guidance_scale"] == guidance_scale
+                    )
+                    assert image.generation_parameters["index_of_image_in_batch"] == idx
+                    assert image.generation_parameters["language"] == language
+
         # Test saving and loading images
         with tempfile.TemporaryDirectory() as temp_dir:
             image_path = os.path.join(temp_dir, "image.png")
@@ -178,8 +255,14 @@ def test_image_generation_model_generate_images(self):
             mask_pil_image.save(mask_path, format="PNG")
             mask_image = vision_models.Image.load_from_file(mask_path)
 
-            # Test generating image from base image
+            # Test generating image from base image
             prompt2 = "Ancient book style"
+            edit_mode = "inpainting-insert"
+            mask_mode = "foreground"
+            mask_dilation = 0.06
+            product_position = "fixed"
+            output_mime_type = "image/jpeg"
+            compression_quality = 90
             image_response2 = model.edit_image(
                 prompt=prompt2,
                 # Optional:
@@ -188,6 +271,12 @@ def test_image_generation_model_generate_images(self):
                 guidance_scale=guidance_scale,
                 base_image=image1,
                 mask=mask_image,
+                edit_mode=edit_mode,
+                mask_mode=mask_mode,
+                mask_dilation=mask_dilation,
+                product_position=product_position,
+                output_mime_type=output_mime_type,
+                compression_quality=compression_quality,
                 language=language,
             )
             assert len(image_response2.images) == number_of_images
@@ -199,6 +288,90 @@ def test_image_generation_model_generate_images(self):
                 assert image.generation_parameters["seed"] == seed
                 assert image.generation_parameters["guidance_scale"] == guidance_scale
                 assert image.generation_parameters["index_of_image_in_batch"] == idx
+                assert image.generation_parameters["edit_mode"] == edit_mode
+                assert image.generation_parameters["mask_mode"] == mask_mode
+                assert image.generation_parameters["mask_dilation"] == mask_dilation
+                assert image.generation_parameters["product_position"] == product_position
+                assert image.generation_parameters["mime_type"] == output_mime_type
+                assert (
+                    image.generation_parameters["compression_quality"]
+                    == compression_quality
+                )
+                assert image.generation_parameters["language"] == language
+                assert "base_image_hash" in image.generation_parameters
+                assert "mask_hash" in image.generation_parameters
+
+            prompt3 = "Chocolate chip cookies"
+            edit_mode = "inpainting-insert"
+            mask_mode = "semantic"
+            segmentation_classes = [1, 13, 17, 9, 18]
+            product_position = "fixed"
+            output_mime_type = "image/png"
+
+            image_response3 = model.edit_image(
+                prompt=prompt3,
+                number_of_images=number_of_images,
+                seed=seed,
+                guidance_scale=guidance_scale,
+                base_image=image1,
+                mask=mask_image,
+                edit_mode=edit_mode,
+                mask_mode=mask_mode,
+                segmentation_classes=segmentation_classes,
+                product_position=product_position,
+                output_mime_type=output_mime_type,
+                language=language,
+            )
+
+            assert len(image_response3.images) == number_of_images
+            for idx, image in enumerate(image_response3):
+                assert image.generation_parameters
+                assert image.generation_parameters["prompt"] == prompt3
+                assert image.generation_parameters["seed"] == seed
+                assert image.generation_parameters["guidance_scale"] == guidance_scale
+                assert image.generation_parameters["index_of_image_in_batch"] == idx
+                assert image.generation_parameters["edit_mode"] == edit_mode
+                assert image.generation_parameters["mask_mode"] == mask_mode
+                assert (
+                    image.generation_parameters["segmentation_classes"]
+                    == segmentation_classes
+                )
+                assert image.generation_parameters["product_position"] == product_position
+                assert image.generation_parameters["mime_type"] == output_mime_type
                 assert image.generation_parameters["language"] == language
                 assert "base_image_hash" in image.generation_parameters
                 assert "mask_hash" in image.generation_parameters
+
+    def test_image_verification_model_verify_image(self):
+        """Tests the image verification model verifying watermark presence in an image."""
+        verification_model = vision_models.WatermarkVerificationModel.from_pretrained(
+            "imageverification@001"
+        )
+        model = vision_models.ImageGenerationModel.from_pretrained(
+            "imagegeneration@005"
+        )
+        seed = 1
+        guidance_scale = 15
+        language = "en"
+        image_verification_response = verification_model.verify_image(
+            image=_create_blank_image()
+        )
+        assert image_verification_response.watermark_verification_result == "REJECT"
+
+        prompt = "A street lit up on a rainy night"
+        image_response = model.generate_images(
+            prompt=prompt,
+            number_of_images=1,
+            seed=seed,
+            guidance_scale=guidance_scale,
+            language=language,
+            add_watermark=True,
+        )
+        assert len(image_response.images) == 1
+
+        image_with_watermark = vision_models.Image(image_response.images[0]._image_bytes)
+
+        image_verification_response = verification_model.verify_image(
+            image_with_watermark
+        )
+        assert image_verification_response.watermark_verification_result == "ACCEPT"
diff --git a/tests/unit/aiplatform/test_vision_models.py b/tests/unit/aiplatform/test_vision_models.py
index 7eb8328cfe..a567e7424d 100644
--- a/tests/unit/aiplatform/test_vision_models.py
+++ b/tests/unit/aiplatform/test_vision_models.py
@@ -91,6 +91,18 @@
     },
 }
 
+_IMAGE_VERIFICATION_PUBLISHER_MODEL_DICT = {
+    "name": "publishers/google/models/imageverification",
+    "version_id": "001",
+    "open_source_category": "PROPRIETARY",
+    "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
+    "publisher_model_template": "projects/{project}/locations/{location}/publishers/google/models/imageverification@001",
+    "predict_schemata": {
+        "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/watermark_verification_model_1.0.0.yaml",
+        "parameters_schema_uri": "gs://google-cloud-aiplatform/schema/predict/params/watermark_verification_model_1.0.0.yaml",
+    },
+}
+
 
 def make_image_base64(width: int, height: int) -> str:
     image: PIL_Image.Image = PIL_Image.new(mode="RGB", size=(width, height))
@@ -119,7 +131,9 @@ def make_image_generation_response_gcs(count: int = 1) -> Dict[str, Any]:
     for _ in range(count):
         predictions.append(
             {
-                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
+                "gcsUri": (
+                    "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+                ),
                 "mimeType": "image/png",
             }
         )
@@ -181,7 +195,9 @@ def setup_method(self):
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)
 
-    def _get_image_generation_model(self) -> preview_vision_models.ImageGenerationModel:
+    def _get_image_generation_model(
+        self,
+    ) -> preview_vision_models.ImageGenerationModel:
         """Gets the image generation model."""
         aiplatform.init(
             project=_TEST_PROJECT,
@@ -320,9 +336,9 @@ def test_generate_images(self):
             image_response[0].save(location=image_path)
             image1 = preview_vision_models.GeneratedImage.load_from_file(image_path)
             # assert image1._pil_image.size == (width, height)
-            assert image1.generation_parameters
             assert image1.generation_parameters["prompt"] == prompt1
             assert image1.generation_parameters["language"] == language
+            assert image1.generation_parameters["negative_prompt"] == negative_prompt1
 
             # Preparing mask
             mask_path = os.path.join(temp_dir, "mask.png")
@@ -441,6 +457,14 @@ def test_generate_images_gcs(self):
             return_value=gca_predict_response,
         ) as mock_predict:
             prompt2 = "Ancient book style"
+            edit_mode = "inpainting-insert"
+            mask_mode = "background"
+            mask_dilation = 0.06
+            product_position = "fixed"
+            output_mime_type = "image/jpeg"
+            compression_quality = 80
+            safety_filter_level = "block_fewest"
+            person_generation = "allow_all"
             image_response2 = model.edit_image(
                 prompt=prompt2,
                 # Optional:
@@ -449,8 +473,14 @@ def test_generate_images_gcs(self):
                 guidance_scale=guidance_scale,
                 base_image=image1,
                 mask=mask_image,
-                language=language,
-                output_gcs_uri=output_gcs_uri,
+                edit_mode=edit_mode,
+                mask_mode=mask_mode,
+                mask_dilation=mask_dilation,
+                product_position=product_position,
+                output_mime_type=output_mime_type,
+                compression_quality=compression_quality,
+                safety_filter_level=safety_filter_level,
+                person_generation=person_generation,
             )
             predict_kwargs = mock_predict.call_args[1]
             actual_parameters = predict_kwargs["parameters"]
@@ -458,7 +488,19 @@ def test_generate_images_gcs(self):
             assert actual_instance["prompt"] == prompt2
             assert actual_instance["image"]["gcsUri"]
             assert actual_instance["mask"]["image"]["gcsUri"]
-            assert actual_parameters["language"] == language
+            assert actual_parameters["editConfig"]["editMode"] == edit_mode
+            assert actual_parameters["editConfig"]["maskMode"] == mask_mode
+            assert actual_parameters["editConfig"]["maskDilation"] == mask_dilation
+            assert (
+                actual_parameters["editConfig"]["productPosition"] == product_position
+            )
+            assert actual_parameters["outputOptions"]["mimeType"] == output_mime_type
+            assert (
+                actual_parameters["outputOptions"]["compressionQuality"]
+                == compression_quality
+            )
+            assert actual_parameters["safetyFilterLevel"] == safety_filter_level
+            assert actual_parameters["personGeneration"] == person_generation
 
             assert len(image_response2.images) == number_of_images
             for image in image_response2:
@@ -466,8 +508,20 @@ def test_generate_images_gcs(self):
                 assert image.generation_parameters["prompt"] == prompt2
                 assert image.generation_parameters["base_image_uri"]
                 assert image.generation_parameters["mask_uri"]
-                assert image.generation_parameters["language"] == language
-                assert image.generation_parameters["storage_uri"] == output_gcs_uri
+                assert image.generation_parameters["edit_mode"] == edit_mode
+                assert image.generation_parameters["mask_mode"] == mask_mode
+                assert image.generation_parameters["mask_dilation"] == mask_dilation
+                assert image.generation_parameters["product_position"] == product_position
+                assert image.generation_parameters["mime_type"] == output_mime_type
+                assert (
+                    image.generation_parameters["compression_quality"]
+                    == compression_quality
+                )
+                assert (
+                    image.generation_parameters["safety_filter_level"]
+                    == safety_filter_level
+                )
+                assert image.generation_parameters["person_generation"] == person_generation
 
     @unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
     def test_generate_images_requests_square_images_by_default(self):
@@ -514,6 +568,96 @@ def test_generate_images_requests_square_images_by_default(self):
         actual_parameters = predict_kwargs["parameters"]
         assert "sampleImageSize" not in actual_parameters
 
+    def test_generate_images_requests_9x16_images(self):
+        """Tests that the model class generates 9x16 images."""
+        model = self._get_image_generation_model()
+
+        aspect_ratio = "9:16"
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+        ) as mock_predict:
+            model.generate_images(prompt="test", aspect_ratio=aspect_ratio)
+            predict_kwargs = mock_predict.call_args[1]
+            actual_parameters = predict_kwargs["parameters"]
+            assert actual_parameters["aspectRatio"] == aspect_ratio
+
+    def test_generate_images_requests_with_aspect_ratio(self):
+        """Tests that the model class generates images with different aspect ratios."""
+
+        def test_aspect_ratio(aspect_ratio: str):
+            model = self._get_image_generation_model()
+
+            with mock.patch.object(
+                target=prediction_service_client.PredictionServiceClient,
+                attribute="predict",
+            ) as mock_predict:
+                model.generate_images(prompt="test", aspect_ratio=aspect_ratio)
+                predict_kwargs = mock_predict.call_args[1]
+                actual_parameters = predict_kwargs["parameters"]
+                assert actual_parameters["aspectRatio"] == aspect_ratio
+
+        aspect_ratios = ["1:1", "9:16", "16:9", "4:3", "3:4"]
+        for aspect_ratio in aspect_ratios:
+            test_aspect_ratio(aspect_ratio)
+
+    def test_generate_images_requests_add_watermark(self):
+        """Tests that the model class generates images with watermark."""
+        model = self._get_image_generation_model()
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+        ) as mock_predict:
+            model.generate_images(
+                prompt="test",
+                add_watermark=True,
+            )
+            predict_kwargs = mock_predict.call_args[1]
+            actual_parameters = predict_kwargs["parameters"]
+            assert actual_parameters["addWatermark"]
+
+    def test_generate_images_requests_safety_filter_level(self):
+        """Tests that the model class applies safety filter levels."""
+        model = self._get_image_generation_model()
+
+        safety_filter_levels = [
+            "block_most",
+            "block_some",
+            "block_few",
+            "block_fewest",
+        ]
+
+        for level in safety_filter_levels:
+            with mock.patch.object(
+                target=prediction_service_client.PredictionServiceClient,
+                attribute="predict",
+            ) as mock_predict:
+                model.generate_images(
+                    prompt="test",
+                    safety_filter_level=level,
+                )
+                predict_kwargs = mock_predict.call_args[1]
+                actual_parameters = predict_kwargs["parameters"]
+                assert actual_parameters["safetyFilterLevel"] == level
+
+    def test_generate_images_requests_person_generation(self):
+        """Tests that the model class applies the person generation setting."""
+        model = self._get_image_generation_model()
+
+        for person_generation in ["dont_allow", "allow_adult", "allow_all"]:
+            with mock.patch.object(
+                target=prediction_service_client.PredictionServiceClient,
+                attribute="predict",
+            ) as mock_predict:
+                model.generate_images(
+                    prompt="test",
+                    person_generation=person_generation,
+                )
+                predict_kwargs = mock_predict.call_args[1]
+                actual_parameters = predict_kwargs["parameters"]
+                assert actual_parameters["personGeneration"] == person_generation
+
     def test_upscale_image_on_generated_image(self):
         """Tests image upscaling on generated images."""
         model = self._get_image_generation_model()
@@ -711,6 +855,51 @@ def test_ask_question(self):
 
         assert actual_answers == image_answers
 
+
+@pytest.mark.usefixtures("google_auth_mock")
+class TestImageVerificationModel:
+    """Unit tests for the image verification models."""
+
+    def setup_method(self):
+        importlib.reload(initializer)
+        importlib.reload(aiplatform)
+
+    def teardown_method(self):
+        initializer.global_pool.shutdown(wait=True)
+
+    def test_get_image_verification_results(self):
+        """Tests the image verification model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _IMAGE_VERIFICATION_PUBLISHER_MODEL_DICT
+            ),
+        ) as mock_get_publisher_model:
+            model = preview_vision_models.WatermarkVerificationModel.from_pretrained(
+                "imageverification@001"
+            )
+            mock_get_publisher_model.assert_called_once_with(
+                name="publishers/google/models/imageverification@001",
+                retry=base._DEFAULT_RETRY,
+            )
+
+            image = generate_image_from_file()
+            gca_prediction_response = gca_prediction_service.PredictResponse()
+            gca_prediction_response.predictions.append({"decision": "REJECT"})
+
+            with mock.patch.object(
+                target=prediction_service_client.PredictionServiceClient,
+                attribute="predict",
+                return_value=gca_prediction_response,
+            ):
+                actual_results = model.verify_image(image=image)
+                assert actual_results.watermark_verification_result == "REJECT"
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 class TestMultiModalEmbeddingModels:
     """Unit tests for the image generation models."""
@@ -780,7 +969,10 @@ def test_image_embedding_model_with_image_and_text(self):
         test_embeddings = [0, 0]
         gca_predict_response = gca_prediction_service.PredictResponse()
         gca_predict_response.predictions.append(
-            {"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
+            {
+                "imageEmbedding": test_embeddings,
+                "textEmbedding": test_embeddings,
+            }
         )
 
         image = generate_image_from_file()
@@ -847,7 +1039,10 @@ def test_image_embedding_model_with_lower_dimensions(self):
         test_embeddings = [0] * dimension
         gca_predict_response = gca_prediction_service.PredictResponse()
         gca_predict_response.predictions.append(
-            {"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
+            {
+                "imageEmbedding": test_embeddings,
+                "textEmbedding": test_embeddings,
+            }
        )
 
         image = generate_image_from_file()
@@ -883,7 +1078,10 @@ def test_image_embedding_model_with_gcs_uri(self):
         test_embeddings = [0, 0]
         gca_predict_response = gca_prediction_service.PredictResponse()
         gca_predict_response.predictions.append(
-            {"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
+            {
+                "imageEmbedding": test_embeddings,
+                "textEmbedding": test_embeddings,
+            }
         )
 
         image = generate_image_from_gcs_uri()
@@ -919,7 +1117,10 @@ def test_image_embedding_model_with_storage_url(self):
         test_embeddings = [0, 0]
         gca_predict_response = gca_prediction_service.PredictResponse()
         gca_predict_response.predictions.append(
-            {"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
+            {
+                "imageEmbedding": test_embeddings,
+                "textEmbedding": test_embeddings,
+            }
         )
 
         image = generate_image_from_storage_url()
diff --git a/vertexai/preview/vision_models.py b/vertexai/preview/vision_models.py
index ca5d019814..470b48105f 100644
--- a/vertexai/preview/vision_models.py
+++ b/vertexai/preview/vision_models.py
@@ -15,18 +15,20 @@
 """Classes for working with vision models."""
 
 from vertexai.vision_models._vision_models import (
+    GeneratedImage,
     Image,
+    ImageCaptioningModel,
     ImageGenerationModel,
     ImageGenerationResponse,
-    ImageCaptioningModel,
     ImageQnAModel,
     ImageTextModel,
-    GeneratedImage,
     MultiModalEmbeddingModel,
     MultiModalEmbeddingResponse,
     Video,
     VideoEmbedding,
     VideoSegmentConfig,
+    WatermarkVerificationModel,
+    WatermarkVerificationResponse,
 )
 
 __all__ = [
@@ -36,10 +38,12 @@
     "ImageCaptioningModel",
     "ImageQnAModel",
     "ImageTextModel",
+    "WatermarkVerificationModel",
     "GeneratedImage",
     "MultiModalEmbeddingModel",
     "MultiModalEmbeddingResponse",
     "Video",
     "VideoEmbedding",
     "VideoSegmentConfig",
+    "WatermarkVerificationResponse",
 ]
diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py
index 507955c316..9e0c45ebdd 100644
--- a/vertexai/vision_models/_vision_models.py
+++ b/vertexai/vision_models/_vision_models.py
@@ -21,7 +21,7 @@
 import json
 import pathlib
 import typing
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 import urllib
 
 from google.cloud import storage
@@ -312,35 +312,104 @@ def _generate_images(
         number_of_images: int = 1,
         width: Optional[int] = None,
         height: Optional[int] = None,
+        aspect_ratio: Optional[Literal["1:1", "9:16", "16:9", "4:3", "3:4"]] = None,
         guidance_scale: Optional[float] = None,
         seed: Optional[int] = None,
         base_image: Optional["Image"] = None,
         mask: Optional["Image"] = None,
+        edit_mode: Optional[
+            Literal[
+                "inpainting-insert", "inpainting-remove", "outpainting", "product-image"
+            ]
+        ] = None,
+        mask_mode: Optional[Literal["background", "foreground", "semantic"]] = None,
+        segmentation_classes: Optional[List[int]] = None,
+        mask_dilation: Optional[float] = None,
+        product_position: Optional[Literal["fixed", "reposition"]] = None,
+        output_mime_type: Optional[Literal["image/png", "image/jpeg"]] = None,
+        compression_quality: Optional[float] = None,
         language: Optional[str] = None,
         output_gcs_uri: Optional[str] = None,
+        add_watermark: Optional[bool] = False,
+        safety_filter_level: Optional[
+            Literal["block_most", "block_some", "block_few", "block_fewest"]
+        ] = None,
+        person_generation: Optional[
+            Literal["dont_allow", "allow_adult", "allow_all"]
+        ] = None,
     ) -> "ImageGenerationResponse":
         """Generates images from text prompt.
 
         Args:
             prompt: Text prompt for the image.
-            negative_prompt: A description of what you want to omit in
-                the generated images.
+            negative_prompt: A description of what you want to omit in the generated
+                images.
             number_of_images: Number of images to generate. Range: 1..8.
             width: Width of the image. One of the sizes must be 256 or 1024.
             height: Height of the image. One of the sizes must be 256 or 1024.
-            guidance_scale: Controls the strength of the prompt.
-                Suggested values are:
+            aspect_ratio: Aspect ratio for the image. Supported values are:
+                * 1:1 - Square image
+                * 9:16 - Portrait image
+                * 16:9 - Landscape image
+                * 4:3 - Landscape, desktop ratio
+                * 3:4 - Portrait, desktop ratio
+            guidance_scale: Controls the strength of the prompt. Suggested values
+                are:
                 * 0-9 (low strength)
                 * 10-20 (medium strength)
                 * 21+ (high strength)
             seed: Image generation random seed.
             base_image: Base image to use for the image generation.
             mask: Mask for the base image.
+            edit_mode: Describes the editing mode for the request. Supported values
+                are:
+                * inpainting-insert: fills the mask area based on the text prompt
+                  (requires mask and text)
+                * inpainting-remove: removes the object(s) in the mask area
+                  (requires mask)
+                * outpainting: extends the image based on the mask area
+                  (requires mask)
+                * product-image: changes the background for the predominant product
+                  or subject in the image
+            mask_mode: Requests automatic generation of the mask (versus providing
+                the mask as an input). Supported values are:
+                * background: Automatically generates a mask for all regions except
+                  the primary subject(s) of the image
+                * foreground: Automatically generates a mask for the primary
+                  subject(s) of the image
+                * semantic: Segments one or more of the segmentation classes using
+                  class IDs
+            segmentation_classes: List of class IDs for segmentation. Max of 5 IDs.
+            mask_dilation: Defines the dilation percentage of the mask provided.
+                Float between 0 and 1. Defaults to 0.03.
+            product_position: Defines whether the product should stay fixed or be
+                repositioned. Supported values:
+                * fixed: Fixed position
+                * reposition: Can be moved (default)
+            output_mime_type: Which image format the output should be saved as.
+                Supported values:
+                * image/png: Save as a PNG image
+                * image/jpeg: Save as a JPEG image
+            compression_quality: Level of compression if the output mime type is
+                image/jpeg. Float between 0 and 100.
             language: Language of the text prompt for the image. Default: None.
-                Supported values are `"en"` for English, `"hi"` for Hindi,
-                `"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for
-                automatic language detection.
+                Supported values are `"en"` for English, `"hi"` for Hindi, `"ja"`
+                for Japanese, `"ko"` for Korean, and `"auto"` for automatic language
+                detection.
             output_gcs_uri: Google Cloud Storage uri to store the generated images.
+            add_watermark: Add a watermark to the generated image.
+            safety_filter_level: Adds a filter level to safety filtering. Supported
+                values are:
+                * "block_most" : Strongest filtering level, most strict blocking
+                * "block_some" : Block some problematic prompts and responses
+                * "block_few" : Block fewer problematic prompts and responses
+                * "block_fewest" : Block very few problematic prompts and responses
+            person_generation: Allow generation of people by the model. Supported
+                values are:
+                * "dont_allow" : Block generation of people
+                * "allow_adult" : Generate adults, but not children
+                * "allow_all" : Generate adults and children
 
         Returns:
             An `ImageGenerationResponse` object.
@@ -393,7 +462,9 @@ def _generate_images(
         parameters = {}
         max_size = max(width or 0, height or 0) or None
-        if max_size:
+        if aspect_ratio is not None:
+            parameters["aspectRatio"] = aspect_ratio
+        elif max_size:
             # Note: The size needs to be a string
             parameters["sampleImageSize"] = str(max_size)
             if height is not None and width is not None and height != width:
@@ -421,6 +492,48 @@ def _generate_images(
             parameters["storageUri"] = output_gcs_uri
             shared_generation_parameters["storage_uri"] = output_gcs_uri
 
+        parameters["editConfig"] = {}
+        if edit_mode is not None:
+            parameters["editConfig"]["editMode"] = edit_mode
+            shared_generation_parameters["edit_mode"] = edit_mode
+
+        if mask_mode is not None:
+            parameters["editConfig"]["maskMode"] = mask_mode
+            shared_generation_parameters["mask_mode"] = mask_mode
+
+        if segmentation_classes is not None:
+            parameters["editConfig"]["segmentationClasses"] = segmentation_classes
+            shared_generation_parameters["segmentation_classes"] = segmentation_classes
+
+        if mask_dilation is not None:
+            parameters["editConfig"]["maskDilation"] = mask_dilation
+            shared_generation_parameters["mask_dilation"] = mask_dilation
+
+        if product_position is not None:
+            parameters["editConfig"]["productPosition"] = product_position
+            shared_generation_parameters["product_position"] = product_position
+
+        parameters["outputOptions"] = {}
+        if output_mime_type is not None:
+            parameters["outputOptions"]["mimeType"] = output_mime_type
+            shared_generation_parameters["mime_type"] = output_mime_type
+
+        if compression_quality is not None:
+            parameters["outputOptions"]["compressionQuality"] = compression_quality
+            shared_generation_parameters["compression_quality"] = compression_quality
+
+        if add_watermark is not None:
+            parameters["addWatermark"] = add_watermark
+            shared_generation_parameters["add_watermark"] = add_watermark
+
+        if safety_filter_level is not None:
+            parameters["safetyFilterLevel"] = safety_filter_level
+            shared_generation_parameters["safety_filter_level"] = safety_filter_level
+
+        if person_generation is not None:
+            parameters["personGeneration"] = person_generation
+            shared_generation_parameters["person_generation"] = person_generation
+
         response = self._endpoint.predict(
             instances=[instance],
             parameters=parameters,
@@ -446,29 +559,57 @@ def generate_images(
         *,
         negative_prompt: Optional[str] = None,
         number_of_images: int = 1,
+        aspect_ratio: Optional[Literal["1:1", "9:16", "16:9", "4:3", "3:4"]] = None,
         guidance_scale: Optional[float] = None,
         language: Optional[str] = None,
         seed: Optional[int] = None,
         output_gcs_uri: Optional[str] = None,
+        add_watermark: Optional[bool] = False,
+        safety_filter_level: Optional[
+            Literal["block_most", "block_some", "block_few", "block_fewest"]
+        ] = None,
+        person_generation: Optional[
+            Literal["dont_allow", "allow_adult", "allow_all"]
+        ] = None,
"ImageGenerationResponse": """Generates images from text prompt. Args: prompt: Text prompt for the image. - negative_prompt: A description of what you want to omit in - the generated images. + negative_prompt: A description of what you want to omit in the generated + images. number_of_images: Number of images to generate. Range: 1..8. - guidance_scale: Controls the strength of the prompt. - Suggested values are: + aspect_ratio: Changes the aspect ratio of the generated image Supported + values are: + * "1:1" : 1:1 aspect ratio + * "9:16" : 9:16 aspect ratio + * "16:9" : 16:9 aspect ratio + * "4:3" : 4:3 aspect ratio + * "3:4" : 3;4 aspect_ratio + guidance_scale: Controls the strength of the prompt. Suggested values + are: * 0-9 (low strength) * 10-20 (medium strength) * 21+ (high strength) language: Language of the text prompt for the image. Default: None. - Supported values are `"en"` for English, `"hi"` for Hindi, - `"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for automatic language detection. + Supported values are `"en"` for English, `"hi"` for Hindi, `"ja"` + for Japanese, `"ko"` for Korean, and `"auto"` for automatic language + detection. seed: Image generation random seed. output_gcs_uri: Google Cloud Storage uri to store the generated images. - + add_watermark: Add a watermark to the generated image + safety_filter_level: Adds a filter level to Safety filtering. Supported + values are: + * "block_most" : Strongest filtering level, most strict + blocking + * "block_some" : Block some problematic prompts and responses + * "block_few" : Block fewer problematic prompts and responses + * "block_fewest" : Block very few problematic prompts and responses + person_generation: Allow generation of people by the model Supported + values are: + * "dont_allow" : Block generation of people + * "allow_adult" : Generate adults, but not children + * "allow_all" : Generate adults and children Returns: An `ImageGenerationResponse` object. """ @@ -476,13 +617,14 @@ def generate_images( prompt=prompt, negative_prompt=negative_prompt, number_of_images=number_of_images, - # b/295946075 The service stopped supporting image sizes. - width=None, - height=None, + aspect_ratio=aspect_ratio, guidance_scale=guidance_scale, language=language, seed=seed, output_gcs_uri=output_gcs_uri, + add_watermark=add_watermark, + safety_filter_level=safety_filter_level, + person_generation=person_generation, ) def edit_image( @@ -494,9 +636,26 @@ def edit_image( negative_prompt: Optional[str] = None, number_of_images: int = 1, guidance_scale: Optional[float] = None, + edit_mode: Optional[ + Literal[ + "inpainting-insert", "inpainting-remove", "outpainting", "product-image" + ] + ] = None, + mask_mode: Optional[Literal["background", "foreground", "semantic"]] = None, + segmentation_classes: Optional[List[str]] = None, + mask_dilation: Optional[float] = None, + product_position: Optional[Literal["fixed", "reposition"]] = None, + output_mime_type: Optional[Literal["image/png", "image/jpeg"]] = None, + compression_quality: Optional[float] = None, language: Optional[str] = None, seed: Optional[int] = None, output_gcs_uri: Optional[str] = None, + safety_filter_level: Optional[ + Literal["block_most", "block_some", "block_few", "block_fewest"] + ] = None, + person_generation: Optional[ + Literal["dont_allow", "allow_adult", "allow_all"] + ] = None, ) -> "ImageGenerationResponse": """Edits an existing image based on text prompt. 
@@ -512,11 +671,55 @@ def edit_image(
                 * 0-9 (low strength)
                 * 10-20 (medium strength)
                 * 21+ (high strength)
+            edit_mode: Describes the editing mode for the request. Supported values
+                are:
+                * inpainting-insert: fills the mask area based on the text prompt
+                  (requires mask and text)
+                * inpainting-remove: removes the object(s) in the mask area
+                  (requires mask)
+                * outpainting: extends the image based on the mask area
+                  (requires mask)
+                * product-image: changes the background for the predominant product
+                  or subject in the image
+            mask_mode: Requests automatic generation of the mask (versus providing
+                the mask as an input). Supported values are:
+                * background: Automatically generates a mask for all regions except
+                  the primary subject(s) of the image
+                * foreground: Automatically generates a mask for the primary
+                  subject(s) of the image
+                * semantic: Segments one or more of the segmentation classes using
+                  class IDs
+            segmentation_classes: List of class IDs for segmentation. Max of 5 IDs.
+            mask_dilation: Defines the dilation percentage of the mask provided.
+                Float between 0 and 1. Defaults to 0.03.
+            product_position: Defines whether the product should stay fixed or be
+                repositioned. Supported values:
+                * fixed: Fixed position
+                * reposition: Can be moved (default)
+            output_mime_type: Which image format the output should be saved as.
+                Supported values:
+                * image/png: Save as a PNG image
+                * image/jpeg: Save as a JPEG image
+            compression_quality: Level of compression if the output mime type is
+                image/jpeg. Float between 0 and 100.
             language: Language of the text prompt for the image. Default: None.
                 Supported values are `"en"` for English, `"hi"` for Hindi,
-                `"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for automatic language detection.
+                `"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for
+                automatic language detection.
             seed: Image generation random seed.
             output_gcs_uri: Google Cloud Storage uri to store the edited images.
+            safety_filter_level: Adds a filter level to safety filtering. Supported
+                values are:
+                * "block_most" : Strongest filtering level, most strict blocking
+                * "block_some" : Block some problematic prompts and responses
+                * "block_few" : Block fewer problematic prompts and responses
+                * "block_fewest" : Block very few problematic prompts and responses
+            person_generation: Allow generation of people by the model. Supported
+                values are:
+                * "dont_allow" : Block generation of people
+                * "allow_adult" : Generate adults, but not children
+                * "allow_all" : Generate adults and children
 
         Returns:
             An `ImageGenerationResponse` object.
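A companion sketch for the editing path, exercising the mask controls documented above (the file names are illustrative; any `vision_models.Image` works for the base image and mask):

```python
from vertexai.preview.vision_models import Image, ImageGenerationModel

model = ImageGenerationModel.from_pretrained("imagegeneration@006")

# Illustrative local files standing in for a real base image and mask.
base_image = Image.load_from_file("base.png")
mask_image = Image.load_from_file("mask.png")

response = model.edit_image(
    prompt="Ancient book style",
    base_image=base_image,
    mask=mask_image,
    edit_mode="inpainting-insert",
    mask_mode="semantic",
    segmentation_classes=[1, 13, 17],  # at most 5 class IDs
    mask_dilation=0.06,
    product_position="fixed",
    output_mime_type="image/jpeg",
    compression_quality=80,
)
response[0].save(location="edited.jpg")
```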
@@ -529,8 +732,18 @@ def edit_image(
             seed=seed,
             base_image=base_image,
             mask=mask,
+            edit_mode=edit_mode,
+            mask_mode=mask_mode,
+            segmentation_classes=segmentation_classes,
+            mask_dilation=mask_dilation,
+            product_position=product_position,
+            output_mime_type=output_mime_type,
+            compression_quality=compression_quality,
             language=language,
             output_gcs_uri=output_gcs_uri,
+            add_watermark=False,  # Not supported for editing yet
+            safety_filter_level=safety_filter_level,
+            person_generation=person_generation,
         )
 
     def upscale_image(
@@ -1020,3 +1233,52 @@ class ImageTextModel(ImageCaptioningModel, ImageQnAModel):
     # since SDK Model Garden classes should follow the design pattern of exactly 1 SDK class to 1 Model Garden schema URI
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_reasoning_model_1.0.0.yaml"
+
+
+@dataclasses.dataclass
+class WatermarkVerificationResponse:
+
+    __module__ = "vertexai.preview.vision_models"
+
+    _prediction_response: Any
+    watermark_verification_result: Optional[str] = None
+
+
+class WatermarkVerificationModel(_model_garden_models._ModelGardenModel):
+    """Verifies if an image has a watermark."""
+
+    __module__ = "vertexai.preview.vision_models"
+
+    _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/watermark_verification_model_1.0.0.yaml"
+
+    def verify_image(self, image: Image) -> WatermarkVerificationResponse:
+        """Verifies the watermark of an image.
+
+        Args:
+            image: The image to verify.
+
+        Returns:
+            A WatermarkVerificationResponse, containing the confidence level of
+            the image being watermarked.
+        """
+        if not image:
+            raise ValueError("Image is required.")
+
+        instance = {}
+
+        if image._gcs_uri:
+            instance["image"] = {"gcsUri": image._gcs_uri}
+        else:
+            instance["image"] = {"bytesBase64Encoded": image._as_base64_string()}
+
+        parameters = {}
+        response = self._endpoint.predict(
+            instances=[instance],
+            parameters=parameters,
+        )
+
+        verification_decision = response.predictions[0].get("decision")
+        return WatermarkVerificationResponse(
+            _prediction_response=response,
+            watermark_verification_result=verification_decision,
+        )
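Finally, an end-to-end sketch of the new watermark verification flow; it assumes a `GeneratedImage` is accepted wherever an `Image` is (the `verify_image` body above only touches `_gcs_uri` and `_as_base64_string()`), and the model IDs come from the tests in this change:

```python
from vertexai.preview.vision_models import (
    ImageGenerationModel,
    WatermarkVerificationModel,
)

generation_model = ImageGenerationModel.from_pretrained("imagegeneration@006")
verification_model = WatermarkVerificationModel.from_pretrained(
    "imageverification@001"
)

# Generate a single watermarked image...
response = generation_model.generate_images(
    prompt="A street lit up on a rainy night",
    number_of_images=1,
    add_watermark=True,
)

# ...then ask the verifier whether the watermark is present.
result = verification_model.verify_image(image=response[0])
print(result.watermark_verification_result)  # "ACCEPT" or "REJECT"
```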