In [None]:
!pip install -q transformers accelerate bitsandbytes datasets peft pillow tqdm \
              evaluate sentence-transformers rouge-score sacrebleu

import os
import json
import re
from collections import defaultdict

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

import numpy as np
import evaluate
from sentence_transformers import SentenceTransformer, util
import matplotlib.pyplot as plt

from transformers import Blip2Processor, Blip2ForConditionalGeneration
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, PeftModel

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

# Hugging Face model id and output locations (all created eagerly below).
MODEL_NAME = "Salesforce/blip2-opt-2.7b"
RESULTS_DIR = "results"
PLOTS_DIR = os.path.join(RESULTS_DIR, "plots")
CKPT_DIR = "checkpoints"

for _out_dir in (RESULTS_DIR, PLOTS_DIR, CKPT_DIR):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

# -------------------------
# Shared BLIP-2 processor (image preprocessing + tokenizer); used by
# blip_generate and by the LoRA training loop.
# -------------------------
processor = Blip2Processor.from_pretrained(MODEL_NAME)


Device: cuda


In [None]:
# =========================
# TextVQA Dataset
# =========================
class TextVQADataset(Dataset):
    """Map-style wrapper around the ``lmms-lab/textvqa`` Hugging Face dataset.

    Each item is a dict with ``image``, ``question``, ``answers`` (list, empty
    when missing/None), ``image_id`` and ``ocr_tokens`` (list, empty when
    missing/None).
    """

    def __init__(self, split="train", cache_dir=None, limit=None):
        """
        Args:
            split (str): "train", "validation" or "test".
            cache_dir (str, optional): Hugging Face datasets cache directory.
            limit (int, optional): keep only the first ``limit`` samples (for
                quick runs). Values larger than the split are clamped to its size.
        """
        self.split = split
        print(f"Loading TextVQA split: {split} ...")
        self.dataset = load_dataset("lmms-lab/textvqa", split=split, cache_dir=cache_dir)
        if limit is not None:
            # select() with out-of-range indices would fail, so clamp first.
            n_keep = max(0, min(int(limit), len(self.dataset)))
            self.dataset = self.dataset.select(range(n_keep))
        print(f"Loaded {len(self.dataset)} samples.")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        return {
            "image": item["image"],
            "question": item["question"],
            # `or []` also normalises an explicit None, not just a missing key,
            # so downstream code can always iterate these fields.
            "answers": item.get("answers") or [],
            "image_id": item["image_id"],
            "ocr_tokens": item.get("ocr_tokens") or [],
        }


def vqa_collate(batch):
    """Collate samples into a dict of per-field lists.

    No tensor stacking is done: every field value is passed through as-is, so
    images stay whatever the dataset returned and variable-length answer/OCR
    lists are kept intact.
    """
    fields = ("image", "question", "answers", "image_id", "ocr_tokens")
    return {field: [sample[field] for sample in batch] for field in fields}

# =========================
# Metrics
# =========================
def _load_text_metric(name):
    """Load one ``evaluate`` metric by name; return None (after a warning) on failure."""
    try:
        return evaluate.load(name)
    except Exception as e:
        print(f"Warning loading text metric '{name}':", e)
        return None


# Each metric is loaded independently so that one failure (for example a
# missing NLTK resource needed by METEOR) does not disable the others.
# calculate_metrics() skips any backend that is None.
bleu_metric = _load_text_metric("bleu")
meteor_metric = _load_text_metric("meteor")
rouge_metric = _load_text_metric("rouge")

try:
    # Kept on CPU so it does not compete with BLIP-2 for GPU memory.
    semantic_model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
except Exception as e:
    print("Warning loading semantic model:", e)
    semantic_model = None


def preprocess_answer(ans):
    """Normalise an answer for exact-match comparison.

    Converts to ``str``, lower-cases, turns newlines/tabs into spaces and strips
    surrounding whitespace.
    """
    ans = str(ans).lower()
    ans = ans.replace("\n", " ").replace("\t", " ").strip()
    return ans


