In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Imagen Product Recontext - Evaluation at Scale

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-creative-studio/blob/main/experiments/Imagen_Product_Recontext/evaluation_imagen_product_recontext_at_scale.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-creative-studio%2Fmain%2Fexperiments%2FImagen_Product_Recontext%2Fevaluation_imagen_product_recontext_at_scale.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-creative-studio/main/experiments/Imagen_Product_Recontext/evaluation_imagen_product_recontext_at_scale.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/bigquery/import?url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-creative-studio/main/experiments/Imagen_Product_Recontext/evaluation_imagen_product_recontext_at_scale.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/bigquery/v1/32px.svg" alt="BigQuery Studio logo"><br> Open in BigQuery Studio
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-creative-studio/blob/main/experiments/Imagen_Product_Recontext/evaluation_imagen_product_recontext_at_scale.ipynb">
      <img width="32px" src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

| | |
|-|-|
|Author(s) | [Layolin Jesudhass](https://github.com/LUJ20), Isidro De Loera

## Imports

In [12]:
#!pip install --upgrade google-genai google-cloud-storage ipywidgets

from google.colab import auth
from google import genai
from google.genai import types
from google.cloud import storage
import os, json, re
from json import JSONDecodeError
import pandas as pd
import tempfile
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from PIL import Image

In [13]:
# --------------------------------------------------
# Authenticate this Colab session with GCP. `auth` is already imported in the
# Imports cell above, so it is not re-imported here.
auth.authenticate_user()
print("Authenticated with Google Cloud")
# --------------------------------------------------


Authenticated with Google Cloud


## Initialize Vertex AI Client

In [14]:
# ─── Authenticate & Initialize GenAI Client
# Authenticating again here keeps this cell usable even if the previous cell
# was skipped.
auth.authenticate_user()
print("Authenticated with Google Cloud")

# Configuration: GCP project, Vertex AI location and the GCS layout.
#   inputs : gs://BUCKET_NAME/INPUT_PREFIX<product>/<images>
#   outputs: gs://BUCKET_NAME/OUTPUT_PREFIX<product>/<output image + <product>_evaluation.json>
PROJECT_ID    = "consumer-genai-experiments"
LOCATION      = "global"
BUCKET_NAME   = "id_test_bucket"
INPUT_PREFIX  = "cymbal_retail/product_images_input/"
OUTPUT_PREFIX = "cymbal_retail/product_images_output/"

# GenAI client (Vertex AI backend)
client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)
print("GenAI client ready")

# Storage client. Pass the project explicitly so the client does not depend on
# a default project being discoverable from the Colab environment.
storage_client = storage.Client(project=PROJECT_ID)
print("Storage client ready")


Authenticated with Google Cloud
GenAI client ready
Storage client ready


## Helper Functions


In [15]:
# ─── System Instruction & User Prompt
# si_text1 sets the evaluator persona. msg1_text1 lists the six scoring
# dimensions and the exact JSON shape to return.
# generate() sends msg1_text1, then the input photos, then the output image
# LAST. The images are passed as gs:// file references rather than as named
# attachments, so the prompt identifies the output by position, not file name.
si_text1 = """You are an expert visual quality evaluator for AI-generated e-commerce images.
Your task is to compare up to three original product images (studio-style, white background)
to one AI-generated output image (recontextualized lifestyle photo).
Your goal is to assess if the output image faithfully and attractively represents the product
while adhering to commercial quality and brand standards.
Do not assume any information beyond what is visible in the images.

For each of the six evaluation dimensions, assign a score from 1 to 5 and provide a short justification.

Finally, calculate an overall quality score based on your judgment of the image's commercial viability.
This score should summarize the output's overall fitness for use in real-world e-commerce listings
(e.g., Cymbal Retail product pages). Return this score under the `overall_score` key in the JSON.

Use your understanding of visual coherence, aesthetic judgment, and product realism to make your assessments."""

