### Resources
- [Fine-tune Llama2 with DPO](https://huggingface.co/blog/dpo-trl) | [codebase](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2/scripts)

### Setup

In [49]:
import os
from datasets import load_dataset, Dataset
import torch
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, AutoPeftModelForCausalLM
from trl import SFTTrainer, DPOTrainer
from ml_collections import config_dict
import huggingface_hub
import wandb
from tqdm.auto import tqdm

from utils import LLMSampleCB

# Route every W&B run started from this notebook to a single project.
os.environ.update(WANDB_PROJECT="dpo_llama2_finetuning")


# huggingface_hub.login()
# wandb.login()

# huggingface-cli login
# wandb login



### Load dataset

In [2]:
# Stream the paired StackExchange preference data rather than downloading it
# in full; each split is then consumed lazily.
dataset = load_dataset("lvwerra/stack-exchange-paired", streaming=True)

train_ds, test_ds = dataset["train"], dataset["test"]

Downloading readme:   0%|          | 0.00/737 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/72 [00:00<?, ?it/s]

In [3]:
# Show the streaming splits and their feature names.
dataset

IterableDatasetDict({
    train: IterableDataset({
        features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
        n_shards: 72
    })
    test: IterableDataset({
        features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
        n_shards: 12
    })
})

### Sample dataset

In [4]:
# Number of examples to materialise from each streaming split.
sample_size = 100


def take_sample(stream, n, desc):
    """Materialise the first `n` examples of a streaming split as a list.

    Args:
        stream: a streaming split exposing `.take(n)` (here: the splits
            returned by `load_dataset(..., streaming=True)`).
        n: maximum number of examples to pull from the stream.
        desc: label shown on the progress bar.

    Returns:
        A list with at most `n` example dicts, in stream order.
    """
    # `.take(n)` stops after exactly n examples, and the tqdm bar reports
    # progress for any sample size.
    return list(tqdm(stream.take(n), total=n, desc=desc))


train_sample_data = take_sample(train_ds, sample_size, "train sample")
test_sample_data = take_sample(test_ds, sample_size, "test sample")

In [5]:
# Turn the sampled row lists into in-memory Datasets so that `.map` and
# `.select` are available downstream.
train_ds_sample, test_ds_sample = (
    Dataset.from_list(rows) for rows in (train_sample_data, test_sample_data)
)

In [6]:
# Summarise both sampled splits (features and row counts).
print(train_ds_sample, test_ds_sample, sep="\n")

Dataset({
    features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
    num_rows: 100
})
Dataset({
    features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
    num_rows: 100
})


### Preprocessing

In [7]:
def return_prompt_and_responses(samples):
    """Convert a batch of raw rows into prompt / chosen / rejected columns.

    Args:
        samples: batched mapping with list-valued "question", "response_j"
            and "response_k" entries.

    Returns:
        A dict with "prompt" (each question wrapped in the
        "Question: ...\n\nAnswer: " template), "chosen" (the "response_j"
        values) and "rejected" (the "response_k" values).
    """
    prompts = []
    for question in samples["question"]:
        prompts.append("Question: " + question + "\n\nAnswer: ")

    return {
        "prompt": prompts,
        "chosen": samples["response_j"],
        "rejected": samples["response_k"],
    }

In [8]:
# Keep the raw column names so they can be passed as `remove_columns` when
# the prompt / chosen / rejected columns are built.
original_columns = train_ds_sample.column_names
original_columns

['qid', 'question', 'date', 'metadata', 'response_j', 'response_k']

In [9]:
# Same transformation for both splits: build the prompt / chosen / rejected
# columns in batches and drop the raw columns.
map_kwargs = dict(batched=True, remove_columns=original_columns)

train_ds_sample_prepared = train_ds_sample.map(return_prompt_and_responses, **map_kwargs)
test_ds_sample_prepared = test_ds_sample.map(return_prompt_and_responses, **map_kwargs)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [10]:
# Check that only the prompt / chosen / rejected columns remain.
train_ds_sample_prepared

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 100
})

### Supervised Fine-tuning step

In [None]:
# Settings for the SFT stage: base checkpoint plus LoRA hyper-parameters.
script_args = config_dict.ConfigDict(
    {
        "model_name": "meta-llama/Llama-2-7b-hf",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
    }
)

In [None]:
def formatting_func(example):
    """Build one SFT training text: the prompt followed by the chosen answer.

    Args:
        example: mapping with string-valued "prompt" and "chosen" entries.

    Returns:
        The concatenation of `example["prompt"]` and `example["chosen"]`.
    """
    prompt, answer = example["prompt"], example["chosen"]
    return prompt + answer

# print(formatting_func(train_ds_sample_prepared[1]))

In [None]:
# --- Quantisation: load the base model in 4-bit NF4 with double quantisation.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

# NOTE(review): `torch_dtype` is float16 while the 4-bit compute dtype and the
# trainer (`bf16=True`) use bfloat16 — confirm the mix is intentional.
model_kwargs = dict(
    device_map={"": 0},
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    use_cache=False,
    quantization_config=bnb_config,
)

