In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Imagen Product Recontext - Generation at Scale

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-creative-studio/blob/main/experiments/Imagen_Product_Recontext/imagen_product_recontext_at_scale.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-creative-studio%2Fmain%2Fexperiments%2FImagen_Product_Recontext%2Fimagen_product_recontext_at_scale.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-creative-studio/main/experiments/Imagen_Product_Recontext/imagen_product_recontext_at_scale.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/bigquery/import?url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-creative-studio/main/experiments/Imagen_Product_Recontext/imagen_product_recontext_at_scale.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/bigquery/v1/32px.svg" alt="BigQuery Studio logo"><br> Open in BigQuery Studio
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-creative-studio/blob/main/experiments/Imagen_Product_Recontext/imagen_product_recontext_at_scale.ipynb">
      <img width="32px" src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

| | |
|-|-|
|Author(s) | [Layolin Jesudhass](https://github.com/LUJ20), Isidro De Loera |

## Imports

In [12]:
!pip install --upgrade --user google-cloud-aiplatform

from google import genai
from google.genai import types
from google.cloud import aiplatform, storage
from google.cloud.aiplatform.gapic import PredictResponse
from google.colab import auth

import base64
import io
import re
import timeit
import os
import json
from pathlib import Path
from typing import Any, Dict, List, Generator
from PIL import Image
import matplotlib.pyplot as plt



In [13]:
# --------------------------------------------------
# Authenticate this Colab session with Google Cloud.
# `google.colab` is only importable inside a Colab runtime, so this cell is
# Colab-specific.
from google.colab import auth
auth.authenticate_user()
print("Authenticated with Google Cloud")
# --------------------------------------------------


Authenticated with Google Cloud


## Helper Functions


In [14]:
#Core Helper Functions for GCS scanning and image display

# GCS‑scanning + debug prints

def get_mime_type(uri: str) -> str:
    """Map an image URI's file extension to its MIME type.

    The extension is taken with ``os.path.splitext`` and compared
    case-insensitively. Only ``.png``, ``.jpg`` and ``.jpeg`` are recognised.

    Args:
        uri: Path or URI whose trailing extension identifies the image format.

    Returns:
        ``"image/png"`` or ``"image/jpeg"``.

    Raises:
        ValueError: If the extension is anything else (including missing).
    """
    mime_by_suffix = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
    }
    _, suffix = os.path.splitext(uri)
    suffix = suffix.lower()
    mime = mime_by_suffix.get(suffix)
    if mime is None:
        raise ValueError(f"Unsupported extension: {suffix}")
    return mime

def discover_product_batches(input_gcs_uri: str) -> Generator[Dict[str, object], None, None]:
    """
    Walk a GCS prefix and emit one batch per first-level sub-folder.

    Every object under ``input_gcs_uri`` is listed. Objects whose path below
    the prefix ends in .png/.jpg/.jpeg (case-insensitive) are grouped by the
    first path component below the prefix. Within each folder the images are
    ordered by base filename and only the first three are kept.

    Because this is a generator, nothing runs (including URI validation)
    until the first item is requested.

    Yields dicts with:
      - product_folder (e.g. "product_5")
      - image_parts   (List[types.Part])
      - product_uris  (List[str])

    Raises:
      ValueError: if ``input_gcs_uri`` is not of the form gs://bucket/prefix.
    """
    parsed = re.match(r"gs://([^/]+)/(.+?)/?$", input_gcs_uri)
    if parsed is None:
        raise ValueError(f"Invalid GCS URI: {input_gcs_uri}")
    bucket_name = parsed.group(1)
    # Normalise to exactly one trailing slash so the prefix can be sliced off.
    prefix = parsed.group(2).rstrip("/") + "/"
    client = storage.Client()
    print(f"🔍 Scanning bucket={bucket_name} prefix={prefix}")

    # Materialise the listing so the total can be reported up front.
    blobs = list(client.list_blobs(bucket_name, prefix=prefix))
    print(f"  • Found {len(blobs)} total objects under {prefix}")

    # Bucket the image blobs by the folder directly below the prefix.
    image_suffixes = (".png", ".jpg", ".jpeg")
    blobs_by_folder: Dict[str, List[storage.blob.Blob]] = {}
    for blob in blobs:
        relative = blob.name[len(prefix):]  # e.g. "product_5/thermos_1.png"
        folder, slash, filename = relative.partition("/")
        if not slash:
            # Object sits directly under the prefix, not inside a product folder.
            continue
        if filename.lower().endswith(image_suffixes):
            blobs_by_folder.setdefault(folder, []).append(blob)

    print(f"  • Discovered product folders: {list(blobs_by_folder)}")

    for folder, folder_blobs in blobs_by_folder.items():
        # Order by base filename, then keep at most three images per product.
        selected = sorted(folder_blobs, key=lambda blob: os.path.basename(blob.name))[:3]

        uris = []
        parts = []
        for blob in selected:
            uri = f"gs://{bucket_name}/{blob.name}"
            uris.append(uri)
            parts.append(
                types.Part(
                    file_data=types.FileData(file_uri=uri, mime_type=get_mime_type(uri))
                )
            )

        yield {
            "product_folder": folder,
            "image_parts":   parts,
            "product_uris":  uris,
        }

