In [None]:
!pip install import-ipynb

In [None]:
%cd /content/drive/MyDrive/SHBT261FinalProject/code/src

In [3]:
"""
LoRA Fine-tuning Script for TextVQA
Fine-tunes Qwen2.5-VL with LoRA adapters on TextVQA training set
"""
import os
import json
from datetime import datetime
from collections import Counter

import torch
from torch.utils.data import DataLoader
from transformers import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from tqdm import tqdm
from PIL import Image

import import_ipynb
from data_loader import TextVQADataset
from model import get_model_and_processor, save_lora_weights, get_generation_config
from metrics import compute_all_metrics, print_metrics

In [None]:
class TrainingConfig:
    """Plain settings holder for one LoRA fine-tuning run.

    Each constructor argument is stored, unmodified, on the instance under the
    same name and in signature order, so ``vars(config)`` gives a flat mapping
    of every setting in a stable order.

    Args:
        model_name: Model identifier passed to the model loader.
        data_dir: Directory passed to the dataset class.
        output_dir: Parent directory for per-run output folders.
        lora_r, lora_alpha, lora_dropout: LoRA adapter settings forwarded to
            the model loader.
        num_epochs: Number of passes over the training loader.
        batch_size: Samples per micro-batch.
        gradient_accumulation_steps: Micro-batches per optimizer step.
        learning_rate, weight_decay: Optimizer settings.
        warmup_ratio: Fraction of the total optimizer steps used for warmup.
        max_grad_norm: Gradient-norm clipping threshold.
        max_train_samples, max_val_samples: Values forwarded as
            ``max_samples`` to the train / validation datasets
            (``None`` presumably means "no cap" -- confirm in the dataset class).
        use_4bit: Flag forwarded to the model loader.
        seed: Seed for the torch RNG.
        eval_steps, save_steps, logging_steps: Interval settings, counted in
            optimizer steps, for evaluation, checkpoint saving and loss
            logging; how they are honored is up to the training loop.
    """

    def __init__(
        self,
        model_name="Qwen/Qwen2.5-VL-3B-Instruct",
        data_dir="textvqa_data/data",
        output_dir="checkpoints",
        lora_r=32,
        lora_alpha=32,
        lora_dropout=0.05,
        num_epochs=3,
        batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        warmup_ratio=0.1,
        weight_decay=0.01,
        max_grad_norm=1.0,
        max_train_samples=None,
        max_val_samples=500,
        use_4bit=True,
        seed=42,
        eval_steps=500,
        save_steps=500,
        logging_steps=50,
    ):
        # Tuple targets are bound left to right, so attribute insertion order
        # (and therefore the key order of ``vars(self)``) follows the signature.

        # Model and filesystem locations.
        self.model_name, self.data_dir, self.output_dir = (
            model_name, data_dir, output_dir,
        )
        # LoRA adapter settings.
        self.lora_r, self.lora_alpha, self.lora_dropout = (
            lora_r, lora_alpha, lora_dropout,
        )
        # Loop sizing.
        self.num_epochs, self.batch_size, self.gradient_accumulation_steps = (
            num_epochs, batch_size, gradient_accumulation_steps,
        )
        # Optimizer / scheduler settings.
        self.learning_rate, self.warmup_ratio = learning_rate, warmup_ratio
        self.weight_decay, self.max_grad_norm = weight_decay, max_grad_norm
        # Dataset size limits.
        self.max_train_samples, self.max_val_samples = (
            max_train_samples, max_val_samples,
        )
        # Quantization flag and RNG seed.
        self.use_4bit, self.seed = use_4bit, seed
        # Periodic-action intervals.
        self.eval_steps, self.save_steps, self.logging_steps = (
            eval_steps, save_steps, logging_steps,
        )

In [None]:
def get_most_common_answer(answers):
    """Return the answer string that occurs most often in ``answers``.

    Ties go to the answer that appears first in ``answers``. A falsy
    ``answers`` (empty list, ``None``) yields the empty string.
    """
    if answers:
        top_answer, _votes = Counter(answers).most_common(1)[0]
        return top_answer
    return ""

