# **0. Setup**

## Drive

In [None]:
from google.colab import drive
drive.mount("/content/drive")
%cd /content/drive/MyDrive/MLQA-TSR/

## Libs

In [None]:
!pip install qdrant_client

In [None]:
import os
import json
import pickle
from pathlib import Path
from tqdm import tqdm
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers import AutoModel

## Configs


In [None]:
# from huggingface_hub import login
# login("x")

In [None]:
# Retrieval settings. In this notebook they are only used to build the file
# name of the retrieval cache (RETRIEVAL_PATH), so they must match the values
# that were used when that cache was produced.
# NOTE(review): presumably the number of texts / images / objects retrieved
# per query -- confirm against the retrieval (subtask 1) notebook.
N_TEXTS = 10
M_IMGS = 5
K_OBJS = 3
# Default number of few-shot examples for build_messages_adaptive (0 = zero-shot).
K_SHOTS = 0  # TRY ZERO-SHOT FIRST!

# Checkpoint selection; the Qwen line is kept as a commented-out alternative.
# MODEL_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"
MODEL_NAME = "OpenGVLab/InternVL3-8B"


## Paths

In [None]:
# Dataset layout on the mounted Drive: one folder per split, each holding a
# question JSON file and an image directory.
DATA_ROOT = "/content/drive/MyDrive/MLQA-TSR/LIFEISTOUGH/data"
_train_root = f"{DATA_ROOT}/train_data"
_test_root = f"{DATA_ROOT}/public_test"

TRAIN_JSON = f"{_train_root}/vlsp_2025_train.json"
TEST_JSON = f"{_test_root}/vlsp_2025_public_test_task1.json"
TRAIN_IMAGE_DIR = Path(_train_root) / "train_images"
TEST_IMAGE_DIR = Path(_test_root) / "public_test_images"

# Every run writes into its own sub-folder of OUT_DIR, created on demand.
OUT_DIR = "/content/drive/MyDrive/MLQA-TSR/outputs"
run_time = "20251225-fulldata"
RUN_DIR = os.path.join(OUT_DIR, run_time)
os.makedirs(RUN_DIR, exist_ok=True)

# The retrieval cache name encodes the retrieval settings it was built with.
_retrieval_name = f"subtask1_retrieval_n{N_TEXTS}_m{M_IMGS}_k{K_OBJS}.pkl"
RETRIEVAL_PATH = os.path.join(RUN_DIR, _retrieval_name)

# **1. Load data**

In [None]:
# Read both question files (UTF-8 JSON lists of sample dicts).
with open(TRAIN_JSON, "r", encoding="utf-8") as train_file, \
        open(TEST_JSON, "r", encoding="utf-8") as test_file:
    train_data = json.load(train_file)
    test_data = json.load(test_file)

# Index the training split by sample id: the full record, and the image
# identifier (stored under "image_id" or, failing that, "image").
train_by_id = {}
train_id2img = {}
for _sample in train_data:
    _sample_id = _sample["id"]
    train_by_id[_sample_id] = _sample
    train_id2img[_sample_id] = _sample.get("image_id") or _sample.get("image")

# SECURITY: pickle executes arbitrary code on load -- only open cache files
# produced by your own retrieval run.
with open(RETRIEVAL_PATH, "rb") as cache_file:
    retrieval_cache = pickle.load(cache_file)

print(f"Loaded {len(train_data)} training samples")
print(f"Loaded {len(test_data)} test samples")

# **2. Helper/Util**

In [None]:
def format_choices(choices: dict) -> str:
    """
    Render answer options as one "<key>. <text>" line per option.

    Keys A, B, C, D come first in that order; any other keys follow in the
    dict's own order. An empty or missing mapping yields an empty string.
    """
    if not choices:
        return ""

    canonical = ("A", "B", "C", "D")
    ordered = [label for label in canonical if label in choices]
    ordered.extend(label for label in choices if label not in canonical)
    return "\n".join(f"{label}. {choices[label]}" for label in ordered)


In [None]:
def build_qa_text(question: str, choices: dict | None) -> str:
    if choices:
        return question.strip() + "\n" + format_choices(choices)
    return question.strip()

