In [None]:
# ============================================================
# Multimodal Image Search Notebook (Colab)
# ------------------------------------------------------------
# Goal:
#   - Download JPG/JPEG/PNG images from S3
#   - Create multimodal embeddings with Amazon Bedrock Titan Multimodal Embeddings (G1)
#   - Store vectors in ChromaDB
#   - Search images with natural language queries (e.g., "blue t-shirt")
#   - (Optional) Use OpenAI Vision to summarize retrieved images with S1..Sn references
#
# What you need beforehand:
#   - AWS credentials available in Colab (env vars recommended):
#       AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_DEFAULT_REGION
#   - IAM permissions:
#       s3:ListBucket, s3:GetObject on your bucket
#       bedrock:InvokeModel on the Titan embedding model
#   - OpenAI API key (optional for summaries):
#       OPENAI_API_KEY
# ============================================================


In [None]:
# --- Install dependencies (Colab) ---
!pip install -q boto3 pillow chromadb openai


In [None]:
# ============================================================
# 0) Imports + Global Settings
# ============================================================
import base64
import functools
import io
import json
import os
import re
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import boto3
import chromadb
import matplotlib.pyplot as plt
from botocore.exceptions import ClientError, NoCredentialsError
from openai import OpenAI
from PIL import Image, ImageOps


In [None]:
# Credentials: never hardcode keys in the notebook (saved copies and outputs leak them).
# Values already present in the environment are kept; only missing ones are prompted for.
# Leaving an AWS prompt blank lets boto3 fall back to its default credential chain.
import getpass

for _aws_var in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"):
    if not os.environ.get(_aws_var):
        _aws_value = getpass.getpass(f"Enter {_aws_var} (blank to use boto3's default credential chain): ")
        if _aws_value:
            os.environ[_aws_var] = _aws_value
os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")  # change if needed

# OpenAI is only used by the optional summary step (7E); leave blank to skip it.
if not os.environ.get("OPENAI_API_KEY"):
    _openai_key = getpass.getpass("Enter your OpenAI API key (blank to skip): ")
    if _openai_key:
        os.environ["OPENAI_API_KEY"] = _openai_key

In [None]:
# ============================================================
# 0A) Configure your run
# ============================================================
# S3 location
BUCKET = "<YOUR BUCKET NAME>"
PREFIX = ""           # e.g., "photos/2024/" or "" for root
SAMPLE_N = 200        # start small; increase as you gain confidence (None = no limit)

# Bedrock region (must match where you invoke Bedrock)
BEDROCK_REGION = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")

# Titan Multimodal Embeddings G1 modelId for embeddings
TITAN_MODEL_ID = "amazon.titan-embed-image-v1"

# Chroma persistence
CHROMA_DIR = "chroma_images_db"
CHROMA_COLLECTION = "photo_library"

# Embedding dimension (Titan supports 256 / 384 / 1024; 1024 is a strong default).
# Indexing and querying must use the same value: vectors of different lengths can't be compared.
TITAN_EMBED_DIMS = (256, 384, 1024)
EMBED_DIM = 1024
assert EMBED_DIM in TITAN_EMBED_DIMS, f"EMBED_DIM must be one of {TITAN_EMBED_DIMS}, got {EMBED_DIM}"

# Local storage for downloaded images
LOCAL_IMAGE_DIR = "data/images"

# Extensions to index from S3 (matched case-insensitively)
IMAGE_EXTS = (".jpg", ".jpeg", ".png")

if BUCKET.startswith("<"):
    print("[CONFIG] BUCKET is still a placeholder; set it before running section 7.")


In [None]:
# ============================================================
# 1) S3 Helpers: list + download
# ============================================================
def get_s3_client(region_name: Optional[str] = None):
    """
    Build a boto3 S3 client.

    Credentials come from boto3's standard resolution chain (environment variables,
    shared credentials/config files, or an attached IAM role). When ``region_name``
    is falsy, boto3 picks the region from its own configuration.
    """
    if not region_name:
        return boto3.client("s3")
    return boto3.client("s3", region_name=region_name)


