In [None]:
# ============================================================
# Video RAG v1 (POC) — AWS S3 + Bedrock Titan (multimodal embeddings) + ChromaDB
# Optional: OpenAI Vision summaries of retrieved frames
#
# POC assumptions:
# - Only a few short videos (< 1 minute each)
# - You want to learn the workflow end-to-end
#
# COST CONTROLS BUILT IN:
# - Sample a frame every N seconds
# - Hard cap frames per video
# - Cache frames already embedded (avoid re-paying)
# ============================================================


In [None]:
!pip install -q boto3 pillow chromadb openai opencv-python-headless


In [None]:
import os
import io
import json
import base64
import hashlib
from pathlib import Path
from typing import List, Dict, Optional, Tuple

import boto3
from botocore.exceptions import ClientError, NoCredentialsError

import cv2
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import chromadb
from openai import OpenAI


In [None]:
# Credentials: reuse whatever is already configured (exported env vars,
# `aws configure` profile, attached IAM role, Colab secrets -> env, ...) and
# only prompt with getpass when nothing is found. This avoids both pasting
# secrets into the notebook and clobbering working credentials with placeholders.
import getpass

os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")  # change if needed

if boto3.Session().get_credentials() is None:
    os.environ["AWS_ACCESS_KEY_ID"] = getpass.getpass("Enter your AWS access key ID: ")
    os.environ["AWS_SECRET_ACCESS_KEY"] = getpass.getpass("Enter your AWS secret access key: ")

if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")


In [None]:
# -----------------------------
# S3
# -----------------------------
# ---------------------------------------------------------------
# S3 source
# ---------------------------------------------------------------
BUCKET = "<YOUR S3 NAME>"
VIDEO_PREFIX = ""  # key prefix to scan, e.g. "poc/videos/"; "" scans the bucket root
VIDEO_EXTS = tuple("." + ext for ext in ("mp4", "mov", "m4v", "avi"))  # compared lower-cased when listing

# Local scratch space for downloaded videos and extracted frames
LOCAL_VIDEO_DIR, LOCAL_FRAME_DIR = "data/videos", "data/frames"

# ---------------------------------------------------------------
# Bedrock (Titan multimodal embeddings)
# ---------------------------------------------------------------
BEDROCK_REGION = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
TITAN_MODEL_ID = "amazon.titan-embed-image-v1"
EMBED_DIM = 1024  # requested embedding length, used for both image and text vectors

# ---------------------------------------------------------------
# Chroma (local persistent vector store)
# ---------------------------------------------------------------
CHROMA_DIR = "chroma_video_db"
CHROMA_COLLECTION = "video_frames"

# ---------------------------------------------------------------
# Cost / safety knobs
# ---------------------------------------------------------------
SAMPLE_EVERY_SEC = 1.0       # seconds between sampled frames; 1 fps suits clips under a minute
MAX_FRAMES_PER_VIDEO = 80    # hard cap on frames extracted from one video
MAX_FRAMES_TOTAL_RUN = 300   # intended cap on frames embedded across all videos in one run
MAX_FRAME_SIDE = 768         # longest side (px) frames are shrunk to before embedding

# ---------------------------------------------------------------
# Optional OpenAI vision summaries
# ---------------------------------------------------------------
OPENAI_VISION_MODEL = "gpt-4.1-mini"


In [None]:
def get_s3_client():
    """Return a new boto3 S3 client (no explicit region or credentials are passed)."""
    client = boto3.client("s3")
    return client