def compute_vqa_accuracy(ground_truth_list, predicted_answer):
    """Return the VQA/TextVQA soft accuracy of one prediction, in [0, 1].

    Uses the official leave-one-out form. For each annotator answer held out,
    the score is ``min(1, #matches among the remaining answers / 3)``. The
    result is the average of those scores. With fewer than two reference
    answers there is nothing to hold out, so ``min(1, #matches / 3)`` is used.

    Args:
        ground_truth_list: annotator answers (empty or None -> 0.0).
        predicted_answer: model prediction (any value; compared as a string).
    """
    if not ground_truth_list:
        return 0.0
    pred = preprocess_answer(predicted_answer)
    gts = [preprocess_answer(a) for a in ground_truth_list]
    matches = [gt == pred for gt in gts]
    n_match = sum(matches)
    if len(gts) < 2:
        return min(1.0, n_match / 3.0)
    # Removing a matching answer leaves n_match - 1 matches among the others.
    scores = [min(1.0, (n_match - is_match) / 3.0) for is_match in matches]
    return sum(scores) / len(scores)


def compute_semantic_similarity(ground_truth_list, predicted_answer):
    """Return the best cosine similarity between the prediction and any reference.

    Embeds the prediction and every reference with the module-level
    ``semantic_model`` and returns the maximum pairwise cosine score as a float.
    Returns 0.0 when no semantic model is loaded. An empty/None reference list
    is treated as a single empty-string reference.
    """
    if semantic_model is None:
        return 0.0

    pred_text = str(predicted_answer)
    references = [str(ref) for ref in ground_truth_list] if ground_truth_list else [""]

    pred_vec = semantic_model.encode(pred_text, convert_to_tensor=True)
    ref_vecs = semantic_model.encode(references, convert_to_tensor=True)
    sims = util.cos_sim(pred_vec, ref_vecs)

    return float(sims.max().item()) if sims.numel() else 0.0


def calculate_metrics(results):
    """Aggregate VQA accuracy and text-similarity metrics over predictions.

    Args:
        results: list of dicts, each with ``"predicted_answer"`` (str) and
            ``"ground_truth_answers"`` (list[str], possibly empty).

    Returns:
        dict with ``accuracy``. It also has ``bleu``, ``meteor`` and
        ``rouge1``/``rouge2``/``rougeL`` for each backend that loaded (0.0 if
        that backend raised), plus ``semantic_similarity`` (0.0 without a
        semantic model). Returns an empty dict for empty ``results``.
    """
    if not results:
        return {}

    def _guarded(label, keys, compute):
        # Run one metric backend; on any failure log it and zero its keys.
        try:
            return compute()
        except Exception as err:
            print(f"Error computing {label}:", err)
            return dict.fromkeys(keys, 0.0)

    accuracy_sum = 0.0
    preds, refs, sem_scores = [], [], []
    for entry in results:
        answer = entry["predicted_answer"]
        truths = entry["ground_truth_answers"]

        accuracy_sum += compute_vqa_accuracy(truths, answer)
        preds.append(str(answer))
        # Corpus metrics need at least one reference per prediction.
        refs.append([str(t) for t in truths] if truths else [""])
        if semantic_model is not None:
            sem_scores.append(compute_semantic_similarity(truths, answer))

    metrics = {"accuracy": accuracy_sum / len(results)}

    if bleu_metric is not None:
        def _bleu():
            scores = bleu_metric.compute(
                predictions=preds,
                references=refs,
                max_order=2,
                smooth=True,
            )
            return {"bleu": float(scores.get("bleu", 0.0))}

        metrics.update(_guarded("BLEU", ("bleu",), _bleu))

    if meteor_metric is not None:
        def _meteor():
            scores = meteor_metric.compute(predictions=preds, references=refs)
            return {"meteor": float(scores.get("meteor", 0.0))}

        metrics.update(_guarded("METEOR", ("meteor",), _meteor))

    if rouge_metric is not None:
        rouge_keys = ("rouge1", "rouge2", "rougeL")

        def _rouge():
            scores = rouge_metric.compute(predictions=preds, references=refs)
            return {key: float(scores.get(key, 0.0)) for key in rouge_keys}

        metrics.update(_guarded("ROUGE", rouge_keys, _rouge))

    metrics["semantic_similarity"] = (
        float(sum(sem_scores) / len(sem_scores)) if sem_scores else 0.0
    )
    return metrics


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [None]:
# =========================
# OCR cleaning & question type
# =========================

def clean_ocr_text(ocr_tokens, max_tokens=15, min_token_length=2):
    """Join a de-noised, de-duplicated subset of OCR tokens into one string.

    A token is kept when, after stripping, it is non-empty, at least
    ``min_token_length`` characters long, at least 30% alphanumeric, and not
    a case-insensitive repeat of an earlier kept token. Scanning stops once
    ``max_tokens`` tokens have been kept. Empty/None input gives "".
    """
    if not ocr_tokens:
        return ""

    kept = []
    seen_lower = set()
    for raw in ocr_tokens:
        tok = raw.strip()
        if not tok or len(tok) < min_token_length:
            continue
        if sum(ch.isalnum() for ch in tok) / len(tok) < 0.3:
            continue  # mostly punctuation / OCR noise

        key = tok.lower()
        if key not in seen_lower:
            seen_lower.add(key)
            kept.append(tok)
        if len(kept) >= max_tokens:
            break

    return " ".join(kept)


