In [None]:
!pip install import-ipynb

In [None]:
%cd /content/drive/MyDrive/SHBT261FinalProject/code/src

In [None]:
"""
Results Visualization and Error Analysis Module
Provides comprehensive analysis of model predictions
"""

import os
import json
import re
from collections import Counter, defaultdict
from typing import List, Dict

import import_ipynb
from metrics import normalize_answer, textvqa_accuracy

In [None]:
def load_predictions(file_path: str) -> List[Dict]:
    """Load model predictions from a JSON file.

    Args:
        file_path: Path to the JSON file to read. The analysis functions in
            this module expect it to hold a list of per-sample dicts.

    Returns:
        The decoded JSON content, returned as-is (no validation is done).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Be explicit about the codec: without it Python falls back to the
    # platform locale (e.g. cp1252 on Windows), which mis-decodes or rejects
    # non-ASCII answers stored as UTF-8.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def categorize_error(prediction: str, ground_truths: List[str], ocr_tokens: List[str]) -> str:
    """Assign a coarse outcome label to a single prediction.

    Every string is passed through ``normalize_answer`` before comparison.
    The checks run in this order and the first one that applies wins:

    * ``"correct"``         - the prediction equals one of the ground truths.
    * ``"format_error"``    - the prediction and some ground truth contain one
      another as a substring (both non-empty).
    * ``"ocr_miss"``        - a ground truth occurs in the OCR text but the
      prediction does not (this includes an empty prediction).
    * ``"reasoning_error"`` - both a ground truth and the prediction occur in
      the OCR text.
    * ``"hallucination"``   - OCR tokens exist, no ground truth occurs in
      them, and the non-empty prediction does not occur in them either.
    * ``"other"``           - anything else.

    Args:
        prediction: Raw model answer.
        ground_truths: Reference answers for the question.
        ocr_tokens: OCR tokens detected in the image (may be empty).

    Returns:
        One of the label strings listed above.
    """
    # Assumes normalize_answer returns a plain str - TODO confirm in metrics.
    pred_norm = normalize_answer(prediction)
    gt_norms = [normalize_answer(gt) for gt in ground_truths]

    if pred_norm in gt_norms:
        return "correct"

    # The empty string is a substring of every string, so without the
    # non-empty guards a blank prediction (or a blank reference) would always
    # be reported as a formatting problem.
    if pred_norm and any(
        gt and (pred_norm in gt or gt in pred_norm) for gt in gt_norms
    ):
        return "format_error"

    # Join once; a match inside a single token is also a match inside the
    # joined text, so no separate per-token scan is required.
    ocr_text = " ".join(normalize_answer(tok) for tok in ocr_tokens)

    gt_in_ocr = any(gt and gt in ocr_text for gt in gt_norms)
    pred_in_ocr = bool(pred_norm) and pred_norm in ocr_text

    if gt_in_ocr:
        return "reasoning_error" if pred_in_ocr else "ocr_miss"
    if pred_norm and not pred_in_ocr and ocr_tokens:
        return "hallucination"

    return "other"


def analyze_errors(results: List[Dict]) -> Dict:
    """Tally outcome labels over all results and keep a few examples of each.

    A sample is counted as correct when ``textvqa_accuracy`` of its
    prediction against its ground truths exceeds 0.5; every other sample is
    labelled by ``categorize_error``. At most five examples are stored per
    label, each trimmed to the first three ground truths and the first five
    OCR tokens.

    Args:
        results: Per-sample dicts with ``"prediction"``, ``"ground_truths"``
            and ``"question"`` keys; ``"ocr_tokens"`` is optional.

    Returns:
        Dict with ``total_samples``, ``correct_count``, ``accuracy`` (0 when
        there are no samples), ``error_distribution`` (label -> count) and
        ``error_examples`` (label -> list of example dicts).
    """
    distribution: Dict = {}
    examples: Dict = {}
    num_correct = 0

    for record in results:
        prediction = record["prediction"]
        references = record["ground_truths"]
        tokens = record.get("ocr_tokens", [])

        if textvqa_accuracy(prediction, references) > 0.5:
            num_correct += 1
            label = "correct"
        else:
            # NOTE(review): categorize_error may itself answer "correct" for
            # an exact match that scored <= 0.5, in which case the "correct"
            # entry of the distribution exceeds correct_count - confirm this
            # is intended.
            label = categorize_error(prediction, references, tokens)

        distribution[label] = distribution.get(label, 0) + 1

        bucket = examples.setdefault(label, [])
        if len(bucket) < 5:
            bucket.append({
                "question": record["question"],
                "prediction": prediction,
                "ground_truths": references[:3],
                "ocr_tokens": tokens[:5]
            })

    num_samples = len(results)
    accuracy = num_correct / num_samples if num_samples else 0

    return {
        "total_samples": num_samples,
        "correct_count": num_correct,
        "accuracy": accuracy,
        "error_distribution": distribution,
        "error_examples": examples,
    }


def analyze_question_types(results: List[Dict]) -> Dict:
    """Break accuracy down by question type.

    Each lower-cased question is assigned to the first pattern that matches;
    questions matching no pattern go to ``"other"``. A sample is correct when
    ``textvqa_accuracy`` of its prediction exceeds 0.5.

    Args:
        results: Per-sample dicts with ``"question"``, ``"prediction"`` and
            ``"ground_truths"`` keys.

    Returns:
        Mapping of question type to a dict with ``correct``, ``total`` and
        ``accuracy`` entries. Only types that occurred are present.
    """
    # Order matters because the first match wins. The generic "what" pattern
    # must come after the specific "what ..." patterns; when it was listed
    # first it captured every question starting with "what " and the specific
    # buckets stayed empty for those questions. "what" therefore holds the
    # what-questions that fit none of the specific buckets.
    patterns = {
        "how many": r"^how many\s",
        "what is the name": r"what is the name",
        "what is the number": r"what is the number|what number",
        "what is the brand": r"what (?:is the )?brand",
        "what is the time": r"what (?:is the )?time",
        "what is the date": r"what (?:is the )?date",
        "what is the price": r"what (?:is the )?price",
        "what color": r"what color",
        "what": r"^what\s",
        "which": r"^which\s",
        "where": r"^where\s",
        "who": r"^who\s",
    }
    # Compile once instead of handing the raw pattern to re.search per sample.
    compiled = [(key, re.compile(pat)) for key, pat in patterns.items()]

    stats = defaultdict(lambda: {"correct": 0, "total": 0})

    for r in results:
        q = r["question"].lower()
        correct = textvqa_accuracy(r["prediction"], r["ground_truths"]) > 0.5

        key = next((name for name, rx in compiled if rx.search(q)), "other")
        stats[key]["total"] += 1
        if correct:
            stats[key]["correct"] += 1

    for entry in stats.values():
        t = entry["total"]
        entry["accuracy"] = entry["correct"] / t if t else 0

    return dict(stats)


def analyze_answer_length(results: List[Dict]) -> Dict:
    """Break accuracy down by the average word count of the ground truths.

    Buckets: ``single_word`` (mean of at most 1 whitespace-separated word),
    ``short_phrase`` (at most 3), ``long_phrase`` (more than 3) and
    ``unknown`` when a sample has no ground truths. A sample is correct when
    ``textvqa_accuracy`` of its prediction exceeds 0.5.

    Args:
        results: Per-sample dicts with ``"ground_truths"`` and
            ``"prediction"`` keys.

    Returns:
        Mapping of bucket name to a dict with ``correct``, ``total`` and
        ``accuracy`` entries. Only buckets that occurred are present.
    """
    buckets: Dict = {}

    for record in results:
        references = record["ground_truths"]
        prediction = record["prediction"]

        if not references:
            bucket_name = "unknown"
        else:
            word_total = sum(len(ref.split()) for ref in references)
            mean_words = word_total / len(references)
            if mean_words <= 1:
                bucket_name = "single_word"
            elif mean_words <= 3:
                bucket_name = "short_phrase"
            else:
                bucket_name = "long_phrase"

        is_correct = textvqa_accuracy(prediction, references) > 0.5

        entry = buckets.setdefault(bucket_name, {"correct": 0, "total": 0})
        entry["total"] += 1
        if is_correct:
            entry["correct"] += 1

    for entry in buckets.values():
        count = entry["total"]
        entry["accuracy"] = entry["correct"] / count if count else 0

    return dict(buckets)


def compare_models(zs_results: List[Dict], ft_results: List[Dict]) -> Dict:
    """Compare zero-shot and fine-tuned predictions on their shared questions.

    Samples are paired by ``"question_id"``; ids present in only one of the
    two inputs are ignored. A prediction is correct when its
    ``textvqa_accuracy`` exceeds 0.5. Up to ten examples are kept where the
    fine-tuned score is strictly higher (``improvements``) and up to ten
    where it is strictly lower (``regressions``).

    Args:
        zs_results: Zero-shot per-sample dicts with ``"question_id"``,
            ``"question"``, ``"prediction"`` and ``"ground_truths"`` keys.
        ft_results: Fine-tuned per-sample dicts with ``"question_id"``,
            ``"prediction"`` and ``"ground_truths"`` keys.

    Returns:
        Dict with ``zeroshot`` and ``finetuned`` (each holding ``correct``,
        ``total`` and ``accuracy``), ``improvements``, ``regressions``,
        ``improvement`` (fine-tuned accuracy minus zero-shot accuracy) and
        ``num_compared``. Accuracies are 0 when nothing could be paired.
    """
    zs_map = {r["question_id"]: r for r in zs_results}
    ft_map = {r["question_id"]: r for r in ft_results}

    # Walk the ids in zero-shot input order rather than through a set, so the
    # capped example lists are the same on every run.
    common = [qid for qid in zs_map if qid in ft_map]
    num_common = len(common)

    comp = {
        "zeroshot": {"correct": 0, "total": num_common},
        "finetuned": {"correct": 0, "total": num_common},
        "improvements": [],
        "regressions": [],
    }

    for qid in common:
        zs = zs_map[qid]
        ft = ft_map[qid]

        zs_acc = textvqa_accuracy(zs["prediction"], zs["ground_truths"])
        ft_acc = textvqa_accuracy(ft["prediction"], ft["ground_truths"])

        if zs_acc > 0.5:
            comp["zeroshot"]["correct"] += 1
        if ft_acc > 0.5:
            comp["finetuned"]["correct"] += 1

        if ft_acc > zs_acc:
            bucket = comp["improvements"]
        elif ft_acc < zs_acc:
            bucket = comp["regressions"]
        else:
            continue

        if len(bucket) < 10:
            bucket.append({
                "question": zs["question"],
                "zeroshot_pred": zs["prediction"],
                "finetuned_pred": ft["prediction"],
                "ground_truths": zs["ground_truths"][:3],
            })

    # Guard the division: with no shared question ids the original raised
    # ZeroDivisionError, unlike the other analyzers which report 0.
    for model in ("zeroshot", "finetuned"):
        correct = comp[model]["correct"]
        comp[model]["accuracy"] = correct / num_common if num_common else 0

    comp["improvement"] = comp["finetuned"]["accuracy"] - comp["zeroshot"]["accuracy"]
    comp["num_compared"] = num_common

    return comp


In [None]:
def run_analysis(predictions_path: str, output_dir="results", model_name="Model"):
    """Run every single-model analysis and write the combined report to disk.

    Creates ``output_dir`` if needed, loads the predictions, runs the error,
    question-type and answer-length analyses, and saves the result as
    ``analysis_detailed_<model_name>.json`` inside ``output_dir``.

    Args:
        predictions_path: Path of the predictions JSON file.
        output_dir: Directory that receives the report.
        model_name: Label embedded in the report file name.

    Returns:
        The report dict that was written.
    """
    os.makedirs(output_dir, exist_ok=True)

    records = load_predictions(predictions_path)

    # Dict displays evaluate in source order, so the three analyses run in
    # the order listed here.
    report = {
        "error_analysis": analyze_errors(records),
        "question_type_analysis": analyze_question_types(records),
        "answer_length_analysis": analyze_answer_length(records)
    }

    report_path = os.path.join(output_dir, f"analysis_detailed_{model_name}.json")
    with open(report_path, "w") as handle:
        json.dump(report, handle, indent=2)

    print("Saved:", report_path)
    return report


In [None]:
def compare_two_models(zs_path, ft_path, output_dir="results"):
    """Compare two prediction files and write the comparison to disk.

    Creates ``output_dir`` if needed, loads both prediction files, runs
    ``compare_models`` on them and saves the result as
    ``model_comparison.json`` inside ``output_dir``.

    Args:
        zs_path: Path of the zero-shot predictions JSON file.
        ft_path: Path of the fine-tuned predictions JSON file.
        output_dir: Directory that receives the comparison file.

    Returns:
        The comparison dict that was written.
    """
    os.makedirs(output_dir, exist_ok=True)

    zeroshot_records = load_predictions(zs_path)
    finetuned_records = load_predictions(ft_path)
    comparison = compare_models(zeroshot_records, finetuned_records)

    destination = os.path.join(output_dir, "model_comparison.json")
    with open(destination, "w") as handle:
        json.dump(comparison, handle, indent=2)

    print("Saved comparison:", destination)
    return comparison
