## Load Llama-3-8B

In [2]:
# Warning: the transformers version used for DPO cannot load Llama 3 and raises errors here.
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
transformers.logging.set_verbosity_error()

In [3]:
# The tokenizer and the model weights are read from the same local checkpoint directory.
llama3_dir = "<some path>/model_zoo/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(llama3_dir)
model = AutoModelForCausalLM.from_pretrained(llama3_dir)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
from datasets import load_dataset
# Read the HH-RLHF data from a local copy.
# Disabled alternative kept for reference: hh_dataset = load_dataset("Anthropic/hh-rlhf")
hh_rlhf_dir = "<some path>/Anthropic_hh-rlhf"
ds = load_dataset(hh_rlhf_dir)

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

# Data Processing

In [30]:
from datasets import load_dataset,load_from_disk
# On-disk preference dataset; the DPO cells further down pass this same path
# as --dataset_name.
dataset_path = "<some path>/trl/examples/datasets/hh_trl"

# Load the saved dataset (indexed by "train" and "test" in the cells below).
raw_datasets = load_from_disk(dataset_path)

In [32]:
# `ds` is a second name for the same object as `raw_datasets`; no copy is made,
# so item assignments through `ds` are visible through `raw_datasets` as well.
ds = raw_datasets

In [33]:
# Build the SFT dataset from the preference data: drop "prompt" and "rejected",
# and rename "chosen" to "messages".
#
# The DatasetDict-level methods return a new DatasetDict, so `raw_datasets` is
# left untouched and this cell can be re-run. The previous per-split assignment
# through the `ds` alias edited the loaded object in place, so a second run
# failed because "prompt"/"rejected" had already been removed.
ds = (
    raw_datasets
    .remove_columns(["prompt", "rejected"])
    .rename_column("chosen", "messages")
)

# Persist the result; the SFT cells below pass this directory as --dataset_name.
ds.save_to_disk("<some path>/trl/examples/datasets/hh_sft_trl/")

Saving the dataset (0/1 shards):   0%|          | 0/160800 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8552 [00:00<?, ? examples/s]

## Dataset Checking

In [2]:
from datasets import load_from_disk
# Reload the SFT dataset written by the data-processing cell to verify what was saved.
dataset_path = "<some path>/trl/examples/datasets/hh_sft_trl"
hh_sft_trl = load_from_disk(dataset_path)

# --- Disabled experiment (kept for reference, not executed) -------------------
# Sketch of rendering a column through the tokenizer's chat template with a
# parallel `map`. NOTE(review): `remove_column` below is probably meant to be
# `remove_columns`, and "completion" is not a column of the datasets shown in
# this notebook — confirm before re-enabling.
# ds = raw_datasets.remove_column("prompt")

# # def process(row):
# #     row["completion"] = tokenizer.apply_chat_template(row["completion"], tokenize=False)
# #     return row

# ds = ds.map(
#     process,
#     num_proc=16,
#     #num_proc=multiprocessing.cpu_count(),
#     load_from_cache_file=True,
# )

# Last expression of the cell: display the split / column / row-count summary.
hh_sft_trl

DatasetDict({
    train: Dataset({
        features: ['messages'],
        num_rows: 160800
    })
    test: Dataset({
        features: ['messages'],
        num_rows: 8552
    })
})

In [3]:
from datasets import load_from_disk
# Reload the original preference dataset (same path as in the data-processing
# section) under its own name so it can be inspected next to hh_sft_trl.
dataset_path = "<some path>/trl/examples/datasets/hh_trl"
hh_trl = load_from_disk(dataset_path)

In [4]:
# Display the split / column / row-count summary of the preference dataset.
hh_trl

DatasetDict({
    train: Dataset({
        features: ['chosen', 'rejected', 'prompt'],
        num_rows: 160800
    })
    test: Dataset({
        features: ['chosen', 'rejected', 'prompt'],
        num_rows: 8552
    })
})

# SFT