# Question categories in priority order: the first category whose keywords
# appear in the question wins.
_QUESTION_TYPE_KEYWORDS = (
    ("brand", ("brand", "logo", "label", "company", "manufacturer", "maker")),
    ("number", ("number", "percent", "%", "price", "$", "cost", "amount", "how much", "how many")),
    ("date", ("year", "date", "born", "since", "when")),
    ("time", ("time", "clock", "what time")),
    ("text", ("text", "say", "spell", "read", "what does", "what is written")),
)


def _keyword_pattern(keywords):
    """Compile a regex matching any keyword.

    Alphanumeric keywords must start at a word boundary, so "read" no longer
    fires on "bread" and "date" no longer fires on "candidate". Symbols such
    as "%" and "$" match anywhere.
    """
    parts = [(r"\b" if kw[0].isalnum() else "") + re.escape(kw) for kw in keywords]
    return re.compile("|".join(parts))


_QUESTION_TYPE_PATTERNS = [
    (q_type, _keyword_pattern(keywords)) for q_type, keywords in _QUESTION_TYPE_KEYWORDS
]


def classify_question(question):
    """Heuristically classify a question as brand/number/date/time/text/general.

    Matching is case-insensitive. Keywords are anchored at the start of a word,
    but may continue into a longer word (e.g. "labels", "years").
    """
    q = question.lower()
    for q_type, pattern in _QUESTION_TYPE_PATTERNS:
        if pattern.search(q):
            return q_type
    return "general"


_DATE_RE = re.compile(r"^\d{1,2}[/-]\d{1,2}[/-]\d{2,4}$")
_TIME_RE = re.compile(r"^\d{1,2}:\d{2}")


def _token_matches_type(token, q_type):
    """Return True if a cleaned OCR token looks relevant for ``q_type``."""
    if q_type == "brand":
        # Capitalised, purely alphabetic words are the most brand-like tokens.
        return token[0].isupper() and token.isalpha()
    if q_type == "number":
        return any(c.isdigit() for c in token) or "$" in token or "%" in token
    if q_type == "date":
        # Digit runs of length >= 4 (years, etc.) or d/m/y-style dates.
        return (token.isdigit() and len(token) >= 4) or bool(_DATE_RE.match(token))
    if q_type == "time":
        return bool(_TIME_RE.match(token)) or (
            any(c.isdigit() for c in token) and len(token) <= 5
        )
    # "text", "general" and unknown types keep every clean token.
    return True


def summarize_ocr_by_type(ocr_tokens, q_type, max_tokens=15, min_token_length=2):
    """Select OCR tokens relevant to a question type and join them with spaces.

    Tokens get the same noise filter as ``clean_ocr_text`` (non-empty, at least
    ``min_token_length`` chars, at least 30% alphanumeric). They are then kept
    only if they suit ``q_type`` (see ``_token_matches_type``). Like
    ``clean_ocr_text``, case-insensitive duplicates are dropped so repeats do
    not consume the ``max_tokens`` budget. Returns "" when nothing qualifies.
    """
    if not ocr_tokens:
        return ""

    filtered = []
    seen = set()
    for token in ocr_tokens:
        token = token.strip()
        if not token or len(token) < min_token_length:
            continue
        if sum(c.isalnum() for c in token) / len(token) < 0.3:
            continue
        if not _token_matches_type(token, q_type):
            continue
        key = token.lower()
        if key in seen:
            continue
        seen.add(key)
        filtered.append(token)
        if len(filtered) >= max_tokens:
            break

    return " ".join(filtered)


