In [None]:
import os
import json
import time
import base64
import asyncio
from io import BytesIO
from pathlib import Path
from PIL import Image
import google.generativeai as genai
from openai import OpenAI
from google.colab import drive, userdata
# Mount Google Drive so the project data and outputs below are reachable.
drive.mount("/content/drive")

# Project layout on Drive.
ROOT_DRIVE_PATH = Path("/content/drive/MyDrive/DSGA1011/Project")
SPATIAL_MM_ROOT = ROOT_DRIVE_PATH / "Spatial-MM"            # holds data/*.json question files
IMAGES_ROOT = ROOT_DRIVE_PATH / "Spatial_MM-Benchmark"
COT_IMAGE_DIR = IMAGES_ROOT / "Spatial_MM_CoT"             # images are looked up here by image_name
OUTPUT_DIR = ROOT_DRIVE_PATH / "spatial_mm_outputs"        # generated datasets and model responses
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# API keys come from Colab's secret store; they are never hardcoded here.
GOOGLE_API_KEY = userdata.get("GOOGLE_API_KEY")
# Fixed: this was assigned to the misspelled name `PENAI_API_KEY`, so the
# `OpenAI(api_key=OPENAI_API_KEY)` line below raised NameError on a fresh kernel.
OPENAI_API_KEY = userdata.get("OPENAI_API_KEY")

genai.configure(api_key=GOOGLE_API_KEY)
openai_client = OpenAI(api_key=OPENAI_API_KEY)




In [None]:
def encode_image_pil(img: Image.Image) -> dict:
    """Serialize a PIL image into a base64 JPEG payload.

    The image is converted to RGB before saving as JPEG.

    Args:
        img: The PIL image to encode.

    Returns:
        A dict with keys ``"mime"`` (always ``"image/jpeg"``) and ``"base64"``
        (the JPEG bytes, base64-encoded and decoded to a UTF-8 string).
    """
    jpeg_buffer = BytesIO()
    rgb_image = img.convert("RGB")
    rgb_image.save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return {"mime": "image/jpeg", "base64": encoded.decode("utf-8")}

def encode_image_path(img_path: Path) -> dict:
    """Load an image file and encode it as a base64 JPEG payload.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The dict produced by ``encode_image_pil`` (``"mime"`` and ``"base64"``).
    """
    # The context manager closes the underlying file handle deterministically
    # instead of leaving it to garbage collection. No explicit RGB conversion is
    # needed here because encode_image_pil performs it before saving.
    with Image.open(img_path) as img:
        return encode_image_pil(img)