def list_video_keys(bucket: str, prefix: str, exts=VIDEO_EXTS, max_keys: Optional[int]=None) -> List[str]:
    """
    List S3 object keys under `prefix` whose names end with one of `exts`.

    Args:
        bucket: S3 bucket name.
        prefix: Key prefix to list under ("" lists the whole bucket).
        exts: File extensions to keep; compared case-insensitively.
        max_keys: Stop after this many matches. None (or 0) means no limit.

    Returns:
        Matching keys, in the order S3 returns them.

    Raises:
        RuntimeError: if credentials are missing or the list call fails; the
            original botocore exception is chained as ``__cause__``.
    """
    s3 = get_s3_client()
    # Lower-case the suffixes once, not once per listed object.
    suffixes = tuple(e.lower() for e in exts)
    keys: List[str] = []

    try:
        # The paginator follows ContinuationToken across truncated pages.
        paginator = s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if key.lower().endswith(suffixes):
                    keys.append(key)
                    if max_keys and len(keys) >= max_keys:
                        return keys
    except NoCredentialsError as e:
        raise RuntimeError("AWS credentials not found in this runtime.") from e
    except ClientError as e:
        raise RuntimeError(f"S3 list failed: {e}") from e

    return keys

def download_s3_video(bucket: str, key: str, local_dir: str = LOCAL_VIDEO_DIR) -> str:
    """
    Download one S3 object into `local_dir`, mirroring the key's "/" folders.

    Args:
        bucket: S3 bucket name.
        key: Object key, e.g. "poc/videos/clip.mp4".
        local_dir: Root directory for downloads (created if missing).

    Returns:
        Local path of the downloaded file, as a string.

    Raises:
        RuntimeError: if the key contains ".." segments (it would resolve
            outside `local_dir`) or the download fails.
    """
    s3 = get_s3_client()
    local_root = Path(local_dir)
    local_root.mkdir(parents=True, exist_ok=True)

    # S3 keys may legally start with "/" or contain "..". Joining such a key
    # straight onto local_root escapes it (a leading "/" makes it absolute),
    # so rebuild the path from its non-empty segments and refuse "..".
    parts = [p for p in key.split("/") if p not in ("", ".")]
    if not parts or ".." in parts:
        raise RuntimeError(f"Refusing to download S3 key outside {local_root}: {key!r}")

    local_path = local_root.joinpath(*parts)
    local_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        s3.download_file(bucket, key, str(local_path))
    except ClientError as e:
        raise RuntimeError(f"Download failed for {key}: {e}") from e

    return str(local_path)


In [None]:
def safe_mkdir(path: str):
    """Create directory `path` plus any missing parents; does nothing if it already exists."""
    target = Path(path)
    target.mkdir(exist_ok=True, parents=True)

def extract_frames_opencv(
    video_path: str,
    out_dir: str,
    sample_every_sec: float = SAMPLE_EVERY_SEC,
    max_frames: int = MAX_FRAMES_PER_VIDEO
) -> List[Dict]:
    """
    Save roughly one frame every `sample_every_sec` seconds as JPEG.

    Args:
        video_path: Local path of a video OpenCV can open.
        out_dir: Directory for the JPEGs (created if missing).
        sample_every_sec: Sampling interval in seconds. It is converted to a
            whole number of frames (at least 1) using the reported FPS.
        max_frames: Stop after saving this many frames; <= 0 saves none.

    Returns:
        One record per saved frame:
        {"frame_path": str, "timestamp_sec": float, "frame_index": int}.

    Raises:
        RuntimeError: if OpenCV cannot open the video.
    """
    import math

    safe_mkdir(out_dir)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open video: {video_path}")

    records: List[Dict] = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or a non-finite value) for FPS; fall back
        # to 30 so the sampling step and timestamps stay finite.
        if not (math.isfinite(fps) and fps > 0):
            fps = 30.0
        step = max(int(fps * sample_every_sec), 1)

        frame_idx = 0
        while len(records) < max_frames:
            # grab() advances the stream; retrieve() produces the BGR image and
            # is only needed for frames we actually keep (read() does both).
            if not cap.grab():
                break

            if frame_idx % step == 0:
                ok, frame = cap.retrieve()
                if not ok:
                    break

                timestamp_sec = frame_idx / fps
                out_path = Path(out_dir) / f"frame_{len(records):06d}_{int(timestamp_sec*1000)}ms.jpg"

                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                Image.fromarray(rgb).save(out_path, format="JPEG", quality=90, optimize=True)

                records.append({
                    "frame_path": str(out_path),
                    "timestamp_sec": float(timestamp_sec),
                    "frame_index": int(frame_idx),
                })

            frame_idx += 1
    finally:
        # Always free the decoder, even if converting or saving a frame fails.
        cap.release()

    return records