# =========================
# Prompt templates
# =========================
# Prompt templates keyed by name. Placeholders ({question}, {ocr_text},
# {ocr_summary}, {q_type}) are filled in by build_formatted_question.
PROMPT_TEMPLATES = dict(
    # --- Baselines: no OCR text injected ---
    default="Question: {question} Answer:",
    descriptive="Based on the image, answer the question briefly: {question}",
    instruction="Look at the image and answer in a few words: {question}",
    direct="{question}",
    text_focus="Focus on any visible text in the image. Question: {question}",
    short_direct="Answer in 1–3 words: {question}",
    # --- OCR-aware: raw cleaned OCR text ({ocr_text}) ---
    ocr_hint="The image contains the following text: {ocr_text}. Question: {question} Answer:",
    ocr_hint_v3="Answer this question about the image: {question}\nVisible text in image: {ocr_text}\nAnswer:",
    basic_ocr="Detected text in the image: {ocr_text}\nQuestion: {question}\nAnswer in a short phrase:",
    # --- OCR-aware: question-type-filtered OCR ({ocr_summary}, {q_type}) ---
    ocr_category=(
        "Relevant text in the image ({q_type}): {ocr_summary}\n"
        "Question: {question}\n"
        "Use ONLY that text when answering.\n"
        "Answer briefly:"
    ),
    structured_ocr=(
        "Relevant text in the image ({q_type}): {ocr_summary}\n"
        "Question: {question}\n"
        "Answer in 1–3 words using that text:"
    ),
)


# Whole OCR fragments (label + placeholder + separator) removed from the
# {ocr_text} templates when no usable OCR text exists, so no dangling
# punctuation or blank lines are left behind.
_OCR_TEXT_FRAGMENTS = (
    "The image contains the following text: {ocr_text}. ",
    "Detected text in the image: {ocr_text}\n",
    "\nVisible text in image: {ocr_text}",
)


def _strip_ocr_text_slot(tpl):
    """Remove the ``{ocr_text}`` part of a template when there is no OCR text."""
    for fragment in _OCR_TEXT_FRAGMENTS:
        tpl = tpl.replace(fragment, "")
    if "{ocr_text}" in tpl:
        # Unknown layout: drop placeholder and labels, then remove empty lines.
        tpl = (
            tpl.replace("{ocr_text}", "")
               .replace("Detected text in the image:", "")
               .replace("Visible text in image:", "")
               .replace("The image contains the following text:", "")
        )
        tpl = "\n".join(line.strip() for line in tpl.splitlines() if line.strip())
    return tpl


def _strip_ocr_summary_slot(tpl):
    """Remove the ``{ocr_summary}`` line(s) when no type-relevant OCR text exists."""
    return (
        tpl.replace("Relevant text in the image ({q_type}): {ocr_summary}\n", "")
           .replace("Use ONLY that text when answering.\n", "")
           .replace("Answer in 1–3 words using that text:", "Answer briefly:")
    )


def build_formatted_question(question, ocr_tokens, template_name=None):
    """Build the BLIP-2 text prompt for one question.

    Args:
        question: raw question string.
        ocr_tokens: OCR tokens for the image (may be empty/None).
        template_name: key into ``PROMPT_TEMPLATES``. None returns the raw
            question. Unknown names fall back to "Question: {question} Answer:".

    Templates with ``{ocr_summary}`` get OCR filtered by question type.
    Templates with ``{ocr_text}`` get generally cleaned OCR. When the OCR
    selection is empty, the OCR part of the template is removed rather than
    left blank.
    """
    if template_name is None:
        return question

    tpl = PROMPT_TEMPLATES.get(template_name, "Question: {question} Answer:")
    # str.format ignores unused keyword arguments, so every field can always be
    # supplied; this avoids KeyError for templates with an unexpected mix.
    fields = {"question": question, "ocr_text": "", "ocr_summary": "", "q_type": ""}

    if "{q_type}" in tpl or "{ocr_summary}" in tpl:
        fields["q_type"] = classify_question(question)
    if "{ocr_summary}" in tpl:
        fields["ocr_summary"] = summarize_ocr_by_type(ocr_tokens, fields["q_type"])
        if not fields["ocr_summary"]:
            tpl = _strip_ocr_summary_slot(tpl)
    if "{ocr_text}" in tpl:
        fields["ocr_text"] = clean_ocr_text(ocr_tokens)
        if not fields["ocr_text"]:
            tpl = _strip_ocr_text_slot(tpl)

    return tpl.format(**fields)


In [None]:
# =========================
# BLIP-2 OPT-2.7B Model Wrapper
# =========================

def load_blip2_opt(dtype=None):
    """Load the BLIP-2 checkpoint named by MODEL_NAME onto ``device`` in eval mode.

    Args:
        dtype: torch dtype for the weights. Defaults to float16 when ``device``
            is "cuda" and float32 otherwise.
    """
    print("Loading BLIP-2 OPT-2.7B ...")
    if dtype is None:
        dtype = torch.float32 if device != "cuda" else torch.float16

    blip2 = Blip2ForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=dtype)
    blip2 = blip2.to(device)
    blip2.eval()
    return blip2