def prepare_training_batch(batch, processor, device):
    """Build model inputs (including ``labels``) for one training micro-batch.

    Each sample becomes a two-turn chat: a user turn holding an image
    placeholder plus the question with an answering instruction, and an
    assistant turn holding the sample's most frequent reference answer. The
    chats are rendered with the processor's chat template, then tokenized
    together with the images using padding.

    Args:
        batch: List of samples; each is indexed with ``"image"``,
            ``"question"`` and ``"answers"``.
        processor: Object providing ``apply_chat_template`` and a call
            interface taking ``text``/``images``/``padding``/``return_tensors``.
        device: Device every tensor-like value is moved to.

    Returns:
        dict of processor outputs moved to ``device`` plus a ``"labels"``
        entry: a copy of ``input_ids`` where padded positions
        (``attention_mask == 0``) are replaced by -100 so that padding does
        not contribute to the loss.
    """
    # Label value ignored by the cross-entropy loss of Hugging Face LM heads.
    ignore_index = -100

    images = [sample["image"] for sample in batch]
    questions = [sample["question"] for sample in batch]
    answers = [get_most_common_answer(sample["answers"]) for sample in batch]

    conversations = [
        [
            {"role": "user",
             "content": [
                 {"type": "image"},
                 {"type": "text",
                  "text": f"{q}\nAnswer with only the exact text/number from the image."}
             ]},
            {"role": "assistant", "content": a},
        ]
        for q, a in zip(questions, answers)
    ]

    # add_generation_prompt=False: the assistant turn is already part of the chat.
    texts = [
        processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
        for conv in conversations
    ]

    inputs = processor(
        text=texts, images=images, padding=True, return_tensors="pt"
    )
    inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

    labels = inputs["input_ids"].clone()
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        # Sequences in a batch have different lengths; without this mask the
        # model would also be trained to predict the padding token.
        labels = labels.masked_fill(attention_mask == 0, ignore_index)
    # NOTE(review): prompt and image-placeholder tokens are still supervised;
    # consider masking everything before the assistant answer as well.
    inputs["labels"] = labels

    return inputs

In [None]:
def evaluate_model(model, processor, dataset, max_samples=None, device="cuda"):
    """Generate an answer for each sample and score the answers.

    Samples are processed one at a time under ``torch.no_grad()``: the
    question is wrapped in a single-turn chat with an image placeholder,
    rendered with the generation prompt, and passed to ``model.generate``
    with the kwargs from ``get_generation_config()``. Only the newly
    generated tokens (everything after the prompt length) are decoded.

    Args:
        model: Model exposing ``eval``/``train``/``generate``.
        processor: Processor providing ``apply_chat_template``, a call
            interface, and ``batch_decode``.
        dataset: Indexable collection of samples with ``"image"``,
            ``"question"`` and ``"answers"`` entries.
        max_samples: Evaluate only the first ``max_samples`` items; a falsy
            value (``None`` or 0) evaluates the whole dataset.
        device: Device the processor outputs are moved to.

    Returns:
        Whatever ``compute_all_metrics(predictions, ground_truths)`` returns.

    The model's train/eval mode is put back to what it was on entry, even
    when generation raises. If the model has no ``training`` attribute it is
    left in training mode.
    """
    # Remember the caller's mode so evaluation does not silently flip it.
    was_training = getattr(model, "training", True)
    model.eval()
    predictions, ground_truths = [], []
    gen_cfg = get_generation_config()

    n = min(len(dataset), max_samples or len(dataset))

    try:
        with torch.no_grad():
            for idx in tqdm(range(n), desc="Evaluating", leave=False):
                sample = dataset[idx]
                img = sample["image"]
                q = sample["question"]

                # NOTE(review): this instruction differs from the one used to
                # build training batches in this file -- confirm that the
                # mismatch is intentional.
                conv = [
                    {"role": "user",
                     "content": [
                         {"type": "image"},
                         {"type": "text", "text": f"{q}\nAnswer concisely."}
                     ]}
                ]

                text = processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
                inputs = processor(text=[text], images=[img], return_tensors="pt", padding=True)
                inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

                out_ids = model.generate(**inputs, **gen_cfg)
                # generate() returns prompt + continuation; keep the continuation only.
                input_len = inputs["input_ids"].shape[1]
                gen_ids = out_ids[:, input_len:]

                pred = processor.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
                predictions.append(pred)
                ground_truths.append(sample["answers"])
    finally:
        model.train(was_training)

    return compute_all_metrics(predictions, ground_truths)

