In [None]:
# Install the import-ipynb package; it is imported further down as `import_ipynb`.
!pip install import-ipynb

In [None]:
# Change the working directory to the project's source folder.
# NOTE(review): presumably this is a Colab Google Drive mount and the folder holds the
# data_loader / model / metrics modules imported below — confirm.
%cd /content/drive/MyDrive/SHBT261FinalProject/code/src

In [None]:
"""
Evaluate Fine-tuned Model on TextVQA
Evaluates model with LoRA weights on validation/test sets
"""

import os
import json
from tqdm import tqdm
from datetime import datetime
import torch
import numpy as np

import import_ipynb
from data_loader import TextVQADataset
from model import load_lora_weights, get_generation_config
from metrics import compute_all_metrics, print_metrics

In [None]:
def run_inference_finetuned(
    model,
    processor,
    dataset,
    max_samples=None,
    save_predictions=True,
    output_dir="results",
    metric_filename="finetune_metric.json",
    pred_filename="finetune_predictions.json",
):
    """Generate answers for dataset samples, score them, and optionally save both.

    Each sample is turned into a single-turn chat prompt (one image plus the
    question and a fixed answering instruction), passed through
    ``model.generate`` and decoded. Only the newly generated tokens are
    decoded; the prompt tokens are sliced off first.

    Args:
        model: Model exposing ``eval()``, ``parameters()`` and ``generate()``.
            Inputs are moved to the device of its first parameter.
        processor: Processor exposing ``apply_chat_template``, a ``__call__``
            that accepts ``text``/``images`` and ``batch_decode``.
        dataset: Object supporting ``len()`` and integer indexing. Every
            sample must provide ``"image"``, ``"question"``, ``"answers"``,
            ``"image_id"`` and ``"question_id"``; ``"ocr_tokens"`` is optional.
        max_samples: Upper bound on the number of samples evaluated, taken
            from the start of the dataset. ``None`` or ``0`` evaluates the
            whole dataset.
        save_predictions: When true, write the per-sample results and the
            metrics as JSON files into ``output_dir``.
        output_dir: Directory for the JSON files (created if missing).
        metric_filename: File name of the metrics JSON inside ``output_dir``.
        pred_filename: File name of the predictions JSON inside ``output_dir``.

    Returns:
        dict with keys ``"predictions"`` (list of decoded answer strings),
        ``"results"`` (list of per-sample record dicts) and ``"metrics"``
        (whatever ``compute_all_metrics`` returned).
    """
    model.eval()
    device = next(model.parameters()).device

    predictions, ground_truths, questions_list, results = [], [], [], []
    # A falsy max_samples (None or 0) falls back to the full dataset length.
    num_samples = min(len(dataset), max_samples or len(dataset))
    gen_config = get_generation_config()

    # Loop-invariant instruction appended to every question.
    instruction = "Answer with only the exact text/number from the image."

    print(f"\nRunning inference on {num_samples} samples...")

    with torch.no_grad():
        for idx in tqdm(range(num_samples)):
            sample = dataset[idx]
            image = sample["image"]
            question = sample["question"]
            answers = sample["answers"]

            # Single user turn: an image placeholder followed by the text prompt.
            conv = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": f"{question}\n{instruction}"},
                    ],
                }
            ]

            text = processor.apply_chat_template(
                conv, tokenize=False, add_generation_prompt=True
            )

            inputs = processor(
                text=[text], images=[image],
                return_tensors="pt", padding=True
            )
            # Move tensors to the model's device; leave non-tensor entries untouched.
            inputs = {
                k: v.to(device) if hasattr(v, "to") else v
                for k, v in inputs.items()
            }

            out_ids = model.generate(**inputs, **gen_config)
            # The output starts with the prompt tokens; keep only what was generated.
            input_len = inputs["input_ids"].shape[1]
            gen_ids = out_ids[:, input_len:]

            prediction = processor.batch_decode(
                gen_ids, skip_special_tokens=True
            )[0].strip()

            predictions.append(prediction)
            ground_truths.append(answers)
            questions_list.append(question)

            results.append({
                "image_id": str(sample["image_id"]),
                "question_id": int(sample["question_id"]),
                "question": question,
                "prediction": prediction,
                "ground_truths": answers,
                "ocr_tokens": list(sample.get("ocr_tokens", [])),
            })

    metrics = compute_all_metrics(predictions, ground_truths, questions_list)

    if save_predictions:
        os.makedirs(output_dir, exist_ok=True)

        # ensure_ascii=False emits raw non-ASCII characters, so the file must be
        # opened as UTF-8 explicitly instead of relying on the locale default.
        pred_path = os.path.join(output_dir, pred_filename)
        with open(pred_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print("Predictions saved:", pred_path)

        metric_path = os.path.join(output_dir, metric_filename)
        with open(metric_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print("Metrics saved:", metric_path)

    return {
        "predictions": predictions,
        "results": results,
        "metrics": metrics,
    }