In [10]:
import os
import json
import base64
import asyncio
import time
import pandas as pd
from pathlib import Path
from typing import List, Dict, Any, Callable
from io import BytesIO
from PIL import Image
from tqdm.notebook import tqdm

from google.colab import drive
from google.colab import userdata
from openai import OpenAI
import google.generativeai as genai

# --- Paths (Drive must be mounted before any Drive path below is used) ---
ROOT_DRIVE_PATH = "/content/drive/MyDrive/DSGA1011/Project"
drive.mount('/content/drive')
SPATIAL_MM_ROOT = Path('/content/Spatial-MM')                  # local clone of the Spatial-MM repo
IMAGES_ROOT = Path(ROOT_DRIVE_PATH) / 'Spatial_MM-Benchmark'   # benchmark images stored on Drive
COT_IMAGE_DIR = IMAGES_ROOT / 'Spatial_MM_CoT'
OUTPUT_DIR = Path(ROOT_DRIVE_PATH) / 'spatial_mm_outputs'      # model-response JSON files go here
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Clone the benchmark repo once. Fail loudly on a failed clone; otherwise the
# later open() of SPATIAL_MM_ROOT/data/*.json fails with a less obvious error.
if not SPATIAL_MM_ROOT.exists():
    clone_status = os.system(f'git clone https://github.com/FatemehShiri/Spatial-MM.git {SPATIAL_MM_ROOT}')
    if clone_status != 0:
        raise RuntimeError(f"git clone of Spatial-MM failed (exit status {clone_status})")

# Load each provider's key independently so a missing key for one provider does
# not prevent the other provider from being configured (best-effort: errors are
# reported, not raised, so a Gemini-only or OpenAI-only run still works).
try:
    GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
    genai.configure(api_key=GOOGLE_API_KEY)
except Exception as e:
    print(f"Error loading Google API key: {e}")

try:
    OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')
    openai_client = OpenAI(api_key=OPENAI_API_KEY)
except Exception as e:
    print(f"Error loading OpenAI API key: {e}")


Mounted at /content/drive


In [None]:
async def call_gemini_async(model_name, prompt, image_b64, max_tokens):
    """Query a Gemini model with one text prompt plus one image.

    Args:
        model_name: Gemini API model id, e.g. "gemini-2.5-flash".
        prompt: Fully formatted text prompt.
        image_b64: Dict whose "base64" entry holds the base64-encoded image
            bytes (the "mime" entry is not read here).
        max_tokens: Value passed as ``max_output_tokens``.

    Returns:
        The response text with surrounding whitespace stripped.
    """
    gemini = genai.GenerativeModel(model_name)
    raw_bytes = base64.b64decode(image_b64["base64"])
    pil_image = Image.open(BytesIO(raw_bytes))
    config = {"max_output_tokens": max_tokens, "temperature": 0}

    # generate_content is a blocking call; run it on the loop's default
    # executor so several requests can be in flight at once.
    def _blocking_request():
        reply = gemini.generate_content([prompt, pil_image], generation_config=config)
        return reply.text.strip()

    return await asyncio.get_running_loop().run_in_executor(None, _blocking_request)

async def call_openai_async(model_name, prompt, image_b64, max_tokens):
    """Query an OpenAI chat-completions model with one text prompt plus one image.

    Args:
        model_name: OpenAI API model id, e.g. "gpt-4o" or "gpt-5-mini".
        prompt: Fully formatted text prompt.
        image_b64: Dict with "mime" and "base64" entries; sent inline as a data URL.
        max_tokens: Output-token budget. Sent as ``max_completion_tokens`` for
            GPT-5 / o-series models (which reject ``max_tokens``) and as
            ``max_tokens`` for everything else.

    Returns:
        The first choice's message text, stripped of surrounding whitespace.

    Raises:
        RuntimeError: if the first choice carries no message content.
    """
    image_url = f"data:{image_b64['mime']};base64,{image_b64['base64']}"
    messages = [
        {"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": image_url}}
        ]}
    ]

    # Real API ids are hyphenated ("gpt-5-mini"), so match on the actual
    # prefixes; the bare "gpt5" substring is kept only for backward compatibility.
    reasoning_prefixes = ("gpt-5", "o1", "o3", "o4")
    if "gpt5" in model_name or model_name.startswith(reasoning_prefixes):
        kwargs = {"max_completion_tokens": max_tokens}
    else:
        kwargs = {"max_tokens": max_tokens}

    loop = asyncio.get_running_loop()
    # The SDK client call is blocking; run it on the default executor.
    resp = await loop.run_in_executor(None, lambda: openai_client.chat.completions.create(
        model=model_name, messages=messages, **kwargs
    ))
    choice = resp.choices[0]
    content = choice.message.content
    if content is None:
        # May happen when the token budget is used up before any visible output
        # (e.g. by reasoning tokens); report it clearly instead of failing on .strip().
        raise RuntimeError(
            f"{model_name} returned no message content "
            f"(finish_reason={getattr(choice, 'finish_reason', None)})"
        )
    return content.strip()