In [None]:
def normalize_image_to_jpeg_bytes(image_path: str, max_side: int = MAX_FRAME_SIDE, quality: int = 90) -> bytes:
    """
    Re-encode an image file as clean RGB JPEG bytes.

    The EXIF orientation is applied and any non-RGB mode is converted to RGB.
    The image is then shrunk, keeping aspect ratio, so its longer side is at
    most `max_side`. This helps prevent Bedrock "Unable to process provided image".

    Args:
        image_path: Path to an image Pillow can open.
        max_side: Upper bound in pixels for the longer side; <= 0 disables resizing.
        quality: JPEG quality passed to Pillow's encoder.

    Returns:
        The encoded JPEG as bytes.
    """
    with Image.open(image_path) as img:
        img = ImageOps.exif_transpose(img)
        if img.mode != "RGB":
            img = img.convert("RGB")

        w, h = img.size
        if max_side > 0 and max(w, h) > max_side:
            scale = max(w, h) / max_side
            # Clamp to 1px: on a very thin image the short side could otherwise
            # truncate to 0, which resize() rejects.
            new_size = (max(1, int(w / scale)), max(1, int(h / scale)))
            img = img.resize(new_size)

        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=quality, optimize=True)
        return buf.getvalue()


In [None]:
# One Bedrock Runtime client per region, reused across calls. This getter runs
# once per embedded frame, and building a boto3 client each time is
# comparatively slow.
_BEDROCK_CLIENTS: Dict[str, object] = {}

def get_bedrock_runtime(region_name: str = BEDROCK_REGION):
    """
    Return a cached Bedrock Runtime client for `region_name`.

    The first call for a region creates the client; later calls return the same
    object. Clear `_BEDROCK_CLIENTS` to force fresh clients, e.g. after changing
    credentials.
    """
    client = _BEDROCK_CLIENTS.get(region_name)
    if client is None:
        client = boto3.client("bedrock-runtime", region_name=region_name)
        _BEDROCK_CLIENTS[region_name] = client
    return client

def _titan_invoke(payload: Dict) -> List[float]:
    """
    Send one JSON request to the Titan embedding model and return its "embedding".

    Raises:
        RuntimeError: if the response has a non-empty "message" or lacks "embedding".
    """
    resp = get_bedrock_runtime().invoke_model(
        modelId=TITAN_MODEL_ID,
        body=json.dumps(payload),
        accept="application/json",
        contentType="application/json"
    )
    data = json.loads(resp["body"].read())
    if data.get("message"):
        raise RuntimeError(data["message"])
    if "embedding" not in data:
        raise RuntimeError(f"Titan response has no 'embedding' field (keys: {sorted(data)})")
    return data["embedding"]

def titan_embed_image_bytes(jpeg_bytes: bytes, output_dim: int = EMBED_DIM) -> List[float]:
    """Embed raw JPEG bytes with Titan, requesting `output_dim` dimensions."""
    return _titan_invoke({
        "inputImage": base64.b64encode(jpeg_bytes).decode("utf-8"),
        "embeddingConfig": {"outputEmbeddingLength": output_dim}
    })

def titan_embed_image_safe(image_path: str, output_dim: int = EMBED_DIM) -> List[float]:
    """Normalize an image file to clean RGB JPEG bytes, then embed it with Titan."""
    jpeg = normalize_image_to_jpeg_bytes(image_path)
    return titan_embed_image_bytes(jpeg, output_dim=output_dim)