def list_image_keys(
    bucket: str,
    prefix: str = "",
    exts: Tuple[str, ...] = IMAGE_EXTS,
    max_keys: Optional[int] = None
) -> List[str]:
    """
    List image object keys from S3 under a prefix (paginated).

    Args:
        bucket: S3 bucket name.
        prefix: Only keys starting with this string are listed ("" = whole bucket).
        exts: Extensions to keep, matched case-insensitively against the key's end.
              A single string (e.g. ".png") is also accepted.
        max_keys: Stop after this many matching keys; None or 0 means no limit.

    Returns:
        Matching keys, in the order S3 lists them.

    Raises:
        RuntimeError: if AWS credentials are missing or the ListObjectsV2 call fails.

    Usability tips:
      - Use prefix to narrow scope (much faster/cheaper).
      - Start with max_keys (e.g., 200) to validate before indexing everything.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(exts, str):
        exts = (exts,)
    # Lower-case once instead of once per listed object.
    wanted_exts = tuple(e.lower() for e in exts)

    s3 = get_s3_client()
    paginator = s3.get_paginator("list_objects_v2")
    keys: List[str] = []

    try:
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if key.lower().endswith(wanted_exts):
                    keys.append(key)
                    if max_keys and len(keys) >= max_keys:
                        return keys
    except NoCredentialsError as e:
        raise RuntimeError("AWS credentials not found. Set env vars or configure auth.") from e
    except ClientError as e:
        raise RuntimeError(f"S3 list error: {e}") from e

    return keys


def _safe_local_path(root: Path, key: str) -> Optional[Path]:
    """Return ``root / key``, or None if that path would resolve outside ``root``."""
    candidate = root / key
    try:
        # Absolute keys ("/etc/x") or ".." segments resolve outside root -> ValueError.
        candidate.resolve().relative_to(root.resolve())
    except ValueError:
        return None
    return candidate


def download_s3_objects(bucket: str, keys: List[str], local_dir: str = LOCAL_IMAGE_DIR) -> List[str]:
    """
    Download S3 objects to a local folder, recreating S3 'folders' as subfolders.

    Args:
        bucket: S3 bucket name.
        keys: Object keys to download.
        local_dir: Destination root; each key is written to ``local_dir / key``.

    Returns:
        Local paths (str) of the objects that downloaded successfully. Failures are
        skipped, so the result is NOT index-aligned with ``keys`` when anything fails.

    Failure behavior:
      - Continues downloading even if some objects fail (S3 or local filesystem errors).
      - Prints failures.
      - Skips keys that would be written outside ``local_dir`` (absolute keys or "..").
    """
    s3 = get_s3_client()
    root = Path(local_dir)
    root.mkdir(parents=True, exist_ok=True)

    downloaded = []
    for key in keys:
        local_path = _safe_local_path(root, key)
        if local_path is None:
            # S3 keys are arbitrary strings; never let one escape the download folder.
            print(f"[DOWNLOAD SKIPPED] {key} -> resolves outside {root}")
            continue
        try:
            local_path.parent.mkdir(parents=True, exist_ok=True)
            s3.download_file(bucket, key, str(local_path))
            downloaded.append(str(local_path))
        except (ClientError, OSError) as e:
            print(f"[DOWNLOAD FAILED] {key} -> {e}")

    return downloaded


In [None]:
# ============================================================
# 2) Image normalization (recovery from failures)
# ------------------------------------------------------------
# Why this exists:
#   Bedrock may reject some images ("Unable to process provided image") due to:
#     - corruption/truncation
#     - unusual PNG chunks/profiles
#     - weird encodings
#
# We normalize to a clean RGB JPEG and optionally resize.
# ============================================================
def normalize_image_to_jpeg_bytes(image_path: str, max_side: int = 2048, quality: int = 90) -> bytes:
    """
    Load an image, fix EXIF orientation, flatten it to RGB, optionally downscale it,
    and return JPEG-encoded bytes.

    Re-encoding to a plain RGB JPEG removes most of the quirks (odd PNG chunks,
    profiles, palette/alpha modes) that make Bedrock reject images.

    Args:
        image_path: Path of the image file to read.
        max_side: Longest allowed side in pixels; larger images are downscaled,
                  keeping the aspect ratio. Must be positive.
        quality: JPEG quality passed to Pillow's encoder.

    Returns:
        JPEG-encoded image bytes.

    Raises:
        ValueError: if ``max_side`` is not positive.
        OSError (or other Pillow decode errors): if the file cannot be opened/decoded.
    """
    if max_side <= 0:
        raise ValueError(f"max_side must be positive, got {max_side}")

    with Image.open(Path(image_path)) as src:
        # Fix camera rotation from EXIF
        img = ImageOps.exif_transpose(src)

        # Flatten transparency onto white. A plain convert("RGB") keeps whatever colour
        # sits under transparent pixels (often black), which distorts cut-out images.
        if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
            rgba = img.convert("RGBA")
            background = Image.new("RGB", rgba.size, (255, 255, 255))
            background.paste(rgba, mask=rgba.getchannel("A"))
            img = background
        elif img.mode != "RGB":
            img = img.convert("RGB")

        # Downscale if too large (keeps aspect ratio). max(1, ...) keeps very thin
        # images from rounding a side down to 0 px, which resize() rejects.
        w, h = img.size
        scale = max(w, h) / max_side
        if scale > 1:
            new_size = (max(1, int(w / scale)), max(1, int(h / scale)))
            img = img.resize(new_size)

        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=quality, optimize=True)
        return buf.getvalue()


In [None]:
# ============================================================
# 3) Bedrock Titan Multimodal Embeddings (Image + Text)
# ============================================================
@functools.lru_cache(maxsize=None)
def get_bedrock_runtime(region_name: str = BEDROCK_REGION):
    """
    Return a Bedrock Runtime client (used for InvokeModel calls) for ``region_name``.

    The client is cached per argument value: the embedding helpers call this once per
    image, and building a fresh boto3 client on every call is needlessly slow.
    Credentials are captured when a client is first created, so after changing
    credentials in the same session call ``get_bedrock_runtime.cache_clear()``.
    """
    return boto3.client("bedrock-runtime", region_name=region_name)


def _titan_invoke(payload: Dict, region: str, model_id: str) -> List[float]:
    """
    Send one InvokeModel request with a JSON ``payload`` and return the response's "embedding".

    Shared by the image and text helpers so request/response handling lives in one place.

    Raises:
        RuntimeError: on AWS client or credential errors, when the response body carries
                      an error "message", or when it has no "embedding" field.
    """
    br = get_bedrock_runtime(region)
    try:
        resp = br.invoke_model(
            modelId=model_id,
            body=json.dumps(payload),
            accept="application/json",
            contentType="application/json"
        )
    except (ClientError, NoCredentialsError) as e:
        raise RuntimeError(f"Bedrock InvokeModel failed: {e}") from e

    data = json.loads(resp["body"].read())
    if data.get("message"):
        raise RuntimeError(data["message"])
    if "embedding" not in data:
        raise RuntimeError(f"Bedrock response has no 'embedding' field (keys: {sorted(data)})")
    return data["embedding"]


def titan_embed_image_safe(
    image_path: str,
    output_dim: int = EMBED_DIM,
    region: str = BEDROCK_REGION,
    model_id: str = TITAN_MODEL_ID
) -> List[float]:
    """
    Create an embedding for an image using Titan Multimodal Embeddings G1.

    The file is first re-encoded with ``normalize_image_to_jpeg_bytes`` so most
    problematic images still work.

    Args:
        image_path: Local path of the image.
        output_dim: Requested embedding length (``outputEmbeddingLength``).
        region: AWS region of the Bedrock endpoint.
        model_id: Bedrock model id.

    Returns:
        The embedding as a list of floats.

    Raises:
        RuntimeError: if the Bedrock call fails or returns an error message.
        OSError (or other Pillow decode errors): if the local image cannot be read.
    """
    jpeg_bytes = normalize_image_to_jpeg_bytes(image_path)
    payload = {
        "inputImage": base64.b64encode(jpeg_bytes).decode("utf-8"),
        "embeddingConfig": {"outputEmbeddingLength": output_dim},
    }
    return _titan_invoke(payload, region, model_id)


def titan_embed_text(
    query: str,
    output_dim: int = EMBED_DIM,
    region: str = BEDROCK_REGION,
    model_id: str = TITAN_MODEL_ID
) -> List[float]:
    """
    Create an embedding for a text query using the SAME Titan multimodal model.

    Using one model for both images and text is what puts them in a shared vector
    space and enables text->image retrieval.

    Args:
        query: Query text.
        output_dim: Requested embedding length; must match the one used for images.
        region: AWS region of the Bedrock endpoint.
        model_id: Bedrock model id.

    Returns:
        The embedding as a list of floats.

    Raises:
        RuntimeError: if the Bedrock call fails or returns an error message.
    """
    payload = {
        "inputText": query,
        "embeddingConfig": {"outputEmbeddingLength": output_dim},
    }
    return _titan_invoke(payload, region, model_id)


In [None]:
# ============================================================
# 4) ChromaDB setup + ingest (upsert) with failure recovery
# ============================================================
def get_chroma_collection(persist_dir: str = CHROMA_DIR, collection_name: str = CHROMA_COLLECTION):
    """
    Open, creating it on first use, the Chroma collection stored on disk under ``persist_dir``.

    The index outlives the Python kernel as long as the directory survives. A Colab
    runtime reset wipes local disk, so on Colab it only lasts for the current session.
    """
    persistent_client = chromadb.PersistentClient(path=persist_dir)
    collection = persistent_client.get_or_create_collection(name=collection_name)
    return collection


def build_image_records(image_keys: List[str], local_paths: List[str]) -> List[Dict]:
    """
    Pair S3 keys with their downloaded files, one record per image.

    Each record has:
      - "id": the S3 key (stable, so re-indexing upserts instead of duplicating)
      - "path": the local file path
      - "metadata": {"s3_key": key, "filename": basename of the local path}

    Args:
        image_keys: S3 keys, one per image.
        local_paths: Local file paths; local_paths[i] must be the file for image_keys[i].

    Raises:
        ValueError: if the two lists differ in length. ``zip`` would silently drop the
                    tail, and a length mismatch usually means the lists are misaligned
                    too (e.g. after a failed download), attaching keys to wrong files.
    """
    if len(image_keys) != len(local_paths):
        raise ValueError(
            f"image_keys ({len(image_keys)}) and local_paths ({len(local_paths)}) "
            "must have the same length"
        )
    return [
        {
            "id": key,
            "path": path,
            "metadata": {"s3_key": key, "filename": Path(path).name},
        }
        for key, path in zip(image_keys, local_paths)
    ]


def upsert_images_to_chroma(
    collection,
    records: List[Dict],
    output_dim: int = EMBED_DIM,
    batch_size: int = 8,
    show_progress_every: int = 10
) -> List[Tuple[str, str, str]]:
    """
    Create embeddings for images and upsert them into Chroma, batch by batch.

    Args:
        collection: Chroma collection to write into.
        records: Dicts with "id", "path" and "metadata" keys (see build_image_records).
        output_dim: Embedding length requested from Titan.
        batch_size: Records embedded before each Chroma upsert. Must be positive.
        show_progress_every: Print a progress line every N batches; <= 0 disables it.

    Returns:
        A list of (id, path, error message) for every record that failed to embed,
        for inspection/retry.

    Raises:
        ValueError: if ``batch_size`` is not positive.

    Failure behavior:
      - Skips images that fail embedding and keeps going.

    Usability tips:
      - Start small (SAMPLE_N=200), then scale to thousands once stable.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    failed: List[Tuple[str, str, str]] = []
    stored_before = collection.count()

    for batch_idx, start in enumerate(range(0, len(records), batch_size), start=1):
        batch = records[start:start + batch_size]

        ids, metas, docs, embs = [], [], [], []

        for r in batch:
            try:
                emb = titan_embed_image_safe(r["path"], output_dim=output_dim)
                ids.append(r["id"])
                metas.append(r["metadata"])
                docs.append("")  # no text doc needed for pure image search
                embs.append(emb)
            except Exception as e:  # best-effort: record the failure, keep indexing
                failed.append((r["id"], r["path"], str(e)))

        # Only upsert if at least one succeeded
        if ids:
            collection.upsert(ids=ids, metadatas=metas, documents=docs, embeddings=embs)

        # Guard the modulo: show_progress_every=0 would raise ZeroDivisionError.
        if show_progress_every > 0 and batch_idx % show_progress_every == 0:
            print(f"[PROGRESS] batches={batch_idx} stored_now={collection.count()} failed={len(failed)}")

    # Upserting an id that already exists does not change count(), so newly_added
    # only counts ids that were not in the collection before.
    stored_after = collection.count()
    print(f"[DONE] stored_before={stored_before} stored_after={stored_after} newly_added={stored_after - stored_before} failed={len(failed)}")
    return failed