In [None]:
def train(config: TrainingConfig):
    """Run LoRA fine-tuning and write all artifacts to a fresh run directory.

    Steps:
        1. Seed torch, create ``<output_dir>/run_<timestamp>`` and dump the
           config to ``config.json``.
        2. Load model + processor via ``get_model_and_processor`` and the
           train / validation ``TextVQADataset`` splits.
        3. Train with AdamW and a cosine schedule with warmup, accumulating
           gradients over ``config.gradient_accumulation_steps`` micro-batches
           and clipping to ``config.max_grad_norm`` before each optimizer step.
        4. Every ``logging_steps`` optimizer steps record the running mean
           loss; every ``eval_steps`` evaluate on the validation set and save
           ``best_model`` when ``metrics["accuracy"]`` improves; every
           ``save_steps`` save ``checkpoint_step_<N>``. An interval of 0 or
           ``None`` disables that action.
        5. After each epoch write the collected loss log to
           ``training_log.json``; after the last epoch save ``final_model``.

    Args:
        config: Settings for the run (see ``TrainingConfig``).

    Returns:
        None. Results are written under the run directory.
    """
    print("\n==============================")
    print("   TextVQA LoRA Fine-tuning   ")
    print("==============================")

    torch.manual_seed(config.seed)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(config.output_dir, f"run_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)

    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump(vars(config), f, indent=2)

    print("Loading model...")
    model, processor = get_model_and_processor(
        model_name=config.model_name,
        use_4bit=config.use_4bit,
        use_lora=True,
        lora_r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
    )
    device = next(model.parameters()).device
    print("Model loaded on:", device)

    print("Loading datasets...")
    train_ds = TextVQADataset(config.data_dir, "train", max_samples=config.max_train_samples)
    val_ds = TextVQADataset(config.data_dir, "validation", max_samples=config.max_val_samples)

    train_loader = DataLoader(
        train_ds,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,     # IMPORTANT for Colab
        collate_fn=lambda x: x,  # keep raw sample lists; batching happens in prepare_training_batch
    )

    accum_steps = max(1, config.gradient_accumulation_steps)
    num_batches = len(train_loader)
    # Micro-batches left over at the end of an epoch form a shorter final
    # window that still gets its own optimizer step, hence the ceil division.
    tail_size = num_batches % accum_steps
    steps_per_epoch = (num_batches + accum_steps - 1) // accum_steps
    total_steps = steps_per_epoch * config.num_epochs
    warmup_steps = int(total_steps * config.warmup_ratio)

    # Only parameters that require gradients are optimized and clipped.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params, lr=config.learning_rate, weight_decay=config.weight_decay
    )
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    model.train()
    optimizer.zero_grad()
    global_step, best_acc = 0, 0.0
    log = []
    log_path = os.path.join(run_dir, "training_log.json")

    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        print(f"\nEpoch {epoch+1}/{config.num_epochs}")

        for step, batch in enumerate(tqdm(train_loader)):
            # Normalize by the real size of the current accumulation window so
            # the shorter final window is averaged correctly.
            in_tail = tail_size > 0 and step >= num_batches - tail_size
            window = tail_size if in_tail else accum_steps

            inputs = prepare_training_batch(batch, processor, device)
            outputs = model(**inputs)
            (outputs.loss / window).backward()

            epoch_loss += outputs.loss.item()

            window_done = (step + 1) % accum_steps == 0 or (step + 1) == num_batches
            if not window_done:
                continue

            torch.nn.utils.clip_grad_norm_(trainable_params, config.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            if config.logging_steps and global_step % config.logging_steps == 0:
                avg_loss = epoch_loss / (step + 1)
                log.append({"step": global_step, "loss": avg_loss})

            if config.eval_steps and global_step % config.eval_steps == 0:
                metrics = evaluate_model(model, processor, val_ds, device=device)
                model.train()  # make sure training mode is active again
                print_metrics(metrics, f"Validation @ step {global_step}")
                if metrics["accuracy"] > best_acc:
                    best_acc = metrics["accuracy"]
                    save_lora_weights(model, os.path.join(run_dir, "best_model"))

            if config.save_steps and global_step % config.save_steps == 0:
                save_lora_weights(
                    model, os.path.join(run_dir, f"checkpoint_step_{global_step}")
                )

        # Report the mean per-batch loss, not the sum over the epoch.
        mean_epoch_loss = epoch_loss / max(1, num_batches)
        print(f"Epoch {epoch+1} finished. Loss: {mean_epoch_loss:.4f}")

        # Persist the loss log each epoch so an interrupted run keeps it.
        with open(log_path, "w") as f:
            json.dump(log, f, indent=2)

    save_lora_weights(model, os.path.join(run_dir, "final_model"))
    print("Training completed.")