msg1_text1 = types.Part.from_text(text="""You will be shown up to 4 images:
• up to 3 input product photos, shown first
• 1 AI-generated lifestyle image of the same product, always shown LAST

Your task is to evaluate the output (last) image across the following 6 dimensions.
Each dimension should be scored on a scale from 1 to 5, where:
- 5 = Excellent
- 4 = Good
- 3 = Acceptable
- 2 = Poor
- 1 = Unacceptable

For each dimension, explain your score with 1–2 sentences of justification.

### Evaluation Dimensions:

1. **Product Fidelity** – Does the product in the output match the shape, color, texture, and identity seen in the input images?
2. **Scene Realism** – Does the background setting make physical and spatial sense? Are lighting and shadows natural?
3. **Aesthetic Quality** – Is the image visually appealing? Consider composition, balance, lighting, and professional polish.
4. **Brand Integrity** – Are any visible logos, labels, or branding preserved, undistorted, and realistic?
5. **Policy Compliance** – Does the image follow Cymbal Retail content policies (no people, kids, unsafe objects, political/religious content)?
6. **Imaging Quality** – Is the image sharp, high-resolution, and free from noise, blurs, or compression artifacts?

Please return the results in the following JSON format:

{
  "product_fidelity": { "score": X, "comment": "..." },
  "scene_realism":   { "score": X, "comment": "..." },
  "aesthetic_quality": { "score": X, "comment": "..." },
  "brand_integrity":   { "score": X, "comment": "..." },
  "policy_compliance": { "score": X, "comment": "..." },
  "imaging_quality":   { "score": X, "comment": "..." },
  "overall_score":     { "score": X, "comment": "..." }
}

Guidelines for overall_score:
This should reflect the lowest common denominator (e.g., an otherwise perfect image
with policy violations would get a lower overall).
Use your judgment, not just the numeric average — it's OK to weight fidelity or
compliance more heavily than, say, aesthetic.
""")

generate_config = types.GenerateContentConfig(
    temperature=0.1,
    top_p=0.95,
    seed=0,
    max_output_tokens=65535,
    # Request a raw JSON body so the response parses directly with json.loads.
    # strip_code_fences() still handles any fenced reply.
    response_mime_type="application/json",
    system_instruction=[types.Part.from_text(text=si_text1)],
    thinking_config=types.ThinkingConfig(thinking_budget=0),
    safety_settings=[
        types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH",      threshold="OFF"),
        types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
        types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
        types.SafetySetting(category="HARM_CATEGORY_HARASSMENT",        threshold="OFF"),
    ]
)

# ─── Helpers ─────────────────────────────────────────────────────────────────
def strip_code_fences(text: str) -> str:
    """Return the body of the first Markdown code fence in `text`, if any.

    Models sometimes wrap JSON in ```json ... ``` fences. The `json` tag is
    optional and matched case-insensitively, and the body may start on the
    same line as the opening fence. Whitespace around the body is dropped.
    If no complete fence is present, `text` is returned unchanged.
    """
    match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text, flags=re.IGNORECASE)
    return match.group(1) if match else text

def make_part(uri: str) -> types.Part:
    """Wrap a gs:// image URI in a GenAI Part that references the object.

    The MIME type comes from the file extension, so every format that
    get_image_uris accepts (.jpg/.jpeg/.png/.bmp/.gif/.webp) is labelled
    correctly instead of being sent as image/jpeg. Unknown extensions fall
    back to image/jpeg.
    NOTE(review): confirm that the target Gemini model accepts image/gif and
    image/bmp; image/png, image/jpeg and image/webp are the common cases.
    """
    mime_by_ext = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
        ".gif": "image/gif",
        ".bmp": "image/bmp",
    }
    ext = os.path.splitext(uri)[1].lower()
    mime = mime_by_ext.get(ext, "image/jpeg")
    return types.Part(file_data=types.FileData(file_uri=uri, mime_type=mime))

def list_product_folders(bucket_name: str, prefix: str) -> list[str]:
    """List the immediate sub-folder names found under `prefix` in a bucket.

    Every object named `<prefix><folder>/...` contributes `<folder>`. Objects
    sitting directly under `prefix` (no further "/") are ignored.

    Returns:
        The distinct folder names, sorted alphabetically.
    """
    folder_names = set()
    for blob in storage_client.list_blobs(bucket_name, prefix=prefix):
        relative = blob.name[len(prefix):]
        if "/" not in relative:
            continue
        head, _sep, _rest = relative.partition("/")
        folder_names.add(head)
    return sorted(folder_names)