base_model = AutoModelForCausalLM.from_pretrained(script_args.model_name, **model_kwargs)

# --- Tokenizer: reuse EOS as the pad token.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# The attribute is `padding_side`. Assigning `padding_size` only creates an
# unused attribute and leaves the tokenizer's default padding side in effect.
tokenizer.padding_side = "right"

# --- LoRA adapter on the attention query/value projections.
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

# --- Trainer hyper-parameters; checkpoints are written to ./sft.
training_args = TrainingArguments(
    output_dir="./sft",
    # max_steps=500,
    num_train_epochs=1,
    logging_steps=10,
    save_steps=10,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=False,
    group_by_length=False,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.05,
    optim="paged_adamw_32bit",
    bf16=True,
    remove_unused_columns=False,
    run_name="sft_llama2",
    report_to="wandb"
)

# --- SFT trainer: texts come from `formatting_func` (prompt + chosen answer)
# and are packed into training sequences (`packing=True`).
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_ds_sample_prepared,
    eval_dataset=test_ds_sample_prepared,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,
    formatting_func=formatting_func
)

In [None]:
# Run supervised fine-tuning; checkpoints go to the trainer's output_dir (./sft).
trainer.train()

### DPO step

In [50]:
import torch
from transformers import GenerationConfig
from transformers.integrations import WandbCallback
import wandb


class LLMSampleCB(WandbCallback):
    """W&B callback that logs a table of sample generations after each evaluation.

    NOTE: this definition shadows the `LLMSampleCB` imported from `utils` in
    the imports cell at the top of the notebook.
    """

    def __init__(self, trainer, test_dataset, num_samples=10, max_new_tokens=256, log_model="checkpoint"):
        """Capture the model, tokenizer and a fixed subset of prompts.

        Args:
            trainer: trainer exposing `.model` and `.tokenizer`.
            test_dataset: dataset with a "prompt" column, supporting
                `len()` and `.select(...)`.
            num_samples: number of leading rows to generate for; clamped to
                the dataset size.
            max_new_tokens: generation budget stored in the generation config.
            log_model: stored on `self._log_model` for the parent callback.
        """
        super().__init__()
        self._log_model = log_model
        # Clamp so a dataset smaller than `num_samples` does not make
        # `.select` fail on out-of-range indices.
        n_rows = min(num_samples, len(test_dataset))
        self.sample_dataset = test_dataset.select(range(n_rows))
        self.model, self.tokenizer = trainer.model, trainer.tokenizer
        self.gen_config = GenerationConfig.from_pretrained(trainer.model.name_or_path, max_new_tokens=max_new_tokens)

    def generate(self, prompt):
        """Generate a completion for `prompt` and return only the new text."""
        # Send inputs to wherever the model lives instead of assuming "cuda".
        device = self.model.device
        tokenized_prompt = self.tokenizer(prompt, return_tensors='pt')
        tokenized_prompt = {k: v.to(device) for k, v in tokenized_prompt.items()}

        with torch.inference_mode():
            output = self.model.generate(**tokenized_prompt, generation_config=self.gen_config, pad_token_id=self.tokenizer.eos_token_id)

        # Drop the echoed prompt tokens so only the continuation is decoded.
        prompt_len = len(tokenized_prompt["input_ids"][0])
        return self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    def samples_table(self, examples):
        """Build a `wandb.Table` with one row per example.

        Columns are the prompt, the generation, and every field of the
        generation config.
        """
        # The generation config is the same for every row; serialise it once.
        gen_settings = self.gen_config.to_dict()
        records_table = wandb.Table(columns=["prompt", "generation"] + list(gen_settings.keys()))
        for example in tqdm(examples, leave=False):
            prompt = example["prompt"]
            generation = self.generate(prompt=prompt)
            records_table.add_data(prompt, generation, *gen_settings.values())
        return records_table

    def on_evaluate(self, args, state, control, **kwargs):
        """After the parent's evaluation hook, log the sample-generation table."""
        super().on_evaluate(args, state, control, **kwargs)
        records_table = self.samples_table(self.sample_dataset)
        self._wandb.log({"sample_predictions": records_table})

In [79]:
model_path = "./sft/checkpoint-290/"
# NOTE(review): this flag only controls `use_cache` below; TrainingArguments
# is given `gradient_checkpointing=False`. Confirm which setting is intended.
gradient_checkpointing = True

# Policy and reference both start from the same SFT adapter checkpoint, loaded
# in 4-bit with identical settings. Only the policy is marked trainable.
shared_load_kwargs = dict(
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
    use_cache=not gradient_checkpointing,
)

model = AutoPeftModelForCausalLM.from_pretrained(model_path, is_trainable=True, **shared_load_kwargs)
model_ref = AutoPeftModelForCausalLM.from_pretrained(model_path, **shared_load_kwargs)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# The attribute is `padding_side`. Assigning `padding_size` only creates an
# unused attribute and leaves the tokenizer's default padding side in effect.
tokenizer.padding_side = "right"

