In [None]:
import os
import torch
import pandas as pd
from PIL import Image
from torch.optim import AdamW
from tqdm.notebook import tqdm
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
from bert_score import score as bertscore_score
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from transformers import BlipProcessor,BlipForQuestionAnswering,get_scheduler

# Data locations: Kaggle mounts attached datasets under /kaggle/input, so use the
# Kaggle layout when that directory exists and the local repository layout otherwise.
if os.path.isdir("/kaggle/input"):
    # Kaggle configuration (the relative paths assume the notebook's working
    # directory is /kaggle/working, i.e. '../input' == '/kaggle/input')
    SRC_PATH = '../input/vrdata/data/csvs/vqa.csv'
    IMAGE_DIR = '../input/vrdata/data/curated_images'
    DEST_DIR = "/kaggle/working"
else:
    # Local configuration
    SRC_PATH = '../data/csvs/vqa.csv'
    IMAGE_DIR = '../data/curated_images'
    DEST_DIR = "../data/csvs"

# Model and training hyperparameters
MODEL_NAME         = "Salesforce/blip-vqa-base"
BATCH_SIZE         = 16    # training DataLoader batch size
EVAL_BATCH_SIZE    = 32    # evaluation DataLoader batch size
N_EPOCHS           = 3
LEARNING_RATE      = 5e-5
LORA_R             = 16    # LoRA rank
LORA_ALPHA         = 32
LORA_DROPOUT       = 0.05
MAX_LENGTH         = 128   # token cap for question/answer tokenization and for generation
WARMUP_STEPS       = 0     # linear-scheduler warmup steps

# Accelerator setup: fp16 mixed precision; DEVICE is the device assigned to this process.
accelerator = Accelerator(mixed_precision="fp16")
DEVICE = accelerator.device
# Reports accelerate's process count (one device per process), not every GPU visible.
print(f"Using {accelerator.state.num_processes} GPU(s), fp16")

# Dataset
class VQADataset(Dataset):
    """Map-style dataset of (image, question, answer) triples for VQA fine-tuning.

    Rows whose image file does not exist under ``image_dir`` are skipped with a
    printed warning. Images are only opened when an item is requested.
    """

    def __init__(self, df, image_dir):
        """Index the rows of *df* whose image file exists on disk.

        Args:
            df: DataFrame with ``filename``, ``question`` and ``answer`` columns.
            image_dir: Directory that the ``filename`` values are joined onto.

        Raises:
            RuntimeError: If no referenced image exists.
        """
        self.image_dir = image_dir
        self.entries = []
        # Walk the three columns directly; iterrows() builds a Series per row and is slow.
        rows = zip(df['filename'], df['question'], df['answer'])
        for filename, question, answer in tqdm(rows, total=len(df), desc="Verifying images"):
            img_path = os.path.join(image_dir, str(filename))
            if os.path.exists(img_path):
                self.entries.append((img_path, str(question), str(answer)))
            else:
                print(f"Warning: Missing {img_path}")

        if not self.entries:
            raise RuntimeError("No valid images found.")

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        """Return entry *idx* as ``{"image": RGB PIL image, "question": str, "answer": str}``."""
        img_path, question, answer = self.entries[idx]
        # Close the file handle right away; convert() returns an in-memory copy,
        # so the image stays usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        return {"image": image, "question": question, "answer": answer}

# Collate function
def vqa_collate_fn(batch):
    """Convert a list of ``VQADataset`` items into a BLIP-ready batch of tensors.

    Questions become ``input_ids``/``attention_mask`` and answers become
    ``labels``. Both are padded to the longest sequence in the batch and
    truncated to ``MAX_LENGTH``. Uses the module-level ``processor``.

    Tensors are returned on the CPU. The DataLoaders are passed through
    ``accelerator.prepare``, which moves every batch to the process's device.
    Moving them here would be redundant, and returning CUDA tensors from
    DataLoader worker processes is discouraged by PyTorch, so doing it here
    would block ``num_workers > 0``.
    """
    images = [item["image"] for item in batch]
    questions = [item["question"] for item in batch]
    answers = [item["answer"] for item in batch]

    enc = processor(
        images=images,
        text=questions,
        text_target=answers,
        padding="longest",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )

    return {
        "input_ids": enc.input_ids,
        "attention_mask": enc.attention_mask,
        "pixel_values": enc.pixel_values,
        "labels": enc.labels,
    }

# Load model and processor
print(f"Loading {MODEL_NAME}…")
model     = BlipForQuestionAnswering.from_pretrained(MODEL_NAME)
processor = BlipProcessor.from_pretrained(MODEL_NAME, use_fast=True)
print(f"{MODEL_NAME} loaded successfully.")