# Display helpers
def download_gcs_image_bytes(uri: str) -> bytes:
    """Download the object that a ``gs://bucket/object`` URI points to.

    Args:
        uri: GCS URI; everything after the first ``/`` following the bucket
            name is used as the object name.

    Returns:
        The object's content as bytes.

    Raises:
        ValueError: If ``uri`` does not start with ``gs://<bucket>/``.
    """
    parsed = re.match(r"gs://([^/]+)/(.*)", uri)
    if parsed is None:
        raise ValueError(f"Invalid GCS URI: {uri}")
    bucket_name = parsed.group(1)
    object_name = parsed.group(2)
    gcs = storage.Client()
    bucket = gcs.bucket(bucket_name)
    blob = bucket.blob(object_name)
    return blob.download_as_bytes()

def prediction_to_pil_image(pred: PredictResponse, size=(640, 640)) -> Image.Image:
    """Decode a prediction's base64 image payload into a PIL image.

    Args:
        pred: Object indexed with the key ``"bytesBase64Encoded"``; the value
            is base64-decoded and opened as an image.
        size: Bounding box passed to ``Image.thumbnail``. The image is resized
            in place by that call before it is returned.

    Returns:
        The decoded image after the ``thumbnail`` call.
    """
    raw_bytes = base64.b64decode(pred["bytesBase64Encoded"])
    picture = Image.open(io.BytesIO(raw_bytes))
    picture.thumbnail(size)
    return picture

