### Resources
- [Fine-tune Llama2 with DPO](https://huggingface.co/blog/dpo-trl) | [codebase](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2/scripts)

### Setup

In [1]:
from datasets import load_dataset, Dataset

import torch
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, AutoPeftModelForCausalLM
from trl import SFTTrainer, DPOTrainer
from ml_collections import config_dict
import huggingface_hub

from utils import LLMSampleCB

# huggingface_hub.login()
# wandb.login()

# huggingface-cli login
# wandb login



### Load dataset

In [2]:
# Stream the paired Stack Exchange dataset; with streaming=True the result is
# an IterableDatasetDict keyed by split name.
# Optional arguments left disabled: split="train", data_dir="data/rl"
dataset = load_dataset("lvwerra/stack-exchange-paired", streaming=True)

# Keep a handle on each split for the sampling step below.
train_ds, test_ds = dataset["train"], dataset["test"]

Resolving data files:   0%|          | 0/72 [00:00<?, ?it/s]

In [3]:
# Show the available splits together with their features and shard counts.
dataset

IterableDatasetDict({
    train: IterableDataset({
        features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
        n_shards: 72
    })
    test: IterableDataset({
        features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
        n_shards: 12
    })
})

### Sample dataset

In [15]:
sample_size = 100


def take_first(stream, limit, log_every=1000):
    """Collect the first `limit` examples from an iterable (streaming) dataset.

    Args:
        stream: Any iterable of examples, e.g. a streaming dataset split.
        limit: Number of examples to keep. Values <= 0 yield an empty list.
        log_every: Print a progress line each time this many examples
            have been collected.

    Returns:
        A list with at most `limit` examples, in stream order.
    """
    collected = []
    if limit <= 0:
        return collected

    for count, example in enumerate(stream, start=1):
        collected.append(example)

        if count % log_every == 0:
            print(f"[INFO] processing {count} of {limit}...")

        # Stop right after the last wanted example so the stream is never
        # advanced for an example that would only be thrown away.
        if count == limit:
            break

    return collected


# Same sampling logic for both splits instead of two copy-pasted loops.
train_sample_data = take_first(train_ds, sample_size)
test_sample_data = take_first(test_ds, sample_size)

In [16]:
# Turn the sampled rows (lists of example dicts) into in-memory Dataset objects.
train_ds_sample, test_ds_sample = (
    Dataset.from_list(rows) for rows in (train_sample_data, test_sample_data)
)

In [17]:
# Sanity check: both samples should report the same features and row count.
print(train_ds_sample, test_ds_sample, sep="\n")

Dataset({
    features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
    num_rows: 100
})
Dataset({
    features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
    num_rows: 100
})


### Preprocessing

In [18]:
def return_prompt_and_responses(samples):
    """Convert a batch of raw rows into prompt / chosen / rejected columns.

    Args:
        samples: Batch as a dict of lists with the keys "question",
            "response_j" and "response_k".

    Returns:
        A dict with three keys: "prompt" (each question wrapped in the
        "Question: ...\n\nAnswer: " template), "chosen" (the "response_j"
        column) and "rejected" (the "response_k" column).
    """
    prompts = []
    for question in samples["question"]:
        prompts.append("Question: " + question + "\n\nAnswer: ")

    chosen = samples["response_j"]
    rejected = samples["response_k"]

    return {"prompt": prompts, "chosen": chosen, "rejected": rejected}

In [19]:
# Remember the raw column names so they can be dropped when the samples are
# mapped to the prompt / chosen / rejected layout.
original_columns = train_ds_sample.column_names
original_columns

['qid', 'question', 'date', 'metadata', 'response_j', 'response_k']

In [20]:
# Apply the same batched conversion to both samples and drop the raw columns,
# leaving only the columns produced by return_prompt_and_responses.
# (batch_size is not set explicitly.)
train_ds_sample_prepared, test_ds_sample_prepared = (
    split.map(
        return_prompt_and_responses,
        batched=True,
        remove_columns=original_columns,
    )
    for split in (train_ds_sample, test_ds_sample)
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [21]:
# Confirm that only the prompt / chosen / rejected columns remain.
train_ds_sample_prepared

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 100
})

### Supervised Fine-tuning step

In [None]:
# Settings for the SFT step: the base checkpoint plus the LoRA hyperparameters.
script_args = config_dict.ConfigDict(
    {
        "model_name": "meta-llama/Llama-2-7b-hf",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
    }
)

In [None]:
def formatting_func(example):
    """Build the SFT training text for one example.

    Concatenates the example's "prompt" with its "chosen" response, so the
    model is fine-tuned on the preferred answer only.
    """
    return example["prompt"] + example["chosen"]

