Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update KTO example with good dataset & chat format #1481

Merged
merged 2 commits into from
Mar 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
105 changes: 32 additions & 73 deletions examples/scripts/kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,49 +13,45 @@
# limitations under the License.

"""
Run the KTO training script with the following command with some example arguments.
In general, the optimal configuration for KTO will be similar to that of DPO:
Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.

# regular:
# Full training:
python examples/scripts/kto.py \
--model_name_or_path=gpt2 \
--model_name_or_path=stabilityai/stablelm-2-zephyr-1_6b \
--per_device_train_batch_size 16 \
--max_steps 1000 \
--num_train_epochs 1 \
--learning_rate 2e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="kto_anthropic_hh" \
--output_dir="kto-aligned-model" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
--logging_first_step

# peft:
# LoRA:
python examples/scripts/kto.py \
--model_name_or_path=gpt2 \
--model_name_or_path=stabilityai/stablelm-2-zephyr-1_6b \
--per_device_train_batch_size 16 \
--max_steps 1000 \
--num_train_epochs 1 \
--learning_rate 2e-4 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="kto_anthropic_hh" \
--output_dir="kto-aligned-model-lora" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""

from dataclasses import dataclass, field
from typing import Optional
from dataclasses import dataclass

from datasets import Dataset, load_dataset
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config
Expand All @@ -68,85 +64,48 @@ class ScriptArguments:
The arguments for the KTO training script.
"""

# debugging
sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})


def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)

if search_term_idx == -1:
raise ValueError(f"Prompt and response does not contain '{search_term}'")

return prompt_and_response[: search_term_idx + len(search_term)]


def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'completion': List[str],
'label': List[bool],
}

Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))

flat_data = {
"prompt": [],
"completion": [],
"label": [],
}
for sample in dataset:
prompt = extract_anthropic_prompt(sample["chosen"])
flat_data["prompt"].append(prompt)
flat_data["completion"].append(sample["chosen"][len(prompt) :])
flat_data["label"].append(True)
flat_data["prompt"].append(prompt)
flat_data["completion"].append(sample["rejected"][len(prompt) :])
flat_data["label"].append(False)

return dataset.from_dict(flat_data)
dataset_name: str = "trl-lib/kto-mix-14k"


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
script_args, kto_args, model_args = parser.parse_args_into_dataclasses()

# 1. load a pretrained model
# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
raise ValueError(
"Tokenizer must have a chat template in order to format the examples. Alternatively, adjust this script to format the examples differently."
)

# Load the dataset
dataset = load_dataset(script_args.dataset_name)

# 2. Load the Anthropic Helpful-Harmless dataset
train_dataset = get_hh("train", sanity_check=script_args.sanity_check)
# Apply chat template
def format_dataset(example):
example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
kashif marked this conversation as resolved.
Show resolved Hide resolved
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
return example

# 3. Load evaluation dataset
eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)
formatted_dataset = dataset.map(format_dataset)

# 4. initialize the KTO trainer
# Initialize the KTO trainer
kto_trainer = KTOTrainer(
model,
model_ref,
args=kto_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=formatted_dataset["train"],
eval_dataset=formatted_dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_args),
)

# 5. train and save the model
# Train and push the model to the Hub
kto_trainer.train()
kto_trainer.save_model(kto_args.output_dir)
kto_trainer.push_to_hub()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should put push_to_hub() in all our example scripts to track usage on the Hub