<a href="https://colab.research.google.com/github/mrm8488/shared_colab_notebooks/blob/master/Idefics_8b_dpo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preference Optimization for Vision Language Models with `TRL`

![pic](https://huggingface.co/blog/assets/dpo_vlm/thumbnail.png)

> Colab for the post: https://huggingface.co/blog/dpo_vlm

> Author: [mrm8488](https://twitter.com/mrm8488)

### Install required dependencies

In [1]:
! pip install -U transformers datasets trl peft accelerate wandb

Collecting transformers
  Downloading transformers-4.42.4-py3-none-any.whl (9.3 MB)
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.42.3
    Uninstalling transformers-4.42.3:
      Successfully uninstalled transformers-4.42.3
Successfully installed transformers-4.42.4
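
Because `-U` pulls the latest releases, the versions you get may differ from the ones used when this notebook was written, and trainer arguments have changed between TRL releases. A minimal sketch for recording what is actually installed, using only the standard library (`installed_versions` is a helper name introduced here, not part of any of these packages):

```python
from importlib.metadata import PackageNotFoundError, version

# The packages installed by the pip command above.
PACKAGES = ["transformers", "datasets", "trl", "peft", "accelerate", "wandb"]


def installed_versions(packages):
    """Return {package name: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:<14}{ver or 'not installed'}")
```

Keeping this output next to a training run makes it easier to reproduce the run later.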


In [2]:
! wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


### Check our GPU! ✅

In [3]:
! nvidia-smi

Fri Jul 12 12:26:07 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   60C    P8              17W /  72W |      1MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

### Imports

In [5]:
from datasets import features, load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
import torch
from trl import DPOConfig, DPOTrainer
from peft import LoraConfig
import os

In [6]:
os.environ["WANDB_PROJECT"] = "idefics2-8b-dpo-rlaif-v"

### Load the model and dataset 🧠 📚

In [7]:
ds_id = "openbmb/RLAIF-V-Dataset"
dataset = load_dataset(ds_id, split="train")

In [8]:
dataset

Dataset({
    features: ['ds_name', 'image', 'question', 'chosen', 'rejected', 'origin_dataset', 'origin_split', 'idx', 'image_path'],
    num_rows: 83132
})

In [12]:
# Train on a random 20k-example subset to shorten the run.
# Comment out the next line to train on all 83,132 examples.
dataset = dataset.shuffle(seed=42).select(range(20000))

In [13]:
dataset

Dataset({
    features: ['ds_name', 'image', 'question', 'chosen', 'rejected', 'origin_dataset', 'origin_split', 'idx', 'image_path'],
    num_rows: 20000
})

In [11]:
model_id = "HuggingFaceM4/idefics2-8b"

model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id, do_image_splitting=False)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Format the dataset ⚙

In [14]:
def format_ds(example):
    # Prepare the input for the chat template
    prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}]
    chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}]
    rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}]
    # Apply the chat template
    prompt = processor.apply_chat_template(prompt, tokenize=False)
    chosen = processor.apply_chat_template(chosen, tokenize=False)
    rejected = processor.apply_chat_template(rejected, tokenize=False)
    # Shrink the image in place so that neither side exceeds half of the
    # processor's longest edge; this lowers memory use and helps prevent OOM errors.
    max_size = processor.image_processor.size["longest_edge"] // 2
    example["image"].thumbnail((max_size, max_size))
    return {"images": [example["image"]], "prompt": prompt, "chosen": chosen, "rejected": rejected}
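
The resize step in `format_ds` can be illustrated with plain arithmetic. The sketch below approximates what `Image.thumbnail` does: it keeps the aspect ratio, only ever shrinks, and leaves images that already fit unchanged. `fit_within` is a hypothetical helper, the value `980` for `longest_edge` is an assumption about this processor's configuration, and Pillow's exact rounding may differ by a pixel.

```python
def fit_within(width, height, max_size):
    """Approximate the size after shrinking to fit a max_size x max_size box.

    The aspect ratio is preserved and images that already fit are returned
    unchanged. This mirrors the documented behaviour of Image.thumbnail,
    but it is not Pillow's implementation.
    """
    if width <= max_size and height <= max_size:
        return width, height
    scale = min(max_size / width, max_size / height)
    return max(1, round(width * scale)), max(1, round(height * scale))


LONGEST_EDGE = 980            # assumed processor.image_processor.size["longest_edge"]
max_size = LONGEST_EDGE // 2  # same halving as in format_ds

print(fit_within(2000, 1000, max_size))  # (490, 245)
print(fit_within(300, 200, max_size))    # (300, 200), already small enough
```

Note that `thumbnail` modifies the image in place, which is why `format_ds` does not reassign `example["image"]`.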

In [15]:
dataset = dataset.map(format_ds, remove_columns=dataset.column_names, num_proc=os.cpu_count())

In [16]:
# Make sure the images are decoded; this prevents them from being stored as raw bytes.
# More info here https://github.com/huggingface/blog/pull/2148#discussion_r1667400478
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True))
dataset = dataset.cast(f)

### Set the Training Arguments 🛢

In [18]:
training_args = DPOConfig(
    output_dir="idefics2-8b-dpo-rlaif-v",
    bf16=True,
    gradient_checkpointing=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=32,
    num_train_epochs=1,
    push_to_hub=True,
    dataset_num_proc=os.cpu_count(),
    dataloader_num_workers=os.cpu_count(),
    logging_steps=10,
    push_to_hub_token=os.environ.get("HF_TOKEN"),  # read your HF token from the environment; never hard-code it
    report_to="wandb",
    )
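
With `per_device_train_batch_size=2` and `gradient_accumulation_steps=32`, each optimizer step sees 64 preference pairs per device. A rough sketch of that arithmetic follows; the helper names are introduced here for illustration, a single GPU is assumed, and the step count is an estimate because the trainer may round a partial final step differently.

```python
import math


def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Number of examples that contribute to one optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices


def approx_optimizer_steps(num_examples, per_device_batch, grad_accum_steps,
                           num_devices=1, epochs=1):
    """Estimate the number of optimizer steps (a partial last step is rounded up)."""
    per_step = effective_batch_size(per_device_batch, grad_accum_steps, num_devices)
    return math.ceil(num_examples / per_step) * epochs


print(effective_batch_size(2, 32))           # 64
print(approx_optimizer_steps(20000, 2, 32))  # estimate for the 20k subset
print(approx_optimizer_steps(83132, 2, 32))  # estimate for the full dataset
```

With `logging_steps=10`, a loss value is logged every 10 of these optimizer steps.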



### Define our `DPO` Trainer instance 🔧

In [19]:
trainer = DPOTrainer(
    model,
    ref_model=None,  # not needed when using peft
    args=training_args,
    train_dataset=dataset,
    tokenizer=processor,
    peft_config=LoraConfig(target_modules="all-linear"),
)
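
For intuition about what the trainer optimizes, here is a minimal sketch of the standard sigmoid DPO loss for a single preference pair, written in plain Python. It is not TRL's implementation: the function name and the example log-probabilities are made up for illustration, and `beta=0.1` is an assumed value. The reference log-probabilities stand for the frozen base model, which is why no separate `ref_model` has to be passed when training LoRA adapters.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for one (chosen, rejected) pair.

    loss = -log(sigmoid(beta * margin)), where margin measures how much more
    the policy prefers the chosen answer than the reference model does.
    """
    margin = ((policy_chosen_logp - ref_chosen_logp)
              - (policy_rejected_logp - ref_rejected_logp))
    x = beta * margin
    # -log(sigmoid(x)) == log(1 + exp(-x)), computed in a numerically stable way.
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))


# Policy identical to the reference: the loss equals log(2).
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))
# Policy shifted toward the chosen answer: the loss drops below log(2).
print(dpo_loss(-8.0, -14.0, -10.0, -12.0))
```

The loss falls as the policy raises the likelihood of the chosen answer relative to the rejected one, measured against the reference model.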



### Let's train our model 🏋️‍♀️

In [None]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mmrm8488[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
Could not estimate the number of tokens of the input, floating-point operations will not be computed




### Push the resulting model to the HF Hub 🔼

In [None]:
trainer.push_to_hub()