In [None]:
!pip install import-ipynb

In [None]:
# Make the project's source folder the working directory: every path used later in this
# notebook ("results", "textvqa_data/data", "checkpoints/...") is relative to it.
# NOTE(review): assumes Google Drive is already mounted at /content/drive — no mount call
# is visible in this notebook; confirm before running on a fresh runtime.
%cd /content/drive/MyDrive/SHBT261FinalProject/code/src

In [None]:
"""
Main entry point for TextVQA experiments
Provides a unified interface for all tasks
"""

import torch
import os

# NOTE(review): import_ipynb is never referenced by name in this notebook; presumably it is
# imported for its side effect of making the project's .ipynb files importable — confirm.
import import_ipynb
from model import get_model_and_processor, load_lora_weights
from data_loader import TextVQADataset
from evaluate_zeroshot import run_inference_zeroshot
from evaluate_finetuned import run_inference_finetuned
# NOTE(review): only `train` and `TrainingConfig` are bound here, yet the fine-tuning cells
# call `train_lora(...)`. Unless the star import below happens to export that name, those
# calls raise NameError — confirm the intended training entry point.
from train_lora import train, TrainingConfig
# NOTE(review): the star import hides where `run_analysis` and `compare_two_models` (used in
# the error-analysis cell) come from — presumably this module; consider explicit imports.
from analyze_results import *

# Directory (relative to the current working directory) from which the output paths passed
# to the evaluation and analysis calls are built; created up front if it does not exist.
RESULT_DIR = "results"
os.makedirs(RESULT_DIR, exist_ok=True)

### Zero-shot

In [None]:
# Validation split; this `dataset` object is reused by every evaluation cell below.
# max_samples=None — presumably "no cap on the number of samples"; confirm in data_loader.
dataset = TextVQADataset("textvqa_data/data", split="validation", max_samples=None)

# Base Qwen2.5-VL-3B-Instruct model and its processor, requested with use_4bit=True and
# without LoRA (use_lora=False). Also reused by the OCR ablation cell.
model, processor = get_model_and_processor(
    model_name="Qwen/Qwen2.5-VL-3B-Instruct",
    use_4bit=True,
    use_lora=False,
)

# Zero-shot baseline with OCR disabled. The output_json path is the same file the
# error-analysis cell later passes as predictions_path / zs_path.
outputs = run_inference_zeroshot(
    model=model,
    processor=processor,
    dataset=dataset,
    use_ocr=False,
    output_json=f"{RESULT_DIR}/zeroshot_metric.json"
)

# Last expression of the cell: display the "metrics" entry of the returned mapping.
outputs["metrics"]


#### Ablation: plain vs. OCR-enhanced

In [None]:
# Same zero-shot run as the baseline cell, but with use_ocr=True and a separate output file.
# Relies on `model`, `processor` and `dataset` created in the zero-shot cell above.
outputs = run_inference_zeroshot(
    model=model,
    processor=processor,
    dataset=dataset,
    use_ocr=True,
    output_json=f"{RESULT_DIR}/zeroshot_metric_ocr.json"
)

# Last expression of the cell: display the "metrics" entry of the returned mapping.
outputs["metrics"]


### Fine-tune

In [None]:
# LoRA fine-tune with rank 16 for 3 epochs, max_train_samples=None (presumably the whole
# training set — confirm in train_lora). Checkpoint goes to checkpoints/full_r16.
# NOTE(review): the import cell binds `train` and `TrainingConfig` from the train_lora
# module, not a callable named `train_lora` — confirm this call resolves.
train_lora(
    data_dir="textvqa_data/data",
    output_dir="checkpoints/full_r16",
    model_name="Qwen/Qwen2.5-VL-3B-Instruct",
    max_train_samples=None,
    num_epochs=3,
    lora_r=16,
)

# Load the adapter written above (lora_path equals the output_dir of the training call).
# This rebinds `model` / `processor`, replacing the zero-shot model under those names.
model, processor = load_lora_weights(
    model_name="Qwen/Qwen2.5-VL-3B-Instruct",
    lora_path="checkpoints/full_r16",
    use_4bit=True,
)

# Evaluate on the validation `dataset` created in the zero-shot cell. The output_json path
# is the same file the error-analysis cell later reads for the fine-tuned model.
outputs = run_inference_finetuned(
    model=model,
    processor=processor,
    dataset=dataset,
    output_json=f"{RESULT_DIR}/full_finetune_metric_r16.json",
)

# Last expression of the cell: display the "metrics" entry of the returned mapping.
outputs["metrics"]

#### Ablation: training data size (5k vs full)

