# Fine-tune Llama 2 with DPO

https://huggingface.co/blog/dpo-trl

> Reinforcement Learning from Human Feedback (RLHF) has become the de facto last training step of LLMs such as GPT-4 or Claude to ensure that the language model's outputs are aligned with human expectations such as chattiness or safety features. However, it brings some of the complexity of RL into NLP: we need to build a good reward function, train the model to estimate the value of a state, and at the same time be careful not to stray too far from the original model and produce gibberish instead of sensible text.

Papers:
- [Direct Preference Optimization](https://arxiv.org/abs/2305.18290): "... the RL-based objective used by existing methods to an objective which can be directly optimized via a simple binary cross-entropy loss which simplifies this process of refining LLMs greatly."

Datasets:
- [Stack-Exchange Preference](https://huggingface.co/datasets/lvwerra/stack-exchange-paired): "... contains ranked answers to questions on the various stack-exchange portals."

## DPO vs PPO

> In the traditional model of optimising human-derived preferences via RL, the go-to method has been to use an auxiliary reward model and fine-tune the model of interest so that it maximizes this given reward via the machinery of RL. Intuitively we use the reward model to provide feedback to the model we are optimising so that it generates high-reward samples more often and low-reward samples less often. At the same time we use a frozen reference model to make sure that whatever is generated does not deviate too much and continues to maintain generation diversity. This is typically done by adding a KL penalty to the full reward maximisation objective via a reference model, which serves to prevent the model from learning to cheat or exploit the reward model.
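
A minimal sketch of the KL-penalised reward described above, assuming the summed token log-probabilities of one sampled response have already been computed under the policy and under the frozen reference model. The function name and arguments are hypothetical, and `logp_policy - logp_ref` is the usual single-sample estimate of the KL term rather than an exact KL divergence.

```python
def kl_penalised_reward(reward: float, logp_policy: float, logp_ref: float, beta: float = 0.1) -> float:
    """Reward-model score minus a KL penalty for one sampled response.

    logp_policy and logp_ref are the summed token log-probabilities of the
    response under the model being trained and the frozen reference model.
    """
    kl_estimate = logp_policy - logp_ref  # single-sample estimate of KL(policy || reference)
    return reward - beta * kl_estimate


# A response the policy now favours more than the reference did gets penalised slightly.
print(kl_penalised_reward(reward=1.0, logp_policy=-10.0, logp_ref=-12.0, beta=0.1))
```

In PPO-style RLHF this penalised score, rather than the raw reward-model output, is typically what the policy is optimised against; a larger `beta` keeps the policy closer to the reference model.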

> The DPO formulation bypasses the reward modeling step and directly optimises the language model on preference data via a key insight: namely an analytical mapping from the reward function to the optimal RL policy that enables the authors to transform the RL loss over the reward and reference models to a loss over the reference model directly! This mapping intuitively measures how well a given reward function aligns with the given preference data. DPO thus starts with the optimal solution to the RLHF loss and via a change of variables derives a loss over only the reference model!
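
For reference, the objective from the DPO paper can be written as below, where $\pi_\theta$ is the policy being trained, $\pi_\text{ref}$ the frozen reference (SFT) model, $(x, y_w, y_l)$ a prompt with its preferred and dispreferred responses drawn from the preference dataset $\mathcal{D}$, $\sigma$ the logistic function and $\beta$ a temperature hyperparameter. The trainable policy appears alongside the reference model, and no separate reward model is involved:

$$
\mathcal{L}_\text{DPO}(\pi_\theta; \pi_\text{ref}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_\text{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_\text{ref}(y_l \mid x)}\right)\right]
$$

Each term $\beta \log \frac{\pi_\theta(y \mid x)}{\pi_\text{ref}(y \mid x)}$ acts as an implicit reward, so the loss is a binary cross-entropy that pushes the implicit reward of $y_w$ above that of $y_l$.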

## How to train with TRL

A typical RLHF pipeline consists of four steps:

1. a supervised fine-tuning (SFT) step
2. the process of annotating data with preference labels
3. training a reward model on the preference data
4. the RL optimization step

"... the DPO training does away with the task of reward modeling and RL (steps 3 and 4) and directly optimizes the DPO objective on preference-annotated data."

> In this respect we still need step 1, but instead of steps 3 and 4 we need to provide the DPOTrainer in TRL with preference data from step 2, which has a very specific format, namely a dictionary with the following three keys:

* `prompt`: the context prompt given to the model at inference time for text generation
* `chosen`: the preferred generated response to the corresponding prompt
* `rejected`: the response that is not preferred, i.e. one that should not be sampled for the given prompt
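
As an illustration, a single record in this format could look like the made-up example below (it is not taken from the Stack Exchange dataset), together with a small hypothetical helper that checks the three required keys are present and hold strings.

```python
from typing import Any, Dict

REQUIRED_KEYS = ("prompt", "chosen", "rejected")


def is_preference_record(record: Dict[str, Any]) -> bool:
    """Return True if `record` has string values for prompt, chosen and rejected."""
    return all(isinstance(record.get(key), str) for key in REQUIRED_KEYS)


# Hypothetical example pair, using the "Question: ... Answer: " prompt template of this notebook.
example = {
    "prompt": "Question: How do I reverse a list in Python?\n\nAnswer: ",
    "chosen": "Use slicing (`items[::-1]`) for a reversed copy, or `items.reverse()` to reverse in place.",
    "rejected": "Python lists cannot be reversed.",
}

print(is_preference_record(example))
```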

In [None]:
import os
from dataclasses import dataclass, field
from typing import Dict, Optional
from types import SimpleNamespace

import torch
from datasets import load_dataset, Dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import DPOTrainer

In [None]:
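from huggingface_hub import login
import os

# Llama 2 is a gated model, so authenticate with your own Hugging Face token.
# Read it from the environment instead of hard-coding a secret in the notebook.
login(token=os.environ["HF_TOKEN"])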

In [None]:
config = SimpleNamespace(
    model_name_or_path="meta-llama/Llama-2-7b-hf",
)
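
The cells below read their hyperparameters from a `script_args` object that this notebook never defines. A minimal stand-in, assuming the values below: the field names are exactly the ones the later cells access, but every value is an illustrative assumption rather than a setting from the blog post, so adjust them to your data and hardware.

```python
from types import SimpleNamespace

# Hypothetical hyperparameters: names match what later cells read, values are assumptions.
script_args = SimpleNamespace(
    # DPO
    beta=0.1,                      # DPO temperature, typically 0.1 to 0.5
    max_prompt_length=512,         # token budget for the prompt
    max_length=1024,               # token budget for prompt + response
    # optimisation
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    optimizer_type="paged_adamw_32bit",  # paged AdamW from bitsandbytes
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    max_steps=1000,
    # LoRA
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    # logging, evaluation and checkpointing
    logging_steps=10,
    save_steps=100,
    eval_steps=100,
    output_dir="./results",
    report_to="none",
    sanity_check=True,             # load at most 1,000 training examples while experimenting
)
```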

In [None]:
def get_stack_exchange_paired(
    data_dir: str = "data/rl",
    sanity_check: bool = False,
    cache_dir: Optional[str] = None,
    num_proc: int = 24,
) -> Dataset:
    """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts are structured as follows:
      "Question: " + <prompt> + "\n\nAnswer: "
    """
    dataset = load_dataset(
        "lvwerra/stack-exchange-paired",
        split="train",
        cache_dir=cache_dir,
        data_dir=data_dir,
    )
    original_columns = dataset.column_names

    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def return_prompt_and_responses(samples) -> Dict[str, list]:
        return {
            "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
            "chosen": samples["response_j"],
            "rejected": samples["response_k"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )
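
To see what the batched `map` produces, the same transformation can be applied by hand to a tiny made-up batch in the dataset's column layout (`question`, `response_j`, `response_k`). `to_dpo_format` below is a standalone copy of the inner `return_prompt_and_responses` function, written out only for illustration.

```python
from typing import Dict, List


def to_dpo_format(samples: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Standalone copy of return_prompt_and_responses for a batch of rows."""
    return {
        "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
        "chosen": samples["response_j"],    # response_j becomes the preferred answer
        "rejected": samples["response_k"],  # response_k becomes the dispreferred answer
    }


# Made-up batch of two rows in the column layout of lvwerra/stack-exchange-paired.
toy_batch = {
    "question": ["What does `git stash` do?", "How do I exit Vim?"],
    "response_j": ["It shelves uncommitted changes so you can reapply them later.", "Press Esc, then type :q and Enter."],
    "response_k": ["It deletes your repository.", "Close the terminal."],
}

converted = to_dpo_format(toy_batch)
print(repr(converted["prompt"][0]))
```

With `batched=True`, `map` hands the function whole columns as lists, which is why every value in the returned dictionary is a list with the same length as the input batch.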


In [None]:
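# 1. load the policy model to optimise. config.model_name_or_path points at the base
# Llama 2 weights, so use AutoModelForCausalLM: AutoPeftModelForCausalLM expects a
# directory containing a PEFT adapter (adapter_config.json), e.g. an SFT checkpoint.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    config.model_name_or_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    # 4-bit quantisation via bitsandbytes to reduce GPU memory
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
)
model.config.use_cache = False  # the KV cache is not needed for training and conflicts with gradient checkpointing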

In [None]:
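# 1b. load the frozen reference model (ideally the same SFT model as the policy)
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_ref = AutoModelForCausalLM.from_pretrained(
    config.model_name_or_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
)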

In [None]:
# 1c. load the tokenizer that matches the model
tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token  # Llama 2 ships without a pad token; reuse EOS for padding

In [None]:
# 2. Load the Stack-exchange paired dataset
train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
train_dataset = train_dataset.filter(
    lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
    and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
)
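
Note that `len` on a string counts characters, whereas `max_length` in `DPOTrainer` is a token budget, so the filter above is usually stricter than necessary (a token typically spans several characters). A minimal sketch of a predicate that accepts any length function; `fits_max_length` is a hypothetical helper, and the whitespace word counter in the example merely stands in for a real tokenizer.

```python
from typing import Callable, Dict


def fits_max_length(example: Dict[str, str], max_length: int,
                    count: Callable[[str], int] = len) -> bool:
    """True if prompt+chosen and prompt+rejected both fit in `max_length` units.

    With the default `count=len` the units are characters, exactly like the
    filter above; pass a token-counting function to filter by tokens instead.
    """
    prompt_len = count(example["prompt"])
    return (prompt_len + count(example["chosen"]) <= max_length
            and prompt_len + count(example["rejected"]) <= max_length)


example = {"prompt": "a b c", "chosen": "d e", "rejected": "f g h i"}
print(fits_max_length(example, 8))                                  # characters: 5 + 7 = 12 > 8
print(fits_max_length(example, 8, count=lambda s: len(s.split())))  # words: 3 + 4 = 7 <= 8
```

With this notebook's tokenizer, a token counter could be `lambda s: len(tokenizer(s)["input_ids"])`, used as `train_dataset.filter(lambda x: fits_max_length(x, script_args.max_length, count=...))`. Because each call may add special tokens such as BOS, summing the parts slightly over-counts, which errs on the safe side.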

In [None]:
# 3. Load evaluation dataset
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
eval_dataset = eval_dataset.filter(
    lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
    and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
)

In [None]:
# 4a. initialize training arguments:
training_args = TrainingArguments(
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    max_steps=script_args.max_steps,
    logging_steps=script_args.logging_steps,
    save_steps=script_args.save_steps,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    gradient_checkpointing=script_args.gradient_checkpointing,
    learning_rate=script_args.learning_rate,
    evaluation_strategy="steps",
    eval_steps=script_args.eval_steps,
    output_dir=script_args.output_dir,
    report_to=script_args.report_to,
    lr_scheduler_type=script_args.lr_scheduler_type,
    warmup_steps=script_args.warmup_steps,
    optim=script_args.optimizer_type,
    bf16=True,
    remove_unused_columns=False,
    run_name="dpo_llama2",
)
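
How much data a run consumes follows from these arguments: each optimizer step uses `per_device_train_batch_size × gradient_accumulation_steps × number of devices` preference pairs. A small hypothetical helper for checking how many passes over the filtered training set `max_steps` implies; the numbers in the example, including the dataset size, are made up.

```python
from typing import Tuple


def training_budget(per_device_batch: int, grad_accum: int, num_devices: int,
                    max_steps: int, dataset_size: int) -> Tuple[int, float]:
    """Return (pairs per optimizer step, approximate epochs over the dataset)."""
    effective_batch = per_device_batch * grad_accum * num_devices
    epochs = max_steps * effective_batch / dataset_size
    return effective_batch, epochs


# Batch 4, 4 accumulation steps, 1 GPU, 1000 steps, 20,000 filtered pairs (assumed size).
print(training_budget(4, 4, 1, 1000, 20_000))
```

With `sanity_check=True` the training set is capped at 1,000 examples before filtering, so the same `max_steps` can amount to many epochs over a small subset.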

In [None]:
# 4b. initialize the lora arguments:
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=[  # attention and MLP projection layers of the Llama architecture
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    bias="none",
    task_type="CAUSAL_LM",
)
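
LoRA keeps each adapted `Linear(d_in, d_out)` frozen and learns two low-rank factors, `A` (r × d_in) and `B` (d_out × r), so every adapted layer adds `r × (d_in + d_out)` trainable parameters. A back-of-the-envelope count, assuming Llama-2-7B's dimensions (hidden size 4096, MLP size 11008, 32 decoder layers) and adapters on all seven projection layers (`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`); biases are ignored since `bias="none"`.

```python
from typing import List, Tuple


def lora_param_count(r: int, shapes: List[Tuple[int, int]], num_layers: int) -> int:
    """Trainable parameters LoRA adds: r * (d_in + d_out) per adapted Linear, per layer."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes)
    return per_layer * num_layers


# Assumed Llama-2-7B dimensions (hidden size, MLP size, number of decoder layers).
HIDDEN, INTERMEDIATE, LAYERS = 4096, 11008, 32

llama_shapes = [
    (HIDDEN, HIDDEN),        # q_proj
    (HIDDEN, HIDDEN),        # k_proj
    (HIDDEN, HIDDEN),        # v_proj
    (HIDDEN, HIDDEN),        # o_proj
    (HIDDEN, INTERMEDIATE),  # gate_proj
    (HIDDEN, INTERMEDIATE),  # up_proj
    (INTERMEDIATE, HIDDEN),  # down_proj
]

print(lora_param_count(8, llama_shapes, LAYERS))  # 19988480
```

At r = 8 that is roughly 20M trainable parameters, about 0.3% of the ~6.7B base weights; the count scales linearly with `r`.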


In [None]:
# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    max_prompt_length=script_args.max_prompt_length,
    max_length=script_args.max_length,
)

> Once we have the dataset sorted, the DPO loss is essentially a supervised loss which obtains an implicit reward via a reference model, and thus at a high level the DPOTrainer requires the base model we wish to optimize as well as a reference model ... where the beta hyper-parameter is the temperature parameter for the DPO loss, typically in the range 0.1 to 0.5. This controls how much we pay attention to the reference model, in the sense that the smaller beta gets, the more we ignore the reference model:

```python
dpo_trainer = DPOTrainer(
    model,                 # base model from SFT pipeline
    model_ref,             # typically a copy of the SFT trained base model
    beta=0.1,              # temperature hyperparameter of DPO (usually between 0.1 and 0.5)
    train_dataset=dataset, # dataset prepared above
    tokenizer=tokenizer,   # tokenizer
    args=training_args,    # training arguments e.g. batch size, lr, etc.
)
```
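
To make the role of beta concrete, here is a minimal pure-Python sketch of the per-example DPO loss from the DPO paper, assuming the summed log-probabilities of the chosen and rejected responses under the policy and the reference model are already available. `dpo_loss` and `neg_log_sigmoid` are illustrative helpers, not TRL's implementation.

```python
import math


def neg_log_sigmoid(z: float) -> float:
    """Numerically stable -log(sigmoid(z)) = log(1 + exp(-z))."""
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float, beta: float = 0.1) -> float:
    """Per-example DPO loss from sequence log-probabilities."""
    # implicit rewards: beta * log(pi_theta(y|x) / pi_ref(y|x))
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    return neg_log_sigmoid(chosen_reward - rejected_reward)


# Policy identical to the reference: the loss is log(2) whatever beta is.
print(dpo_loss(-12.0, -13.0, -12.0, -13.0))
# Policy moved 2 nats towards the chosen and 2 nats away from the rejected answer.
for beta in (0.1, 0.5):
    print(beta, dpo_loss(-10.0, -15.0, -12.0, -13.0, beta=beta))
```

With the smaller beta the same shift away from the reference model buys a much smaller drop in loss, so the optimiser has to move the policy further from the reference to reduce it; this is one way to read the statement that a smaller beta pays less attention to the reference model.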



In [None]:
# 6. train
dpo_trainer.train()
dpo_trainer.save_model(script_args.output_dir)

In [None]:
# 7. save
output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
dpo_trainer.model.save_pretrained(output_dir)