def display_row(items: List[Any], figsize=(12, 4)):
    """Show ``items`` side by side in a single matplotlib row.

    Each item is rendered according to its type:
      * a PIL image is drawn directly;
      * any mapping containing the key ``"bytesBase64Encoded"`` is decoded
        with ``prediction_to_pil_image`` and drawn;
      * anything else is rendered as centred text via ``str(item)``.

    The mapping check uses ``collections.abc.Mapping`` rather than ``dict`` so
    that map-like prediction objects which are not ``dict`` subclasses are
    drawn as images instead of being dumped as text.
    NOTE(review): assumes the entries of ``resp.predictions`` are proto-plus
    map objects (mappings, but not dicts) - confirm against the client library.

    Args:
        items: Objects to display; an empty list prints a notice and returns.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    # Local import: Mapping is not among the notebook's file-level imports.
    from collections.abc import Mapping

    if not items:
        print("No items to display.")
        return
    fig, axes = plt.subplots(1, len(items), figsize=figsize)
    if len(items) == 1:
        # With a single column plt.subplots returns a bare Axes, not a sequence.
        axes = [axes]
    for ax, it in zip(axes, items):
        if isinstance(it, Image.Image):
            ax.imshow(it)
        elif isinstance(it, Mapping) and "bytesBase64Encoded" in it:
            ax.imshow(prediction_to_pil_image(it))
        else:
            ax.text(0.5, 0.5, str(it), ha="center", va="center", wrap=True)
        ax.axis("off")
    # Lay out this specific figure rather than pyplot's "current" figure.
    fig.tight_layout()
    plt.show()

# generate(): ask Gemini for a recontext prompt and a product description
def generate(
    image_parts: List[types.Part],
    project: str = "consumer-genai-experiments",
    location: str = "global",
    model: str = "gemini-2.5-flash",
) -> Dict[str, str]:
    """Ask Gemini to describe a product and propose a scene for it.

    The product images are sent together with a fixed user instruction and
    system instruction. The streamed reply is concatenated, an optional
    Markdown code fence is stripped, and the remainder is parsed as JSON.

    Args:
        image_parts: Image parts of one product (the prompts speak of up to 3).
        project: Google Cloud project used for the Vertex AI genai client.
        location: Location used for the genai client.
        model: Model name passed to ``generate_content_stream``.

    Returns:
        The parsed JSON object. The prompts ask the model for the keys
        ``"Prompt"`` and ``"product_description"``; this function does not
        validate that they are present.

    Raises:
        json.JSONDecodeError: If the model output is not valid JSON after the
            code fence has been stripped (including an empty reply).
    """
    # Kept local so the function still works if the notebook's import cell
    # was edited; both names are also imported at file level.
    from google import genai
    from google.genai import types

    client = genai.Client(
        vertexai=True,
        project=project,
        location=location,
    )

    user_instr = types.Part.from_text(text="""
Analyze the provided images of a single product (up to 3). First, identify and describe the product in accurate, natural language: focus on material, color, shape, form, pattern, and distinctive design features.

Then, determine a visually appropriate and realistic background or scene where the product would naturally appear and look appealing. Base this on the product’s style and category — for example, place a desk lamp in a home office, or a sneaker in a modern studio.

DO NOT GENERATE PEOPLE/CHILDREN
Your output should be returned as a JSON object in the format below:
{
  "Prompt": "<rich description of product and proposed scene>",
  "product_description": "<just the product, no scene>"
}

Do not reference the original image’s background or lighting.
Do not use placeholders like “in a nice room” — be specific about the setting (e.g., “in a sunlit bohemian-style bedroom with woven textures and indoor plants”).
""")

    system_instr = types.Part.from_text(text="""
Role:
You are an expert visual analyst and prompt engineer for AI-based image generation. Your task is to analyze up to 3 input images of the same product, and generate a single, high-quality prompt suitable for AI-driven product image recontextualization.

Your objectives are twofold:

Describe the product accurately: Identify product category, form, material, texture, color, patterns, and notable design features.

Propose a compelling background/scene: Select a suitable environment in which the product would naturally and attractively appear — based on its likely usage, aesthetic, and category.

DO NOT GENERATE PEOPLE/CHILDREN

Input:
Up to 3 images of the same product (e.g., different angles or lighting).
No metadata, no background descriptions provided — just images.