def resize_and_encode_image_path(img_path: Path, scale: float) -> dict:
    """Load an image file, rescale it, and encode it as a base64 JPEG payload.

    Args:
        img_path: Path of the image file to read.
        scale: Multiplicative factor applied to both width and height.
            Must be positive.

    Returns:
        The dict produced by ``encode_image_pil`` (``"mime"`` and ``"base64"``).

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    # Close the file handle deterministically; the converted RGB copy lives in
    # memory and stays usable after the file is closed.
    with Image.open(img_path) as source:
        rgb_image = source.convert("RGB")

    # int() truncates, so a small image times a small scale could yield a
    # zero-sized dimension, which resize() rejects. Clamp to at least one pixel.
    new_size = (
        max(1, int(rgb_image.width * scale)),
        max(1, int(rgb_image.height * scale)),
    )
    resized = rgb_image.resize(new_size, Image.Resampling.LANCZOS)
    return encode_image_pil(resized)


In [None]:
async def call_gemini_async(model_name: str, prompt: str, image_b64: dict, max_tokens: int) -> str:
    """Send a text prompt plus one image to a Gemini model without blocking the loop.

    The blocking ``generate_content`` call is run in the event loop's default
    executor.

    Args:
        model_name: Gemini model identifier passed to ``genai.GenerativeModel``.
        prompt: Text prompt sent alongside the image.
        image_b64: Dict whose ``"base64"`` entry holds the base64-encoded image.
        max_tokens: Value used for ``max_output_tokens``.

    Returns:
        The stripped response text, or ``""`` when the response text is empty.
    """
    gemini = genai.GenerativeModel(model_name)
    decoded_bytes = base64.b64decode(image_b64["base64"])
    pil_image = Image.open(BytesIO(decoded_bytes))
    generation_config = {"max_output_tokens": max_tokens, "temperature": 0.0}

    def _generate() -> str:
        response = gemini.generate_content(
            [prompt, pil_image],
            generation_config=generation_config,
        )
        text = response.text
        return text.strip() if text else ""

    return await asyncio.get_running_loop().run_in_executor(None, _generate)

async def call_openai_async(model_name: str, prompt: str, image_b64: dict, max_tokens: int) -> str:
    """Send a text prompt plus one image to an OpenAI chat model without blocking the loop.

    The blocking ``chat.completions.create`` call is run in the event loop's
    default executor.

    Args:
        model_name: OpenAI model identifier (e.g. ``"gpt-4o"``, ``"gpt-5-mini"``).
        prompt: Text prompt sent alongside the image.
        image_b64: Dict with ``"mime"`` and ``"base64"`` entries; they are
            combined into a ``data:`` URL for the image part of the message.
        max_tokens: Output token budget for the completion.

    Returns:
        The stripped message content of the first choice, or ``""`` when the
        content is empty.
    """
    image_url = f"data:{image_b64['mime']};base64,{image_b64['base64']}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        }
    ]

    # Fixed: the old check `"gpt5" in model_name` never matched the real model
    # id "gpt-5-mini", so GPT-5 requests were sent with `max_tokens` and
    # `temperature=0.0`. Compare on a dash-free lowercase form so both spellings
    # ("gpt-5-mini" and "gpt5_mini") are recognised.
    normalized_name = model_name.lower().replace("-", "")
    if "gpt5" in normalized_name:
        # GPT-5 family: the budget goes in `max_completion_tokens`, and a
        # non-default temperature is rejected, so it is left unset.
        request_kwargs = {"max_completion_tokens": max_tokens}
    else:
        request_kwargs = {"max_tokens": max_tokens, "temperature": 0.0}

    def _call():
        resp = openai_client.chat.completions.create(
            model=model_name,
            messages=messages,
            **request_kwargs,
        )
        return (resp.choices[0].message.content or "").strip()

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _call)

# Adapters that pin a concrete provider model id. Each takes
# (prompt, image payload, max tokens) positionally and returns the coroutine
# produced by the provider-specific async caller.
def _gemini_flash(prompt, image_b64, max_tokens):
    return call_gemini_async("gemini-2.5-flash", prompt, image_b64, max_tokens)


def _gemini_flash_lite(prompt, image_b64, max_tokens):
    return call_gemini_async("gemini-2.5-flash-lite", prompt, image_b64, max_tokens)


def _gpt4o(prompt, image_b64, max_tokens):
    return call_openai_async("gpt-4o", prompt, image_b64, max_tokens)


def _gpt5_mini(prompt, image_b64, max_tokens):
    return call_openai_async("gpt-5-mini", prompt, image_b64, max_tokens)


# Short model key -> adapter for that model.
MODEL_DISPATCH = {
    "gemini_flash": _gemini_flash,
    "gemini_flash_lite": _gemini_flash_lite,
    "gpt4o": _gpt4o,
    "gpt5_mini": _gpt5_mini,
}

In [None]:
def _parse_key_objects(raw: str) -> list:
    """Turn a model reply into a clean list of object names."""
    text = raw.strip()

    # Try the reply as-is, then the outermost [...] span. The second candidate
    # recovers replies that wrap the JSON array in a Markdown code fence or in
    # surrounding prose, which a plain json.loads() rejects.
    candidates = [text]
    start, end = text.find("["), text.rfind("]")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(parsed, list):
            names = (str(item).strip() for item in parsed)
            return [name for name in names if name]

    # Last resort: treat the reply as a comma/newline separated list. Fence
    # marker lines are dropped so they do not end up as bogus object names.
    content_lines = [ln for ln in text.splitlines() if not ln.strip().startswith("```")]
    flat = ",".join(content_lines).replace("[", "").replace("]", "")
    cleaned = (part.strip().strip('"').strip("'").strip() for part in flat.split(","))
    return [part for part in cleaned if part]


async def extract_key_objects_with_gemini(question: str) -> list:
    """Ask Gemini which objects must be located to answer a question.

    The blocking ``generate_content`` call is run in the event loop's default
    executor. The reply is parsed as a JSON array when possible (including when
    it is wrapped in a Markdown code fence); otherwise it is split on commas
    and newlines.

    Args:
        question: The question text to analyse.

    Returns:
        A list of non-empty, stripped object-name strings (possibly empty).
    """
    prompt = (
        "You are helping to solve visual multi-hop spatial reasoning questions.\n"
        "Given the question text below, identify the most important objects\n"
        "that must be located in the image to answer the question.\n\n"
        "Return ONLY a JSON array of strings (no explanation).\n\n"
        f"Question:\n{question}\n"
    )
    model = genai.GenerativeModel("gemini-2.5-flash")

    def _call():
        resp = model.generate_content(prompt)
        return (resp.text or "").strip()

    loop = asyncio.get_running_loop()
    raw = await loop.run_in_executor(None, _call)
    return _parse_key_objects(raw)


async def call_vision_model_for_bboxes(image: Image.Image, key_objects: list) -> str:
    """Ask Gemini for normalized bounding boxes of the named objects in an image.

    The blocking ``generate_content`` call is run in the event loop's default
    executor.

    Args:
        image: The PIL image sent to the model.
        key_objects: Object names to locate; interpolated into the prompt.

    Returns:
        The model's stripped response text (``""`` when the text is empty).
        The prompt requests JSON, but the reply is returned unparsed.
    """
    prompt_lines = [
        "Given the image and the list of object names below, return normalized bounding boxes.",
        "Output MUST be valid JSON in the following format:",
        "",
        "[",
        '  {"object": "<name>", "bbox": {"x": <float>, "y": <float>, "w": <float>, "h": <float>}},',
        "  ...",
        "]",
        "",
        "Coordinates x, y are the top-left corner, and w, h are width and height.",
        "All coordinates must be normalized to [0,1].",
        "If you cannot find an object, omit it from the list.",
        "",
        f"Objects: {key_objects}",
    ]
    request_text = "\n".join(prompt_lines) + "\n"
    gemini = genai.GenerativeModel("gemini-2.5-flash")

    def _generate() -> str:
        response = gemini.generate_content([request_text, image])
        text = response.text
        return text.strip() if text else ""

    return await asyncio.get_running_loop().run_in_executor(None, _generate)

def parse_bbox_response(raw: str) -> list:
    """Parse a bounding-box reply into a flat list of box dicts.

    Accepts a JSON array of ``{"object": ..., "bbox": {"x","y","w","h"}}``
    items. The array may be wrapped in a Markdown code fence or surrounded by
    prose; in that case the outermost ``[...]`` span is parsed instead.

    Args:
        raw: The raw model reply.

    Returns:
        A list of ``{"object": str, "x": float, "y": float, "w": float,
        "h": float}`` dicts. Malformed items are skipped; an unparseable or
        non-list reply yields ``[]``.
    """
    candidates = [raw]
    if isinstance(raw, str):
        # Previously a fenced reply (```json ... ```) failed json.loads() and
        # every box was silently dropped; retry on the bracketed span.
        start, end = raw.find("["), raw.rfind("]")
        if 0 <= start < end:
            candidates.append(raw[start:end + 1])

    parsed = None
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            continue
        if isinstance(parsed, list):
            break

    if not isinstance(parsed, list):
        return []

    out = []
    for item in parsed:
        try:
            name = str(item["object"]).strip()
            box = item["bbox"]
            x, y = float(box["x"]), float(box["y"])
            w, h = float(box["w"]), float(box["h"])
        except (KeyError, TypeError, ValueError):
            # Missing key, wrong container type, or non-numeric coordinate.
            continue
        out.append({"object": name, "x": x, "y": y, "w": w, "h": h})
    return out

In [None]:
async def generate_spatial_record(example: dict) -> dict:
    """Augment one dataset example with key objects, bounding boxes and scene-graph text.

    Args:
        example: Source example; must contain ``"image_name"`` and
            ``"question"``. Its ``"reasoning"`` entry, if present, is copied
            to ``"scene_graph_text"``.

    Returns:
        A shallow copy of ``example`` with ``"key_objects"``, ``"bboxes"`` and
        ``"scene_graph_text"`` added.
    """
    image_name, question = example["image_name"], example["question"]

    # Step 1: text-only call to pick the objects worth localising.
    key_objects = await extract_key_objects_with_gemini(question)

    # Step 2: vision call to localise those objects in the image.
    rgb_image = Image.open(COT_IMAGE_DIR / image_name).convert("RGB")
    bbox_reply = await call_vision_model_for_bboxes(rgb_image, key_objects)

    return {
        **example,
        "key_objects": key_objects,
        "bboxes": parse_bbox_response(bbox_reply),
        "scene_graph_text": example.get("reasoning"),
    }

async def build_spatial_dataset(dataset_name: str, concurrency: int = 10):
    """Build (or resume building) the augmented spatial dataset JSON.

    Reads the source examples, skips those whose ``image_name`` already appears
    in the output file, runs ``generate_spatial_record`` on the rest with
    bounded concurrency, and writes the accumulated records to the output file
    every 10 completed examples and once more at the end.

    Args:
        dataset_name: Must be ``"spatial_cot_multihop"``.
        concurrency: Maximum number of examples processed at once.

    Raises:
        ValueError: If ``dataset_name`` is not ``"spatial_cot_multihop"``.
    """
    if dataset_name != "spatial_cot_multihop":
        raise ValueError("Configured only for dataset_name='spatial_cot_multihop'.")

    src_path = SPATIAL_MM_ROOT / "data" / "multihop_reasoning_309.json"
    out_path = OUTPUT_DIR / f"{dataset_name}_spatial_data.json"

    def _checkpoint(records):
        # Overwrite the output file with everything collected so far.
        with open(out_path, "w") as fh:
            json.dump(records, fh, indent=2)

    with open(src_path, "r") as fh:
        examples = json.load(fh)

    # Resume support: reload earlier results and skip images already present.
    results = []
    if out_path.exists():
        with open(out_path, "r") as fh:
            results = json.load(fh)
    already_done = {record["image_name"] for record in results}

    pending = [ex for ex in examples if ex["image_name"] not in already_done]
    if not pending:
        print(f"Spatial data already built at {out_path}")
        return

    gate = asyncio.Semaphore(concurrency)

    async def _guarded(example):
        # A failure on one example is logged and reported as None so that the
        # remaining examples keep going.
        async with gate:
            try:
                return await generate_spatial_record(example)
            except Exception as err:
                print(f"[SPATIAL FAIL] {example['image_name']}: {err}")
                return None

    tasks = [asyncio.create_task(_guarded(ex)) for ex in pending]
    started = time.time()
    total = len(pending)

    for finished, next_done in enumerate(asyncio.as_completed(tasks), start=1):
        record = await next_done
        if record is not None:
            results.append(record)
        if finished % 10 == 0:
            _checkpoint(results)
            print(f"[SPATIAL] {finished}/{total} processed...")

    _checkpoint(results)

    print(f"Saved spatial dataset to {out_path} in {time.time() - started:.1f}s")


In [None]:
# Prompt templates keyed by prompt name. Each template is a str.format()
# pattern; "max_tokens" is the output token budget passed to the model call.
PROMPT_CONFIGS = {
    # Expects {bboxes} and {question}.
    "multihop_baseline_bbox": {
        "template": "\n\n".join(
            [
                "Given the bounding boxes below and the image, answer the question "
                "with a single word or short phrase.",
                "Bounding Boxes:\n{bboxes}",
                "Question: {question}",
                "Answer:",
            ]
        ),
        "max_tokens": 1000,
    },
    # Expects {scene_graph} and {question}.
    "multihop_baseline_scene_graph": {
        "template": "\n\n".join(
            [
                "Given the scene graph below and the image, answer the question "
                "with a single word or short phrase.",
                "Scene Graph (Reasoning):\n{scene_graph}",
                "Question: {question}",
                "Answer:",
            ]
        ),
        "max_tokens": 1000,
    },
}

In [None]:
def is_connection_error(err: Exception) -> bool:
    """Report whether an exception message looks like a retryable transport error.

    The check is a case-insensitive substring match on ``str(err)``.

    Args:
        err: The exception to classify.

    Returns:
        True if the message contains any of the known markers, else False.
    """
    retryable_markers = (
        "connection reset by peer",
        "connection aborted",
        "remote end closed connection without response",
        "timed out",
        "temporarily unavailable",
        "429",
        "503",
    )
    message = str(err).lower()
    return any(marker in message for marker in retryable_markers)

async def query_spatial_model(dataset_name: str, prompt_key: str, model_key: str, concurrency: int = 5):
    """Query one model with one prompt style over the spatial dataset, with resume.

    Loads the dataset written by ``build_spatial_dataset``, skips images that
    already have a non-empty ``raw_response`` in the output file, and queries
    the model for the rest with bounded concurrency. Connection-type errors are
    retried up to five times with a progressively smaller image. Results are
    written every 10 new successes and once more at the end.

    Args:
        dataset_name: Must be ``"spatial_cot_multihop"``.
        prompt_key: Key into ``PROMPT_CONFIGS``.
        model_key: Key into ``MODEL_DISPATCH``.
        concurrency: Maximum number of in-flight requests.

    Raises:
        ValueError: If ``dataset_name`` is unsupported, or ``prompt_key`` is in
            ``PROMPT_CONFIGS`` but has no prompt-building branch here.
        FileNotFoundError: If the spatial dataset file has not been built yet.
    """
    if dataset_name != "spatial_cot_multihop":
        raise ValueError("Configured only for dataset_name='spatial_cot_multihop'.")

    in_path = OUTPUT_DIR / f"{dataset_name}_spatial_data.json"
    if not in_path.exists():
        raise FileNotFoundError(f"{in_path} not found. Run build_spatial_dataset first.")

    with open(in_path, "r") as f:
        data = json.load(f)

    out_path = OUTPUT_DIR / f"{dataset_name}_{prompt_key}_{model_key}.json"

    existing = []
    if out_path.exists():
        with open(out_path, "r") as f:
            existing = json.load(f)
    processed_ids = {r["image_name"] for r in existing if r.get("raw_response")}
    # Images whose stored record has an empty/missing response. They are
    # re-queued below, so the old record must be replaced rather than kept next
    # to the new one (previously every resume appended a duplicate).
    stale_ids = {r["image_name"] for r in existing if not r.get("raw_response")}

    to_process = [ex for ex in data if ex["image_name"] not in processed_ids]
    if not to_process:
        print(f"Already finished {model_key} / {prompt_key} on {dataset_name}")
        return

    sem = asyncio.Semaphore(concurrency)
    model_fn = MODEL_DISPATCH[model_key]
    cfg = PROMPT_CONFIGS[prompt_key]
    max_tok = cfg["max_tokens"]

    async def worker(example: dict):
        image_name = example["image_name"]
        img_path = COT_IMAGE_DIR / image_name

        if prompt_key == "multihop_baseline_bbox":
            bboxes_str = json.dumps(example.get("bboxes", []), indent=2)
            prompt = cfg["template"].format(
                question=example["question"],
                bboxes=bboxes_str,
            )
        elif prompt_key == "multihop_baseline_scene_graph":
            # The key can be present with a None value (it is copied from an
            # optional "reasoning" field); `or ""` keeps the literal text
            # "None" out of the prompt.
            scene_graph = example.get("scene_graph_text") or ""
            prompt = cfg["template"].format(
                question=example["question"],
                scene_graph=scene_graph,
            )
        else:
            raise ValueError(f"Unknown prompt_key: {prompt_key}")

        max_retries = 5
        # Scale used on each retry attempt (index 0 is unused: the first
        # attempt sends the image at full size).
        resize_scales = [1.0, 0.9, 0.8, 0.6, 0.5]

        for attempt in range(max_retries):
            # Note: the backoff sleep below happens while the semaphore slot is
            # still held, so a retrying worker keeps occupying one slot.
            async with sem:
                try:
                    if attempt == 0:
                        img_b64 = encode_image_path(img_path)
                    else:
                        scale = resize_scales[min(attempt, len(resize_scales) - 1)]
                        img_b64 = resize_and_encode_image_path(img_path, scale=scale)

                    resp = await model_fn(prompt, img_b64, max_tok)
                    rec = dict(example)
                    rec["raw_response"] = resp
                    return rec

                except Exception as e:
                    if is_connection_error(e) and attempt < max_retries - 1:
                        wait_t = 2 * (attempt + 1)
                        print(f"[{model_key}/{prompt_key}] {image_name} error: {e} -> retry in {wait_t}s")
                        await asyncio.sleep(wait_t)
                        continue
                    print(f"[{model_key}/{prompt_key}] {image_name} FAILED: {e}")
                    return None

    tasks = [asyncio.create_task(worker(ex)) for ex in to_process]
    results = list(existing)
    start = time.time()
    done = 0
    success = 0

    for task in asyncio.as_completed(tasks):
        rec = await task
        done += 1
        if rec is None:
            continue

        name = rec["image_name"]
        if name in stale_ids:
            # Drop the earlier empty-response record(s) for this image now that
            # a fresh record is available.
            results = [
                r for r in results
                if r.get("raw_response") or r["image_name"] != name
            ]
            stale_ids.discard(name)

        results.append(rec)
        success += 1
        # Checkpoint only when a new success lands. Previously this test ran on
        # every completion, so the file was rewritten on each failure while
        # `success` sat at 0 or at a multiple of 10.
        if success % 10 == 0:
            with open(out_path, "w") as f:
                json.dump(results, f, indent=2)
            print(f"[{model_key}/{prompt_key}] {success}/{done}/{len(to_process)} saved...")

    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)

    print(
        f"Done {dataset_name} / {prompt_key} / {model_key}: "
        f"{success}/{len(to_process)} in {time.time() - start:.1f}s"
    )


In [None]:
# (model key, concurrency limit) pairs, in the order they are run.
# Model keys must exist in MODEL_DISPATCH.
running = list(
    {
        "gpt4o": 1,
        "gpt5_mini": 5,
        "gemini_flash": 15,
        "gemini_flash_lite": 10,
    }.items()
)

# Prompt styles evaluated for every model; keys must exist in PROMPT_CONFIGS.
prompts = ["multihop_baseline_bbox", "multihop_baseline_scene_graph"]

async def main_runner():
    """Build the spatial dataset, then query every model with every prompt style.

    Runs are sequential: each (model, prompt) combination finishes before the
    next starts. The per-model number from ``running`` is passed as the
    concurrency limit of that model's queries.
    """
    dataset_name = "spatial_cot_multihop"

    print(">>> Building spatial dataset (if needed)...")
    await build_spatial_dataset(dataset_name=dataset_name, concurrency=10)

    # Flatten the model x prompt grid into one stream of jobs, models outermost.
    jobs = (
        (model_key, limit, prompt_key)
        for model_key, limit in running
        for prompt_key in prompts
    )
    for model_key, limit, prompt_key in jobs:
        print(f"\n>>> Starting {model_key} with {prompt_key}")
        await query_spatial_model(
            dataset_name=dataset_name,
            prompt_key=prompt_key,
            model_key=model_key,
            concurrency=limit,
        )

# Kick off the full pipeline. Top-level `await` is not valid in a plain Python
# script; this relies on the notebook (IPython/Colab) running cells inside an
# event loop.
await main_runner()