def get_image_uris(bucket_name: str, prefix: str, max_images: int = 3) -> list[str]:
    """Return up to `max_images` gs:// URIs of image objects under `prefix`.

    An object counts as an image when its extension (case-insensitive) is one
    of .jpg/.jpeg/.png/.bmp/.gif/.webp. Names are sorted before truncation so
    the same objects are picked on every run.
    """
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}
    image_names = []
    for blob in storage_client.list_blobs(bucket_name, prefix=prefix):
        _stem, ext = os.path.splitext(blob.name)
        if ext.lower() in image_exts:
            image_names.append(blob.name)
    image_names.sort()
    return ["gs://" + bucket_name + "/" + name for name in image_names[:max_images]]

def find_output_uri(bucket_name: str, product: str) -> str:
    """Locate the AI-generated output image for `product` in GCS.

    Searches gs://<bucket_name>/<OUTPUT_PREFIX><product>/ for image objects
    (.jpg/.jpeg/.png/.bmp/.gif/.webp) whose *file name* contains "output"
    (case-insensitive).

    Returns:
        The gs:// URI of the alphabetically first matching object.

    Raises:
        FileNotFoundError: if no matching image exists under the folder.
    """
    exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}
    prefix = f"{OUTPUT_PREFIX}{product}/"
    candidates = []
    for blob in storage_client.list_blobs(bucket_name, prefix=prefix):
        file_name = os.path.basename(blob.name).lower()
        # Only the file name is tested. The folder path itself may contain
        # "output" (e.g. ".../product_images_output/"), which would make every
        # image in the folder look like an output image.
        if 'output' in file_name and os.path.splitext(file_name)[1] in exts:
            candidates.append(blob.name)
    if not candidates:
        raise FileNotFoundError(f"No output image under gs://{bucket_name}/{prefix}")
    return f"gs://{bucket_name}/{min(candidates)}"

# ─── Core Eval ───────────────────────────────────────────────────────────────
def generate(input_uris, output_uri, model: str = "gemini-2.5-flash") -> str:
    """Ask Gemini to score one output image against its input product photos.

    The request is the evaluation prompt (msg1_text1), then the input images,
    then the output image last. The streamed reply is concatenated into one
    string.

    Args:
        input_uris: gs:// URIs of the reference product photos.
        output_uri: gs:// URI of the AI-generated image to evaluate.
        model: Gemini model ID. The default keeps the original model.

    Returns:
        The full response text, expected to contain the JSON scorecard.
    """
    parts = [msg1_text1, *(make_part(u) for u in input_uris), make_part(output_uri)]
    stream = client.models.generate_content_stream(
        model=model,
        contents=[types.Content(role="user", parts=parts)],
        config=generate_config,
    )
    # A chunk's `.text` is None when it carries no text parts (e.g. a trailing
    # metadata-only chunk). Treat that as empty instead of failing on str + None.
    return "".join(chunk.text or "" for chunk in stream)

def evaluate_product(product: str) -> dict:
    """Evaluate one product's output image against its input photos with Gemini.

    Args:
        product: folder name under INPUT_PREFIX / OUTPUT_PREFIX.

    Returns:
        The parsed JSON scorecard, as a dict keyed by dimension name.

    Raises:
        FileNotFoundError: the product has no input images, or no output image.
        ValueError: the model returned an empty reply or JSON that is not an object.
        json.JSONDecodeError: the reply is not valid JSON. The raw text is
            printed first to aid debugging.
    """
    in_pref    = f"{INPUT_PREFIX}{product}/"
    inputs     = get_image_uris(BUCKET_NAME, in_pref)
    if not inputs:
        # Without reference photos the model cannot judge product fidelity.
        # Fail this product rather than record a meaningless score.
        raise FileNotFoundError(f"No input images under gs://{BUCKET_NAME}/{in_pref}")
    output_uri = find_output_uri(BUCKET_NAME, product)
    print(f"  • input URIs: {inputs}")
    print(f"  • output URI: {output_uri}")

    raw = generate(inputs, output_uri)
    if not raw.strip():
        raise ValueError(f"Empty response for {product}")

    clean = strip_code_fences(raw).strip()
    try:
        result = json.loads(clean)
    except JSONDecodeError:
        print(f"Raw text for {product}:\n{raw}\n")
        raise
    # Downstream tabulation calls .get()/.items() on the result, so require an object.
    if not isinstance(result, dict):
        raise ValueError(f"Expected a JSON object for {product}, got {type(result).__name__}")
    return result


