# Source
# https://huggingface.co/blog/mlabonne/orpo-llama-3

In [1]:
import gc
import os

import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format

  from .autonotebook import tqdm as notebook_tqdm


# Configuration

In [2]:
# Hugging Face access token passed as `token=` to the `from_pretrained` calls
# in this notebook. It is read from the environment so that the credential is
# never stored in the notebook file: export HF_TOKEN before starting the
# kernel. If the variable is unset this evaluates to None, in which case
# `from_pretrained` receives `token=None`.
# SECURITY: a literal token was previously committed on this line. Treat it as
# compromised and revoke it at https://huggingface.co/settings/tokens.
access_token = os.environ.get("HF_TOKEN")

In [3]:
# --- Precision and attention backend ---------------------------------------
# `torch_dtype` is reused below as the 4-bit compute dtype; the attention
# implementation string is handed to the model loader in a later cell.
torch_dtype = torch.bfloat16
attn_implementation = "flash_attention_2"

# Number of training epochs, forwarded to ORPOConfig(num_train_epochs=...).
N_EPOCHS = 100

# --- Checkpoint names --------------------------------------------------------
new_model = "OrpoLlama-3-8B"                # output path used by save_model()
base_model = "meta-llama/Meta-Llama-3-8B"   # checkpoint that gets fine-tuned

# --- QLoRA: load the base weights in 4-bit NF4 with double quantisation ------
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# --- LoRA adapter: rank 16, alpha 32, 5% dropout, no bias terms --------------
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # Projection layers that receive an adapter.
    target_modules=[
        "up_proj",
        "down_proj",
        "gate_proj",
        "k_proj",
        "q_proj",
        "v_proj",
        "o_proj",
    ],
)

# Model and Tokenizer setup

In [4]:
# Tokenizer that belongs to the base checkpoint (authenticated download).
tokenizer = AutoTokenizer.from_pretrained(base_model, token=access_token)

# Base model, quantised on load according to `bnb_config` and placed on the
# available devices automatically.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    token=access_token,
    attn_implementation=attn_implementation,
    device_map="auto",
    quantization_config=bnb_config,
)

# Apply TRL's chat-format setup to the model/tokenizer pair. This adds special
# tokens to the vocabulary (see the warning in the cell output), so model and
# tokenizer must be updated together.
model, tokenizer = setup_chat_format(model, tokenizer)

# Wrap the quantised model so it can be trained with k-bit weights.
model = prepare_model_for_kbit_training(model)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:06<00:00,  1.58s/it]


# Dataset preparation

In [10]:
# Pull every split of the preference dataset and shuffle it with a fixed seed.
dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all").shuffle(seed=42)
# Optional: uncomment to restrict the run to the first 3930 shuffled rows.
# dataset = dataset.select(range(3930))

def format_chat_template(row):
    """Render the preference pair of one dataset row as chat-template text.

    The "chosen" and "rejected" entries of ``row`` are each replaced by the
    string produced by the module-level ``tokenizer``'s
    ``apply_chat_template(..., tokenize=False)``. The row is modified in place
    and the same object is returned, as expected by ``Dataset.map``.
    """
    for column in ("chosen", "rejected"):
        conversation = row[column]
        row[column] = tokenizer.apply_chat_template(conversation, tokenize=False)
    return row

# Rewrite every row with `format_chat_template`, using as many worker
# processes as `os.cpu_count()` reports.
dataset = dataset.map(
    format_chat_template,
    num_proc=os.cpu_count(),
)
# Hold out 0.1% of the rows for evaluation. The split is seeded (same seed as
# the shuffle above) so the same rows form the test set on every run; without
# a seed the held-out rows change between runs and validation metrics cannot
# be compared or reproduced.
dataset = dataset.train_test_split(test_size=0.001, seed=42)

# Training

## Configure Trainer

In [11]:
# Hyper-parameters for the ORPO run, grouped by concern.
orpo_args = ORPOConfig(
    # Where checkpoints and logs are written.
    output_dir="./results/",
    # ORPO objective and sequence-length limits.
    beta=0.1,
    max_length=1024,
    max_prompt_length=512,
    # Optimiser and learning-rate schedule.
    optim="paged_adamw_8bit",
    learning_rate=8e-6,
    lr_scheduler_type="linear",
    warmup_steps=10,
    num_train_epochs=N_EPOCHS,
    # Batching: 2 sequences x 4 accumulation steps = 8 per optimiser step
    # on each device.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    # Logging and evaluation cadence.
    # NOTE(review): with eval_steps=1 a full pass over the eval split runs
    # after every optimiser step; the recorded output shows roughly 50 s per
    # evaluation. Raise this if per-step validation is not needed.
    evaluation_strategy="steps",
    eval_steps=1,
    logging_steps=1,
    # Keep all dataset columns when batches are built.
    remove_unused_columns=False,
)