In [None]:
def get_train_image_path(sample_id: str) -> Path:
    """
    Resolve the image file of a training sample.

    Looks the image id up in train_id2img, then returns the first existing
    file among "<image id>.jpg" and "<image id>.png" in TRAIN_IMAGE_DIR.

    Raises:
        KeyError: the sample has no image id.
        FileNotFoundError: neither candidate file exists.
    """
    img_id = train_id2img.get(sample_id)
    if img_id is None:
        raise KeyError(f"Cannot find image_id for {sample_id}")

    candidates = (TRAIN_IMAGE_DIR / f"{img_id}{suffix}" for suffix in (".jpg", ".png"))
    found = next(filter(Path.exists, candidates), None)
    if found is None:
        raise FileNotFoundError(f"Image not found: {img_id}")
    return found

In [None]:
def get_test_image_path(test_sample: dict) -> Path:
    """
    Resolve the image file of a test sample.

    The image id is read from "image_id" (or "image" as a fallback); the first
    existing file among "<image id>.jpg" and "<image id>.png" in
    TEST_IMAGE_DIR is returned.

    Raises:
        FileNotFoundError: neither candidate file exists.
    """
    stem = test_sample.get("image_id") or test_sample.get("image")
    existing = (
        path
        for path in (TEST_IMAGE_DIR / f"{stem}{suffix}" for suffix in (".jpg", ".png"))
        if path.exists()
    )
    try:
        return next(existing)
    except StopIteration:
        raise FileNotFoundError(f"Test image not found: {stem}") from None

In [None]:
def get_few_shot_examples(test_id: str, retrieval_cache: dict, k: int, min_score: float = 0.5) -> list:
    """
    Pick up to k retrieved training samples whose score passes a threshold.

    Args:
        test_id: Id of the test sample whose retrieval hits are wanted.
        retrieval_cache: Maps a test id to an iterable of (record, score)
            pairs; each record exposes payload["sample_id"].
            NOTE(review): hits are taken in stored order, so the cache is
            presumably sorted best-first -- confirm against the retrieval step.
        k: Maximum number of examples to return; 0 or less returns none.
        min_score: Hits scoring below this value are skipped.

    Returns:
        A list of {"train_id": ..., "score": float} dicts, at most k long.
        Empty when the test id is not in the cache.
    """
    # Without this guard the loop below would append one hit before checking
    # the limit, so a request for zero examples would still return one.
    if k <= 0:
        return []

    shots = []
    for rec, score in retrieval_cache.get(test_id, ()):
        if score < min_score:
            continue

        shots.append({
            "train_id": rec.payload["sample_id"],
            "score": float(score),
        })
        if len(shots) >= k:
            break

    return shots

# **3. System prompt**

In [None]:
# Candidate system prompts. All of them ask for the bare answer only: a letter
# for A/B/C/D questions, or "Đúng"/"Sai" for true/false questions. Exactly one
# is selected as SYSTEM_PROMPT at the end of this cell.

# Version 1: the original baseline prompt (English, short).
# NOTE: the wording is kept verbatim, including the "No need explanation
# needed." phrasing, so that baseline runs stay reproducible.
SYSTEM_PROMPT_V1 = (
    "Given an image and a question both about traffic in Vietnam. "
    "Multiple choices and yes/no questions shall be provided. "
    "If A, B, C, D were given, choose the letter only. "
    "If Đúng (Correct); Sai (Wrong) were given, choose Đúng (Correct) or Sai (Wrong) only. "
    "No need explanation needed."
)

# Version 2: English, more detailed; gives the model an expert role, lists the
# steps to follow and names the regulations to rely on.
SYSTEM_PROMPT_V2 = """You are an expert in Vietnamese traffic laws and traffic signs.

Given an image showing traffic signs and a question in Vietnamese, you must:
1. Analyze the traffic signs in the image
2. Understand the question and provided choices
3. Select the correct answer based on Vietnamese traffic regulations (QCVN 41:2024/BGTVT and Law 36/2024/QH15)

Answer format:
- If choices A, B, C, D are given: respond with ONLY the letter (A, B, C, or D)
- If Đúng/Sai (True/False) is asked: respond with ONLY "Đúng" or "Sai"

Do NOT provide explanations. Answer with the letter or word only."""

