Skip to content

Commit

Permalink
feat: Add safety filter levels, watermark support and person generati…
Browse files Browse the repository at this point in the history
…on support for Imagen 2

Changelog:
- Added `add_watermark` option to `generate_image` call for adding a SynthID watermark to generated images.
- Added a `edit_mode` option to `edit_image` call. Can now choose between 4 edit modes -
  - `inpainting-insert` : Edit the image within the masked region. Needs both mask and prompt
  - `inpainting-remove`: Remove objects within the masked region. Needs only mask
  - `outpainting`: Extend the image based on the mask area.
  - `product-image`: Changes background for primary subject of the image
- Added a `mask_mode` option to `edit_image` call. Can now choose between 3 mask generation modes, instead of providing masks:
  - `background`: Select everything except the primary subject(s) of the image
  - `foreground`: Select the primary subject(s) of the image
  - `semantic`: Segment one or more of the segmentation classes using class ID
- Added a `segmentation_classes` option for passing a list of class IDs when `semantic` mask_mode is used. Can send upto 5 classes
- Added a `mask_dilation` option for setting the dilation percentage of mask
- Added a `product_position` option to allow repositioning of products in the image. Supported values are:
  - `reposition`: Products can be repositioned
  - `fixed`: Product location is fixed
- Added a `output_mime_type` option to select which image format should the output be returned as. Supported values are:
  - `image/png`
  - `image/jpeg`
- Added a `compression_quality` option to select compression quality when output is `image/jpeg`.
- Added a safety filter level for selecting the level of prompt and image filtering by Responsible AI filters. Supported values are:
  - `"block_most"`  : The strictest filter. Blocks most
  - `"block_some"`  : Second most strict filter. Blocks some prompts and images
  - `"block_few"`   : Blocks a few prompts and images
  - `"block_fewest"`: Blocks fewest prompts and images
- Added an option to control person generation. Supported values are:
  - `"dont_allow"`  : Don't generate people at all
  - `"allow_adults"`: Generate adults, but not children
  - `"allow_all"`   : Allows all person generation
- Added the WatermarkVerificationModel to check if an image has a SynthID watermark. The publisher model is `imageverification@001`. The model object contains just one call, `verify_image`. `verify_image` takes only an image as the input and returns a string with one of 2 values:
  - `ACCEPT`     : The image contains a watermark
  - `REJECT`     : The image does not contain a watermark
PiperOrigin-RevId: 617924430
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Mar 21, 2024
1 parent 181dc7a commit e2efdbe
Show file tree
Hide file tree
Showing 4 changed files with 675 additions and 35 deletions.
175 changes: 174 additions & 1 deletion tests/system/aiplatform/test_vision_models.py
Expand Up @@ -162,6 +162,83 @@ def test_image_generation_model_generate_images(self):
assert image.generation_parameters["index_of_image_in_batch"] == idx
assert image.generation_parameters["language"] == language

for width, height in [(1, 1), (9, 16), (16, 9), (4, 3), (3, 4)]:
prompt_aspect_ratio = "A street lit up on a rainy night"
model = vision_models.ImageGenerationModel.from_pretrained(
"imagegeneration@006"
)

number_of_images = 4
seed = 1
guidance_scale = 15
language = "en"
aspect_ratio = f"{width}:{height}"

image_response = model.generate_images(
prompt=prompt_aspect_ratio,
number_of_images=number_of_images,
aspect_ratio=aspect_ratio,
seed=seed,
guidance_scale=guidance_scale,
language=language,
)

assert len(image_response.images) == number_of_images
for idx, image in enumerate(image_response):
assert image.generation_parameters
assert image.generation_parameters["prompt"] == prompt_aspect_ratio
assert image.generation_parameters["aspect_ratio"] == aspect_ratio
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["index_of_image_in_batch"] == idx
assert image.generation_parameters["language"] == language
assert (
abs(
float(image.size[0]) / float(image.size[1])
- float(width) / float(height)
)
<= 0.001
)

person_generation_prompts = [
"A street lit up on a rainy night",
"A woman walking down a street lit up on a rainy night",
"A child walking down a street lit up on a rainy night",
"A man walking down a street lit up on a rainy night",
]

person_generation_levels = ["dont_allow", "allow_adult", "allow_all"]