# print(formatting_func(train_ds_sample_prepared[1]))

In [None]:
# 4-bit NF4 quantization with double quantization for the base weights;
# the quantized matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

# NOTE(review): torch_dtype is float16 while the quantization compute dtype and
# the trainer (bf16=True) use bfloat16 — confirm the mix is intentional.
model_kwargs = dict(
    device_map={"": 0},
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    use_cache=False,
    quantization_config=bnb_config,
    # token=True
)

base_model = AutoModelForCausalLM.from_pretrained(script_args.model_name, **model_kwargs)

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
# Reuse EOS as the padding token and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
# Fixed: the attribute is `padding_side`; the previous `padding_size` only set
# an unused attribute, so the padding side was never actually changed.
tokenizer.padding_side = "right"

# LoRA adapters on the attention query/value projections only.
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

training_args = TrainingArguments(
    output_dir="./sft",
    # max_steps=500,
    num_train_epochs=1,
    logging_steps=10,
    save_steps=10,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=False,
    group_by_length=False,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.05,
    optim="paged_adamw_32bit",
    bf16=True,
    remove_unused_columns=False,
    run_name="sft_llama2",
    report_to="wandb"
)

# Each example is rendered by formatting_func (prompt + chosen answer);
# packing=True is requested so examples are packed into training sequences.
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_ds_sample_prepared,
    eval_dataset=test_ds_sample_prepared,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,
    formatting_func=formatting_func
)

In [None]:
# Run supervised fine-tuning; checkpoints are written under ./sft (output_dir)
# every 10 steps (save_steps).
trainer.train()

### DPO step

In [22]:
# NOTE(review): this must point at a checkpoint produced by the SFT step; the
# step number in the path depends on that run — confirm it exists before running.
model_path = "./sft/checkpoint-290/"
gradient_checkpointing = False

# Loading options shared by the policy and the reference model.
shared_load_kwargs = dict(
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
    # The KV cache is switched off whenever gradient checkpointing is on.
    use_cache=not gradient_checkpointing,
)

# Policy model: the SFT checkpoint, loaded with trainable adapters.
model = AutoPeftModelForCausalLM.from_pretrained(
    model_path,
    is_trainable=True,
    **shared_load_kwargs,
)

# Reference model: same checkpoint, loaded without is_trainable.
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    model_path,
    **shared_load_kwargs,
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Reuse EOS as the padding token and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
# Fixed: the attribute is `padding_side`; the previous `padding_size` only set
# an unused attribute, so the padding side was never actually changed.
tokenizer.padding_side = "right"

# NOTE(review): the recorded run on the 100-row sample finished after 6
# optimizer steps, so warmup_steps / eval_steps / save_steps of 100 were never
# reached — scale them down if training on the small sample is intended.
training_args_dpo = TrainingArguments(
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    # max_steps=3,
    logging_steps=10,
    save_steps=100,
    gradient_accumulation_steps=16,
    # Driven by the variable above so it cannot drift from use_cache.
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=5e-4,
    evaluation_strategy="steps",
    eval_steps=100,
    output_dir="./results",
    report_to="wandb",
    lr_scheduler_type="cosine",
    warmup_steps=100,
    optim="paged_adamw_32bit",
    bf16=True,
    remove_unused_columns=False,
    run_name="dpo_llama2",
)

# NOTE(review): "out_proj", "fc_in", "fc_out" and "wte" do not look like Llama
# module names (the SFT step targets only q_proj / v_proj) — confirm which
# modules actually receive adapters.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "out_proj", "fc_in", "fc_out", "wte"],
    bias="none",
    task_type="CAUSAL_LM"
)

dpo_trainer = DPOTrainer(
    model,
    model_ref,
    tokenizer=tokenizer,
    args=training_args_dpo,
    peft_config=peft_config,
    train_dataset=train_ds_sample_prepared,
    eval_dataset=test_ds_sample_prepared,
    beta=0.1,
    max_prompt_length=512,
    max_length=1024
)


# Optional sample-generation callback (currently disabled).
num_samples = 10
# wandb_cb = LLMSampleCB(dpo_trainer, test_ds_sample_prepared, num_samples=num_samples, max_new_tokens=512)
# dpo_trainer.add_callback(wandb_cb)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [23]:
# Run DPO training; checkpoints and results go to ./results (output_dir).
dpo_trainer.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss


Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


TrainOutput(global_step=6, training_loss=1.2982032299041748, metrics={'train_runtime': 62.954, 'train_samples_per_second': 1.588, 'train_steps_per_second': 0.095, 'total_flos': 0.0, 'train_loss': 1.2982032299041748, 'epoch': 0.96})