DPO/IPO Links:

https://github.com/eric-mitchell/direct-preference-optimization [Owner] \
https://huggingface.co/blog/pref-tuning \
https://github.com/huggingface/alignment-handbook \
https://github.com/dida-do/public/blob/master/fine-tuning_llm/train-dpo.py \
https://www.kaggle.com/code/aisuko/supervised-fine-tuning-llama2-with-dpo \
https://github.com/michaelnny/DPO-LLaMA \
https://plainenglish.io/community/direct-preference-optimization-dpo-a-simplified-approach-to-fine-tuning-large-language-models \
https://towardsdatascience.com/fine-tune-a-mistral-7b-model-with-direct-preference-optimization-708042745aac \
https://github.com/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb \
https://huggingface.co/blog/dpo-trl
https://discuss.huggingface.co/t/sfttrainer-class-and-training-arguements/85976/2

!pip install --upgrade \
“transformers==4.38.2”\
“datasets==2.16.1”\
“accelerate==0.26.1”\
“evaluate==0.4.1”\
“bitsandbytes==0.42.0”\
“trl==0.7.11”\
“peft==0.8.2”

# LLaMA2 with DPO

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.trainer import TrainingArguments
from datasets import load_dataset
from tqdm import tqdm
from peft import LoraConfig, TaskType, AutoPeftModelForCausalLM
from trl.trainer import ConstantLengthDataset
from trl import SFTTrainer, DPOTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
model_name = "huggyllama/llama-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [None]:
dataset = load_dataset(
    path="lvwerra/stack-exchange-paired",
    data_dir = "data/rl",
    split = "train"
)

In [None]:
dataset

In [None]:
dataset = dataset.train_test_split(test_size=0.005, seed=None)

In [None]:
train_data = dataset['train']
test_data = dataset['test']

In [None]:
train_data

In [None]:
def prepare_sample_text(example):
    text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
    return text

In [None]:
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    '''
    Estimate the average number of characters per token in the dataset
    '''
    
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))
    
    return total_characters/total_tokens

In [None]:
chars_per_token = chars_token_ratio(train_data, tokenizer)

In [None]:
train_dataset = ConstantLengthDataset(
    tokenizer,
    train_data,
    formatting_func=prepare_sample_text,
    infinite=True,
    seq_length=1024,
    chars_per_token=chars_per_token
)

test_dataset = ConstantLengthDataset(
    tokenizer,
    test_data,
    formatting_func=prepare_sample_text,
    infinite=False,
    seq_length=1024,
    chars_per_token=chars_per_token
)

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [None]:
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config = bnb_config,
    device_map = "auto",
    torch_dtype = torch.bfloat16,
    trust_remote_code = False
)

In [None]:
base_model.config.use_cache=False

In [None]:
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type = TaskType.CAUSAL_LM
)

In [None]:
training_args=TrainingArguments(
    output_dir="./sft",
    max_steps=100,
    logging_steps=10,
    save_steps=10,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    group_by_length=False,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_steps=50,
    weight_decay=0.05,
    optim="paged_adamw_32bit",
    fp16=True,
    remove_unused_columns=False,
    report_to="none"
)

In [None]:
sft_trainer=SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,
)

In [None]:
# sft_trainer.train()

In [None]:
model = AutoPeftModelForCausalLM.from_pretrained("./sft/checkpoint-100", device_map="auto", torch_dtype=torch.bfloat16)

In [None]:
model = model.merge_and_unload()

In [None]:
model.save_pretrained("./sft/final_merged_checkpoint", safe_serialization=True)

## Direct Preference Optimization

In [2]:
def return_prompt_and_responses(samples):
    return {
        "prompt":[
            "Question:"+question+"\n\nAnswer:" for question in samples["question"]
        ],
        "chosen": samples["response_j"],
        "rejected": samples["response_k"],
    }

In [3]:
def get_stack_exchange_paired(data_dir="data/rl", sanity_check=False, cache_dir=None, num_proc=24):
    dataset=load_dataset(
        "lvwerra/stack-exchange-paired",
        split="train",
        data_dir=data_dir,
        cache_dir=cache_dir,
    )
    original_columns=dataset.column_names
    
    if sanity_check:
        dataset=dataset.select(range(min(len(dataset), 1000)))
    
    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )

In [4]:
model = AutoModelForCausalLM.from_pretrained(
    "./sft/final_merged_checkpoint",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)

model.config.use_cache=False

model_ref = AutoModelForCausalLM.from_pretrained(
    "./sft/final_merged_checkpoint",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.50s/it]
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.91s/it]


In [7]:
tokenizer_dpo=AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer_dpo.pad_token=tokenizer_dpo.eos_token



In [8]:
train_dataset = get_stack_exchange_paired()

Resolving data files: 100%|██████████| 20/20 [00:00<00:00, 180.12it/s]
Generating train split: 7435908 examples [00:54, 136411.61 examples/s]
Map (num_proc=24): 100%|██████████| 7435908/7435908 [01:34<00:00, 79018.24 examples/s]  


In [9]:
train_dataset = train_dataset.filter(
    lambda x: len(x["prompt"])+len(x["chosen"])<=1024 and len(x["prompt"])+len(x["rejected"])<=1024
)

Filter: 100%|██████████| 7435908/7435908 [00:34<00:00, 213449.10 examples/s]


In [10]:
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)

Generating train split: 4483004 examples [00:29, 150622.67 examples/s]
Map (num_proc=24): 100%|██████████| 1000/1000 [00:00<00:00, 3277.29 examples/s]


In [11]:
eval_dataset = eval_dataset.filter(
    lambda x: len(x["prompt"])+len(x["chosen"])<=1024 and len(x["prompt"])+len(x["rejected"])<=1024
)

Filter: 100%|██████████| 1000/1000 [00:00<00:00, 43812.06 examples/s]


In [12]:
training_args=TrainingArguments(
    output_dir="./dpo",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    max_steps=1000,
    logging_steps=10,
    save_steps=100,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    group_by_length=False,
    learning_rate=5e-4,
    evaluation_strategy="steps",
    eval_steps=100,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.05,
    optim="paged_adamw_32bit",
    bf16=True,
    remove_unused_columns=False,
    report_to="none"
)

In [13]:
peft_config=LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q_proj','v_proj','k_proj','out_proj','fc_in','fc_out','wte',],
    bias="none",
    task_type="CAUSAL_LM",
)

In [15]:
dpo_trainer=DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer_dpo,
    max_prompt_length=512,
    max_length=1024,
)

Map:  33%|███▎      | 552611/1659503 [17:07<29:28, 626.06 examples/s]  

In [None]:
dpo_trainer.train()

In [None]:
dpo_trainer.save_model()