In [2]:
# Core runtime: PyTorch is pinned to 1.13.1.
# NOTE(review): torchdata carries no version pin here, so its version is whatever pip
# resolves at install time — consider pinning it as well.
%pip install torch==1.13.1 torchdata

# Hugging Face stack and helpers, all pinned to exact versions.
# NOTE(review): the trl==0.4.1 installed here is later uninstalled and replaced by a
# build from the trl git repository (see the recorded output of that install cell).
%pip install --disable-pip-version-check -q \
    transformers==4.27.2 \
    datasets==2.11.0 \
    accelerate==0.18.0 \
    evaluate==0.4.0 \
    trl==0.4.1 \
    rouge_score==0.1.2 \
    loralib==0.1.1

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.


In [3]:
!pip install git+https://github.com/huggingface/peft.git

Collecting git+https://github.com/huggingface/peft.git
  Cloning https://github.com/huggingface/peft.git to /tmp/pip-req-build-rtv8ktfc
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/peft.git /tmp/pip-req-build-rtv8ktfc
  Resolved https://github.com/huggingface/peft.git to commit a37156c2c7966ff7d1487d33734738414aa0bafc
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
!pip install git+https://github.com/lvwerra/trl.git

Collecting git+https://github.com/lvwerra/trl.git
  Cloning https://github.com/lvwerra/trl.git to /tmp/pip-req-build-7ks838p8
  Running command git clone --filter=blob:none --quiet https://github.com/lvwerra/trl.git /tmp/pip-req-build-7ks838p8
  Resolved https://github.com/lvwerra/trl.git to commit 08f550674c553c36c51d1027613c29f14f3676a5
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: trl
  Building wheel for trl (setup.py) ... [?25ldone
[?25h  Created wheel for trl: filename=trl-0.4.2.dev0-py3-none-any.whl size=59756 sha256=2b0ce8cc4651071e12a765157e0f2b4a7cb9144eae3c1784b957722dd37097ab
  Stored in directory: /tmp/pip-ephem-wheel-cache-n9nvkn_n/wheels/ca/6e/f4/b183ecbed483efdcd2041a8021ce7bcb9f7b09c74bff5bb00a
Successfully built trl
Installing collected packages: trl
  Attempting uninstall: trl
    Found existing installation: trl 0.4.1
    Uninstalling trl-0.4.1:
      Successfully uninstalled trl-0.4.1
Successfully installed trl-0.4.2.dev0


In [5]:
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from torch.optim import Adam
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    HfArgumentParser,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

from trl import AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model, set_seed
from trl.core import LengthSampler

tqdm.pandas()

@dataclass
class ScriptArguments:
    """
    Options for fine-tuning a Seq2Seq LM with PPO.

    Every option is a dataclass field with a default value, so the class can be
    instantiated without arguments; the ``help`` metadata of each field carries its
    description.
    """

    # Name of the Seq2Seq model to fine-tune.
    model_name: Optional[str] = field(
        default="google/flan-t5-base",
        metadata={"help": "the model name"},
    )
    # Optional experiment tracker.
    log_with: Optional[str] = field(
        default=None,
        metadata={"help": "use 'wandb' to log with wandb"},
    )
    # Optimizer learning rate.
    learning_rate: Optional[float] = field(
        default=(1.47e-5) * 2,
        metadata={"help": "the learning rate"},
    )
    # Size of one PPO minibatch.
    mini_batch_size: Optional[int] = field(
        default=4,
        metadata={"help": "the PPO minibatch size"},
    )
    # Size of one batch.
    batch_size: Optional[int] = field(
        default=16,
        metadata={"help": "the batch size"},
    )
    # Number of gradient accumulation steps.
    gradient_accumulation_steps: Optional[int] = field(
        default=1,
        metadata={"help": "the number of gradient accumulation steps"},
    )
    # Directory the fine-tuned model is written to.
    model_save_path: Optional[str] = field(
        default="./peft_fine_tuned_with_detoxification_rewards",
        metadata={"help": "the path to save the model"},
    )


# Parse the script options into a ScriptArguments instance. The call is made with
# return_remaining_strings=True and only element [0] (the dataclass instance) is kept.
# NOTE(review): presumably this is what lets the cell run inside a notebook kernel,
# whose argv holds entries unrelated to ScriptArguments — confirm.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)[0]

# PPO configuration assembled from the parsed options; ppo_epochs is the only value
# hard-coded here rather than taken from script_args.
config = PPOConfig(
    model_name=script_args.model_name,
    learning_rate=script_args.learning_rate,
    log_with=script_args.log_with,
    ppo_epochs=1, # was 100
    mini_batch_size=script_args.mini_batch_size,
    batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
)

In [6]:
def build_dataset(
    config,
    dataset_name="allenai/real-toxicity-prompts",
    input_min_text_length=5,
    input_max_text_length=10,
    toxicity_threshold=0.3,
    test_size=0.2,
):
    """
    Build the query dataset used for PPO training.

    Loads the `train` split of `dataset_name`, keeps only the samples whose prompt
    toxicity score is above `toxicity_threshold`, tokenizes prompt + continuation,
    truncates every sample to a randomly sampled number of tokens and finally returns
    the `train` part of an unshuffled train/test split. Customize this function to
    train the model on another dataset.

    Args:
        config:
            Object exposing `model_name`; the tokenizer is loaded from that name.
        dataset_name (`str`):
            The name of the dataset to be loaded.
        input_min_text_length (`int`):
            Lower bound handed to `LengthSampler` for the truncated query length.
        input_max_text_length (`int`):
            Upper bound handed to `LengthSampler` for the truncated query length.
        toxicity_threshold (`float`):
            Samples are kept only when their prompt toxicity is not `None` and strictly
            greater than this value.
        test_size (`float`):
            Fraction passed to `train_test_split`; the held-out part is discarded.

    Returns:
        The `train` portion of the split dataset, in torch format, with two columns
        added to every sample: `input_ids` (the truncated token ids) and `query` (the
        text decoded from those ids).
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    ds = load_dataset(dataset_name, split="train")

    def filter_fn(sample):
        # Some samples carry no toxicity score at all; those are dropped as well.
        toxicity = sample["prompt"]["toxicity"]
        return toxicity is not None and toxicity > toxicity_threshold

    ds = ds.filter(filter_fn, batched=False)

    # A new length is drawn for every sample, so queries vary in size.
    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        prompt = sample["prompt"]["text"]
        continuation = sample["continuation"]["text"]

        sample["input_ids"] = tokenizer.encode(prompt + continuation)[: input_size()]
        # Decode the truncated ids so `query` matches exactly what the model will see.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")

    # shuffle=False keeps the split deterministic across runs.
    ds = ds.train_test_split(test_size=test_size, shuffle=False)["train"]

    return ds


# Bounds for the length (in tokens) of the truncated queries; build_dataset hands them
# to its LengthSampler.
min_input_length = 30
max_input_length = 40

# Build the training dataset; config.model_name selects the tokenizer used inside.
dataset = build_dataset(config, input_min_text_length=min_input_length, input_max_text_length=max_input_length)

def collator(data):
    """
    Turn a list of sample dicts into a single dict of lists.

    The keys are taken from the first sample; every sample must provide all of them.
    Values are gathered in sample order and are not stacked or padded.
    """
    first_sample = data[0]
    batch = {}
    for key in first_sample:
        batch[key] = [sample[key] for sample in data]
    return batch

Found cached dataset json (/root/.cache/huggingface/datasets/allenai___json/allenai--real-toxicity-prompts-d8a476abeeb3bf44/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e)
Loading cached processed dataset at /root/.cache/huggingface/datasets/allenai___json/allenai--real-toxicity-prompts-d8a476abeeb3bf44/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e/cache-2df44a0de5008bf6.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/allenai___json/allenai--real-toxicity-prompts-d8a476abeeb3bf44/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e/cache-0e795f8f86f24bac.arrow


In [7]:
# Seed the random number generators before the model with its value head is created.
# `set_seed` was imported but never called, so the weights that are initialized at this
# point (rather than loaded from the checkpoint) differed on every kernel restart.
set_seed(config.seed)

# We load the model in bf16 to save memory.
model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name, torch_dtype=torch.bfloat16)
# Wrap the pretrained model so it also exposes the value head PPO needs.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model)

# Reference model that PPO compares the trained policy against.
ref_model = create_reference_model(model)

# Only parameters that require gradients are handed to the optimizer.
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)
#tokenizer.pad_token = tokenizer.eos_token

In [8]:
# We then build the PPOTrainer, passing the config, the model, the reference model, the
# tokenizer, the dataset, the collator and the optimizer created in the previous cells.
ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
    optimizer=optimizer,
)

# Let's re-use Facebook/Meta's hate-speech classifier to compute the reward.
toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_id)

# We load the model in fp16 to save memory and move it to the device the trainer's
# accelerator uses.
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_id, torch_dtype=torch.float16).to(
    ppo_trainer.accelerator.device
)

# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.
# NOTE(review): top_k=0.0 together with top_p=1.0 presumably disables top-k and nucleus
# filtering, i.e. plain sampling from the full distribution — confirm.
generation_kwargs = {
    "min_length": 4, # changed from -1 to workaround "must be 4 min tokens" error
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
#    "pad_token_id": tokenizer.eos_token_id,
}

In [None]:
# Bounds for the number of new tokens generated per query; a length is sampled for
# every single query.
output_min_length = 20
output_max_length = 200
output_length_sampler = LengthSampler(output_min_length, output_max_length)

peft_fine_tuned_with_detoxification_rewards_checkpoint = script_args.model_save_path

# tqdm wraps the dataloader itself (not the enumerate object) so it can read the number
# of batches and show a total / ETA instead of a bare counter.
# NOTE: `epoch` is the batch index produced by enumerate, not a pass over the dataset.
for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # Get response from the policy model
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        # Keep at most the last `gen_len` generated tokens.
        response_tensors.append(response.squeeze()[-gen_len:])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    # Compute the reward from the toxicity classifier's output on each response.
    texts = batch["response"]
    toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
        ppo_trainer.accelerator.device
    )
    # The classifier is only used for scoring, so no autograd graph is needed.
    with torch.no_grad():
        logits = toxicity_model(**toxicity_inputs).logits.float()
    # NOTE(review): column 0 is assumed to be the classifier's "not hate" logit, so a
    # higher value means a less toxic response — confirm against the model's id2label.
    toxicity_labels = (logits[:, 0]).tolist()

    rewards = [torch.tensor(output) for output in toxicity_labels]

    # Run PPO gradient-update step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

    # Save a checkpoint every 100 batches (this also fires on the very first batch).
    # Only the main process writes. `ppo_trainer.save_pretrained` is not used because
    # it depends on the Hugging Face Hub; instead the tokenizer and the unwrapped model
    # are saved directly.
    if epoch % 100 == 0 and ppo_trainer.accelerator.is_main_process:
        ppo_trainer.tokenizer.save_pretrained(peft_fine_tuned_with_detoxification_rewards_checkpoint)
        ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).save_pretrained(
            peft_fine_tuned_with_detoxification_rewards_checkpoint
        )

  f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
  f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
  f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
  f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
  f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
  f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
  f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
  f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
  f"KL divergence is starting to

# Save model

In [None]:
# Final save of the tokenizer and the PPO-trained model to the checkpoint directory.
# Guarded by is_main_process, like the periodic save inside the training loop, so that
# only one process writes when the notebook is run with several accelerate processes.
# `ppo_trainer.save_pretrained` is not used because it depends on the Hugging Face Hub.
if ppo_trainer.accelerator.is_main_process:
    ppo_trainer.tokenizer.save_pretrained(peft_fine_tuned_with_detoxification_rewards_checkpoint)
    ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).save_pretrained(
        peft_fine_tuned_with_detoxification_rewards_checkpoint
    )

In [None]:
# %store peft_fine_tuned_with_detoxification_rewards_checkpoint

In [None]:
# NOTE(review): AutoTokenizer and AutoModelForSeq2SeqLM are already imported in the first
# code cell; GenerationConfig is not used in the cells visible here. Consider moving this
# import to the top import cell.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig

# Reload the model and tokenizer from the checkpoint directory written by the save cells.
# NOTE(review): despite its name, `reward_model` is loaded from the PPO fine-tuned
# checkpoint, i.e. it is the trained policy, not the toxicity classifier used for rewards.
# This also rebinds the module-level `tokenizer` name.
reward_model = AutoModelForSeq2SeqLM.from_pretrained(peft_fine_tuned_with_detoxification_rewards_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(peft_fine_tuned_with_detoxification_rewards_checkpoint)