In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer
import torch
from torch.optim import AdamW
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
from datasets import Dataset,concatenate_datasets
from peft import AutoPeftModelForCausalLM,PeftModel
from trl import  DPOConfig, DPOTrainer
from unsloth import FastLanguageModel
from peft import get_peft_model, TaskType, LoraConfig

In [None]:
# Location of the model produced by the second SFT stage. It is used below both as
# the base checkpoint and as the source of the "reference" adapter.
# NOTE(review): the value looks like a placeholder -- set it before running.
sft_model_root = "/root/of/the/sft_model"

In [None]:
# Load the SFT checkpoint and its tokenizer through Unsloth, in bfloat16 and
# without 4-bit quantization.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = sft_model_root,
    dtype=torch.bfloat16,
    load_in_4bit=False
)
# Pad on the left side of each sequence.
tokenizer.padding_side="left"
# Load the same checkpoint a second time as an adapter named "reference"; the DPO
# configuration refers to it through `ref_adapter_name`.
# NOTE(review): this assumes sft_model_root contains adapter weights that
# `load_adapter` can read -- confirm.
model.load_adapter(sft_model_root, adapter_name="reference")

In [None]:
# Preference data for the RL stage, read from two CSV sources.
# NOTE(review): both paths look like placeholders -- point them at the actual files.
ds1 = Dataset.from_csv("./RL data/harmful") # the harmful subset
ds2 = Dataset.from_csv("./RL data/harmless")# the harmless subset
# Merge both subsets into one dataset, then shuffle with a fixed seed so the
# ordering is the same on every run.
ds = concatenate_datasets([ds1,ds2])
rl_ds = ds.shuffle(seed=42)

In [None]:
def chatml_format(example):
    """Render one preference row into the text layout used for DPO training.

    The raw ``prompt`` is wrapped as a single user turn (no system message is
    prepended) and passed through the module-level ``tokenizer``'s chat
    template as text, with the generation prompt appended. The ``chosen`` and
    ``rejected`` completions each get the tokenizer's EOS token followed by a
    newline appended.

    Args:
        example: Mapping with string fields ``prompt``, ``chosen`` and
            ``rejected``. It is modified in place.

    Returns:
        The same ``example`` object, with the three fields rewritten.
    """
    user_turn = {"role": "user", "content": example["prompt"]}
    example["prompt"] = tokenizer.apply_chat_template(
        [user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

    # Both completions are terminated the same way: EOS token, then a newline.
    completion_suffix = tokenizer.eos_token + "\n"
    for column in ("chosen", "rejected"):
        example[column] = example[column] + completion_suffix

    return example

# Format every row with chatml_format, then keep only the prompt/chosen/rejected columns.
formatted_rl_ds = (
    rl_ds
    .map(chatml_format)
    .select_columns(["prompt", "chosen", "rejected"])
)

In [None]:
# Default training parameters for the DPO run, grouped by purpose.
training_args = DPOConfig(
    # --- Output and logging ---
    output_dir="output_dir",
    logging_steps=10,

    # --- Schedule and optimization ---
    num_train_epochs=1,
    learning_rate=3e-5,
    bf16=True,

    # --- Batching and memory ---
    # Per-device batches of 1 with 16 gradient-accumulation steps.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_length=8192,
    torch_empty_cache_steps=5,

    # --- Loss ---
    # NOTE(review): per the TRL docs, rpo_alpha weights an extra NLL term on the
    # chosen responses -- confirm against the installed TRL version.
    rpo_alpha=1.0,

    # --- Adapters ---
    # "reference" matches the adapter name used when the SFT checkpoint was loaded
    # a second time; "default" is presumably the name given to the adapter loaded
    # with the model -- TODO confirm.
    model_adapter_name="default",
    ref_adapter_name="reference",
)

In [None]:
# Build the DPO trainer from the loaded model, the chat-formatted preference
# dataset, and the tokenizer.
dpo_trainer = DPOTrainer(
    model = model,
    # No separate reference model object is passed; training_args names a
    # "reference" adapter instead.
    ref_model= None,
    args=training_args,
    train_dataset=formatted_rl_ds,
    processing_class=tokenizer
)

In [None]:
# Run DPO training. As the last expression in the cell, the return value is
# displayed as the cell output.
# NOTE(review): no explicit save call is visible after training -- confirm the
# trained adapter is persisted (by checkpointing or a later cell).
dpo_trainer.train()