Output Format:
Return a JSON object in this structure:
{
  "Prompt": "<natural language prompt for recontextualized image generation>",
  "product_description": "<just the product description, no scene>"
}""")

    # One user turn: the text instruction followed by the product images.
    contents = [ types.Content(role="user", parts=[user_instr, *image_parts]) ]
    config = types.GenerateContentConfig(
        temperature=0.2,
        top_p=0.95,
        max_output_tokens=8192,
        # Every listed harm category has its threshold set to "OFF".
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
            types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
            types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
            types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
        ],
        system_instruction=[system_instr],
        thinking_config=types.ThinkingConfig(thinking_budget=0)
    )

    # A streamed chunk may carry no text (``chunk.text`` is None); skip those
    # instead of failing on ``str + None``.
    pieces = []
    for chunk in client.models.generate_content_stream(
        model=model, contents=contents, config=config
    ):
        if chunk.text:
            pieces.append(chunk.text)
    output = "".join(pieces)

    # Strip an optional leading ```lang fence and trailing ``` before parsing.
    clean = output.strip()
    clean = re.sub(r"^```(?:\w+)?\n","", clean)
    clean = re.sub(r"\n```$","", clean)
    return json.loads(clean)


# call_product_recontext()
def call_product_recontext(
    image_bytes_list=None,
    image_uris_list=None,
    prompt=None,
    product_description=None,
    disable_prompt_enhancement=True,
    sample_count=1,
    base_steps=None,
    safety_setting=None,
    person_generation=None,
) -> PredictResponse:
    """Send one product-recontext request and return the raw response.

    Uses the module-level ``predict_client`` and ``model_endpoint``, which
    must be defined before this function is called.

    Args:
        image_bytes_list: Optional inline images. ``bytes``/``bytearray``
            items are base64-encoded here; any other item is sent as-is
            (assumed to be an already base64-encoded string). Previously this
            argument was accepted but ignored.
        image_uris_list: Optional ``gs://`` URIs of product images.
        prompt: Optional scene prompt.
        product_description: Optional description, attached to the first
            product image's ``productConfig``.
        disable_prompt_enhancement: When truthy, ``enhancePrompt`` is sent as
            False; otherwise the parameter is omitted from the request.
        sample_count: Value sent as ``sampleCount``.
        base_steps: Sent as ``baseSteps`` when truthy.
        safety_setting: Sent as ``safetySetting`` when truthy.
        person_generation: Sent as ``personGeneration`` when truthy.

    Returns:
        The object returned by ``predict_client.predict``.

    Raises:
        ValueError: If neither image list contributes at least one image.
    """
    inst: Dict[str, Any] = {"productImages": []}
    if image_bytes_list:
        for payload in image_bytes_list:
            if isinstance(payload, (bytes, bytearray)):
                # The request body is JSON-like, so raw bytes must be sent as
                # base64 text.
                payload = base64.b64encode(payload).decode("ascii")
            inst["productImages"].append({"image": {"bytesBase64Encoded": payload}})
    if image_uris_list:
        for uri in image_uris_list:
            inst["productImages"].append({"image": {"gcsUri": uri}})
    if not inst["productImages"]:
        raise ValueError("No product images provided.")
    if product_description:
        inst["productImages"][0]["productConfig"] = {"productDescription": product_description}
    if prompt:
        inst["prompt"] = prompt

    # Optional parameters are only included when the caller supplied them.
    params = {"sampleCount": sample_count}
    if disable_prompt_enhancement:
        params["enhancePrompt"] = False
    if safety_setting:
        params["safetySetting"] = safety_setting
    if person_generation:
        params["personGeneration"] = person_generation
    if base_steps:
        params["baseSteps"] = base_steps

    start = timeit.default_timer()
    resp = predict_client.predict(
        endpoint=model_endpoint,
        instances=[inst],
        parameters=params,
    )
    print(f"Recontext took {timeit.default_timer()-start:.2f}s")
    return resp

# save & upload helper
def save_and_upload_recontext_image(
    prediction_response,
    product_folder: str,
    output_bucket_name: str,
    output_base_prefix: str = "cymbal_retail/product_images_output",
    image_index: int = 0,
):
    """Write one generated image to a local JPEG and upload it to GCS.

    The image is decoded with ``prediction_to_pil_image``, saved in the
    current working directory as ``<product_folder>_output_<image_index>.jpg``
    and uploaded to
    ``gs://<output_bucket_name>/<output_base_prefix>/<product_folder>/``.
    If nothing exists under that destination prefix yet, an empty placeholder
    object named ``<prefix>/`` is created first.

    NOTE(review): ``prediction_to_pil_image`` applies ``thumbnail`` with its
    default size, so the stored file is the downsized image - confirm that is
    intended for the uploaded output.

    Args:
        prediction_response: A single prediction, or a list whose first
            element is used.
        product_folder: Folder name used in the local filename and the
            destination path.
        output_bucket_name: Destination bucket name.
        output_base_prefix: Destination prefix inside the bucket.
        image_index: Index embedded in the output filename.
    """
    out_pref = f"{output_base_prefix}/{product_folder}"

    if isinstance(prediction_response, list):
        prediction_response = prediction_response[0]
    img = prediction_to_pil_image(prediction_response)

    local = f"{product_folder}_output_{image_index}.jpg"
    # JPEG has no alpha channel, hence the RGB conversion before saving.
    img.convert("RGB").save(local, format="JPEG")
    print(f"Saved local: {local}")

    client = storage.Client()
    bucket = client.bucket(output_bucket_name)
    # ensure folder exists; one result is enough to know the prefix is
    # non-empty, so the listing is capped instead of fetching every object.
    folder_marker = out_pref + "/"
    if not list(bucket.list_blobs(prefix=folder_marker, max_results=1)):
        bucket.blob(folder_marker).upload_from_string("", content_type="application/x-folder")
        print(f"🗂 Created folder gs://{output_bucket_name}/{out_pref}/")

    dst = bucket.blob(f"{out_pref}/{local}")
    dst.upload_from_filename(local, content_type="image/jpeg")
    print(f"Uploaded to gs://{output_bucket_name}/{out_pref}/{local}")

## Initialize Vertex AI Client

In [15]:
# Init clients
# PROJECT_ID / LOCATION configure the Vertex AI SDK and the prediction client
# below. generate() builds its own genai client and does not read these
# constants.
PROJECT_ID = "consumer-genai-experiments"
LOCATION   = "us-central1"

aiplatform.init(project=PROJECT_ID, location=LOCATION)
# predict_client and model_endpoint are read as globals by
# call_product_recontext(), so this cell must run before that helper is called.
predict_client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": f"{LOCATION}-aiplatform.googleapis.com"}
)
# Resource path of the publisher model that receives the predict requests.
model_endpoint = (
    f"projects/{PROJECT_ID}/locations/{LOCATION}"
    + "/publishers/google/models/imagen-product-recontext-preview-06-30"
)
print("Prediction client ready")


Prediction client ready


# Sequential Run

In [16]:
# Sequential one by one
# Processes every discovered product folder in turn: preview the inputs,
# generate a prompt, recontextualize, then display and upload the results.

from datetime import datetime
import io
from PIL import Image

# Define constants
# INPUT_PREFIX is scanned for product sub-folders; results are uploaded to
# OUTPUT_BUCKET under the upload helper's default output prefix.
INPUT_PREFIX = "gs://id_test_bucket/cymbal_retail/product_images_input"
OUTPUT_BUCKET = "id_test_bucket"

# Start time
start_time = datetime.now()
print(f"\nProcess started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

# Discover batches
# The generator is materialised so the folder count is known before processing.
batches = list(discover_product_batches(INPUT_PREFIX))
print(f"\nTotal product folders discovered: {len(batches)}")

if not batches:
    raise RuntimeError("No product batches found—check your GCS path & permissions!")

# Process each batch
for batch in batches:
    print(f"\n=== Processing {batch['product_folder']} ===")

    # Preview
    # Download each input image and show the originals in one row.
    imgs = [Image.open(io.BytesIO(download_gcs_image_bytes(u))) for u in batch["product_uris"]]
    display_row(imgs)

    # Generate prompt & description
    gen = generate(batch["image_parts"])
    print("Prompt:", gen["Prompt"])
    print("Desc:  ", gen["product_description"])

    # Recontextualize
    # disable_prompt_enhancement=False: the helper then leaves `enhancePrompt`
    # out of the request parameters.
    # NOTE(review): person_generation="allow_adult" while the Gemini prompts
    # say "DO NOT GENERATE PEOPLE/CHILDREN" - confirm this is intended.
    resp = call_product_recontext(
        prompt=gen["Prompt"],
        product_description=gen["product_description"],
        image_uris_list=batch["product_uris"],
        disable_prompt_enhancement=False,
        sample_count=1,
        safety_setting="block_low_and_above",
        person_generation="allow_adult",
    )

    # Display & upload
    # Each prediction is saved locally and uploaded with its index in the name.
    preds = list(resp.predictions)
    display_row(preds)
    for i, p in enumerate(preds):
        save_and_upload_recontext_image(
            p,
            batch["product_folder"],
            OUTPUT_BUCKET,
            image_index=i
        )

# End time
end_time = datetime.now()
duration = end_time - start_time
print(f"\nProcess completed at: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"⏱Total time taken: {duration}")


# Parallel threads

In [17]:
# Parallel threads
# Same pipeline as the sequential run, with product folders processed
# concurrently by a thread pool.

import concurrent.futures
from datetime import datetime  # used for the start/end timestamps below
import io
from PIL import Image

# These reassign the same names (and values) the sequential cell defines.
INPUT_PREFIX = "gs://id_test_bucket/cymbal_retail/product_images_input"
OUTPUT_BUCKET = "id_test_bucket"
MAX_WORKERS = 5  # Adjust this to control parallelism

# Start time
start_time = datetime.now()
print(f"\n🔄 Process started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

# Discovery itself runs in the main thread; only per-folder processing is
# handed to the pool.
batches = list(discover_product_batches(INPUT_PREFIX))
print(f"\nTotal product folders discovered: {len(batches)}")

if not batches:
    raise RuntimeError("No product batches found—check your GCS path & permissions!")

import threading

# display_row drives matplotlib's pyplot interface, which keeps global
# "current figure" state. process_batch runs on several worker threads, so
# every display call is serialised through this lock.
_DISPLAY_LOCK = threading.Lock()

def process_batch(batch):
    """Run the full recontext pipeline for one product folder.

    Steps: preview the input images, generate a prompt and description,
    request the recontextualized image(s), display them, and upload each one.

    Any ``Exception`` is caught and printed with the folder name, so one
    failing product does not stop the others; nothing is returned or
    re-raised.

    Args:
        batch: Dict as yielded by ``discover_product_batches``; the keys
            ``product_folder``, ``product_uris`` and ``image_parts`` are read.
    """
    try:
        print(f"\n=== Processing {batch['product_folder']} ===")

        # Preview
        imgs = [Image.open(io.BytesIO(download_gcs_image_bytes(u))) for u in batch["product_uris"]]
        with _DISPLAY_LOCK:
            display_row(imgs)

        # Generate prompt & description
        gen = generate(batch["image_parts"])
        print("Prompt:", gen["Prompt"])
        print("Desc:  ", gen["product_description"])

        # Recontextualize
        resp = call_product_recontext(
            prompt=gen["Prompt"],
            product_description=gen["product_description"],
            image_uris_list=batch["product_uris"],
            disable_prompt_enhancement=False,
            sample_count=1,
            safety_setting="block_low_and_above",
            person_generation="allow_adult",
        )

        # Display & upload
        preds = list(resp.predictions)
        with _DISPLAY_LOCK:
            display_row(preds)
        for i, p in enumerate(preds):
            save_and_upload_recontext_image(
                p,
                batch["product_folder"],
                OUTPUT_BUCKET,
                image_index=i
            )
    except Exception as e:
        print(f"Error processing batch {batch['product_folder']}: {e}")

# Parallel execution
# Submit one task per product folder. process_batch catches and prints its own
# exceptions, so the futures' results are not inspected here.
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    futures = [executor.submit(process_batch, batch) for batch in batches]
    # Block until every submitted task has finished.
    concurrent.futures.wait(futures)

# End time
end_time = datetime.now()
duration = end_time - start_time
print(f"\nProcess completed at: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total time taken: {duration}")