# Short model keys (used in output filenames and the runner cells) mapped to
# async callers with the uniform positional signature (prompt, image_b64, max_tokens).
# Each lambda looks up call_gemini_async / call_openai_async when it is called.
MODEL_DISPATCH = dict(
    gemini_flash=lambda p, i, m: call_gemini_async("gemini-2.5-flash", p, i, m),
    gemini_flash_lite=lambda p, i, m: call_gemini_async("gemini-2.5-flash-lite", p, i, m),
    gpt4o=lambda p, i, m: call_openai_async("gpt-4o", p, i, m),
    gpt5_mini=lambda p, i, m: call_openai_async("gpt-5-mini", p, i, m),
)


In [None]:
def encode_image(img_path):
    """Load an image, convert it to RGB, and return it as base64-encoded JPEG.

    Args:
        img_path: Path (str or Path) to the image file.

    Returns:
        Dict ``{"mime": "image/jpeg", "base64": <base64 text>}``.

    Raises:
        ValueError: if the file cannot be opened, decoded, or re-encoded; the
            underlying exception is chained as ``__cause__``.
    """
    try:
        # Context manager closes the source file handle deterministically.
        with Image.open(img_path) as src:
            img = src.convert("RGB")  # JPEG cannot store alpha/palette modes
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        return {"mime": "image/jpeg", "base64": base64.b64encode(buffered.getvalue()).decode("utf-8")}
    except Exception as e:
        raise ValueError(f"Failed to encode image {img_path}: {e}") from e

def resize_and_encode_image(img_path, scale=0.8):
    """Load an image, scale it, convert to RGB, and return it as base64-encoded JPEG.

    Args:
        img_path: Path (str or Path) to the image file.
        scale: Positive factor applied to both width and height (LANCZOS
            resampling). Each dimension is clamped to at least 1 px so tiny
            images or small scales never produce an invalid 0-px size.

    Returns:
        Dict ``{"mime": "image/jpeg", "base64": <base64 text>}``.

    Raises:
        ValueError: if ``scale`` is not positive, or the file cannot be opened,
            resized, or re-encoded; the underlying exception is chained as
            ``__cause__``.
    """
    try:
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale}")
        with Image.open(img_path) as src:
            img = src.convert("RGB")  # JPEG cannot store alpha/palette modes
        new_size = (max(1, int(img.width * scale)), max(1, int(img.height * scale)))
        img = img.resize(new_size, Image.Resampling.LANCZOS)
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        return {"mime": "image/jpeg", "base64": base64.b64encode(buffered.getvalue()).decode("utf-8")}
    except Exception as e:
        raise ValueError(f"Failed to resize {img_path}: {e}") from e

# Question files shipped in the Spatial-MM repo as data/<name>.json.
_SPATIAL_MM_DATASETS = ("spatial_mm_one_obj", "spatial_mm_two_obj", "multihop_reasoning_309")


def _dataset_json_path(dataset_name):
    """Map a dataset name to its question JSON inside the cloned Spatial-MM repo."""
    if dataset_name not in _SPATIAL_MM_DATASETS:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of {_SPATIAL_MM_DATASETS}"
        )
    return SPATIAL_MM_ROOT / 'data' / f'{dataset_name}.json'