@torch.no_grad()
def blip_generate(model, image, prompt, max_new_tokens=20, num_beams=1):
    """Generate an answer for one (image, prompt) pair and return it as a string.

    Args:
        model: BLIP-2 model (plain or PEFT-wrapped) exposing ``generate``.
        image: a single image accepted by the global ``processor``.
        prompt: text prompt (e.g. from ``build_formatted_question``).
        max_new_tokens: cap on generated tokens.
        num_beams: beam width. Values <= 1 leave the model's default
            generation settings untouched.

    Returns:
        The decoded continuation, with anything up to and including the last
        "answer:" (case-insensitive) removed, stripped of whitespace.
    """
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    gen_kwargs = {"max_new_tokens": max_new_tokens}
    if num_beams > 1:
        gen_kwargs["num_beams"] = num_beams
    sequence = model.generate(**inputs, **gen_kwargs)[0]

    # Depending on the transformers version, generate() may return the prompt
    # ids followed by the new tokens. Drop an echoed prompt so templates that do
    # not contain "Answer:" (descriptive, basic_ocr, ...) don't leak into the
    # prediction. When only new tokens are returned the prefix won't match.
    prompt_ids = inputs["input_ids"][0]
    n_prompt = prompt_ids.shape[0]
    if sequence.shape[0] >= n_prompt and torch.equal(sequence[:n_prompt], prompt_ids):
        sequence = sequence[n_prompt:]

    text = processor.tokenizer.decode(sequence, skip_special_tokens=True)

    # Fallback: if the text still contains an "Answer:" marker, keep only what
    # follows the last one.
    lowered = text.lower()
    marker = "answer:"
    if marker in lowered:
        return text[lowered.rfind(marker) + len(marker):].strip()
    return text.strip()


# =========================
# LoRA
# =========================

def build_lora_model(base_model, r=16, alpha=16, target_modules=("q_proj", "v_proj")):
    """Wrap ``base_model`` with LoRA adapters via ``peft.get_peft_model``.

    Args:
        base_model: model to adapt.
        r: LoRA rank.
        alpha: LoRA scaling (``lora_alpha``).
        target_modules: names of the modules that receive adapters.

    Returns:
        The PEFT-wrapped model (its trainable-parameter summary is printed).
    """
    lora_cfg = LoraConfig(
        task_type="CAUSAL_LM",
        target_modules=list(target_modules),
        r=r,
        lora_alpha=alpha,
        lora_dropout=0.05,
        bias="none",
    )
    peft_model = get_peft_model(base_model, lora_cfg)
    peft_model.print_trainable_parameters()
    return peft_model


def train_one_epoch_lora(model, train_loader, lr=1e-4, max_steps=300, prompt_template="default"):
    """Fine-tune the trainable (LoRA) parameters for up to ``max_steps`` batches.

    Each sample becomes ``"<prompt> <answer>"``.
      * ``<prompt>`` is ``build_formatted_question(question, ocr_tokens,
        prompt_template)``. For ``"default"`` this is
        ``"Question: {q} Answer: {ans}"``, matching the evaluation prompt.
      * ``<answer>`` is the most frequent annotator answer.

    Padding positions are excluded from the loss via label -100, the value
    Hugging Face losses ignore. When the model config exposes an image-token
    id, image placeholder tokens are excluded too.

    Args:
        model: PEFT-wrapped BLIP-2 model.
        train_loader: DataLoader yielding ``vqa_collate`` batches.
        lr: AdamW learning rate.
        max_steps: maximum number of optimizer steps (batches).
        prompt_template: template name passed to ``build_formatted_question``.

    Returns:
        The same ``model`` object, switched to eval mode.
    """
    from itertools import islice

    def _majority_answer(answers):
        # Most frequent annotator answer; ties go to the earliest in the list.
        if not answers:
            return ""
        return max(answers, key=answers.count)

    def _image_token_id():
        # NOTE(review): newer transformers versions insert image placeholder
        # tokens into input_ids and store their id on the config. Older versions
        # expose neither attribute, in which case nothing extra is masked.
        cfg = getattr(model, "config", None)
        for attr in ("image_token_id", "image_token_index"):
            value = getattr(cfg, attr, None)
            if value is not None:
                return value
        return None

    model.train()
    # Only the adapter weights require grad; frozen weights need no optimizer state.
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr)
    image_token_id = _image_token_id()

    n_steps = max(0, min(max_steps, len(train_loader)))
    step = 0
    running_loss = 0.0

    for batch in tqdm(islice(train_loader, n_steps), total=n_steps):
        answers = [_majority_answer(a) for a in batch["answers"]]
        prompts = [
            build_formatted_question(q, ocr, prompt_template)
            for q, ocr in zip(batch["question"], batch["ocr_tokens"])
        ]
        full_texts = [f"{p} {a}" for p, a in zip(prompts, answers)]

        enc = processor(
            images=batch["image"],
            text=full_texts,
            return_tensors="pt",
            padding=True,
        ).to(device)

        labels = enc["input_ids"].clone()
        if "attention_mask" in enc:
            labels[enc["attention_mask"] == 0] = -100  # never supervise padding
        if image_token_id is not None:
            labels[labels == image_token_id] = -100
        enc["labels"] = labels

        loss = model(**enc).loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        step += 1
        if step % 50 == 0:
            print(f"step {step} | loss = {loss.item():.4f}")

    # Average over the steps actually run (the loader may be shorter than max_steps).
    avg_loss = running_loss / max(step, 1)
    print(f"Average train loss: {avg_loss:.4f}")
    model.eval()  # callers evaluate next; dropout must be off
    return model


