Commit
feat: LVM - Add GCS URI support for Imagen Models (imagetext, `imagegeneration`)

PiperOrigin-RevId: 606401323
vertex-sdk-bot authored and Copybara-Service committed Feb 13, 2024
1 parent 32c7197 commit 4109ea8
Showing 2 changed files with 227 additions and 37 deletions.
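
For context, a minimal usage sketch of the capability this change adds, assuming the public `ImageGenerationModel` surface that the tests below exercise; the project, location, bucket path, and model version are placeholders, not part of this commit:

```python
# Illustrative sketch only -- not part of this commit's diff.
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel

vertexai.init(project="my-project", location="us-central1")  # placeholder values

model = ImageGenerationModel.from_pretrained("imagegeneration@002")
response = model.generate_images(
    prompt="Astronaut riding a horse",
    number_of_images=4,
    # New in this change: write the generated images to Cloud Storage instead
    # of returning inline base64 bytes; sent to the service as "storageUri".
    output_gcs_uri="gs://my-bucket/imagen-output/",
)
for image in response.images:
    # Each result records the GCS destination in its generation parameters.
    print(image.generation_parameters["storage_uri"])
```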
125 changes: 125 additions & 0 deletions tests/unit/aiplatform/test_vision_models.py
@@ -114,6 +114,18 @@ def make_image_generation_response(
return {"predictions": predictions}


def make_image_generation_response_gcs(count: int = 1) -> Dict[str, Any]:
    predictions = []
    for _ in range(count):
        predictions.append(
            {
                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
                "mimeType": "image/png",
            }
        )
    return {"predictions": predictions}


def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
    predictions = {
        "bytesBase64Encoded": make_image_base64(upscale_size, upscale_size),
@@ -122,6 +134,14 @@ def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
return {"predictions": [predictions]}


def make_image_upscale_response_gcs() -> Dict[str, Any]:
    predictions = {
        "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
        "mimeType": "image/png",
    }
    return {"predictions": [predictions]}


def generate_image_from_file(
    width: int = 100, height: int = 100
) -> ga_vision_models.Image:
@@ -332,6 +352,111 @@ def test_generate_images(self):
assert image.generation_parameters["mask_hash"]
assert image.generation_parameters["language"] == language

    def test_generate_images_gcs(self):
        """Tests the image generation model."""
        model = self._get_image_generation_model()

        # TODO(b/295946075) The service stopped supporting image sizes.
        # height = 768
        number_of_images = 4
        seed = 1
        guidance_scale = 15
        language = "en"
        output_gcs_uri = "gs://test-bucket/"

        image_generation_response = make_image_generation_response_gcs(
            count=number_of_images
        )
        gca_predict_response = gca_prediction_service.PredictResponse()
        gca_predict_response.predictions.extend(
            image_generation_response["predictions"]
        )

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response,
        ) as mock_predict:
            prompt1 = "Astronaut riding a horse"
            negative_prompt1 = "bad quality"
            image_response = model.generate_images(
                prompt=prompt1,
                # Optional:
                negative_prompt=negative_prompt1,
                number_of_images=number_of_images,
                # TODO(b/295946075) The service stopped supporting image sizes.
                # width=width,
                # height=height,
                seed=seed,
                guidance_scale=guidance_scale,
                language=language,
                output_gcs_uri=output_gcs_uri,
            )
            predict_kwargs = mock_predict.call_args[1]
            actual_parameters = predict_kwargs["parameters"]
            actual_instance = predict_kwargs["instances"][0]
            assert actual_instance["prompt"] == prompt1
            assert actual_parameters["negativePrompt"] == negative_prompt1
            # TODO(b/295946075) The service stopped supporting image sizes.
            # assert actual_parameters["sampleImageSize"] == str(max(width, height))
            # assert actual_parameters["aspectRatio"] == f"{width}:{height}"
            assert actual_parameters["seed"] == seed
            assert actual_parameters["guidanceScale"] == guidance_scale
            assert actual_parameters["language"] == language
            assert actual_parameters["storageUri"] == output_gcs_uri

            assert len(image_response.images) == number_of_images
            for idx, image in enumerate(image_response):
                assert image.generation_parameters
                assert image.generation_parameters["prompt"] == prompt1
                assert image.generation_parameters["negative_prompt"] == negative_prompt1
                # TODO(b/295946075) The service stopped supporting image sizes.
                # assert image.generation_parameters["width"] == width
                # assert image.generation_parameters["height"] == height
                assert image.generation_parameters["seed"] == seed
                assert image.generation_parameters["guidance_scale"] == guidance_scale
                assert image.generation_parameters["language"] == language
                assert image.generation_parameters["index_of_image_in_batch"] == idx
                assert image.generation_parameters["storage_uri"] == output_gcs_uri

        image1 = generate_image_from_gcs_uri()
        mask_image = generate_image_from_gcs_uri()

        # Test generating image from base image
        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response,
        ) as mock_predict:
            prompt2 = "Ancient book style"
            image_response2 = model.edit_image(
                prompt=prompt2,
                # Optional:
                number_of_images=number_of_images,
                seed=seed,
                guidance_scale=guidance_scale,
                base_image=image1,
                mask=mask_image,
                language=language,
                output_gcs_uri=output_gcs_uri,
            )
            predict_kwargs = mock_predict.call_args[1]
            actual_parameters = predict_kwargs["parameters"]
            actual_instance = predict_kwargs["instances"][0]
            assert actual_instance["prompt"] == prompt2
            assert actual_instance["image"]["gcsUri"]
            assert actual_instance["mask"]["image"]["gcsUri"]
            assert actual_parameters["language"] == language

            assert len(image_response2.images) == number_of_images
            for image in image_response2:
                assert image.generation_parameters
                assert image.generation_parameters["prompt"] == prompt2
                assert image.generation_parameters["base_image_uri"]
                assert image.generation_parameters["mask_uri"]
                assert image.generation_parameters["language"] == language
                assert image.generation_parameters["storage_uri"] == output_gcs_uri

    @unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
    def test_generate_images_requests_square_images_by_default(self):
        """Tests that the model class generates square image by default."""
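
A companion sketch for the `edit_image` half of `test_generate_images_gcs` above: editing a GCS-hosted image and writing the results back to GCS. It assumes the `vision_models.Image` class accepts a `gcs_uri` argument (as the `generate_image_from_gcs_uri` test helper suggests) and uses placeholder bucket paths and model version:

```python
# Illustrative sketch only -- not part of this commit's diff.
from vertexai.preview.vision_models import Image, ImageGenerationModel

model = ImageGenerationModel.from_pretrained("imagegeneration@002")

# Reference the base image and mask directly by GCS URI rather than raw bytes.
base_image = Image(gcs_uri="gs://my-bucket/inputs/photo.png")
mask_image = Image(gcs_uri="gs://my-bucket/inputs/photo_mask.png")

edited = model.edit_image(
    prompt="Ancient book style",
    base_image=base_image,
    mask=mask_image,
    # Results are written under this prefix; the request carries the images
    # as {"gcsUri": ...} entries instead of inline bytes.
    output_gcs_uri="gs://my-bucket/imagen-edits/",
)
for image in edited:
    print(
        image.generation_parameters["base_image_uri"],
        image.generation_parameters["storage_uri"],
    )
```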