def _result_key(item):
    """Identity of one benchmark item for resume/dedup.

    Keyed on (image_name, question) rather than image_name alone, in case one
    image is shared by several questions.
    """
    return (item.get('image_name'), item.get('question'))


def _save_results(results, out_path):
    """Write results as JSON atomically (temp file + rename).

    An interrupted write therefore cannot leave a truncated file that would
    break json.load on the next resume.
    """
    tmp_path = out_path.with_name(out_path.name + '.tmp')
    with open(tmp_path, 'w') as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, out_path)


def _is_connection_error(err: Exception) -> bool:
    """Heuristic: does the error message look like a transient network/rate-limit failure?"""
    msg = str(err).lower()
    # NOTE: "429"/"503" are plain substring checks, so any message containing
    # those digits is treated as retryable.
    return (
        "connection reset by peer" in msg
        or "connection aborted" in msg
        or "remote end closed connection without response" in msg
        or "timed out" in msg
        or "temporarily unavailable" in msg
        or "429" in msg
        or "503" in msg
    )


async def process_dataset(dataset_name, prompt_key, model_key, concurrency):
    """Run one (dataset, prompt, model) combination and persist raw responses.

    Resumable: results already in
    OUTPUT_DIR/<dataset>_<prompt_key>_<model_key>.json with a non-empty
    ``raw_response`` are kept and their items skipped. Entries with an empty or
    missing response are dropped and retried, so each item appears at most once.

    Args:
        dataset_name: One of _SPATIAL_MM_DATASETS.
        prompt_key: Key into PROMPT_CONFIGS (template + max_tokens).
        model_key: Key into MODEL_DISPATCH.
        concurrency: Maximum number of in-flight model calls.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    json_path = _dataset_json_path(dataset_name)
    with open(json_path, 'r') as f:
        data = json.load(f)

    out_path = OUTPUT_DIR / f"{dataset_name}_{prompt_key}_{model_key}.json"

    existing_results = []
    if out_path.exists():
        with open(out_path, 'r') as f:
            existing_results = json.load(f)
    # Carry forward only completed entries; incomplete ones are re-queried below
    # and would otherwise end up duplicated in the output.
    results = [x for x in existing_results if x.get('raw_response')]
    done_keys = {_result_key(x) for x in results}

    to_process = [d for d in data if _result_key(d) not in done_keys]
    if not to_process:
        return

    print(f"Processing {len(to_process)} items for {model_key}...")

    semaphore = asyncio.Semaphore(concurrency)
    model_fn = MODEL_DISPATCH[model_key]
    prompt_tmpl = PROMPT_CONFIGS[prompt_key]["template"]
    max_tok = PROMPT_CONFIGS[prompt_key]["max_tokens"]
    max_retries = 5

    async def worker(entry):
        """Query the model for one entry; return entry + 'raw_response', or None on failure/skip."""
        async with semaphore:
            image_name = entry.get('image_name')
            if not image_name:
                print(f"[WARN] Entry missing image_name, skipping")
                return None

            img_path = COT_IMAGE_DIR / image_name
            if not img_path.exists():
                print(f"[WARN] Image not found: {image_name}")
                return None

            # Encoding and prompt formatting are deterministic: do them once,
            # not on every retry.
            try:
                image_b64 = encode_image(img_path)
                prompt = prompt_tmpl.format(question=entry['question'])
            except Exception as e:
                print(f"[FAIL] {image_name} {type(e).__name__}: {e}")
                return None

            for attempt in range(max_retries):
                try:
                    response_text = await model_fn(prompt, image_b64, max_tok)
                except Exception as e:
                    # Retry only transient errors, pausing 2s, 4s, 6s, 8s.
                    if _is_connection_error(e) and attempt < max_retries - 1:
                        await asyncio.sleep(2 * (attempt + 1))
                        continue
                    print(f"[FAIL] {image_name} {type(e).__name__}: {e}")
                    return None
                res_entry = dict(entry)
                res_entry['raw_response'] = response_text
                return res_entry
            return None

    total = len(to_process)
    print(f"Launching processing for {total} images with up to {concurrency} parallel calls...")
    start_time = time.time()

    tasks = [asyncio.create_task(worker(entry)) for entry in to_process]
    success_count = 0
    try:
        for next_done in tqdm(asyncio.as_completed(tasks), total=total):
            res = await next_done
            if res is not None:
                results.append(res)
                success_count += 1
                # Periodic checkpoint so an interruption loses little work.
                if success_count % 10 == 0:
                    _save_results(results, out_path)
    finally:
        # On interrupt/error, stop outstanding API calls instead of letting them
        # keep running in the background, and persist whatever was collected.
        for task in tasks:
            task.cancel()
        _save_results(results, out_path)

    elapsed = time.time() - start_time
    failed_count = total - success_count

    print(f"\n Finished {model_key} ")
    print(f"  Time: {elapsed:.2f}s | Saved to: {out_path.name}")
    print(f"  Succeeded: {success_count} | Failed/skipped: {failed_count}")


In [None]:
# Prompt variants, keyed by the prompt_key used in output filenames.
# Each template is filled via str.format(question=...), so any literal braces
# added to a template must be doubled ("{{" / "}}").
PROMPT_CONFIGS = {

    # Multiple-choice (one/two-object) prompts: direct answer.
    "spatial_obj_baseline": {
        "template": (
            "Given an image, answer the multiple choice question.\n\n"
            "{question}\n\n"
            "Only answer by copying exactly the correct line from the choices above."
        ),
        "max_tokens": 1000,
    },


    # Multiple-choice prompts with a zero-shot chain-of-thought cue.
    # The apostrophe is plain ASCII, matching multihop_zeroshot_cot.
    # NOTE(review): result files already on disk for this key may have been
    # produced with a mis-encoded apostrophe in the prompt; resume skips items by
    # identity, not prompt text, so consider regenerating them.
    "spatial_obj_zeroshot_cot": {
        "template": (
            "Given an image, answer the multiple choice question.\n\n"
            "{question}\n\n"
            "Let's think step by step before answering the question. "
            "Only answer by copying exactly the correct line from the choices above."
        ),
        "max_tokens": 1000,
    },

    # Multi-hop (open-ended) prompts: answer in an "Answer: ..." line.
    "multihop_baseline": {
        "template": (
            "Given an image, answer the question.\n\n"
            "{question}\n\n"
            "Answer the question in the format:\n"
            "Answer: <your answer here>"
        ),
        "max_tokens": 1000,
    },
    "multihop_zeroshot_cot": {
        "template": (
            "Given an image, answer the question about spatial relationship.\n\n"
            "{question}\n\n"
            "Let's think step by step before answering the question. "
            "Explain your reasoning in less than three sentences.\n"
            "Then answer the question in the format:\n"
            "Answer: <your answer here>"
        ),
        "max_tokens": 1000,
    },
}


In [None]:
# one_obj

running = [
    ("gpt4o", 1),
    ("gpt5_mini", 5),
    ("gemini_flash", 15)
    ("gemini_flash_lite", 10)
]

prompts = ["spatial_obj_baseline", "spatial_obj_zeroshot_cot"]

async def main_runner():
    for model, limit in running:
        for prompt in prompts:
            print(f"\n>>> Starting {model} with {prompt}")
            await process_dataset("spatial_mm_one_obj", prompt, model, concurrency=limit)

await main_runner()


In [None]:
# two_obj

running = [
    ("gpt4o", 1),
    ("gpt5_mini", 5),
    ("gemini_flash", 15)
    ("gemini_flash_lite", 10)
]

prompts = ["spatial_obj_baseline", "spatial_obj_zeroshot_cot"]

async def main_runner():
    for model, limit in running:
        for prompt in prompts:
            print(f"\n>>> Starting {model} with {prompt}")
            await process_dataset("spatial_mm_two_obj", prompt, model, concurrency=limit)

await main_runner()


In [None]:
# multihop

running = [
    ("gpt4o", 1),
    ("gpt5_mini", 5),
    ("gemini_flash", 15)
    ("gemini_flash_lite", 10)
]

prompts = ["multihop_baseline", "multihop_zeroshot_cot"]

async def main_runner():
    for model, limit in running:
        for prompt in prompts:
            print(f"\n>>> Starting {model} with {prompt}")
            await process_dataset("multihop_reasoning_309", prompt, model, concurrency=limit)

await main_runner()