In [None]:
# ============================================================
# 5) Query (text -> image retrieval) + display results
# ============================================================
def query_images(collection, text_query: str, k: int = 8, output_dim: int = EMBED_DIM) -> List[Dict]:
    """
    Text -> image retrieval: embed ``text_query`` with Titan and return the top-k hits from Chroma.

    Args:
        collection: Chroma collection populated by upsert_images_to_chroma.
        text_query: Natural-language query, e.g. "blue t-shirt".
        k: Maximum number of hits; clamped to the number of stored images.
        output_dim: Must equal the embedding length used at indexing time.

    Returns:
        Hits in the order Chroma returns them, each
        {"id": ..., "distance": ..., "metadata": {...}}.
        Empty when the collection is empty or k <= 0.
    """
    n_results = min(k, collection.count())
    if n_results <= 0:
        # Nothing to search: skip the Bedrock embedding call entirely.
        return []

    q_emb = titan_embed_text(text_query, output_dim=output_dim)
    res = collection.query(
        query_embeddings=[q_emb],
        n_results=n_results,
        include=["metadatas", "distances"]  # ids always returned by Chroma query
    )
    return [
        {"id": _id, "distance": dist, "metadata": meta}
        for _id, dist, meta in zip(res["ids"][0], res["distances"][0], res["metadatas"][0])
    ]


