In [None]:
# Dataset preparation for PaliGemma

from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import os

class RISCDataset(Dataset):
    """Image-caption dataset for fine-tuning PaliGemma on one split of the RISC captions CSV.

    The CSV at ``csv_path`` must provide the columns ``image`` (file name, joined onto
    ``image_dir``), ``training_caption`` (target text) and ``split``; only rows whose
    ``split`` equals the requested split are kept.

    Each item is the processor output for one (image, prompt, caption) triple with the
    leading batch dimension squeezed out, and always contains a ``labels`` tensor.
    """

    def __init__(self, csv_path, image_dir, split="train", processor=None, max_target_length=32):
        """
        Args:
            csv_path: Path to the captions CSV.
            image_dir: Directory that the ``image`` column is joined onto.
            split: Value of the ``split`` column to keep (e.g. "train", "val").
            processor: PaliGemma-style processor; required by ``__getitem__``.
            max_target_length: Forwarded to the processor as ``max_length`` together with
                ``padding="max_length"``. NOTE(review): no ``truncation`` is requested, so
                longer samples are not cut; confirm how the processor accounts for image tokens.
        """
        self.df = pd.read_csv(csv_path)
        self.df = self.df[self.df['split'] == split].reset_index(drop=True)
        self.image_dir = image_dir
        self.processor = processor
        self.max_target_length = max_target_length

        # Resolve (image path, caption) pairs once so __getitem__ does no DataFrame work.
        self.samples = [
            (os.path.join(self.image_dir, img), caption)
            for img, caption in zip(self.df["image"], self.df["training_caption"])
        ]

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, preprocess and tokenize sample ``idx``.

        Returns:
            dict of tensors from the processor with the batch dim removed, always
            including ``labels`` (-100 marks positions excluded from the loss).

        Raises:
            ValueError: if the dataset was built without a processor.
        """
        if self.processor is None:
            raise ValueError("RISCDataset needs a `processor` to build model inputs.")

        img_path, caption = self.samples[idx]
        # The context manager closes the file handle; convert() returns a separate image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        # NOTE(review): the PaliGemma processor inserts its own BOS after the image tokens,
        # so the explicit "<bos>" here may yield a double BOS — kept to match any inference prompt.
        prompt = "<image> <bos> Describe this image in detail:"

        inputs = self.processor(
            images=image,
            text=prompt,
            suffix=caption,
            padding="max_length",
            max_length=self.max_target_length,
            return_tensors="pt"
        )

        batch = {k: v.squeeze(0) for k, v in inputs.items()}
        # When `suffix` is given, the PaliGemma processor already returns `labels` with the
        # image tokens, prompt and padding set to -100, so only the caption is supervised.
        # Keep them: cloning input_ids instead would put loss on image tokens and the prompt.
        if "labels" not in batch:
            # Fallback for processors that do not build labels: supervise every non-pad token.
            labels = batch["input_ids"].clone()
            pad_id = self.processor.tokenizer.pad_token_id
            if pad_id is not None:
                labels[labels == pad_id] = -100
            batch["labels"] = labels
        return batch

In [None]:
# Configurations

import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor

dataset_dir = 'RISCM'
image_dir = os.path.join(dataset_dir, 'resized')
csv_path = os.path.join(dataset_dir, 'captions_cleaned.csv')
model_name = "google/paligemma-3b-pt-224"
# PaliGemma is a gated model. Read the Hugging Face token from the environment instead of
# hard-coding a credential in the notebook. If HF_TOKEN is unset this is None, and the Hub
# client falls back to any locally cached login (`huggingface-cli login`).
hf_token = os.environ.get("HF_TOKEN")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# QLoRA initialization

from transformers import Trainer, TrainingArguments, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig
import wandb

# wandb.init(project="DI-725-Project", name="qlora_paligemma")

# 4-bit NF4 quantization of the frozen base weights (QLoRA); matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_quant_type="nf4",
      bnb_4bit_compute_dtype=torch.bfloat16
)

# LoRA adapters on the attention projections of the Gemma language model only.
# A plain name list would also match the SigLIP vision tower's q_proj/k_proj/v_proj, so a
# regex restricted to "language_model" keeps the vision tower and projector untrained.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=64,
    lora_alpha=256,
    bias="none",
    target_modules=r".*language_model.*\.(q_proj|k_proj|v_proj|o_proj)",
)

processor = AutoProcessor.from_pretrained(model_name, token=hf_token, use_fast=True)

# Load the base model exactly once, already quantized, directly onto GPU 0.
# (Loading full bf16 copies first only wastes GPU memory and load time.)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_name, quantization_config=bnb_config, device_map={"": 0}, token=hf_token
)
# get_peft_model freezes every base parameter; only the LoRA adapters remain trainable.
model = get_peft_model(model, peft_config)

# Sanity check on the intent that the vision tower and projector stay frozen.
assert not any(
    param.requires_grad
    for name, param in model.named_parameters()
    if "vision_tower" in name or "multi_modal_projector" in name
), "vision tower / multi-modal projector unexpectedly has trainable parameters"
model.print_trainable_parameters()

# wandb.config.update({
#    "model_name": model_name,
#    "vision_tower_requires_grad": False,
#    "peft_config": peft_config.to_dict(),
#    "bnb_config": bnb_config.to_dict(),
# })

In [None]:
# Training

fraction = 1  # Share of each split to keep; 1 uses every row.

# Build the complete split datasets (train first, then val) ...
train_full = RISCDataset(csv_path, image_dir, split="train", processor=processor)
val_full = RISCDataset(csv_path, image_dir, split="val", processor=processor)

# ... then keep only a leading slice of each, sized by `fraction`.
train_size, val_size = (int(len(ds) * fraction) for ds in (train_full, val_full))

train_dataset, val_dataset = (
    torch.utils.data.Subset(ds, range(size))
    for ds, size in ((train_full, train_size), (val_full, val_size))
)

training_args = TrainingArguments(
    output_dir="./qlora_output",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch of 4 samples per device per optimizer step
    warmup_steps=1000,
    # NOTE(review): 1e-6 is far below learning rates typically used for LoRA (~1e-4); confirm intended.
    learning_rate=1e-6,
    weight_decay=1e-6,
    num_train_epochs=2,
    eval_strategy="steps",
    save_strategy="steps",
    adam_beta2=0.999,
    optim="adamw_torch",
    save_steps=500,
    eval_steps=500,  # load_best_model_at_end requires save_steps to be a multiple of eval_steps
    save_total_limit=1,
    label_names=["labels"],
    # Mixed precision in bfloat16, matching bnb_4bit_compute_dtype=torch.bfloat16 of the
    # quantized model; fp16 autocast would mix fp16 and bf16 arithmetic.
    bf16=True,
    report_to="wandb",
    run_name="qlora_paligemma",
    logging_dir="./logs",
    logging_steps=10,
    disable_tqdm=False,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    dataloader_num_workers=4
)

# wandb.config.update({
#    "training_args": training_args.to_dict(),
# })

def stack_collate(examples):
    """Collate a list of per-sample tensor dicts into one batch dict.

    Keys come from the first example; for each key the samples' tensors are combined with
    ``torch.stack``, which does no padding, so they must share a shape.
    """
    batch = {}
    for key in examples[0]:
        batch[key] = torch.stack([example[key] for example in examples])
    return batch


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=processor,
    data_collator=stack_collate,
)

trainer.train()

#  wandb.finish()