In [None]:
!pip install -qU transformers datasets accelerate peft trl bitsandbytes wandb --progress-bar off

# Fine-tune Llama 3 with ORPO

In [None]:
import os
import gc
import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import AUtoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline
from trl import ORPOConfig, ORPOTrainer, setup_chat_format
from google.colab import userdata

# model
base_model = 'meta-llama/Meta-Llama-3-8B'
new_model = 'OrpoLlama-3-8B'

# setups
wb_token = userdata.get('WB_TOKEN')
wandb.login(key=wb_token)

# set torch dtype and attention implementation
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qU flash-attn
    torch_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    torch_dtype = torch.float16
    attn_implementation = 'eager'

In [None]:
# LoRA adapter settings, applied to the projection modules listed below
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "up_proj", "down_proj", "gate_proj",
    ],
)

# 4-bit NF4 quantization with double quantization; compute dtype follows the
# `torch_dtype` selected during setup
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch_dtype,
)

# Tokenizer first, then the quantized base model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    attn_implementation=attn_implementation,
    device_map="auto",
    quantization_config=bnb_config,
)

# Apply the chat format to model + tokenizer, then prepare the model for k-bit training
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

## Load dataset

In [None]:
# Fetch the preference dataset and keep a seeded random subset of 1000 rows
# (small on purpose: this run is a demo)
dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=111)
dataset = dataset.select(range(1000))

def format_chat_template(row):
    """Render the `chosen` and `rejected` fields of `row` with the chat template.

    Each of the two fields is overwritten in place with the untokenized output
    of the module-level `tokenizer.apply_chat_template`; every other field is
    left untouched. The same `row` object is returned.
    """
    for field in ("chosen", "rejected"):
        row[field] = tokenizer.apply_chat_template(row[field], tokenize=False)
    return row


# Apply the chat template to every row, using one worker per CPU core
dataset = dataset.map(
    format_chat_template,
    num_proc=os.cpu_count()
)
# Hold out 1% of the rows for evaluation. The split is seeded (same seed as the
# shuffle) so the train/test partition is reproducible across runs.
dataset = dataset.train_test_split(test_size=0.01, seed=111)

## Train ORPO

In [None]:
# ORPO training configuration
orpo_args = ORPOConfig(
    # --- optimisation ---
    learning_rate=8e-6,
    lr_scheduler_type="linear",
    warmup_steps=10,
    optim="paged_adamw_8bit",
    beta=0.1,
    num_train_epochs=1,
    # --- sequence lengths ---
    max_length=1024,
    max_prompt_length=512,
    # --- batching ---
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    # --- evaluation, logging, output ---
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` -- confirm against the installed version.
    evaluation_strategy="steps",
    eval_steps=0.2,  # 20% of the total training steps
    logging_steps=1,
    report_to="wandb",
    output_dir="./results/",
)

# Build the trainer from the quantized model, the ORPO arguments, the LoRA
# config and the train/test splits produced above
trainer = ORPOTrainer(
    model=model,
    args=orpo_args,
    peft_config=peft_config,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

In [None]:
# Run ORPO training, then save the result under `new_model`
trainer.train()
trainer.save_model(new_model)

## Inference

In [None]:
# Flush memory: drop the references, run the garbage collector first so that
# objects kept alive only by reference cycles are actually freed, and only then
# release the now-unused cached CUDA memory.
del trainer, model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Load a fresh tokenizer and an unquantized float16 copy of the base model,
# then apply the same chat format that was applied before training
tokenizer = AutoTokenizer.from_pretrained(base_model)
fp16_model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
    return_dict=True,
)
fp16_model, tokenizer = setup_chat_format(fp16_model, tokenizer)

In [None]:
# Load the adapter saved under `new_model` on top of the fp16 base model and
# merge its weights in, leaving a plain (non-PEFT) model
model = PeftModel.from_pretrained(fp16_model, new_model).merge_and_unload()

In [None]:
# Upload the merged model and its tokenizer to the Hub repository `new_model`
model.push_to_hub(repo_id=new_model, use_temp_dir=False)
tokenizer.push_to_hub(repo_id=new_model, use_temp_dir=False)