def titan_embed_text(query: str, output_dim: int = EMBED_DIM) -> List[float]:
    """
    Embed a text query with Titan, in the same vector space as the frame embeddings.

    Raises:
        ValueError: if `query` is empty or whitespace-only. This is checked
            locally so no Bedrock call is spent on a blank query.
    """
    if not query or not query.strip():
        raise ValueError("query must be a non-empty string")
    return _titan_invoke({
        "inputText": query,
        "embeddingConfig": {"outputEmbeddingLength": output_dim}
    })


In [None]:
def get_chroma_collection(persist_dir: str = CHROMA_DIR, collection_name: str = CHROMA_COLLECTION):
    """Open the named collection in an on-disk Chroma store at `persist_dir`, creating it on first use."""
    store = chromadb.PersistentClient(path=persist_dir)
    collection = store.get_or_create_collection(name=collection_name)
    return collection

def make_frame_id(video_key: str, timestamp_sec: float) -> str:
    """
    Build a deterministic Chroma ID for a frame.

    The ID is the SHA-1 hex digest of "<video_key>|<timestamp in whole ms>".
    Milliseconds are truncated with int(), the same rounding used in frame
    filenames, so re-extracting a video reproduces the same IDs.
    """
    millis = int(timestamp_sec*1000)
    digest = hashlib.sha1(f"{video_key}|{millis}".encode("utf-8"))
    return digest.hexdigest()

def already_indexed(collection, frame_id: str) -> bool:
    """
    Return True if `frame_id` is already stored in `collection`.

    This check is best-effort: any lookup error is treated as "not indexed" so
    indexing can continue. A warning is emitted in that case, because a silent
    False means the frame gets re-embedded (and re-billed).
    """
    import warnings

    try:
        got = collection.get(ids=[frame_id], include=[])
        return len(got["ids"]) > 0
    except Exception as e:
        warnings.warn(
            f"Chroma lookup failed for id {frame_id}: {e!r}; treating it as not indexed.",
            stacklevel=2,
        )
        return False


In [None]:
def index_video_to_chroma(
    collection,
    bucket: str,
    video_key: str,
    sample_every_sec: float = SAMPLE_EVERY_SEC,
    max_frames_per_video: int = MAX_FRAMES_PER_VIDEO,
    max_frames_total_run: int = MAX_FRAMES_TOTAL_RUN,
    upsert_batch_size: int = 16
) -> Dict:
    """
    Index one video into Chroma:
      - download video locally
      - extract frames every N seconds
      - embed frames with Titan (frames whose ID is already stored are skipped)
      - upsert into Chroma with timestamp metadata

    Args:
        collection: Chroma collection to write to.
        bucket: S3 bucket holding the video.
        video_key: S3 key of the video.
        sample_every_sec: Frame sampling interval in seconds.
        max_frames_per_video: Cap on frames extracted from this video.
        max_frames_total_run: Cap on new embeddings made by THIS call. The
            counter starts fresh on every call, so a caller bounding a
            multi-video run must pass in whatever budget is left.
        upsert_batch_size: Number of frames buffered per Chroma upsert.

    Returns:
        Stats dict with keys video_key, video_path, frames_extracted,
        frames_indexed, skipped_cached, skipped_budget (frames not processed
        because the cap was reached), and failed (list of
        {frame_path, timestamp_sec, error}).

    Raises:
        RuntimeError: if the key contains ".." segments.
    """
    # Frames are mirrored under LOCAL_FRAME_DIR by key. Rebuild the path from
    # non-empty segments so a leading "/" cannot make it absolute. Check this
    # before downloading anything.
    key_parts = [p for p in video_key.split("/") if p not in ("", ".")]
    if not key_parts or ".." in key_parts:
        raise RuntimeError(f"Refusing to write frames for unsafe S3 key: {video_key!r}")
    frame_dir = str(Path(LOCAL_FRAME_DIR).joinpath(*key_parts))

    video_path = download_s3_video(bucket, video_key, local_dir=LOCAL_VIDEO_DIR)

    frames = extract_frames_opencv(
        video_path=video_path,
        out_dir=frame_dir,
        sample_every_sec=sample_every_sec,
        max_frames=max_frames_per_video
    )

    stats = {
        "video_key": video_key,
        "video_path": video_path,
        "frames_extracted": len(frames),
        "frames_indexed": 0,
        "skipped_cached": 0,
        "skipped_budget": 0,
        "failed": []
    }

    ids, embs, metas, docs = [], [], [], []
    remaining = max_frames_total_run  # per-call cap on new (paid) embeddings

    for pos, f in enumerate(frames):
        if remaining <= 0:
            # Record how many frames the cap left unprocessed.
            stats["skipped_budget"] = len(frames) - pos
            break

        frame_id = make_frame_id(video_key, f["timestamp_sec"])
        if already_indexed(collection, frame_id):
            stats["skipped_cached"] += 1
            continue

        try:
            emb = titan_embed_image_safe(f["frame_path"], output_dim=EMBED_DIM)
            ids.append(frame_id)
            embs.append(emb)
            metas.append({
                "video_bucket": bucket,
                "video_key": video_key,
                "timestamp_sec": f["timestamp_sec"],
                "frame_index": f["frame_index"],
                "frame_path": f["frame_path"],
            })
            docs.append("")

            stats["frames_indexed"] += 1
            remaining -= 1

        except Exception as e:
            # Keep going: one bad frame should not abort the whole video.
            stats["failed"].append({
                "frame_path": f["frame_path"],
                "timestamp_sec": f["timestamp_sec"],
                "error": str(e)
            })

        if len(ids) >= upsert_batch_size:
            collection.upsert(ids=ids, embeddings=embs, metadatas=metas, documents=docs)
            ids, embs, metas, docs = [], [], [], []

    # Flush the final partial batch.
    if ids:
        collection.upsert(ids=ids, embeddings=embs, metadatas=metas, documents=docs)

    return stats