In [None]:
# Single-process SFT run of examples/scripts/sft.py on facebook/opt-350m,
# reporting to wandb and writing to ./sft_openassistant-guanaco.
# Per device: batch 64 x 16 accumulation steps = 1024 samples per optimizer step.
# No --dataset_name is passed here, so the script's own default applies
# (NOTE(review): presumably the guanaco dataset, judging by the output dir — confirm).
# NOTE(review): --max_steps=-1 presumably means "no step cap, use
# --num_train_epochs" — confirm against the trainer arguments.
!python examples/scripts/sft.py \
    --model_name_or_path="facebook/opt-350m" \
    --report_to="wandb" \
    --learning_rate=1.41e-5 \
    --per_device_train_batch_size=64 \
    --gradient_accumulation_steps=16 \
    --output_dir="sft_openassistant-guanaco" \
    --logging_steps=1 \
    --num_train_epochs=3 \
    --max_steps=-1 \
    --gradient_checkpointing

In [None]:
# FSDP
# SFT of Llama-3-8B on the hh_sft_trl dataset, launched through accelerate with
# the FSDP_Llama.yaml config on 8 processes, using flash_attention_2.
# Batch 1 x 8 accumulation steps x 8 processes = 64 samples per optimizer step
# (assuming one device per process).
# The command ends at --warmup_steps: the former trailing "#\" was a leftover
# continuation marker with nothing to continue and has been removed.
!accelerate launch --config_file=examples/accelerate_configs/FSDP_Llama.yaml --num_processes 8 examples/scripts/sft_fsdp.py \
    --dataset_name="<some path>/trl/examples/datasets/hh_sft_trl" \
    --model_name_or_path="<some path>/model_zoo/Meta-Llama-3-8B" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --attn_implementation 'flash_attention_2' \
    --learning_rate 2e-6 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 10 \
    --output_dir="sft_anthropic_hh_1" \
    --optim adamw_torch \
    --warmup_steps 150

In [None]:
# Zero 2 will lead to OOM, similar for DPO
# SFT of Llama-3-8B on the hh_sft_trl dataset, launched through accelerate with
# the deepspeed_zero3_sft.yaml config on 8 processes.
# Batch 1 x 8 accumulation steps x 8 processes = 64 samples per optimizer step
# (assuming one device per process).
!accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3_sft.yaml --num_processes 8 examples/scripts/sft.py \
    --dataset_name="<some path>/trl/examples/datasets/hh_sft_trl" \
    --model_name_or_path="<some path>/model_zoo/Meta-Llama-3-8B" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 5e-7 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 10 \
    --output_dir="sft_anthropic_hh" \
    --optim adamw_torch \
    --warmup_steps 150
# Optional flags, currently disabled. They are kept outside the command on
# purpose: when they sat behind a "\" continuation, everything after the first
# "#" became one shell comment, so uncommenting a later flag had no effect.
# To enable one, append " \" to the last flag line above and add it there.
#   --eval_steps 500
#   --bf16
#   --logging_first_step
#   --no_remove_unused_columns

# RLHF

In [None]:
# just python
# Single-process DPO run of examples/scripts/dpo.py on Llama-3-8B with the
# hh_trl preference dataset, in bf16, batch 4 with no gradient accumulation.
# NOTE(review): --learning_rate 1e-3 differs from the 5e-7 used by every other
# Llama-3 cell in this notebook — confirm it is intended for an 8B model.
# NOTE(review): --output_dir is the same "dpo_anthropic_hh" used by the other
# DPO cells, so runs share one output directory.
!python examples/scripts/dpo.py \
    --dataset_name="<some path>/trl/examples/datasets/hh_trl" \
    --model_name_or_path="<some path>/model_zoo/Meta-Llama-3-8B" \
    --per_device_train_batch_size 4 \
    --learning_rate 1e-3 \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="dpo_anthropic_hh" \
    --warmup_steps 150 \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns

## Deepspeed Zero