for i in range(0, 3):
for j in range(0, i + 1):
image_response = model.generate_images(
prompt=person_generation_prompts[j],
number_of_images=number_of_images,
seed=seed,
guidance_scale=guidance_scale,
language=language,
person_generation=person_generation_levels[j],
)
if i == j:
assert len(image_response.images) == number_of_images
else:
assert len(image_response.images) < number_of_images
for idx, image in enumerate(image_response):
assert (
image.generation_parameters["person_generation"]
== person_generation_levels[j]
)
assert (
image.generation_parameters["prompt"]
== person_generation_prompts[j]
)
assert image.generation_parameters["seed"] == seed
assert (
image.generation_parameters["guidance_scale"] == guidance_scale
)
assert image.generation_parameters["index_of_image_in_batch"] == idx
assert image.generation_parameters["language"] == language

# Test saving and loading images
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir, "image.png")
Expand All @@ -178,8 +255,14 @@ def test_image_generation_model_generate_images(self):
mask_pil_image.save(mask_path, format="PNG")
mask_image = vision_models.Image.load_from_file(mask_path)

# Test generating image from base image
# Test generating image from base image
prompt2 = "Ancient book style"
edit_mode = "inpainting-insert"
mask_mode = "foreground"
mask_dilation = 0.06
product_position = "fixed"
output_mime_type = "image/jpeg"
compression_quality = 0.90
image_response2 = model.edit_image(
prompt=prompt2,
# Optional:
Expand All @@ -188,6 +271,12 @@ def test_image_generation_model_generate_images(self):
guidance_scale=guidance_scale,
base_image=image1,
mask=mask_image,
edit_mode=edit_mode,
mask_mode=mask_mode,
mask_dilation=mask_dilation,
product_position=product_position,
output_mime_type=output_mime_type,
compression_quality=compression_quality,
language=language,
)
assert len(image_response2.images) == number_of_images
Expand All @@ -199,6 +288,90 @@ def test_image_generation_model_generate_images(self):
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["index_of_image_in_batch"] == idx
assert image.generation_parameters["edit_mode"] == edit_mode
assert image.generation_parameters["mask_mode"] == mask_mode
assert image.generation_parameters["mask_dilation"] == mask_dilation
assert image.generation_parameters["product_position"] == product_position
assert image.generation_parameters["mime_type"] == output_mime_type
assert (
image.generation_parameters["compression_quality"]
== compression_quality
)
assert image.generation_parameters["language"] == language
assert "base_image_hash" in image.generation_parameters
assert "mask_hash" in image.generation_parameters

prompt3 = "Chocolate chip cookies"
edit_mode = "inpainting-insert"
mask_mode = "semantic"
segmentation_classes = [1, 13, 17, 9, 18]
product_position = "fixed"
output_mime_type = "image/png"

image_response3 = model.edit_image(
prompt=prompt3,
number_of_images=number_of_images,
seed=seed,
guidance_scale=guidance_scale,
base_image=image1,
mask=mask_image,
edit_mode=edit_mode,
mask_mode=mask_mode,
segmentation_classes=segmentation_classes,
product_position=product_position,
output_mime_type=output_mime_type,
language=language,
)

assert len(image_response3.images) == number_of_images
for idx, image in enumerate(image_response3):
assert image.generation_parameters
assert image.generation_parameters["prompt"] == prompt3
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["index_of_image_in_batch"] == idx
assert image.generation_parameters["edit_mode"] == edit_mode
assert image.generation_parameters["mask_mode"] == mask_mode
assert (
image.generation_parameters["segmentation_classes"]
== segmentation_classes
)
assert image.generation_parameters["product_position"] == product_position
assert image.generation_parameters["mime_type"] == output_mime_type
assert image.generation_parameters["language"] == language
assert "base_image_hash" in image.generation_parameters
assert "mask_hash" in image.generation_parameters

def test_image_verification_model_verify_image(self):
"""Tests the image verification model verifying watermark presence in an image."""
verification_model = vision_models.ImageVerificationModel.from_pretrained(
"imageverification@001"
)
model = vision_models.ImageGenerationModel.from_pretrained(
"imagegeneration@005"
)
seed = 1
guidance_scale = 15
language = "en"
image_verification_response = verification_model.verify_image(
image=_create_blank_image()
)
assert image_verification_response["decision"] == "REJECT"

prompt = "A street lit up on a rainy night"
image_response = model.generate_images(
prompt=prompt,
number_of_images=1,
seed=seed,
guidance_scale=guidance_scale,
language=language,
add_watermark=True,
)
assert len(image_response.images) == 1

image_with_watermark = vision_models.Image(image_response.images[0].image_bytes)

image_verification_response = verification_model.verify_image(
image_with_watermark
)
assert image_verification_response["decision"] == "ACCEPT"

0 comments on commit e2efdbe

Please sign in to comment.