In [None]:
def query_video_frames(collection, text_query: str, k: int = 8) -> List[Dict]:
    """
    Return up to `k` stored frames nearest to `text_query` in Titan embedding space.

    Args:
        collection: Chroma collection holding frame embeddings.
        text_query: Natural-language query, embedded with Titan text input.
        k: Maximum number of hits.

    Returns:
        Hits in the order Chroma returns them, each
        {"id": str, "distance": float, "metadata": dict}. The result is empty
        when k <= 0 or the collection is empty, and no Bedrock call is made then.
    """
    if k <= 0:
        return []
    # Never ask for more neighbours than the collection holds. Over-asking can
    # trigger Chroma warnings or errors, and an empty collection would waste
    # a paid embedding call.
    k = min(k, collection.count())
    if k <= 0:
        return []

    q = titan_embed_text(text_query, output_dim=EMBED_DIM)
    res = collection.query(
        query_embeddings=[q],
        n_results=k,
        include=["metadatas", "distances"]
    )
    return [
        {"id": _id, "distance": dist, "metadata": meta}
        for _id, dist, meta in zip(res["ids"][0], res["distances"][0], res["metadatas"][0])
    ]

def show_frame_hits(hits: List[Dict], n: int = 6):
    """
    Plot the first `n` hits, one figure each.

    Each title shows the frame file name, distance, video key and timestamp.
    Hits without a frame file on local disk are skipped. Missing distance or
    timestamp values are shown as placeholders instead of raising.
    """
    for h in hits[:max(n, 0)]:
        meta = h.get("metadata") or {}
        p = meta.get("frame_path")
        if not p or not Path(p).exists():
            continue

        # Image.open is lazy; convert() loads the pixels so the file can be
        # closed immediately instead of leaking one handle per hit.
        with Image.open(p) as src:
            img = src.convert("RGB")

        dist = h.get("distance")
        ts = meta.get("timestamp_sec")
        dist_txt = f"{dist:.4f}" if dist is not None else "n/a"
        ts_txt = f"{ts:.2f}s" if ts is not None else "?"

        fig, ax = plt.subplots()
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"{Path(p).name} | dist={dist_txt}\n{meta.get('video_key')} @ {ts_txt}")
        plt.show()
        plt.close(fig)  # don't keep every figure alive in the kernel