In [None]:
# Zero3 Full parameters
# DPO of Llama-3-8B on the hh_trl preference dataset, launched through
# accelerate with the deepspeed_zero3.yaml config on 8 processes, in bf16 with
# the adamw_torch optimizer.
# Batch 2 x 8 accumulation steps x 8 processes = 128 pairs per optimizer step
# (assuming one device per process). Sequences are capped with --max_length 512
# and --max_prompt_length 128.
# NOTE(review): --output_dir is the same "dpo_anthropic_hh" used by the other
# DPO cells, so runs share one output directory.
!accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml --num_processes 8 examples/scripts/dpo.py \
    --dataset_name="<some path>/trl/examples/datasets/hh_trl" \
    --model_name_or_path="<some path>/model_zoo/Meta-Llama-3-8B" \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 4 \
    --learning_rate 5e-7 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing True \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="dpo_anthropic_hh" \
    --optim adamw_torch \
    --max_length 512 \
    --max_prompt_length 128 \
    --warmup_steps 150 \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns

In [None]:
# Zero3 Full parameters (with Flash Attention)
# Same launch as the plain ZeRO-3 DPO cell (deepspeed_zero3.yaml, 8 processes,
# dpo.py, bf16) with two differences visible in the flags:
#   * --attn_implementation 'flash_attention_2' is added
#   * --optim is rmsprop instead of adamw_torch
# Batch 2 x 8 accumulation steps x 8 processes = 128 pairs per optimizer step
# (assuming one device per process).
# NOTE(review): --output_dir is the same "dpo_anthropic_hh" used by the other
# DPO cells, so runs share one output directory.
!accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml --num_processes 8 examples/scripts/dpo.py \
    --dataset_name="<some path>/trl/examples/datasets/hh_trl" \
    --model_name_or_path="<some path>/model_zoo/Meta-Llama-3-8B" \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 4 \
    --attn_implementation 'flash_attention_2' \
    --learning_rate 5e-7 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing True \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="dpo_anthropic_hh" \
    --optim rmsprop \
    --max_length 512 \
    --max_prompt_length 128 \
    --warmup_steps 150 \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns

In [None]:
# Zero2 Full parameters (OOM)
# DPO of Llama-3-8B launched with the deepspeed_zero2.yaml config on 8 processes.
# This cell runs a separate script, examples/scripts/dpo_zero2.py, rather than
# dpo.py; the remaining flags match the ZeRO-3 + flash-attention cell.
# Recorded outcome per the title above: this configuration runs out of memory.
!accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero2.yaml --num_processes 8 examples/scripts/dpo_zero2.py \
    --dataset_name="<some path>/trl/examples/datasets/hh_trl" \
    --model_name_or_path="<some path>/model_zoo/Meta-Llama-3-8B" \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 4 \
    --attn_implementation 'flash_attention_2' \
    --learning_rate 5e-7 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing True \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="dpo_anthropic_hh" \
    --optim rmsprop \
    --max_length 512 \
    --max_prompt_length 128 \
    --warmup_steps 150 \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns
    # Disabled flag, not part of the command above: --sanity_check

## FSDP

In [None]:
# FSDP: Note that batch_size 4 will currently lead to OOM error
# DPO of Llama-3-8B on the hh_trl preference dataset, launched through
# accelerate with the FSDP_Llama.yaml config on 8 processes, running
# examples/scripts/dpo_fsdp.py in bf16 with the rmsprop optimizer.
# Batch 2 x 4 accumulation steps x 8 processes = 64 pairs per optimizer step
# (assuming one device per process).
# NOTE(review): --output_dir is the same "dpo_anthropic_hh" used by the other
# DPO cells, so runs share one output directory.
!accelerate launch --config_file=examples/accelerate_configs/FSDP_Llama.yaml --num_processes 8 examples/scripts/dpo_fsdp.py \
    --dataset_name="<some path>/trl/examples/datasets/hh_trl" \
    --model_name_or_path="<some path>/model_zoo/Meta-Llama-3-8B" \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 4 \
    --learning_rate 5e-7 \
    --gradient_accumulation_steps 4 \
    --gradient_checkpointing True \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="dpo_anthropic_hh" \
    --optim rmsprop \
    --max_length 512 \
    --max_prompt_length 128 \
    --warmup_steps 150 \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns
##
# NOTE(review): the lines below look like FSDP/DDP settings noted for reference;
# they are comments only and are not passed to the command above — confirm
# where they are meant to be set.
# ddp_timeout: 1800
# fsdp: "full_shard auto_wrap"
# fsdp_transformer_layer_cls_to_wrap: "LlamaDecoderLayer"