def show_hit_images(
    hits: List[Dict],
    s3_to_local: Dict[str, str],
    n: int = 6,
    bucket: Optional[str] = None
):
    """
    Display the top ``n`` hits, titled with filename, distance and S3 location.

    Args:
        hits: Output of query_images.
        s3_to_local: Maps S3 key -> local file path; hits without a local file are skipped.
        n: How many of the top hits to consider.
        bucket: Bucket shown in the s3:// location. None uses the global BUCKET;
                "" shows the bare key.
    """
    if bucket is None:
        bucket = BUCKET

    for h in hits[:n]:
        key = h["metadata"]["s3_key"]
        path = s3_to_local.get(key)
        if not path:
            continue

        # `with` closes the file handle; exif_transpose shows the photo upright, the
        # same orientation normalize_image_to_jpeg_bytes uses when embedding.
        with Image.open(path) as src:
            img = ImageOps.exif_transpose(src)
            fig, ax = plt.subplots()
            ax.imshow(img)
            ax.axis("off")
            ax.set_title(f"{Path(path).name} | dist={h['distance']:.4f}\n{s3_key_to_display(bucket, key)}")
            plt.show()
        # Close explicitly so long result loops don't accumulate open figures.
        plt.close(fig)


def s3_key_to_display(bucket: Optional[str], key: str) -> str:
    """
    Human-readable location of an S3 object for titles and logs.

    Returns ``s3://<bucket>/<key>``, or just ``key`` when ``bucket`` is empty/None.
    This is a display label, not an HTTP URL.
    """
    if not bucket:
        return key
    return f"s3://{bucket}/{key}"