In [None]:
def image_to_b64(path: str) -> str:
    """Read the file at `path` and return its bytes as a base64 text string (no data-URL prefix)."""
    raw = Path(path).read_bytes()
    encoded = base64.b64encode(raw)
    return encoded.decode("utf-8")

def summarize_frames_with_openai(text_query: str, hits: List[Dict], n_images: int = 4, model: str = OPENAI_VISION_MODEL) -> str:
    """
    Ask an OpenAI vision model to describe the top retrieved frames.

    The first `n_images` hits are considered, and hits whose frame file is
    missing locally are dropped. The remaining frames are labelled S1..Sn
    consecutively. Each image is preceded by a text part carrying its label,
    so the model can tie every statement to a specific frame.

    Args:
        text_query: The query the frames were retrieved for.
        hits: Output of `query_video_frames` (dicts with "metadata" and "distance").
        n_images: Maximum number of frames to send.
        model: OpenAI chat model that accepts image input.

    Returns:
        The model's text reply ("" if it returned no text). If no local frames
        are available, a short notice is returned and no API call is made.
    """
    frames = []  # (label, frame_path, index line)
    for h in hits[:max(n_images, 0)]:
        meta = h.get("metadata") or {}
        p = meta.get("frame_path")
        if not p or not Path(p).exists():
            continue
        # Number only the frames actually sent, so labels have no gaps.
        label = f"S{len(frames) + 1}"
        ts = meta.get("timestamp_sec")
        dist = h.get("distance")
        ts_txt = f"{ts:.2f}s" if ts is not None else "?"
        dist_txt = f"{dist:.4f}" if dist is not None else "n/a"
        frames.append((label, p, f"{label}: {meta.get('video_key')} @ {ts_txt} | dist={dist_txt}"))

    if not frames:
        return "No local frames available to summarize."

    prompt = (
        f"User query: {text_query}\n\n"
        "Retrieved video frames index:\n" + "\n".join(line for _, _, line in frames) + "\n\n"
        "Instructions:\n"
        f"1) For each frame S1..S{len(frames)}, describe what you see.\n"
        "2) Explain briefly why it matches the query.\n"
        "3) Provide a short overall summary.\n"
        "CRITICAL: Reference frames using S1, S2, ... explicitly."
    )

    content = [{"type": "text", "text": prompt}]
    for label, p, _ in frames:
        content.append({"type": "text", "text": f"Frame {label}:"})
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image_to_b64(p)}"}
        })

    client = OpenAI()
    resp = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a concise video frame analyst."},
            {"role": "user", "content": content}
        ],
        temperature=0.2
    )
    return resp.choices[0].message.content or ""


In [None]:
# POC: only the first 3 matching videos under the prefix.
video_keys = list_video_keys(bucket=BUCKET, prefix=VIDEO_PREFIX, max_keys=3)
video_keys


In [None]:
collection = get_chroma_collection()

# MAX_FRAMES_TOTAL_RUN is meant as a budget for the whole run. However,
# index_video_to_chroma's `max_frames_total_run` counter restarts on every
# call. So each video gets only the budget still left, and the loop stops
# once it is spent.
stats_all = []
frames_embedded = 0
for vk in video_keys:
    budget_left = MAX_FRAMES_TOTAL_RUN - frames_embedded
    if budget_left <= 0:
        print(f"Run budget of {MAX_FRAMES_TOTAL_RUN} frames reached; skipping remaining videos.")
        break
    stats = index_video_to_chroma(collection, BUCKET, vk, max_frames_total_run=budget_left)
    frames_embedded += stats["frames_indexed"]
    stats_all.append(stats)
    print(json.dumps(stats, indent=2))

print("Chroma count:", collection.count())



In [None]:
query_text = "T-Shirt"
hits = query_video_frames(collection, query_text, k=8)
show_frame_hits(hits, n=6)


In [None]:
summary_text = summarize_frames_with_openai("T-Shirt", hits, n_images=4)
print(summary_text)