training_args_dpo = TrainingArguments(
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    # max_steps=3,
    logging_steps=10,
    save_steps=100,
    gradient_accumulation_steps=16,
    gradient_checkpointing=False,
    learning_rate=5e-4,
    evaluation_strategy="epoch",
    # evaluation_strategy="steps",
    # eval_steps=100,
    output_dir="./results",
    report_to="wandb",
    lr_scheduler_type="cosine",
    warmup_steps=100,
    optim="paged_adamw_32bit",
    bf16=True,
    remove_unused_columns=False,
    # run_name="dpo_llama2",
)

# Only names that exist in the Llama module tree are listed. The previous
# extras ("out_proj", "fc_in", "fc_out", "wte") are not Llama module names
# and matched nothing. To adapt more layers, use the Llama names
# ("o_proj", "gate_proj", "up_proj", "down_proj").
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

# NOTE(review): an explicit reference model and a `peft_config` are both
# passed — check the installed TRL version accepts this combination, and
# whether the second model copy is needed.
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    tokenizer=tokenizer,
    args=training_args_dpo,
    peft_config=peft_config,
    train_dataset=train_ds_sample_prepared,
    eval_dataset=test_ds_sample_prepared,
    beta=0.1,
    max_prompt_length=512,
    max_length=1024
)


# num_samples = 10
# wandb_cb = LLMSampleCB(dpo_trainer, test_ds_sample_prepared, num_samples=num_samples, max_new_tokens=512)
# dpo_trainer.add_callback(wandb_cb)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [80]:
# Run DPO training, then finish the active W&B run.
dpo_trainer.train()
wandb.finish()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/rejected,Logps/chosen,Logits/rejected,Logits/chosen
0,No log,1.296792,-1.559337,-0.8043,0.28,-0.755037,-261.20108,-429.136627,-0.520507,-0.310032


VBox(children=(Label(value='0.035 MB of 0.048 MB uploaded (0.003 MB deduped)\r'), FloatProgress(value=0.719719…

0,1
eval/logits/chosen,▅▅▆▆▂▂▁▁█
eval/logits/rejected,▆▆██▂▂▁▁▆
eval/logps/chosen,▂▂▁▁▃▃▁▁█
eval/logps/rejected,▃▃██▃▃▃▃▁
eval/loss,▃▃██▄▄▅▅▁
eval/rewards/accuracies,██████▁▁▅
eval/rewards/chosen,▂▂▁▁▃▃▁▁█
eval/rewards/margins,▄▄▁▁▄▄▃▃█
eval/rewards/rejected,▃▃██▃▃▃▃▁
eval/runtime,▄▄▅▅██▁▁▄

0,1
eval/logits/chosen,-0.31003
eval/logits/rejected,-0.52051
eval/logps/chosen,-429.13663
eval/logps/rejected,-261.20108
eval/loss,1.29679
eval/rewards/accuracies,0.28
eval/rewards/chosen,-1.55934
eval/rewards/margins,-0.75504
eval/rewards/rejected,-0.8043
eval/runtime,74.9112


In [81]:
# Defined here so the cell runs on a fresh kernel; elsewhere the name is only
# assigned inside commented-out code.
num_samples = 10

# Leading evaluation prompts used for qualitative generation checks.
sample_dataset = test_ds_sample_prepared.select(range(num_samples))
sample_dataset

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 10
})

In [86]:
# Reuse the DPO-trained policy and its tokenizer for ad-hoc generation.
model_t, tokenizer_t = dpo_trainer.model, dpo_trainer.tokenizer
tokenizer_t.pad_token = tokenizer_t.eos_token
# The attribute is `padding_side`. Assigning `padding_size` only creates an
# unused attribute and leaves the tokenizer's default padding side in effect.
tokenizer_t.padding_side = "right"
gen_config_t = GenerationConfig.from_pretrained(dpo_trainer.model.name_or_path, max_new_tokens=512)

# Use the last sampled prompt, which is the value a full pass over
# `sample_dataset` would leave behind.
prompt = sample_dataset[-1]["prompt"]

In [90]:
# Inspect the trained model's module tree to see where LoRA layers were injected.
dpo_trainer.model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaFlashAttention2(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_layer): Linear4bit

In [97]:
# Load the untuned base checkpoint and move it to the GPU as a baseline for
# comparison. The trailing expression displays its module tree.
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda")
base_model

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Lin

In [None]:
# Generate a baseline completion from the untuned base model for the same
# prompt, using the same generation settings prepared in `gen_config_t`.
tokenized_prompt = tokenizer_t(prompt, return_tensors="pt")
# Send inputs to wherever the model lives instead of assuming "cuda".
tokenized_prompt = {k: v.to(base_model.device) for k, v in tokenized_prompt.items()}

with torch.inference_mode():
    # Without an explicit generation config, `generate` uses default length
    # limits and ignores the `max_new_tokens` budget set in `gen_config_t`.
    output = base_model.generate(
        **tokenized_prompt,
        generation_config=gen_config_t,
        pad_token_id=tokenizer_t.eos_token_id,
    )

# Decode only the newly generated tokens (skip the echoed prompt) and show them.
prompt_len = len(tokenized_prompt["input_ids"][0])
base_generation = tokenizer_t.decode(output[0][prompt_len:], skip_special_tokens=True)
print(base_generation)