In [None]:
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
python examples/scripts/dpo.py \
    --dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \
    --model_name_or_path=gpt2 \
    --per_device_train_batch_size 4 \
    --learning_rate 1e-3 \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="dpo_anthropic_hh" \
    --warmup_steps 150 \
    --report_to wandb \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns

# peft:
python examples/scripts/dpo.py \
    --dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \
    --model_name_or_path=gpt2 \
    --per_device_train_batch_size 4 \
    --learning_rate 1e-3 \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="dpo_anthropic_hh" \
    --optim rmsprop \
    --warmup_steps 150 \
    --report_to wandb \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns \
    --use_peft \
    --lora_r=16 \
    --lora_alpha=16
"""

import logging
import multiprocessing
import os
from contextlib import nullcontext

TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser

if TRL_USE_RICH:
    init_zero_verbose()
    FORMAT = "%(message)s"

    from rich.console import Console
    from rich.logging import RichHandler

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    RichProgressCallback,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


if TRL_USE_RICH:
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)

In [9]:
%%bash --out TOP_LEVEL
printf "$(git rev-parse --show-toplevel)"

In [33]:
#parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
#args, training_args, model_config = parser.parse_args_and_config()

import yaml
from pathlib import Path

# Force use our print callback
if TRL_USE_RICH:
    training_args.disable_tqdm = True
    console = Console()

config = yaml.safe_load(Path(TOP_LEVEL + '/configs/default.yaml').read_text())


################
# Model & Tokenizer
################
# torch_dtype = (
#     model_config.torch_dtype
#     if model_config.torch_dtype in ["auto", None]
#     else getattr(torch, model_config.torch_dtype)
# )
config = yaml.safe_load(Path(TOP_LEVEL + '/configs/default.yaml').read_text())
match config['model']['torch_dtype']:
    case 'float16':
        torch_dtype = torch.float16
    case 'float32':
        torch_dtype = torch.float32
    case 'float64':
        torch_dtype = torch.float64
    case 'bfloat16':
        torch_dtype = torch.bfloat16
    case 'auto':
        torch_dtype = "auto"
    case _:
        raise ValueError('torch_dtype is invalid')
    
# quantization_config = get_quantization_config(model_config)
# model_kwargs = dict(
#     revision=model_config.model_revision,
#     trust_remote_code=model_config.trust_remote_code,
#     attn_implementation=model_config.attn_implementation,
#     torch_dtype=torch_dtype,
#     use_cache=False if training_args.gradient_checkpointing else True,
#     device_map=get_kbit_device_map() if quantization_config is not None else None,
#     quantization_config=quantization_config,
# )
model = AutoModelForCausalLM.from_pretrained(config['model']['path'], torch_dtype=torch_dtype).to("cuda") #, **model_kwargs)
#peft_config = get_peft_config(model_config)
#if peft_config is None:
model_ref = AutoModelForCausalLM.from_pretrained(config['model']['path']) #, **model_kwargs)
# else:
#     model_ref = None
tokenizer = AutoTokenizer.from_pretrained(config['model']['path'])
# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token
# if tokenizer.chat_template is None:
#     tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
# if args.ignore_bias_buffers:
#     # torch distributed hack
#     model._ddp_params_and_buffers_to_ignore = [
#         name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
#     ]

# ################
# # Optional rich context managers
# ###############
# init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
# save_context = (
#     nullcontext()
#     if not TRL_USE_RICH
#     else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
# )

In [14]:
def return_prompt_and_responses(samples): # -> Dict[str, str, str]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"],   # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

train_dataset = dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)
ds = train_dataset.select(range(100))
print(train_dataset)

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 7435908
})


In [23]:
################
# Dataset
################
# ds = load_dataset(args.dataset_name)
# if args.sanity_check:
#     for key in ds:
#         ds[key] = ds[key].select(range(50))

# def process(row):
#     row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
#     row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
#     return row

# ds = ds.map(
#     process,
#     num_proc=multiprocessing.cpu_count(),
#     load_from_cache_file=False,
# )
train_dataset = ds #[0:90] #[args.dataset_train_split]
#eval_dataset = ds[91:99] #[args.dataset_test_split]
#print(train_dataset)

In [None]:
################
# Training
################
#with init_context:
NUM_TRAIN_EPOCHS = 20
OUTPUT_DIR = TOP_LEVEL + f"/alfred/output/{config['model']['path']},torch_dtype={torch_dtype}/epoch={NUM_TRAIN_EPOCHS}"
#os.makedirs(os.path.dirname(OUTPUT_DIR), exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

training_args = DPOConfig(
    beta=0.1,
    # does not automatically save model output
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_TRAIN_EPOCHS
)

trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    train_dataset=train_dataset,
    #eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    #peft_config=get_peft_config(model_config),
    #callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)

trainer.train()

#with save_context:

In [27]:
trainer.save_model(OUTPUT_DIR+'/final-dpo1')

In [28]:
print(torch_dtype)
model = AutoModelForCausalLM.from_pretrained(OUTPUT_DIR + '/final-dpo1', torch_dtype=torch_dtype).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR + '/final-dpo1') #config["tokenizer"]["path"])

torch.float16


In [29]:
prompt_ids = torch.tensor(tokenizer("Echo all the statements that are provided.\n")["input_ids"])
query_ids = torch.tensor(tokenizer("Why is the sky blue?")["input_ids"])
prompt_and_query_ids = torch.cat([prompt_ids, query_ids], dim=0)
print(prompt_and_query_ids)
tokenizer.decode(prompt_and_query_ids, skip_special_tokens=True)

tensor([224619,   1728,    368,  43163,    861,   1306,  15984,    336,  23857,
           632,    368,  60614,  29853,     34])


'Echo all the statements that are provided.\nWhy is the sky blue?'

In [34]:
response_ids = model.generate(
                  torch.Tensor(prompt_and_query_ids).unsqueeze(0).to("cuda"),
    num_beams=1, max_new_tokens=100,              
    repetition_penalty=1.2 #,temperature = 0
            )
print(response_ids[0])
tokenizer.batch_decode(response_ids)

tensor([224619,   1728,    368,  43163,    861,   1306,  15984,    336,  23857,
           632,    368,  60614,  29853,     34,   1387,  12300,    427,   1119,
          5893,  39152,    664,   2632,  32391,    461,   3595,    267,    567,
         47490,      5,    791,  16554,  18681,   1776,  73173,  37287,  10925,
            17,   1004,   5827,     15,    718,  64559,  55326,    189,     36,
         19182,  12364,    375,    280,    660,   4052,   1002,    654,  12490,
            12,    361,   2131,   2782,  24763,   3804,   2592,   9671,     30,
           613,   6635,   1380,   2175,  22779,   1256,  13682,   1809,   3784,
         32046,  29369,   1331,   3776,  51890,    530,  17393,  27660,    661,
          1320,   6648,   1130, 226305,    919,   2914,   3509,   2494,   6168,
          6416,   1400,    722,  15397,   3262,   1152,   5382,   6147,  43624,
          6054,      2], device='cuda:0')


['Echo all the statements that are provided.\nWhy is the sky blue? The answer to this question depends on your understanding of what a "black" or "brownish browning" means. In general, it refers to:\nA black body (or an object with no light) in which there exists only one color; for example,\nThe sun\'s surface has been darkened by its radiation and therefore appears as if it\'s not shining at any time. (This effect can be seen when you look through windows.)</s>']