# Apply LoRA: wrap the model so that only the injected adapter weights train.
# NOTE(review): verify which of these target_modules names actually exist in BLIP
# (e.g. via model.named_modules()). Both HF-style (q_proj/...) and BERT-style
# (query/...) names are listed, and entries that match no module add no adapters.
lora_cfg = LoraConfig(
    r             = LORA_R,
    lora_alpha    = LORA_ALPHA,
    target_modules= ["q_proj","k_proj","v_proj","query","key","value"],
    lora_dropout  = LORA_DROPOUT,
    bias          = "none",
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()

# Prepare datasets and loaders (80/20 split with a fixed seed for reproducibility)
df = pd.read_csv(SRC_PATH)
train_df, val_df = train_test_split(df, test_size=0.2, random_state=7)

train_ds = VQADataset(train_df, IMAGE_DIR)
val_ds   = VQADataset(val_df,   IMAGE_DIR)
print(f"train={len(train_ds)}, val={len(val_ds)}")

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=vqa_collate_fn)
eval_loader  = DataLoader(val_ds,   batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=vqa_collate_fn)

# Optimizer and scheduler. Only the adapter parameters require gradients after
# get_peft_model, so hand AdamW just those rather than every frozen base weight.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer   = AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=0.01)
# One scheduler step per training batch, over all epochs.
total_steps = len(train_loader) * N_EPOCHS
scheduler   = get_scheduler(
    "linear", optimizer=optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_steps
)

# Accelerator preparation: wraps model/optimizer/scheduler for the configured
# precision and process layout and makes the loaders place batches on the device.
model, optimizer, train_loader, eval_loader, scheduler = accelerator.prepare(
    model, optimizer, train_loader, eval_loader, scheduler
)
# NOTE(review): the Accelerator is constructed without `log_with`, so this
# presumably registers no trackers.
accelerator.init_trackers("vqa-lora")

# Finetuning
print("Starting finetuning…")
for epoch in range(1, N_EPOCHS + 1):
    # ---- Training ----
    model.train()
    total_loss = 0.0
    # Only the local main process draws progress bars, so multi-process runs show one bar.
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch} train",
                     disable=not accelerator.is_local_main_process)
    for step, batch in enumerate(train_bar, 1):
        out = model(**batch)
        loss = out.loss
        accelerator.backward(loss)
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        total_loss += loss.item()
        train_bar.set_postfix(train_loss=total_loss / step)

    # ---- Evaluation ----
    model.eval()
    # accelerator.prepare may wrap the model (e.g. DistributedDataParallel), and such
    # wrappers do not expose generate(), so generate with the unwrapped model.
    gen_model = accelerator.unwrap_model(model)
    pad_id = processor.tokenizer.pad_token_id
    preds, refs = [], []
    for batch in tqdm(eval_loader, desc=f"Epoch {epoch} eval",
                      disable=not accelerator.is_local_main_process):
        with torch.no_grad():
            gen_ids = gen_model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                pixel_values=batch["pixel_values"],
                max_length=MAX_LENGTH,
                num_beams=1,
                no_repeat_ngram_size=2,
            )
        # Pad to a common length and gather from every process so the metric covers
        # the whole validation set. gather_for_metrics drops the duplicate samples
        # used to even out the final batch. Both calls are effectively no-ops
        # on a single process.
        gen_ids = accelerator.pad_across_processes(gen_ids, dim=1, pad_index=pad_id)
        labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=pad_id)
        gen_ids, labels = accelerator.gather_for_metrics((gen_ids, labels))
        preds.extend(processor.tokenizer.batch_decode(
            gen_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True))
        refs.extend(processor.tokenizer.batch_decode(
            labels, skip_special_tokens=True, clean_up_tokenization_spaces=True))

    # All processes finish evaluating before the global main process scores and saves;
    # saving from one process avoids concurrent writes to the same checkpoint directory.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # Mean BERTScore F1 (baseline-rescaled) of generated vs. reference answers.
        _, _, f1 = bertscore_score(preds, refs, lang="en", rescale_with_baseline=True)
        avg_f1 = f1.mean().item()
        print(f"\nEpoch {epoch}: train_loss={total_loss/len(train_loader):.4f}, eval_bertscore_f1={avg_f1:.4f}\n")
        ckpt = os.path.join(DEST_DIR, f"blip{epoch}")
        os.makedirs(ckpt, exist_ok=True)
        # Save through the unwrapped model; distributed wrappers lack save_pretrained().
        accelerator.unwrap_model(model).save_pretrained(ckpt)
        processor.save_pretrained(ckpt)
        print(f"Model saved to {ckpt}.")

# Close whatever init_trackers opened (and synchronize processes).
accelerator.end_training()
print("Finetuning complete.")