## Sequential Run

In [16]:
# Sequential run: evaluate products one at a time (a timing baseline for the
# parallel run) and save each product's JSON scorecard to GCS.
from datetime import datetime

start_time = datetime.now()
print(f"🔍 Scanning for product folders… (Started at {start_time.strftime('%Y-%m-%d %H:%M:%S')})")

products = list_product_folders(BUCKET_NAME, INPUT_PREFIX)
print(f"Found {len(products)} products")

# Reuse one bucket handle for every upload instead of rebuilding it per product.
results_bucket = storage_client.bucket(BUCKET_NAME)

all_results = {}
products_with_errors = []
for p in products:
    try:
        res = evaluate_product(p)
        all_results[p] = res
        # Per-product scorecard, written in the same indented format as the
        # parallel run so both cells produce identical files.
        json_path = f"{OUTPUT_PREFIX}{p}/{p}_evaluation.json"
        results_bucket.blob(json_path).upload_from_string(
            json.dumps(res, indent=2), content_type='application/json')
    except Exception as err:
        # Best-effort batch: report the failure and continue with the next product.
        products_with_errors.append(p)
        print(f" Failed to evaluate {p}: {err}")

end_time = datetime.now()
print(f"\nCompleted at {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total time taken: {end_time - start_time}")
print(f"Evaluated {len(all_results)}/{len(products)} products; errors: {products_with_errors}")


## Parallel Run

In [17]:
# Parallel run: each product's work is network calls (GCS listing, the Gemini
# request, the JSON upload), so a thread pool evaluates several at once.
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

start_time = datetime.now()
print(f"Scanning for product folders… (Started at {start_time.strftime('%Y-%m-%d %H:%M:%S')})")

products = list_product_folders(BUCKET_NAME, INPUT_PREFIX)
print(f"Found {len(products)} products")

MAX_WORKERS = 8  # Change as needed; bounded by API quota and system limits.


def process_product(p):
    """Evaluate one product and upload its JSON scorecard.

    Returns:
        (p, result_dict) on success, or (p, None) if evaluation or upload
        failed. The error is printed rather than raised so one bad product
        does not abort the whole pool.
    """
    try:
        res = evaluate_product(p)
        json_path = f"{OUTPUT_PREFIX}{p}/{p}_evaluation.json"
        storage_client.bucket(BUCKET_NAME).blob(json_path).upload_from_string(
            json.dumps(res, indent=2), content_type='application/json')
        return (p, res)
    except Exception as err:
        print(f"Failed to evaluate {p}: {err}")
        return (p, None)


completed = {}
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    futures = [executor.submit(process_product, p) for p in products]
    for future in as_completed(futures):
        p, result = future.result()
        if result is not None:
            completed[p] = result

# as_completed yields in finish order. Rebuild in product order so the summary
# CSV and the viewer list products in the same order on every run.
all_results = {p: completed[p] for p in products if p in completed}

end_time = datetime.now()
print(f"\nCompleted at {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total time taken: {end_time - start_time}")
print(f"Evaluated {len(all_results)}/{len(products)} products")




In [18]:
# Sample stats: wall-clock time of the evaluation cells above, from one run.
# Timings vary with product count, API quota and network conditions.
# Sequential : Total time taken: 0:01:23.264453
# Parallel : 2 threads : Total time taken: 0:00:38.668612
# Parallel : 4 threads : Total time taken: 0:00:20.207010
# Parallel : 8 threads : Total time taken: 0:00:11.559379

## Tabulate & Save Summary