In [None]:
# ============================================================
# 6) (Optional) OpenAI Vision "G" step: summarize retrieved images
# ------------------------------------------------------------
# Goal:
#   - Retrieve top-k images (R)
#   - Show top N with stable labels S1..Sn
#   - Send those same images to OpenAI vision model
#   - Model returns summary referencing S1..Sn so you can visually validate
#
# Requirements:
#   - OPENAI_API_KEY set
# ============================================================
def image_to_b64(path: str) -> str:
    """Read the file at ``path`` as-is and return its bytes base64-encoded as text."""
    raw_bytes = Path(path).read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def display_hit_images_labeled(
    hits: List[Dict],
    s3_to_local: Dict[str, str],
    n_images: int = 4,
    bucket: Optional[str] = None
) -> List[Dict]:
    """
    Display the top hits labeled S1..Sn, in the same order they will be sent to the model.

    Labels are contiguous over the images actually shown: a hit without a local file
    is skipped WITHOUT consuming a label, so "S1..Sn" in the prompt never refers to
    a missing image.

    Args:
        hits: Output of query_images.
        s3_to_local: Maps S3 key -> local file path.
        n_images: How many of the top hits to consider.
        bucket: Bucket used for the s3:// location (None/"" shows the bare key).

    Returns:
        One dict per shown image with keys "label", "path", "filename", "s3_key",
        "distance" and "s3_uri".
    """
    labeled = []
    for h in hits[:n_images]:
        key = h["metadata"]["s3_key"]
        path = s3_to_local.get(key)
        if not path:
            continue

        label = f"S{len(labeled) + 1}"
        filename = h["metadata"].get("filename", Path(path).name)
        dist = h["distance"]
        s3_uri = s3_key_to_display(bucket, key)

        labeled.append({
            "label": label,
            "path": path,
            "filename": filename,
            "s3_key": key,
            "distance": dist,
            "s3_uri": s3_uri
        })

        # `with` closes the file handle; exif_transpose shows the photo upright.
        with Image.open(path) as src:
            img = ImageOps.exif_transpose(src)
            fig, ax = plt.subplots()
            ax.imshow(img)
            ax.axis("off")
            ax.set_title(f"{label} | {filename} | dist={dist:.4f}\n{s3_uri}")
            plt.show()
        plt.close(fig)

    return labeled