# Version 3: Vietnamese counterpart of version 2, so the instructions are in
# the same language as the questions.
SYSTEM_PROMPT_V3 = """Bạn là chuyên gia về luật giao thông và biển báo giao thông Việt Nam.

Với một hình ảnh và câu hỏi về giao thông, hãy:
- Phân tích biển báo trong ảnh
- Chọn đáp án đúng dựa trên QCVN 41:2024/BGTVT và Luật 36/2024/QH15

Định dạng trả lời:
- Nếu có lựa chọn A, B, C, D: chỉ trả lời CHỮ CÁI (A, B, C hoặc D)
- Nếu là câu hỏi Đúng/Sai: chỉ trả lời "Đúng" hoặc "Sai"

KHÔNG giải thích. Chỉ trả lời chữ cái hoặc từ."""

# Version 4: English, targets the weaker multiple-choice results by spelling
# out how to compare every option against the image.
SYSTEM_PROMPT_V4 = """You are an expert in Vietnamese traffic laws and signs (QCVN 41:2024/BGTVT).

IMPORTANT INSTRUCTIONS:
1. Analyze the traffic signs carefully in the image
2. For Multiple Choice questions with options A, B, C, D:
   - Read ALL options carefully
   - Compare each option with what you see in the image
   - Choose the MOST ACCURATE option
   - Answer with ONLY the letter: A, B, C, or D
3. For Yes/No questions (Đúng/Sai):
   - Answer with ONLY: Đúng or Sai

Do NOT add explanations, periods, or extra text. Just the answer."""

# Version 5: Vietnamese counterpart of version 4 (careful, option-by-option
# analysis; no explanations, no trailing period).
SYSTEM_PROMPT_V5 = """Bạn là chuyên gia luật giao thông Việt Nam (QCVN 41:2024/BGTVT).

HƯỚNG DẪN QUAN TRỌNG:
1. Quan sát KỸ CÀNG biển báo trong ảnh
2. Với câu hỏi trắc nghiệm (A, B, C, D):
   - Đọc KỸ TẤT CẢ các đáp án
   - So sánh TỪNG đáp án với những gì thấy trong ảnh
   - Chọn đáp án CHÍNH XÁC NHẤT
   - Chỉ trả lời CHỮ CÁI: A, B, C hoặc D
3. Với câu hỏi Đúng/Sai:
   - Chỉ trả lời: Đúng hoặc Sai

KHÔNG giải thích. KHÔNG thêm dấu chấm. CHỈ trả lời đáp án."""

# Active prompt: version 5 (Vietnamese, detailed instructions). Change this
# single assignment to switch prompts for every message list built afterwards.
SYSTEM_PROMPT = SYSTEM_PROMPT_V5

# **4. Build Messages**

In [None]:
def build_messages_adaptive(
    test_sample: dict,
    retrieval_cache: dict,
    train_by_id: dict,
    k_shot: int = K_SHOTS,
    min_score: float = 0.5,
) -> list:
    """
    Build the chat message list for one test sample.

    Layout: the system prompt, then (only when k_shot > 0) one user/assistant
    pair per usable few-shot example, then the test question as the final
    user turn. A few-shot example is dropped when its training record or its
    image cannot be found.

    Args:
        test_sample: Test record with "id", "question", optional "choices"
            and an image id.
        retrieval_cache: Retrieval hits per test id, as consumed by
            get_few_shot_examples.
        train_by_id: Training records keyed by sample id.
        k_shot: Maximum number of few-shot examples; 0 means zero-shot.
        min_score: Minimum retrieval score for an example to be used.

    Returns:
        A list of {"role": ..., "content": ...} dicts.
    """
    def user_turn(image_path, sample):
        # One user message: the image first, then question (+ options) text.
        return {
            "role": "user",
            "content": [
                {"type": "image", "image": str(image_path)},
                {"type": "text", "text": build_qa_text(sample["question"], sample.get("choices"))}
            ],
        }

    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Zero-shot runs never touch the retrieval cache.
    demos = []
    if k_shot > 0:
        demos = get_few_shot_examples(test_sample["id"], retrieval_cache, k=k_shot, min_score=min_score)

    for demo in demos:
        demo_id = demo["train_id"]
        demo_sample = train_by_id.get(demo_id)
        if demo_sample is None:
            continue

        try:
            demo_image = get_train_image_path(demo_id)
        except (KeyError, FileNotFoundError):
            continue

        conversation.append(user_turn(demo_image, demo_sample))
        conversation.append({
            "role": "assistant",
            "content": [{"type": "text", "text": demo_sample["answer"]}]
        })

    # The actual query always comes last.
    conversation.append(user_turn(get_test_image_path(test_sample), test_sample))
    return conversation