In [19]:
# ─── Tabulate & Save Summary ─────────────────────────────────────────────────
# One row per evaluated product: its URIs, the overall score and comment, then a
# <dim>_score / <dim>_comment column pair for every other dimension returned.
rows = []
for prod, metrics in all_results.items():
    overall_info = metrics.get("overall_score")
    if not isinstance(overall_info, dict):
        overall_info = {}
    overall = overall_info.get("score")
    comment = overall_info.get("comment")
    if overall is None or comment is None:
        raise KeyError(f"JSON for {prod} missing overall_score fields")
    row = {
        "product": prod,
        "input_product_uri":  f"gs://{BUCKET_NAME}/{INPUT_PREFIX}{prod}/",
        "output_product_uri": find_output_uri(BUCKET_NAME, prod),
        "overall_score": overall,
        "overall_comment": comment,
    }
    for dim, info in metrics.items():
        # Skip the overall entry (already captured) and anything that is not a
        # {"score", "comment"} object, such as an extra free-text field.
        if dim == "overall_score" or not isinstance(info, dict):
            continue
        row[f"{dim}_score"]   = info.get("score")
        row[f"{dim}_comment"] = info.get("comment")
    rows.append(row)

# DataFrame & CSV
df = pd.DataFrame(rows)
# The viewer compares scores numerically. Coerce values the model may have
# emitted as strings (e.g. "4") to numbers; unparsable values become NaN.
for col in [c for c in df.columns if c.endswith("_score")]:
    df[col] = pd.to_numeric(df[col], errors="coerce")

summary_path = f"{OUTPUT_PREFIX}evaluation_summary.csv"
storage_client.bucket(BUCKET_NAME).blob(summary_path).upload_from_string(
    df.to_csv(index=False), content_type='text/csv')
print(f"Saved evaluation summary CSV to {summary_path}")


## Threshold-Based Image Viewer

In [20]:
# ─── Threshold-Based Image Viewer
def display_product_images(product: str):
    """Show a product's input images and its AI-generated output side by side.

    Images are downloaded straight into memory. Nothing is written to a temp
    directory, so same-named objects from different products cannot overwrite
    each other and no files are left behind.
    """
    import io  # local import: only this viewer needs in-memory byte buffers

    input_uris = get_image_uris(BUCKET_NAME, f"{INPUT_PREFIX}{product}/")
    output_uri = find_output_uri(BUCKET_NAME, product)
    uris = input_uris + [output_uri]
    titles = [f"Input {i}" for i in range(1, len(input_uris) + 1)] + ["Output"]

    bucket = storage_client.bucket(BUCKET_NAME)
    images = []
    for uri in uris:
        key = uri.removeprefix(f"gs://{BUCKET_NAME}/")
        images.append(Image.open(io.BytesIO(bucket.blob(key).download_as_bytes())))

    # squeeze=False keeps `axes` 2-D even for a single image. A plain
    # plt.subplots(1, 1) returns one bare Axes that zip() cannot iterate.
    fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5), squeeze=False)
    for ax, im, title in zip(axes[0], images, titles):
        ax.imshow(im)
        ax.set_title(title)
        ax.axis('off')
    plt.show()

def review_threshold(threshold: float = 4.0):
    """Interactively browse products whose overall score is at or below `threshold`.

    Builds an index slider over the matching products, in `df` row order.
    Moving the slider prints the product's score and overall comment and
    renders its input/output images. The first match is shown immediately.

    Raises:
        RuntimeError: if the summary frame `df` has no `overall_score` column.
    """
    if 'overall_score' not in df.columns:
        raise RuntimeError("`df` missing `overall_score`.")

    low_scoring = df[df['overall_score'] <= threshold]['product'].tolist()
    if len(low_scoring) == 0:
        print(f"No products <= {threshold}")
        return

    def lookup(prod, column):
        # First summary-row value of `column` for this product.
        return df.loc[df['product'] == prod, column].iloc[0]

    index_slider = widgets.IntSlider(
        min=0,
        max=len(low_scoring) - 1,
        description='Index',
        continuous_update=False,
    )
    viewer = widgets.Output()

    def render(change):
        with viewer:
            viewer.clear_output(wait=True)
            position = change['new']
            prod = low_scoring[position]
            score = lookup(prod, 'overall_score')
            comment = lookup(prod, 'overall_comment')
            print(f"Product {position+1}/{len(low_scoring)}: {prod} (score={score:.3f})")
            print(f"Comment: {comment}")
            display_product_images(prod)

    index_slider.observe(render, names='value')
    display(index_slider, viewer)
    render({'new': 0})


In [21]:
# Browse every product whose overall score is 4.0 or lower (inclusive).
review_threshold(threshold=4.0)

IntSlider(value=0, continuous_update=False, description='Index', max=11)

Output()