def summarize_images_with_openai(
    user_query: str,
    labeled_images: List[Dict],
    model: str = "gpt-4.1-mini"
) -> str:
    """
    Send labeled images to an OpenAI vision model and ask for a per-image description
    that references the labels S1..Sn.

    Each image is re-encoded with ``normalize_image_to_jpeg_bytes`` before upload. That
    makes the ``data:image/jpeg`` MIME type accurate for PNG inputs too, applies EXIF
    rotation, and caps the image size.

    Args:
        user_query: The original search query.
        labeled_images: Output of display_hit_images_labeled (label, path, filename,
                        s3_key, distance).
        model: OpenAI chat model with image input support.

    Returns:
        The model's text reply ("" if the reply has no text content), or a short
        notice when there is nothing to summarize.
    """
    if not labeled_images:
        # Nothing was shown; don't spend an API call on a prompt with no images.
        return "No images to summarize."

    client = OpenAI()

    # Build a small index table for traceability
    index_lines = [
        f"{x['label']}: {x['filename']} | {x['s3_key']} | dist={x['distance']:.4f}"
        for x in labeled_images
    ]
    labels = ", ".join(x["label"] for x in labeled_images)

    prompt_text = (
        f"User query: {user_query}\n\n"
        "Retrieved images index:\n" + "\n".join(index_lines) + "\n\n"
        "Instructions:\n"
        f"1) For each image ({labels}), describe what you see in 1–2 sentences.\n"
        "2) For each, explain briefly why it matches (or doesn't match) the query.\n"
        "3) End with a short overall summary.\n"
        "CRITICAL: Use the exact labels S1, S2, ... in your output."
    )

    # A single user message with mixed content blocks. Each image is preceded by its
    # label so the model doesn't have to infer the label from position alone.
    content = [{"type": "text", "text": prompt_text}]
    for item in labeled_images:
        jpeg_b64 = base64.b64encode(normalize_image_to_jpeg_bytes(item["path"])).decode("utf-8")
        content.append({"type": "text", "text": f"{item['label']}:"})
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{jpeg_b64}"}
        })

    resp = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a concise visual cataloging assistant."},
            {"role": "user", "content": content}
        ],
        temperature=0.2
    )
    # message.content may be None; keep the declared str return type.
    return resp.choices[0].message.content or ""


In [None]:
# ============================================================
# 7) RUN: Step-by-step execution
# ============================================================

# 7A) List & download images from S3 (start small)
image_keys = list_image_keys(BUCKET, prefix=PREFIX, max_keys=SAMPLE_N)
print("Found image keys:", len(image_keys))
print("Sample keys:", image_keys[:5])

local_images = download_s3_objects(BUCKET, image_keys, local_dir=LOCAL_IMAGE_DIR)
print("Downloaded local files:", len(local_images))
print("Sample local paths:", local_images[:3])

# download_s3_objects skips failed objects, so its result is NOT index-aligned with
# image_keys: zipping the two would pair keys with the wrong files after any failure.
# Rebuild the key -> path mapping from the files that actually downloaded
# (each key is stored at LOCAL_IMAGE_DIR/<key>).
downloaded_paths = set(local_images)
s3_to_local = {}
for key in image_keys:
    expected_path = str(Path(LOCAL_IMAGE_DIR) / key)
    if expected_path in downloaded_paths:
        s3_to_local[key] = expected_path

# Keep both lists aligned (and limited to successful downloads) for 7B/7C.
image_keys = list(s3_to_local)
local_images = list(s3_to_local.values())
assert local_images, "No images downloaded: check BUCKET, PREFIX, credentials and IAM permissions."