# **5. Load Model**

In [None]:
# processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
# model = AutoModelForVision2Seq.from_pretrained(
#     MODEL_NAME,
#     device_map="auto",
#     torch_dtype=torch.float16
# )
# model.eval()
# print("Model loaded")

import torch
from transformers import AutoTokenizer, AutoModel

# Checkpoint to load. Keep this in sync with MODEL_NAME from the Configs cell.
path = "OpenGVLab/InternVL3-8B"

# trust_remote_code is required: the chat model class ships with the
# checkpoint rather than with transformers.
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval().cuda()

# InternVL's chat() takes a tokenizer as its first argument. Previously none
# was created, so inference failed with a NameError.
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# model_infer refers to this object as `processor`; keep both names bound.
processor = tokenizer
print("Model and tokenizer loaded")

# **6. Inference**

In [None]:
# @torch.inference_mode()
# def model_infer(messages: list) -> str:
#     prompt = processor.apply_chat_template(
#         messages,
#         tokenize=False,
#         add_generation_prompt=True,
#     )

#     inputs = processor(
#         text=[prompt],
#         return_tensors="pt",
#     ).to(model.device)

#     outputs = model.generate(
#         **inputs,
#         do_sample=False,
#         max_new_tokens=32,
#     )

#     text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
#     lines = text.strip().splitlines()
#     answer = lines[-1].strip() if lines else text.strip()

#     return answer

@torch.inference_mode()
def model_infer(image, question: str) -> str:
    """
    image: PIL.Image
    question: string
    """
    response = model.chat(
        processor,
        image=image,
        question=question,
        generation_config={
            "do_sample": False,
            "max_new_tokens": 32
        }
    )
    return response.strip()


# **7. EXPERIMENT: Test different K_SHOTS**

In [None]:
from PIL import Image

# Direct zero-shot pass: each test image is sent with its own question, with
# no retrieval and no few-shot examples.
# This cell owns its result lists. Previously it relied on `desc`,
# `predictions` and `errors` left behind by a later cell, so it failed with a
# NameError on a fresh run (or appended to another experiment's results).
predictions = []
errors = []

for sample in tqdm(test_data, desc="Processing (zero-shot, direct)"):
    try:
        # Load the image that belongs to this sample. convert() returns an
        # independent copy, so the file handle can be closed right away.
        image_path = get_test_image_path(sample)
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Question text followed by its answer options (if any).
        question = build_qa_text(
            sample["question"],
            sample.get("choices")
        )

        raw_answer = model_infer(
            image=image,
            question=question
        )

        predictions.append({
            "id": sample["id"],
            "answer": raw_answer
        })

    except Exception as e:
        # Best effort: record the failure and keep one (blank) prediction per
        # test sample so the output stays aligned with test_data.
        errors.append({"id": sample["id"], "error": str(e)})
        predictions.append({"id": sample["id"], "answer": ""})

# Surface failures instead of letting them pass as silent blank answers.
if errors:
    print(f"WARNING: {len(errors)}/{len(test_data)} samples failed; first error: {errors[0]['error']}")


In [None]:
print(f"\n{'='*70}")
print("RUNNING EXPERIMENTS")
print(f"{'='*70}\n")

# One entry per run: the number of few-shot examples (k), the minimum
# retrieval score an example needs in order to be used, and a display label
# that is also used to name the predictions file.
K_EXPERIMENTS = [
    {"k": 0, "min_score": 0.0, "desc": "Zero-shot"},
    {"k": 1, "min_score": 0.6, "desc": "1-shot (score>=0.6)"},
    {"k": 1, "min_score": 0.7, "desc": "1-shot (score>=0.7)"},
    {"k": 2, "min_score": 0.6, "desc": "2-shot (score>=0.6)"},
    {"k": 3, "min_score": 0.6, "desc": "3-shot (score>=0.6)"},
]

results_summary = []