# Trainer that attaches the LoRA adapter described by `peft_config` to `model`
# and trains/evaluates on the two halves of the split dataset.
trainer = ORPOTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=peft_config,
    args=orpo_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39663/39663 [01:26<00:00, 457.76 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 404.10 examples/s]


## Train model

In [12]:
# Run the fine-tuning loop with the configuration built in the previous cell.
trainer.train()
# Persist the result under `new_model`; the inference cell loads it from that
# path with PeftModel.from_pretrained. This line is only reached if training
# finishes without being interrupted.
trainer.save_model(new_model)



Step,Training Loss,Validation Loss,Runtime,Samples Per Second,Steps Per Second,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/rejected,Logps/chosen,Logits/rejected,Logits/chosen,Nll Loss,Log Odds Ratio,Log Odds Chosen
1,4.4338,4.921563,46.908,0.853,0.426,-0.147043,-0.143011,0.5,-0.004032,-1.430109,-1.47043,-1.451707,-1.470076,4.844595,-0.769671,-0.047337
2,4.701,4.866637,50.5028,0.792,0.396,-0.147021,-0.143002,0.475,-0.004019,-1.430023,-1.470213,-1.462631,-1.480098,4.789681,-0.769555,-0.047185
3,5.0984,4.756692,51.6554,0.774,0.387,-0.146952,-0.142946,0.475,-0.004005,-1.429465,-1.469516,-1.484007,-1.499783,4.679748,-0.769437,-0.047007
4,3.9634,4.592246,51.3824,0.778,0.389,-0.146874,-0.142866,0.45,-0.004009,-1.428655,-1.468741,-1.51479,-1.528304,4.515296,-0.769487,-0.047082
5,4.5198,4.374752,52.4593,0.762,0.381,-0.1467,-0.142709,0.475,-0.003991,-1.427088,-1.467002,-1.553582,-1.564027,4.297818,-0.769332,-0.046898
6,3.4786,4.107249,52.0602,0.768,0.384,-0.146322,-0.142386,0.45,-0.003937,-1.423856,-1.463223,-1.598696,-1.60586,4.030358,-0.768914,-0.046296
7,2.4182,3.798208,51.6513,0.774,0.387,-0.14607,-0.142149,0.45,-0.003921,-1.421494,-1.460704,-1.641586,-1.645468,3.721333,-0.76874,-0.04618
8,2.9431,3.448894,50.9486,0.785,0.393,-0.145639,-0.14177,0.45,-0.003869,-1.417701,-1.456388,-1.68324,-1.684193,3.372061,-0.76832,-0.045629
9,2.6249,3.063257,50.7467,0.788,0.394,-0.145308,-0.141453,0.45,-0.003855,-1.414526,-1.453077,-1.717063,-1.715632,2.986442,-0.768151,-0.045555
10,2.2639,2.641032,53.0964,0.753,0.377,-0.1449,-0.141067,0.45,-0.003833,-1.410672,-1.448999,-1.747823,-1.74412,2.564237,-0.767945,-0.045436


KeyboardInterrupt: 

# Inference

In [None]:
# Flush memory: drop the training objects, then release Python garbage and
# the cached CUDA allocations before loading a second copy of the model.
del trainer, model
gc.collect()
torch.cuda.empty_cache()

# Reload tokenizer and model from the base checkpoint, this time without
# quantisation (float16 weights) so the adapter can be merged into them.
# `token=access_token` is passed exactly as in the training-time load so that
# this download is authenticated the same way and does not depend on whatever
# login state happens to be present on the machine.
tokenizer = AutoTokenizer.from_pretrained(base_model, token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
    token=access_token,
)
# Apply the same chat-format setup that was used before training, so the
# reloaded model/tokenizer pair matches the one the adapter was trained on.
model, tokenizer = setup_chat_format(model, tokenizer)

# Merge adapter with base model: load the adapter saved under `new_model`,
# then fold it into the base weights and drop the PEFT wrapper.
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()

In [None]:
# Display the class of `model` as it stands after the merge cell above.
type(model)