In [None]:
# 7B) Quick sanity test: embed one image + embed one text query
vec_img = titan_embed_image_safe(local_images[0], output_dim=EMBED_DIM)
print("Image embedding length:", len(vec_img), "| first 5:", vec_img[:5])

vec_txt = titan_embed_text("t-shirt", output_dim=EMBED_DIM)
print("Text embedding length:", len(vec_txt), "| first 5:", vec_txt[:5])

# Text->image search only works if both vectors share the configured dimension.
assert len(vec_img) == len(vec_txt) == EMBED_DIM, "Embedding length does not match EMBED_DIM"


In [None]:
# 7C) Create/load Chroma collection and upsert embeddings
collection = get_chroma_collection(
    persist_dir=CHROMA_DIR,
    collection_name=CHROMA_COLLECTION,
)

records = build_image_records(image_keys, local_images)
failed = upsert_images_to_chroma(
    collection,
    records,
    output_dim=EMBED_DIM,
    batch_size=8,
)

print("Chroma count:", collection.count())
print("Failed examples:", failed[:3])


[DONE] stored_before=65 stored_after=65 newly_added=0 failed=0
Chroma count: 65
Failed examples: []


In [None]:
# 7D) Query and display results
query = "blue t-shirt"
hits = query_images(collection, query, k=8, output_dim=EMBED_DIM)

print("Top hits:")
for hit in hits[:5]:
    hit_meta = hit["metadata"]
    print(f"- dist={hit['distance']:.4f} | {hit_meta['filename']} | {hit_meta['s3_key']}")

show_hit_images(hits, s3_to_local, n=6)


In [None]:
# 7E) OPTIONAL: OpenAI Vision summary (Generation on top of Retrieval)
# - Requires OPENAI_API_KEY in the environment; the step is skipped otherwise.
# - Displays images labeled S1..Sn, then prints a summary referencing S1..Sn.
if os.environ.get("OPENAI_API_KEY"):
    labeled = display_hit_images_labeled(hits, s3_to_local, n_images=4, bucket=BUCKET)
    summary = summarize_images_with_openai(user_query=query, labeled_images=labeled, model="gpt-4.1-mini")
    print(summary)
else:
    print("[SKIP] OPENAI_API_KEY is not set; skipping the optional OpenAI summary step.")


In [None]:
# ============================================================
# 8) Operational utilities (optional but useful)
# ============================================================
def retry_failed_images(
    collection,
    failed_list: List[Tuple[str, str, str]],
    output_dim: int = EMBED_DIM,
    batch_size: int = 4
) -> List[Tuple[str, str, str]]:
    """
    Retry embedding only the previously failed images, upserting those that now succeed.

    Useful after you fix corrupted downloads or adjust normalization.

    Args:
        collection: Chroma collection to write into.
        failed_list: (id, path, error) tuples as returned by upsert_images_to_chroma;
                     the id is the S3 key.
        output_dim: Embedding length requested from Titan.
        batch_size: Records per Chroma upsert.

    Returns:
        The images that still fail, in the same (id, path, error) format.
    """
    ids = [_id for _id, _path, _err in failed_list]
    paths = [path for _id, path, _err in failed_list]
    # Reuse build_image_records so retried entries get exactly the same record shape.
    retry_records = build_image_records(ids, paths)
    return upsert_images_to_chroma(collection, retry_records, output_dim=output_dim, batch_size=batch_size)


def show_failed_images(failed_list, n=5):
    """
    Print details of the first ``n`` failures and display each image if Pillow can open it.

    Args:
        failed_list: (id, path, error) tuples as returned by upsert_images_to_chroma.
        n: Maximum number of failures to show.

    Corrupted files may not open at all; that error is printed instead of raised.
    """
    for _id, path, err in failed_list[:n]:
        print("----")
        print("ID:", _id)
        print("Path:", path)
        print("Error:", err)
        fig = None
        try:
            # `with` closes the file handle even if decoding fails midway.
            with Image.open(path) as img:
                fig, ax = plt.subplots()
                ax.imshow(img)
                ax.axis("off")
                ax.set_title(Path(path).name)
                plt.show()
        except Exception as e:  # best-effort display: broken files are expected here
            print("Could not display image:", e)
        finally:
            if fig is not None:
                plt.close(fig)

# Example usage:
# show_failed_images(failed, n=5)
# failed_retry = retry_failed_images(collection, failed)