In [None]:
def save_results_json(model_name_prefix, zs_or_ft, prompt_name, split, metrics, results):
    """Write metrics, per-sample results and the run config to RESULTS_DIR.

    The file name is
    ``<prefix>[_finetuned][_prompt_<prompt_name>]_<split>_results.json``.
    "_finetuned" appears only when ``zs_or_ft == "ft"``. The prompt segment
    is omitted when ``prompt_name`` is None.
    """
    name_tokens = [model_name_prefix]
    if zs_or_ft == "ft":
        name_tokens.append("finetuned")
    if prompt_name is not None:
        name_tokens.append(f"prompt_{prompt_name}")
    name_tokens += [split, "results"]
    out_path = os.path.join(RESULTS_DIR, "_".join(name_tokens) + ".json")

    run_config = {
        "model": model_name_prefix,
        "zs_or_ft": zs_or_ft,
        "prompt_name": prompt_name,
        "split": split,
    }
    with open(out_path, "w") as fh:
        json.dump({"metrics": metrics, "results": results, "config": run_config}, fh, indent=2)
    print(f"Saved: {out_path}")


@torch.no_grad()
def run_eval_blip(model, split="validation", limit=200, prompt_template=None, zs_or_ft="zs"):
    """Evaluate a BLIP-2 model on TextVQA, print metrics and save them to JSON.

    Args:
        model: BLIP-2 model (plain or LoRA-wrapped).
        split: TextVQA split to evaluate.
        limit: maximum number of samples; None evaluates the whole split.
        prompt_template: ``PROMPT_TEMPLATES`` key, or None for the raw question.
            A None run is saved without a prompt segment in the file name, so it
            is labelled "none" rather than colliding with the "default" template.
        zs_or_ft: "zs" (zero-shot) or "ft" (fine-tuned); used for labelling.

    Returns:
        (metrics, results) as produced by ``calculate_metrics`` and the
        per-sample loop below.
    """
    from itertools import islice

    # The model may come straight from training; disable dropout before generating.
    model.eval()

    ds = TextVQADataset(split=split)
    loader = DataLoader(ds, batch_size=1, shuffle=False, collate_fn=vqa_collate)
    n_eval = len(ds) if limit is None else max(0, min(limit, len(ds)))

    tag = f"{zs_or_ft} {prompt_template or 'none'}"
    print(f"Evaluating BLIP-2 OPT ({tag}) on {split}, limit={limit} ...")

    results = []
    for batch in tqdm(islice(loader, n_eval), total=n_eval):
        image = batch["image"][0]
        question = batch["question"][0]
        answers = batch["answers"][0]
        image_id = batch["image_id"][0]
        ocr_tokens = batch["ocr_tokens"][0]

        formatted_q = build_formatted_question(question, ocr_tokens, prompt_template)

        try:
            pred = blip_generate(model, image, formatted_q)
        except Exception as e:
            # Best-effort: one failing sample scores as an empty prediction.
            print(f"Error on {image_id}: {e}")
            pred = ""

        results.append({
            "image_id": image_id,
            "question": question,
            "formatted_question": formatted_q,
            "predicted_answer": pred,
            "ground_truth_answers": answers,
        })

    metrics = calculate_metrics(results)
    print("Metrics:", metrics)

    save_results_json(
        model_name_prefix="blip_opt",
        zs_or_ft=zs_or_ft,
        prompt_name=prompt_template,
        split=split,
        metrics=metrics,
        results=results,
    )
    return metrics, results


