## PaliGemma 🤝 Direct Preference Optimization

In [None]:
%pip install -q transformers trl datasets peft bitsandbytes accelerate

Let's load the model and the processor. First we need to log in with `notebook_login`, since PaliGemma is a gated model. For utmost security, you can create a fine-grained token with access to PaliGemma and pass that.

In [None]:
import huggingface_hub

# PaliGemma is gated on the Hub, so authenticate this kernel before any
# model or processor download is attempted.
huggingface_hub.notebook_login()

In [None]:
from transformers import AutoModelForVision2Seq, AutoProcessor
import torch

model_id = "google/paligemma-3b-pt-448"

# Policy model: the weights DPO will update.
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# Frozen reference model used by the DPO loss. Load it in the same dtype as
# the policy: without `torch_dtype` it would default to float32, doubling its
# memory footprint and computing reference log-probs at a different precision.
# (When training PEFT adapters the frozen base weights can act as the
# reference, so this extra copy is then unnecessary.)
model_ref = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# PaliGemma's processor has no image-splitting option (`do_image_splitting` is
# an Idefics2 setting), so no extra kwargs are passed.
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer

If you want to use parameter-efficient fine-tuning (LoRA / QLoRA), run the cell below.

In [None]:
from trl import (
    ModelConfig,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from peft import LoraConfig
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig

# Optional QLoRA setup. A bare `ModelConfig()` leaves PEFT and quantization
# disabled, in which case `peft_config` and `quantization_config` are both None
# and this cell would merely reload the full-precision model. Enable them here.
model_config = ModelConfig(
    model_name_or_path=model_id,
    torch_dtype="bfloat16",  # match the bf16 policy weights
    use_peft=True,
    load_in_4bit=True,
    # LoRA on the attention projections; extend (e.g. MLP projections) as needed
    lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
quantization_config = get_quantization_config(model_config)
device_map = get_kbit_device_map()
peft_config = get_peft_config(model_config)  # handed to DPOTrainer later

# Reload the policy model quantized. `quantization_config` must actually be
# passed here; otherwise the model is loaded un-quantized.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    quantization_config=quantization_config,
)

In [None]:
from datasets import load_dataset

# Preference data: each row carries `prompt` / `chosen` / `rejected`
# conversations plus `images`, which are formatted in the preprocessing step.
DATASET_ID = "HuggingFaceH4/rlaif-v_formatted"
ds = load_dataset(DATASET_ID)

### Preprocessing

PaliGemma doesn't ship with a chat template, so we add one below and use it to turn the `prompt`, `chosen`, and `rejected` conversations into plain strings.

In [4]:
import logging
import os
import torch
from datasets import load_dataset

# PaliGemma has no chat template, so we add one. It renders a fixed preamble
# ("A chat between ..."), then "USER: ..." / "ASSISTANT: ...<eos>" turns, with
# `<image>` marking an image item inside a turn.
# NOTE(review): the preamble is emitted on every render, including when
# `chosen`/`rejected` are rendered on their own; confirm that is intended.
processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""


def process(row):
    """Format one preference example for DPO training (use with `Dataset.map`).

    Args:
        row: a dataset example whose `prompt`, `chosen` and `rejected` fields
            are chat-message lists and whose `images` field is a list of
            image objects exposing `.convert` (presumably PIL images decoded
            by `datasets`; confirm).

    Returns:
        The same row, with the three conversations rendered to strings via
        the chat template and every image converted to RGB.
    """
    for key in ("prompt", "chosen", "rejected"):
        row[key] = processor.apply_chat_template(row[key], tokenize=False)
    # Convert every image, not just the first: rows with no images would
    # otherwise raise IndexError, and extra images would keep their mode.
    row["images"] = [image.convert("RGB") for image in row["images"]]
    return row

We'll do a short run on a random half of the training split, since this notebook is meant for educational purposes.

In [None]:
# Keep a random half of the training split for a short run; the fixed seed
# makes the subset reproducible. `process` renders the chat-formatted
# conversations to plain strings and converts images to RGB; without this
# step the trainer would receive raw message lists.
train_dataset = (
    ds["train"]
    .train_test_split(test_size=0.5, seed=42)["train"]
    .map(process)
)
eval_dataset = ds["test"].map(process)

In [None]:
from trl import DPOConfig, DPOTrainer, ModelConfig

# Set to True only after running the optional PEFT/QLoRA cell, which defines
# `peft_config` and reloads `model` quantized.
USE_PEFT = False

training_args = DPOConfig(
    output_dir="/content/outs",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    # preprocessing workers belong on the config, not as a trainer keyword
    dataset_num_proc=32,
)

trainer = DPOTrainer(
    model,
    # With PEFT the frozen base model serves as the reference; TRL rejects an
    # explicit ref_model combined with peft_config (unless
    # force_use_ref_model=True), so pass None in that case.
    None if USE_PEFT else model_ref,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=processor,
    peft_config=peft_config if USE_PEFT else None,
)
trainer.train()

trainer.save_model(training_args.output_dir)