In [1]:
pip install -U transformers datasets evaluate accelerate trl wandb

Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset

dataset = load_dataset("Jise/hh-rlhf-helpful-base")
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['chosen', 'rejected', 'prompt'],
        num_rows: 43835
    })
    test: Dataset({
        features: ['chosen', 'rejected', 'prompt'],
        num_rows: 2354
    })
})


In [3]:
from transformers import AutoTokenizer

model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
from transformers import AutoModelForSeq2SeqLM, TrainingArguments
from trl import DPOTrainer, DPOConfig
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
import requests
import pickle
from transformers.optimization import Adafactor, AdafactorSchedule
import wandb

run_name = "Flan-T5_DPO_HH-RLHF"

with open("TOKENS.pkl", "rb") as f:
    TOKENS = pickle.load(f)

WANDB_TOKEN = TOKENS["WANDB_TOKEN"]
HF_TOKEN = TOKENS["HF_TOKEN"]

wandb.login(key=WANDB_TOKEN)

if tokenizer.chat_template is None:
    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)

training_args = DPOConfig(
    output_dir="./flan-t5-dpo",
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=1e-5,
    num_train_epochs=1,
    logging_dir="./logs",
    logging_steps=50,
    weight_decay=0.01,
    max_length=512,
    save_total_limit=1,
    truncation_mode='keep_end',
    bf16=True,
    push_to_hub=True,
    report_to="wandb",
    run_name=run_name,
    hub_token=HF_TOKEN,
    hub_model_id="Jise/flan-t5-hh-dpo",
    save_safetensors=False,
)


trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    optimizers=(optimizer, lr_scheduler),
)

2024-11-25 02:29:00.529084: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjiseshen[0m ([33mjise[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/idies/.netrc


Tokenizing eval dataset:   0%|          | 0/2354 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (553 > 512). Running this sequence through the model will result in indexing errors
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [5]:
trainer.train()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/chosen,Logps/rejected,Logits/chosen,Logits/rejected
50,0.7189,0.672445,0.178261,0.125236,0.596186,0.053026,-159.658859,-128.914856,-18.052717,-18.023928
100,0.6828,0.64828,0.171495,0.032626,0.623729,0.138869,-159.726517,-129.840942,-18.309679,-18.320799
150,0.6743,0.639346,0.262279,0.063989,0.633898,0.19829,-158.81868,-129.527328,-18.401348,-18.413612
200,0.6642,0.633705,0.258616,-0.019284,0.636441,0.2779,-158.855301,-130.360046,-18.267185,-18.315479
250,0.6582,0.620953,0.257827,-0.074899,0.652119,0.332726,-158.86319,-130.916214,-18.736591,-18.793457
300,0.6545,0.618397,0.589382,0.247174,0.666949,0.342208,-155.547638,-127.695473,-18.531569,-18.618721
350,0.6445,0.625849,0.380338,-0.123173,0.664831,0.503511,-157.638092,-131.398941,-18.739565,-18.846008
400,0.659,0.615225,0.409246,-0.035849,0.662712,0.445094,-157.34903,-130.525696,-18.65481,-18.776262
450,0.6475,0.631365,0.037991,-0.485717,0.647034,0.523708,-161.061554,-135.024368,-18.771259,-18.874756
500,0.671,0.631687,0.189137,-0.288903,0.648729,0.47804,-159.550079,-133.056244,-18.264383,-18.333036


TrainOutput(global_step=2739, training_loss=0.8236945821839382, metrics={'train_runtime': 23062.1233, 'train_samples_per_second': 1.901, 'train_steps_per_second': 0.119, 'total_flos': 0.0, 'train_loss': 0.8236945821839382, 'epoch': 0.9997262523952916})

In [6]:
results = trainer.evaluate()
print(results)

# 保存模型
trainer.save_model("./flan-t5-dpo-trained")
tokenizer.save_pretrained("./flan-t5-dpo-trained")

{'eval_loss': 1.1456202268600464, 'eval_runtime': 258.6968, 'eval_samples_per_second': 9.099, 'eval_steps_per_second': 1.14, 'eval_rewards/chosen': -12.91569995880127, 'eval_rewards/rejected': -13.150213241577148, 'eval_rewards/accuracies': 0.5381355881690979, 'eval_rewards/margins': 0.23451119661331177, 'eval_logps/chosen': -290.5984802246094, 'eval_logps/rejected': -261.6693420410156, 'eval_logits/chosen': -5.470920562744141, 'eval_logits/rejected': -5.361225128173828, 'epoch': 0.9997262523952916}


('./flan-t5-dpo-trained/tokenizer_config.json',
 './flan-t5-dpo-trained/special_tokens_map.json',
 './flan-t5-dpo-trained/tokenizer.json')