In [None]:
def _parse_result_filename(base):
    """Recover (zs_or_ft, prompt, split) from a result file name without extension.

    Names look like ``blip_opt[_finetuned][_prompt_<name>]_<split>_results``.
    ``<name>`` may itself contain underscores (e.g. ``basic_ocr``), so every
    part between "prompt" and the split is re-joined.
    """
    parts = base.split("_")
    zs_or_ft = "ft" if "finetuned" in parts else "zs"
    has_suffix = len(parts) >= 3 and parts[-1] == "results"
    split = parts[-2] if has_suffix else "validation"
    prompt_end = len(parts) - 2 if has_suffix else len(parts)

    prompt = "none"
    if "prompt" in parts:
        start = parts.index("prompt") + 1
        if start < prompt_end:
            prompt = "_".join(parts[start:prompt_end])
    return zs_or_ft, prompt, split


def load_blip_opt_records(results_dir=None):
    """Collect one summary record per ``blip_opt*.json`` result file.

    Args:
        results_dir: directory to scan; defaults to RESULTS_DIR.

    Returns:
        A list of dicts with keys file, zs_or_ft, prompt, split, accuracy,
        bleu, meteor, rouge1 and semantic_similarity; missing metrics are 0.0.
        The run configuration stored in each file's "config" entry takes
        precedence. The file name is parsed only as a fallback.
    """
    if results_dir is None:
        results_dir = RESULTS_DIR

    records = []
    for fname in sorted(os.listdir(results_dir)):
        if not (fname.startswith("blip_opt") and fname.endswith(".json")):
            continue
        with open(os.path.join(results_dir, fname), "r") as f:
            data = json.load(f)
        metrics = data.get("metrics", {})

        zs_or_ft, prompt, split = _parse_result_filename(os.path.splitext(fname)[0])
        config = data.get("config") or {}
        zs_or_ft = config.get("zs_or_ft") or zs_or_ft
        prompt = config.get("prompt_name") or prompt
        split = config.get("split") or split

        records.append({
            "file": fname,
            "zs_or_ft": zs_or_ft,
            "prompt": prompt,
            "split": split,
            "accuracy": metrics.get("accuracy", 0.0),
            "bleu": metrics.get("bleu", 0.0),
            "meteor": metrics.get("meteor", 0.0),
            "rouge1": metrics.get("rouge1", 0.0),
            "semantic_similarity": metrics.get("semantic_similarity", 0.0),
        })
    return records


def plot_bar(records, metric, title, filename):
    """Save a bar chart of ``metric`` across result records to PLOTS_DIR/filename.

    Bars are ordered by (zs_or_ft, prompt) and labelled ``ZS-<prompt>`` or
    ``FT-<prompt>``. Records lacking ``metric`` are plotted as 0.0.
    """
    if not records:
        print(f"No records for metric {metric}")
        return

    def _bar_label(rec):
        mode = "ZS" if rec["zs_or_ft"] == "zs" else "FT"
        return f"{mode}-{rec['prompt']}"

    ordered = sorted(records, key=lambda rec: (rec["zs_or_ft"], rec["prompt"]))
    labels = [_bar_label(rec) for rec in ordered]
    values = [rec.get(metric, 0.0) for rec in ordered]
    positions = range(len(labels))

    # Widen the figure with the number of bars so rotated labels stay readable.
    fig, ax = plt.subplots(figsize=(max(8, len(labels) * 0.7), 5))
    ax.bar(positions, values)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel(metric)
    ax.set_title(title)
    fig.tight_layout()

    out_path = os.path.join(PLOTS_DIR, filename)
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"Saved plot: {out_path}")


def plot_all_blip_opt():
    """Render one bar chart per summary metric from all saved blip_opt results."""
    records = load_blip_opt_records()
    if not records:
        print("No blip_opt*.json results yet.")
        return

    chart_specs = (
        ("accuracy", "BLIP-2 OPT TextVQA Accuracy (ZS vs FT, prompts)", "blip_opt_accuracy.png"),
        ("bleu", "BLIP-2 OPT BLEU (ZS vs FT, prompts)", "blip_opt_bleu.png"),
        ("meteor", "BLIP-2 OPT METEOR (ZS vs FT, prompts)", "blip_opt_meteor.png"),
        ("rouge1", "BLIP-2 OPT ROUGE-1 (ZS vs FT, prompts)", "blip_opt_rouge1.png"),
        ("semantic_similarity", "BLIP-2 OPT Semantic Similarity (ZS vs FT, prompts)", "blip_opt_semantic.png"),
    )
    for metric, chart_title, png_name in chart_specs:
        plot_bar(records, metric, chart_title, png_name)


