# Using Direct Preference Optimization to Fine-Tune PaliGemma

In [None]:
!pip install -qU transformers trl peft accelerate bitsandbytes

In [None]:
# Authenticate with the Hugging Face Hub before the model and dataset
# downloads in the following cells (opens an interactive token prompt).
import huggingface_hub

huggingface_hub.notebook_login()

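Load the processor (image preprocessor and tokenizer) and the reference model, a frozen copy of the base model that DPO compares the trained policy against: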

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

model_id = 'google/paligemma-3b-pt-448'

# The processor bundles PaliGemma's image preprocessor and its tokenizer.
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer

# Frozen reference model for the DPO loss. It is loaded in the same dtype as
# the policy model (bfloat16, see the next cell) so that both sets of
# log-probabilities are computed at the same precision, and so this extra
# copy of the model takes half the memory of the default float32 load.
model_ref = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)

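Next, we prepare a Q-LoRA setup (4-bit quantization plus LoRA adapters) using TRL's helper functions and load the policy model that will be trained. Note that the default `ModelConfig()` leaves both quantization and LoRA disabled, so as configured the policy model is fine-tuned in full bfloat16 precision: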

In [None]:
from trl import ModelConfig, get_peft_config, get_quantization_config, get_kbit_device_map
from transformers import BitsAndBytesConfig
from peft import LoraConfig

# Q-LoRA switches. With the default ModelConfig(), quantization and LoRA are
# both disabled, so the policy model is trained in full bfloat16. To enable
# Q-LoRA, use e.g. ModelConfig(load_in_4bit=True, use_peft=True) and pass
# `peft_config` to the trainer.
model_config = ModelConfig()

# None unless load_in_4bit / load_in_8bit is set on model_config.
quantization_config = get_quantization_config(model_config)
# get_kbit_device_map() takes no arguments. A device map is only needed when
# the weights are quantized, which is the pattern TRL's own example scripts use.
device_map = get_kbit_device_map() if quantization_config is not None else None
# None unless use_peft is set on model_config.
peft_config = get_peft_config(model_config)

# Policy model to be optimized. quantization_config is passed through so that
# the load_in_4bit / load_in_8bit switches above actually take effect.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map=device_map,
)

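Next, we load a preference dataset for fine-tuning. Each row contains a prompt, a preferred (`chosen`) and a dispreferred (`rejected`) response, and the associated images: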

In [None]:
from datasets import load_dataset

# Preference data: rows carry "prompt", "chosen", "rejected" and "images"
# fields, which the `process` function further down consumes.
DATASET_ID = "HuggingFaceH4/rlaif-v_formatted"
dataset = load_dataset(DATASET_ID)

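PaliGemma does not come with a chat template, so we define one. We also define a `process` function that uses it to render the `prompt`, `chosen`, and `rejected` conversations as text and converts the images to RGB: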

In [None]:
import logging
import os
import torch

# Jinja chat template consumed by processor.apply_chat_template (see `process`).
# It renders:
#   - a fixed two-sentence system preamble, emitted before any message;
#   - for each message, a "USER:" label (role == 'user') or "ASSISTANT:" label
#     (any other role), followed by its text items and an "<image>"
#     placeholder for each image item;
#   - eos_token after every non-user message;
#   - a trailing "ASSISTANT:" only when add_generation_prompt is true (it
#     defaults to false when the caller does not define it).
# NOTE(review): the literal lines (USER:, the item text, <image>, ...) are
# indented inside the template, and that whitespace is likely carried into the
# rendered string. Verify that the rendered output looks as intended.
# NOTE(review): the preamble is rendered on every call, so it is presumably
# repeated when "chosen"/"rejected" are rendered separately from "prompt".
# Confirm that this is intended.
processor.chat_template = """
{% if not add_generation_prompt is defined %}
    {% set add_generation_prompt = false %}
{% endif %}
A chat between a curious user and an artificial intelligence assistant.
The assistant gives helpful, detailed, and polite answers to the user's questions.
{% for message in messages %}
    {% if message['role'] == 'user' %}
        USER:
    {% else %}
        ASSISTANT:
    {% endif %}

    {% for item in message['content'] %}
        {% if item['type'] == 'text' %}
            {{ item['text'] }}
        {% elif item['type'] == 'image' %}
            <image>
        {% endif %}
    {% endfor %}

    {% if message['role'] == 'user' %}
    {% else %}
        {{eos_token}}
    {% endif %}
{% endfor %}

{% if add_generation_prompt %}
    ASSISTANT:
{% endif %}
"""

def process(row):
  """Render one preference example as text and normalise its images.

  Args:
    row: a dataset row whose "prompt", "chosen" and "rejected" fields are
      chat-message lists and whose "images" field is a list of images.

  Returns:
    The same row, with the three conversations replaced by the strings
    produced by ``processor.apply_chat_template(..., tokenize=False)`` and
    every image converted to RGB.
  """
  for key in ("prompt", "chosen", "rejected"):
    row[key] = processor.apply_chat_template(row[key], tokenize=False)
  # Convert every image, not just the first one, and accept rows with no
  # images instead of raising IndexError. A new list is built rather than
  # mutating the input list in place.
  row["images"] = [image.convert("RGB") for image in row["images"]]
  return row

In [None]:
import os

# Keep a random half of the training split to shorten the run. The fixed seed
# makes the selected subset identical on every Restart & Run All.
SPLIT_SEED = 42
NUM_PROC = min(32, os.cpu_count() or 1)

train_subset = dataset["train"].train_test_split(test_size=0.5, seed=SPLIT_SEED)["train"]

# Apply `process` (chat template + RGB conversion) to both splits so the
# trainer receives text rendered with the template defined above.
train_dataset = train_subset.map(process, num_proc=NUM_PROC)
eval_dataset = dataset["test"].map(process, num_proc=NUM_PROC)

Now we can configure the `DPOTrainer`, run training, and save the resulting model:

In [None]:
from trl import DPOConfig, DPOTrainer, ModelConfig

OUTPUT_DIR = "/content/outs"  # Colab path; change when running elsewhere.

training_args = DPOConfig(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=8,
    num_train_epochs=1,
    # Dataset preprocessing parallelism is configured on DPOConfig, not on the
    # trainer, in current TRL releases.
    dataset_num_proc=32,
)

trainer = DPOTrainer(
    model,
    # With LoRA adapters, the reference log-probs come from the base model with
    # the adapters disabled. TRL refuses an explicit ref_model alongside a
    # peft_config, so the separate reference copy is only used without PEFT.
    ref_model=None if peft_config is not None else model_ref,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # `tokenizer=` was renamed to `processing_class=` in TRL 0.12.
    processing_class=processor,
    # None unless use_peft is enabled on model_config.
    peft_config=peft_config,
)
trainer.train()

trainer.save_model(training_args.output_dir)