In [None]:
!pip install import-ipynb

In [None]:
%cd /content/drive/MyDrive/SHBT261FinalProject/code/src

In [None]:
"""
Zero-shot evaluation script for TextVQA.
Evaluates a pretrained vision-language model (VLM) directly on the validation/test sets.
"""

import os
import json
from tqdm import tqdm
from datetime import datetime
import torch
import numpy as np

import import_ipynb
from data_loader import TextVQADataset
from model import get_model_and_processor, get_generation_config
from metrics import compute_all_metrics, print_metrics

In [None]:
def _build_ocr_prompt(question, ocr_tokens):
    """Return the user-turn text: optional OCR token list, question, answer hint."""
    ocr_str = ", ".join(str(t) for t in ocr_tokens) if ocr_tokens else ""
    ocr_section = f"OCR tokens detected: {ocr_str}.\n" if ocr_str else ""

    # NOTE(review): the closing instruction mentions "the OCR tokens above" even
    # when no OCR section was emitted. The wording is left untouched because
    # changing the prompt would change the evaluated predictions.
    return (
        f"{ocr_section}"
        f"Question: {question}\n"
        "Please answer using one of the OCR tokens above or a short phrase derived from them."
    )


def run_inference_zeroshot(
    model,
    processor,
    dataset,
    batch_size=1,
    max_samples=None,
    save_predictions=True,
    output_dir="results",
    metric_filename="zeroshot_metric.json",
    pred_filename="zeroshot_predictions.json",
):
    """Generate an answer for each dataset sample and score the predictions.

    Samples are processed one at a time: an OCR-aware prompt is built, passed
    through the processor's chat template together with the image, and the
    newly generated tokens are decoded as the prediction.

    Args:
        model: Model exposing ``eval()``, ``parameters()`` and ``generate()``.
        processor: Processor exposing ``apply_chat_template()``,
            ``batch_decode()`` and a call interface taking ``text``/``images``.
        dataset: Object supporting ``len()`` and integer indexing. Each item
            is a mapping with the keys ``image``, ``question``, ``answers``,
            ``image_id``, ``question_id`` and, optionally, ``ocr_tokens``.
        batch_size: Accepted for interface compatibility but currently unused;
            inference always runs on a single sample per step.
        max_samples: Upper bound on the number of samples evaluated. ``None``
            or ``0`` means "use the whole dataset".
        save_predictions: When true, write predictions and metrics as JSON
            files into ``output_dir``.
        output_dir: Directory for the JSON outputs (created if missing).
        metric_filename: File name of the metrics JSON.
        pred_filename: File name of the per-sample predictions JSON.

    Returns:
        dict with ``"predictions"`` (list of decoded strings), ``"results"``
        (list of per-sample record dicts) and ``"metrics"`` (the value
        returned by ``compute_all_metrics``).

    Raises:
        ValueError: If ``max_samples`` is negative.
    """
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples must be non-negative, got {max_samples}")

    model.eval()

    predictions = []
    ground_truths = []
    questions_list = []
    results = []

    # A falsy max_samples (None or 0) selects the full dataset.
    num_samples = min(len(dataset), max_samples or len(dataset))
    gen_config = get_generation_config()

    # Loop-invariant: inputs are moved to the device of the first parameter.
    device = next(model.parameters()).device

    print(f"\nRunning inference on {num_samples} samples...")

    with torch.no_grad():
        for idx in tqdm(range(num_samples)):
            sample = dataset[idx]
            image = sample["image"]
            question = sample["question"]
            gt_answers = sample["answers"]

            # OCR-aware prompt
            ocr_tokens = sample.get("ocr_tokens", [])
            text_prompt = _build_ocr_prompt(question, ocr_tokens)

            conv = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": text_prompt},
                    ],
                }
            ]

            text = processor.apply_chat_template(
                conv, tokenize=False, add_generation_prompt=True
            )
            inputs = processor(
                text=[text], images=[image], return_tensors="pt", padding=True
            )
            # Only entries that support ``.to`` are moved; others pass through.
            inputs = {
                k: v.to(device) if hasattr(v, "to") else v
                for k, v in inputs.items()
            }

            output_ids = model.generate(**inputs, **gen_config)
            # Drop the prompt tokens so only the generated continuation is decoded.
            input_len = inputs["input_ids"].shape[1]
            pred_ids = output_ids[:, input_len:]
            prediction = processor.batch_decode(
                pred_ids, skip_special_tokens=True
            )[0].strip()

            predictions.append(prediction)
            ground_truths.append(gt_answers)
            questions_list.append(question)

            results.append({
                "image_id": str(sample["image_id"]),
                "question_id": int(sample["question_id"]),
                "question": question,
                "prediction": prediction,
                "ground_truths": gt_answers,
                "ocr_tokens": ocr_tokens,
            })

    metrics = compute_all_metrics(predictions, ground_truths, questions=questions_list)

    # save outputs
    if save_predictions:
        os.makedirs(output_dir, exist_ok=True)

        # ensure_ascii=False emits raw non-ASCII characters, so the file must be
        # opened as UTF-8 explicitly instead of relying on the locale default.
        pred_path = os.path.join(output_dir, pred_filename)
        with open(pred_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print("Predictions saved:", pred_path)

        metric_path = os.path.join(output_dir, metric_filename)
        with open(metric_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print("Metrics saved:", metric_path)

    return {"predictions": predictions, "results": results, "metrics": metrics}