In [None]:
########################################
# 1) Zero-shot baselines (BLIP-2 OPT)
########################################
blip_zs_model = load_blip2_opt()

# ZS default
zs_default_metrics, _ = run_eval_blip(
    blip_zs_model,
    split="validation",
    limit=200,
    prompt_template="default",
    zs_or_ft="zs",
)

# ZS basic_ocr
zs_ocr_metrics, _ = run_eval_blip(
    blip_zs_model,
    split="validation",
    limit=200,
    prompt_template="basic_ocr",
    zs_or_ft="zs",
)

# Free the zero-shot model before loading a second copy for fine-tuning.
del blip_zs_model
torch.cuda.empty_cache()

########################################
# 2) Fine-tune with LoRA (BLIP-2 OPT)
########################################

train_ds = TextVQADataset(split="train", limit=5000)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=vqa_collate)

base_model = load_blip2_opt()
lora_model = build_lora_model(base_model, r=16, alpha=16, target_modules=("q_proj", "v_proj"))

lora_model = train_one_epoch_lora(
    lora_model,
    train_loader,
    lr=1e-4,
    max_steps=300,
    prompt_template="default",
)

# save LoRA adapter weights
lora_save_path = os.path.join(CKPT_DIR, "blip_opt_lora_r16_a16")
lora_model.save_pretrained(lora_save_path)
print("LoRA saved to:", lora_save_path)

########################################
# 3) Fine-tuned eval:
########################################

# Training ran with model.train(); make sure dropout is disabled before the
# fine-tuned evaluations so results are not perturbed by it.
lora_model.eval()

ft_default_metrics, _ = run_eval_blip(
    lora_model,
    split="validation",
    limit=200,
    prompt_template="default",
    zs_or_ft="ft",
)

ft_descr_metrics, _ = run_eval_blip(
    lora_model,
    split="validation",
    limit=200,
    prompt_template="descriptive",
    zs_or_ft="ft",
)

ft_textfocus_metrics, _ = run_eval_blip(
    lora_model,
    split="validation",
    limit=200,
    prompt_template="text_focus",
    zs_or_ft="ft",
)

ft_basicocr_metrics, _ = run_eval_blip(
    lora_model,
    split="validation",
    limit=200,
    prompt_template="basic_ocr",
    zs_or_ft="ft",
)

ft_structuredocr_metrics, _ = run_eval_blip(
    lora_model,
    split="validation",
    limit=200,
    prompt_template="structured_ocr",
    zs_or_ft="ft",
)

########################################
# 4) plots
########################################
plot_all_blip_opt()


`torch_dtype` is deprecated! Use `dtype` instead!


Loading BLIP-2 OPT-2.7B ...


config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/10.0G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

Loading TextVQA split: validation ...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/20 [00:00<?, ?files/s]

data/train-00000-of-00020.parquet:   0%|          | 0.00/303M [00:00<?, ?B/s]

data/train-00001-of-00020.parquet:   0%|          | 0.00/298M [00:00<?, ?B/s]

data/train-00002-of-00020.parquet:   0%|          | 0.00/290M [00:00<?, ?B/s]

data/train-00003-of-00020.parquet:   0%|          | 0.00/304M [00:00<?, ?B/s]

data/train-00004-of-00020.parquet:   0%|          | 0.00/318M [00:00<?, ?B/s]

data/train-00005-of-00020.parquet:   0%|          | 0.00/262M [00:00<?, ?B/s]

data/train-00006-of-00020.parquet:   0%|          | 0.00/304M [00:00<?, ?B/s]

data/train-00007-of-00020.parquet:   0%|          | 0.00/238M [00:00<?, ?B/s]

data/train-00008-of-00020.parquet:   0%|          | 0.00/280M [00:00<?, ?B/s]

data/train-00009-of-00020.parquet:   0%|          | 0.00/299M [00:00<?, ?B/s]

data/train-00010-of-00020.parquet:   0%|          | 0.00/286M [00:00<?, ?B/s]