In [None]:
from google.colab import drive
from google.colab import userdata

# Mount Google Drive and switch the working directory to "My Drive", so the
# relative dataset paths used later in the notebook (e.g. 'RISCM') resolve there.
drive.mount('/content/drive')
%cd /content/drive/MyDrive/

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import os

class RISCDataset(Dataset):
    """Image-captioning dataset over one split of the RISC captions CSV, encoded for PaliGemma.

    The CSV must provide an ``image`` file name (relative to ``image_dir``), a
    ``training_caption`` and a ``split`` column. Only rows whose ``split`` equals the
    requested split are kept.

    Each item is encoded through the processor's ``suffix`` argument. The prompt is
    the prefix and the caption is the suffix, so the caption tokens are part of
    ``input_ids``. The processor-built ``labels`` are -100 outside the caption,
    which keeps the labels aligned with the model inputs for teacher forcing.
    Padding positions are additionally forced to -100.

    Args:
        csv_path: Path to the captions CSV.
        image_dir: Directory containing the image files.
        split: Value of the ``split`` column to keep (e.g. "train", "val").
        processor: PaliGemma processor exposing ``tokenizer`` and
            ``image_seq_length``. Required despite the ``None`` default.
        max_target_length: Token budget for the text part of the sequence
            (prompt, separators and caption). Every item is padded to
            ``processor.image_seq_length + max_target_length`` tokens, and only
            the caption is truncated when it does not fit.

    Raises:
        ValueError: If ``processor`` is None.
    """

    # The PaliGemma processor expands "<image>" into the image placeholder tokens and
    # inserts BOS after them itself, so the prompt must not contain a literal "<bos>".
    PROMPT = "<image>Describe this image:"

    def __init__(self, csv_path, image_dir, split="train", processor=None, max_target_length=264):
        if processor is None:
            raise ValueError("RISCDataset requires a processor (e.g. the PaliGemma AutoProcessor).")
        all_rows = pd.read_csv(csv_path)
        self.df = all_rows[all_rows["split"] == split].reset_index(drop=True)
        self.image_dir = image_dir
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.max_target_length = max_target_length
        # Fixed total length (image tokens + text budget): every item has the same
        # shape, so a plain torch.stack collator can batch them.
        self.max_length = processor.image_seq_length + max_target_length

        self.samples = [
            (os.path.join(self.image_dir, img), caption)
            for img, caption in zip(self.df["image"], self.df["training_caption"])
        ]

    def __len__(self):
        return len(self.samples)

    def _encode(self, image, caption):
        """Encode one (image, caption) pair into un-batched model inputs including ``labels``."""
        encoded = self.processor(
            text=self.PROMPT,
            images=image,
            suffix=caption,
            padding="max_length",
            # Truncate only the caption (second sequence); image tokens and prompt stay intact.
            truncation="only_second",
            max_length=self.max_length,
            return_tensors="pt",
        )
        item = {key: value.squeeze(0) for key, value in encoded.items()}
        # Never compute loss on padding positions.
        if "attention_mask" in item:
            item["labels"] = item["labels"].masked_fill(item["attention_mask"] == 0, -100)
        return item

    def __getitem__(self, idx):
        img_path, caption = self.samples[idx]
        # Close the file handle promptly; convert() returns an independent copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        return self._encode(image, caption)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor

# Data locations (relative to the current working directory) and model settings.
dataset_dir = 'RISCM'
csv_path = os.path.join(dataset_dir, 'captions_cleaned.csv')  # captions table with a 'split' column
image_dir = os.path.join(dataset_dir, 'resized')               # image files referenced by the CSV
model_name = "google/paligemma-3b-pt-224"

# Hugging Face access token, read from the Colab secret named HF_TOKEN.
hf_token = userdata.get('HF_TOKEN')

# Prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
from transformers import Trainer, TrainingArguments, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig
import wandb
import numpy as np

# QLoRA setup: load the processor, load PaliGemma once with 4-bit NF4 weights, then
# wrap it with LoRA adapters.
#
# The model is loaded only once, directly in quantized form. Any manual
# requires_grad changes made on a separately loaded copy would not carry over to
# this model; get_peft_model decides what is trainable, and
# print_trainable_parameters() below reports it.

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# NOTE(review): module names such as q_proj/k_proj/v_proj may also exist in the
# vision encoder, in which case LoRA adapters are attached there too. Inspect
# model.named_modules() and restrict target_modules (e.g. a regex on
# "language_model") if only the text decoder should be adapted.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=128,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

processor = AutoProcessor.from_pretrained(model_name, token=hf_token, use_fast=True)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    # Keep the non-quantized modules in the same dtype as the 4-bit compute dtype.
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
    token=hf_token,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


In [None]:
# Experiment tracking: authenticate with the W&B API key stored in Colab secrets.
wandb.login(key=userdata.get('WANDB_API_KEY'))

wandb.init(project="DI-725-Project-III", name="qlora_paligemma")

# Log whether any vision-tower parameter is actually trainable. The value is
# computed from the model being trained rather than hard-coded.
vision_tower_requires_grad = any(
    param.requires_grad
    for param_name, param in model.named_parameters()
    if "vision_tower" in param_name
)

wandb.config.update({
    "model_name": model_name,
    "vision_tower_requires_grad": vision_tower_requires_grad,
    "peft_config": peft_config.to_dict(),
    "bnb_config": bnb_config.to_dict(),
})

# Fraction of each split to use (1 = everything); lower it for quick smoke tests.
fraction = 1

train_dataset = RISCDataset(csv_path, image_dir, split="train", processor=processor)
val_dataset = RISCDataset(csv_path, image_dir, split="val", processor=processor)

train_size = int(len(train_dataset) * fraction)
val_size = int(len(val_dataset) * fraction)

# Keep the leading `fraction` of each split (deterministic, no shuffling).
train_dataset = torch.utils.data.Subset(train_dataset, range(train_size))
val_dataset = torch.utils.data.Subset(val_dataset, range(val_size))

# Use bf16 autocast when the GPU supports it, matching the bfloat16 compute dtype
# of the 4-bit layers; otherwise fall back to fp16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir="./qlora_output",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_steps=1000,
    learning_rate=5e-5,
    weight_decay=0.01,
    num_train_epochs=1,
    eval_strategy="epoch",
    save_strategy="epoch",
    adam_beta2=0.999,
    optim="adamw_8bit",
    # Only used when the strategies above are "steps"; ignored with "epoch".
    save_steps=4000,
    eval_steps=4000,
    save_total_limit=3,
    label_names=["labels"],
    fp16=not use_bf16,
    bf16=use_bf16,
    report_to="wandb",
    run_name="qlora_paligemma",
    logging_dir="./logs",
    logging_steps=10,
    disable_tqdm=False,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    dataloader_num_workers=4,
)

wandb.config.update({
    "training_args": training_args.to_dict(),
})


def stack_collate(batch):
    """Batch a list of example dicts by stacking each key's tensors along a new dim 0.

    Requires every example to hold tensors of identical shape for a given key;
    RISCDataset pads to a fixed length for this reason.
    """
    return {key: torch.stack([example[key] for example in batch]) for key in batch[0]}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=processor,
    data_collator=stack_collate,
)

try:
    trainer.train()
finally:
    # Close the W&B run even if training fails or the kernel is interrupted.
    wandb.finish()