In [None]:
# LoRA fine-tune with rank 16 on at most 5000 training samples, checkpoint in checkpoints/5k_r16.
# Note: this run uses num_epochs=1 while the full-data run uses num_epochs=3, so the
# "5k vs full" comparison varies the epoch count as well as the data size.
# NOTE(review): the import cell binds `train` and `TrainingConfig` from the train_lora
# module, not a callable named `train_lora` — confirm this call resolves.
train_lora(
    data_dir="textvqa_data/data",
    output_dir="checkpoints/5k_r16",
    model_name="Qwen/Qwen2.5-VL-3B-Instruct",
    max_train_samples=5000,
    num_epochs=1,
    lora_r=16,
)

# Load the adapter written above (lora_path equals the output_dir of the training call).
model, processor = load_lora_weights(
    model_name="Qwen/Qwen2.5-VL-3B-Instruct",
    lora_path="checkpoints/5k_r16",
    use_4bit=True,
)

# Evaluate on the validation `dataset` created in the zero-shot cell.
outputs = run_inference_finetuned(
    model=model,
    processor=processor,
    dataset=dataset,
    output_json=f"{RESULT_DIR}/5k_finetune_metric_r16.json",
)

# Last expression of the cell: display the "metrics" entry of the returned mapping.
outputs["metrics"]

#### Ablation: LoRA rank (8/16/32)

In [None]:
# LoRA fine-tune with rank 8 on at most 5000 training samples for 1 epoch.
# output_dir must equal the lora_path loaded below and be unique per rank, so that this run
# neither evaluates a different checkpoint nor overwrites the rank-16 one in checkpoints/5k_r16.
# NOTE(review): the import cell binds `train` and `TrainingConfig` from the train_lora
# module, not a callable named `train_lora` — confirm this call resolves.
train_lora(
    data_dir="textvqa_data/data",
    output_dir="checkpoints/5k_r8",
    model_name="Qwen/Qwen2.5-VL-3B-Instruct",
    max_train_samples=5000,
    num_epochs=1,
    lora_r=8,
)

# Load the rank-8 adapter written by the training call above.
model, processor = load_lora_weights(
    model_name="Qwen/Qwen2.5-VL-3B-Instruct",
    lora_path="checkpoints/5k_r8",
    use_4bit=True,
)

# Evaluate on the validation `dataset` created in the zero-shot cell.
outputs = run_inference_finetuned(
    model=model,
    processor=processor,
    dataset=dataset,
    output_json=f"{RESULT_DIR}/5k_finetune_metric_r8.json",
)

# Last expression of the cell: display the "metrics" entry of the returned mapping.
outputs["metrics"]

In [None]:
# LoRA fine-tune with rank 32 on at most 5000 training samples for 1 epoch.
# output_dir must equal the lora_path loaded below and be unique per rank, so that this run
# neither evaluates a different checkpoint nor overwrites the rank-16 one in checkpoints/5k_r16.
# NOTE(review): the import cell binds `train` and `TrainingConfig` from the train_lora
# module, not a callable named `train_lora` — confirm this call resolves.
train_lora(
    data_dir="textvqa_data/data",
    output_dir="checkpoints/5k_r32",
    model_name="Qwen/Qwen2.5-VL-3B-Instruct",
    max_train_samples=5000,
    num_epochs=1,
    lora_r=32,
)

# Load the rank-32 adapter written by the training call above.
model, processor = load_lora_weights(
    model_name="Qwen/Qwen2.5-VL-3B-Instruct",
    lora_path="checkpoints/5k_r32",
    use_4bit=True,
)

# Evaluate on the validation `dataset` created in the zero-shot cell.
outputs = run_inference_finetuned(
    model=model,
    processor=processor,
    dataset=dataset,
    output_json=f"{RESULT_DIR}/5k_finetune_metric_r32.json",
)

# Last expression of the cell: display the "metrics" entry of the returned mapping.
outputs["metrics"]

### Error analysis

In [None]:
# Error analysis over result files produced by earlier cells; both files must already exist
# under RESULT_DIR. `run_analysis` and `compare_two_models` are not imported by name in
# this notebook — presumably they come from `from analyze_results import *`.

# Analyze zero-shot (file written by the zero-shot cell with use_ocr=False)
run_analysis(
    predictions_path=f"{RESULT_DIR}/zeroshot_metric.json",
    model_name="zeroshot",
    output_dir=RESULT_DIR,
)

# Analyze full fine-tune (file written by the rank-16 full-data fine-tune cell)
run_analysis(
    predictions_path=f"{RESULT_DIR}/full_finetune_metric_r16.json",
    model_name="full_r16",
    output_dir=RESULT_DIR,
)

# Compare two models: zero-shot (zs_path) against the full rank-16 fine-tune (ft_path)
compare_two_models(
    zs_path=f"{RESULT_DIR}/zeroshot_metric.json",
    ft_path=f"{RESULT_DIR}/full_finetune_metric_r16.json",
    output_dir=RESULT_DIR,
)