for exp in K_EXPERIMENTS:
    k = exp["k"]
    min_score = exp["min_score"]
    desc = exp["desc"]

    print(f"\n{'='*70}")
    print(f"Experiment: {desc}")
    print(f"{'='*70}")

    predictions = []
    errors = []

    for sample in tqdm(test_data, desc=f"Processing ({desc})"):
        try:
            messages = build_messages_adaptive(
                test_sample=sample,
                retrieval_cache=retrieval_cache,
                train_by_id=train_by_id,
                k_shot=k,
                min_score=min_score,
            )

            # model_infer must accept the chat message list built above.
            raw_answer = model_infer(messages)
            predictions.append({"id": sample["id"], "answer": raw_answer})

        except Exception as e:
            # Best effort: record the failure and keep one (blank) prediction
            # per test sample so predictions stay aligned with test_data.
            errors.append({"id": sample["id"], "error": str(e)})
            predictions.append({"id": sample["id"], "answer": ""})

    # A systematic failure would otherwise look like a plain 0% accuracy.
    if errors:
        print(f"  WARNING: {len(errors)}/{len(test_data)} samples failed; first error: {errors[0]['error']}")

    # Save predictions BEFORE scoring so a scoring problem cannot throw away
    # the inference results.
    out_path = os.path.join(RUN_DIR, f"subtask2_{desc.replace(' ', '_').replace('(', '').replace(')', '')}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(predictions, f, ensure_ascii=False, indent=2)
    print(f"  Saved to: {out_path}")

    # Calculate accuracy (exact string match after stripping whitespace).
    correct = 0
    total = 0
    mc_correct, mc_total = 0, 0
    yn_correct, yn_total = 0, 0

    # predictions holds exactly one entry per test sample, in test_data
    # order, so the two can be paired directly (no per-prediction search).
    for test_sample, pred in zip(test_data, predictions):
        gt_ans = test_sample.get("answer")
        if gt_ans is None:
            # Unlabeled sample: nothing to score against.
            continue

        total += 1
        pred_ans = pred["answer"].strip()
        gt_ans = gt_ans.strip()
        q_type = test_sample.get("question_type", "")

        is_correct = (pred_ans == gt_ans)
        if is_correct:
            correct += 1

        if q_type == "Multiple choice":
            mc_total += 1
            if is_correct:
                mc_correct += 1
        elif q_type == "Yes/No":
            yn_total += 1
            if is_correct:
                yn_correct += 1

    accuracy = correct / total if total > 0 else 0
    mc_acc = mc_correct / mc_total if mc_total > 0 else 0
    yn_acc = yn_correct / yn_total if yn_total > 0 else 0

    print(f"\nResults for {desc}:")
    if total == 0:
        print("  No ground-truth answers found in the test data; accuracies are reported as 0.")
    print(f"  Overall: {correct}/{total} = {accuracy*100:.1f}%")
    print(f"  MC: {mc_correct}/{mc_total} = {mc_acc*100:.1f}%")
    print(f"  YN: {yn_correct}/{yn_total} = {yn_acc*100:.1f}%")

    results_summary.append({
        "config": desc,
        "k_shot": k,
        "min_score": min_score,
        "overall_acc": accuracy,
        "mc_acc": mc_acc,
        "yn_acc": yn_acc,
        "correct": correct,
        "total": total,
        "n_errors": len(errors),
    })


# **8. Summary**

In [None]:

# Banner
_rule = "=" * 50
print(f"\n{_rule}")
print("EXPERIMENT SUMMARY")
print(f"{_rule}\n")

# One table row per experiment, accuracies shown as percentages.
print(f"{'Config':<25} {'Overall':<10} {'MC':<10} {'YN':<10}")
print("-" * 50)
for _row in results_summary:
    _overall, _mc, _yn = (_row[_key] * 100 for _key in ("overall_acc", "mc_acc", "yn_acc"))
    print(f"{_row['config']:<25} {_overall:>6.1f}%   {_mc:>6.1f}%   {_yn:>6.1f}%")

# Best run by overall accuracy (the first one wins a tie).
best = max(results_summary, key=lambda entry: entry["overall_acc"])
print(f"\n BEST CONFIG: {best['config']}")
for _label, _key in (("Overall", "overall_acc"), ("MC", "mc_acc"), ("YN", "yn_acc")):
    print(f"   {_label}: {best[_key]*100:.1f}%")

# Persist the whole table next to the per-run prediction files.
summary_path = os.path.join(RUN_DIR, "experiment_summary.json")
with open(summary_path, "w", encoding="utf-8") as _summary_file:
    _summary_file.write(json.dumps(results_summary, ensure_ascii=False, indent=2))
print(f"\n Summary saved to: {summary_path}")