# Fine-Tune FLAN T5 with Reinforcement Learning (PPO) and PEFT to Generate Less-Toxic Summaries

Fine-tune a FLAN-T5 model to generate less toxic content, using Facebook's hate speech classifier as the reward model. The reward model is a binary classifier that predicts either "not hate" or "hate" for the given text. Proximal Policy Optimization (PPO) will be used to fine-tune the model and reduce its toxicity.

In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

from tqdm import tqdm
tqdm.pandas()

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
# Identifiers of the base checkpoint and of the dialogue-summarization dataset
# that the rest of the notebook refers to.
model_name = "google/flan-t5-base"
hf_dataset_name = "knkarthick/dialogsum" 

# Load every split once and print the DatasetDict to inspect split sizes and columns.
dataset_original = load_dataset(hf_dataset_name)
print(dataset_original)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})


In [3]:
# Tokenizer matching the base checkpoint; `tokenize_f` below reads it as a global.
# No `device_map` is passed: it is a model-placement option, and the tokenizer is
# loaded without it everywhere else in this notebook.
tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_f(row):
    """Turn a batch of dialogues into summarization prompts for PPO.

    Args:
        row: a batch mapping whose ``'dialogue'`` entry is an iterable of strings.

    Returns:
        A dict with two parallel lists:
        ``'input_ids'`` - the token ids produced by the global ``tokenizer`` for each prompt;
        ``'query'`` - those ids decoded back to text. The key has to be called
        "query", which is a requirement of the PPO library.
    """
    input_ids = []
    queries = []

    for dialogue in row['dialogue']:
        # Instruction, then the dialogue, then the cue the model completes.
        prompt = "Summarize the following conversation.\n" + dialogue + "\nSummary: \n"

        encoded = tokenizer.encode(prompt)
        input_ids.append(encoded)
        # The query is the round-tripped prompt, not the raw prompt string.
        queries.append(tokenizer.decode(encoded))

    return {'input_ids': input_ids, 'query': queries}



def build_dataset(model_name, dataset_name, input_min_text_length, input_max_text_length):
    """Build the prompt dataset used for PPO training and evaluation.

    Only the "train" split of ``dataset_name`` is loaded. Dialogues are kept when
    their length in characters is strictly greater than ``input_min_text_length``
    and at most ``input_max_text_length``. The survivors are tokenized with
    ``tokenize_f``, formatted as torch tensors, and split 80/20 without shuffling.

    Args:
        model_name: accepted for interface compatibility; not used in this function.
        dataset_name: dataset identifier passed to ``load_dataset``.
        input_min_text_length: exclusive lower bound on dialogue length (characters).
        input_max_text_length: inclusive upper bound on dialogue length (characters).

    Returns:
        The result of ``train_test_split`` on the filtered, tokenized data.
    """

    def _length_in_range(examples):
        # Batched predicate: one boolean per dialogue in the batch.
        return [
            input_min_text_length < len(text) <= input_max_text_length
            for text in examples['dialogue']
        ]

    train_split = load_dataset(dataset_name, split="train")
    filtered = train_split.filter(_length_in_range, batched=True)

    tokenized = filtered.map(tokenize_f, batched=True)
    tokenized.set_format(type="torch")

    return tokenized.train_test_split(test_size=0.2, shuffle=False, seed=42)


# Keep dialogues longer than 200 and at most 1000 characters, then print the
# resulting train/test DatasetDict.
dataset = build_dataset(model_name, hf_dataset_name, 200, 1000)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})


In [4]:
def print_number_of_trainable_model_parameters(model):
    """Summarize how many of a model's parameters are trainable.

    Despite the name, nothing is printed; the caller prints the returned text.

    Args:
        model: any object exposing ``named_parameters()`` whose items yield
            parameters with ``numel()`` and ``requires_grad``.

    Returns:
        A three-line string with the trainable parameter count, the total
        parameter count, and the trainable share as a percentage. A model
        without parameters reports 0.0% instead of raising ``ZeroDivisionError``.
    """
    trainable_model_params = 0
    all_model_params = 0

    for _, param in model.named_parameters():
        count = param.numel()
        all_model_params += count
        if param.requires_grad:
            trainable_model_params += count

    # Guard the ratio so an empty model does not divide by zero.
    if all_model_params:
        percentage = trainable_model_params / all_model_params * 100
    else:
        percentage = 0.0

    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters {percentage}%"


In [5]:
# Locations of the PEFT training output and of the LoRA adapter loaded below.
# `output_dir` is not referenced again in the visible part of the notebook.
output_dir = f"./dialogue-summary-training-peft"
peft_model_path = "./dialogue-summary-training-peft/lora"

# LoRA settings: rank-32 adapters on the attention "q" and "v" projections.
lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

# Base FLAN-T5 weights in bfloat16.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Attach the adapter stored at `peft_model_path` and keep it trainable.
# NOTE(review): `lora_config` is passed as an extra keyword here; confirm that
# PeftModel.from_pretrained actually consumes it (its documented keyword is
# `config`) rather than using the config saved next to the adapter.
peft_model = PeftModel.from_pretrained(model, 
                                       peft_model_path,
                                       lora_config=lora_config,
                                       torch_dtype=torch.bfloat16,
                                       device_map="auto",
                                       is_trainable=True)

# Report how many parameters PPO will be able to update.
print(f'PEFT model parameters to be updataed:\n{print_number_of_trainable_model_parameters(peft_model)}')

PEFT model parameters to be updataed:
trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters 1.4092820552029972%


In [6]:
# Wrap the PEFT model with a value head so it can be trained with PPO.
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

# The trainable count grows by the value head: a 768 -> 1 linear layer with bias
# (768 weights + 1 bias = 769 parameters), as the printed ValueHead shows.
print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}')
print(ppo_model.v_head)

PPO model parameters to be updated (ValueHead + 769 params):
trainable model parameters: 3539713
all model parameters: 251117569
percentage of trainable model parameters 1.4095839706062143%
ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


In [7]:
# Create a frozen copy of the PPO model that will not be fine-tuned: the
# reference model. It represents the LLM before detoxification, and none of its
# parameters are updated during PPO training. This is on purpose.

ref_model = create_reference_model(ppo_model)
# The report below is expected to show zero trainable parameters.
print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}')


Reference model parameters to be updated:
trainable model parameters: 0
all model parameters: 251117569
percentage of trainable model parameters 0.0%


### 2.1 Prepare Reward Model

In [8]:
# Hate-speech classifier used as the reward model, with its own tokenizer.
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name)
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name)

# Show which class id maps to which label; later cells rely on 0 being "nothate".
print(toxicity_model.config.id2label)

{0: 'nothate', 1: 'hate'}


In [9]:
# Probe the reward model directly with a sentence expected to be non-toxic.
non_toxic_text = "I want to kiss you"

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt")
# Inference only, so gradients are not tracked.
with torch.no_grad():
    logits = toxicity_model(input_ids=toxicity_input_ids.input_ids).logits

print(f'logits [nothate, hate]: {logits.tolist()[0]}')

probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [nothate hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
# Column 0 of the raw logits is the "nothate" class (see the printed id2label).
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')


logits [nothate, hate]: [4.657957077026367, -4.078614234924316]
probabilities [nothate hate]: [0.9998394250869751, 0.00016057782340794802]
reward (high): [4.657957077026367]


In [10]:
# Repeat the probe with a sentence expected to be toxic.
toxic_text = "You are disgusting and terrible and I damn hate you"

toxic_input_ids = toxicity_tokenizer(toxic_text, return_tensors="pt")
# Inference only, so gradients are not tracked.
with torch.no_grad():
    logits = toxicity_model(input_ids=toxic_input_ids.input_ids).logits

print(f'logits [nothate, hate]: {logits.tolist()[0]}')

probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [nothate hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
# NOTE(review): the "(high)" label below is hard-coded. The recorded run printed a
# "nothate" logit of about 4.70 for this sentence, i.e. the classifier did not
# flag it as hate - confirm this probe sentence demonstrates what is intended.
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')


logits [nothate, hate]: [4.697163105010986, -4.222471237182617]
probabilities [nothate hate]: [0.999866247177124, 0.00013371932436712086]
reward (high): [4.697163105010986]


In [11]:
# Pipeline wrapper around the same reward model; used to score text in batches.
# No `device` is passed, and the recorded warning says the model then runs on CPU.
sentiment_pipe = pipeline("sentiment-analysis", model=toxicity_model_name)

# Call options that return the raw logits for every label.
reward_logits_kwargs = {
    "top_k": None, # Return all scores
    "function_to_apply": "none", # set to "none" to retrieve raw logits
    "batch_size": 16,
}

# Call options that return softmax probabilities for every label.
reward_probabilities_kwargs = {
    "top_k": None, # Return all scores
    "function_to_apply": "softmax", # set to "softmax" to apply softmax to retrieve probabilities
    "batch_size": 16,
}

# Score both probe sentences with each set of options.
print("Reward model output for non-toxic text:")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))
print("Reward model output for toxic text:")
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))


Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


Reward model output for non-toxic text:
[{'label': 'nothate', 'score': 4.657957077026367}, {'label': 'hate', 'score': -4.078614234924316}]
[{'label': 'nothate', 'score': 0.9998394250869751}, {'label': 'hate', 'score': 0.00016057782340794802}]
Reward model output for toxic text:
[{'label': 'nothate', 'score': 4.697163105010986}, {'label': 'hate', 'score': -4.222471237182617}]
[{'label': 'nothate', 'score': 0.999866247177124}, {'label': 'hate', 'score': 0.00013371930981520563}]


### 2.3 Evaluate Toxicity

In [12]:
# Toxicity measurement backed by the same hate-speech model; the "hate" label is
# the one reported as the toxicity score.
toxicity_evaluator = evaluate.load("toxicity",
                                   toxicity_model_name,
                                   module_type="measurement",
                                   toxic_label="hate")

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [13]:
# `compute` returns a dict whose "toxicity" entry holds one score per prediction.
toxicity_score = toxicity_evaluator.compute(predictions=[non_toxic_text])

print("Toxicity score for non-toxic text:")
print(toxicity_score["toxicity"])

# Same measurement for the second probe sentence (the name is reused).
toxicity_score = toxicity_evaluator.compute(predictions=[toxic_text])

print("Toxicity score for toxic text:")
print(toxicity_score["toxicity"])

Toxicity score for non-toxic text:
[0.00016057782340794802]
Toxicity score for toxic text:
[0.00013371930981520563]


In [14]:
def evaluate_toxicity(model, toxicity_evaluator, tokenizer, dataset, num_samples):
    """Estimate the toxicity of completions sampled from ``model``.

    For each of the first ``num_samples`` rows of ``dataset``, the prompt stored
    under ``"query"`` is tokenized, a completion of up to 100 new tokens is
    sampled, and the prompt joined with the completion is scored.

    Args:
        model: object exposing ``generate(input_ids=..., generation_config=...)``.
        toxicity_evaluator: object whose ``compute(predictions=[...])`` returns a
            dict with a ``"toxicity"`` list.
        tokenizer: tokenizer used to encode prompts and decode completions.
        dataset: iterable of rows that each contain a ``"query"`` string.
        num_samples: maximum number of rows to score.

    Returns:
        ``(mean, std)`` of the collected toxicity scores.
    """
    max_new_tokens = 100

    # Built once, since it does not depend on the sample. Pure sampling: top_k=0
    # and top_p=1.0 leave the distribution unfiltered. (The keyword was previously
    # misspelled "tok_k", so the top-k filter was never actually disabled.)
    generation_config = GenerationConfig(max_new_tokens=max_new_tokens,
                                         top_k=0.0,
                                         top_p=1.0,
                                         do_sample=True)

    toxicities = []

    for i, sample in tqdm(enumerate(dataset)):
        # `>=` so that exactly `num_samples` rows are scored, not one extra.
        if i >= num_samples:
            break

        input_text = sample["query"]

        input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        response_token_ids = model.generate(input_ids=input_ids, generation_config=generation_config)
        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)

        # Score the prompt together with the generated continuation.
        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " " + generated_text)])
        toxicities.extend(toxicity_score["toxicity"])

    mean = np.mean(toxicities)
    std = np.std(toxicities)

    return mean, std

In [15]:
# Reload the base-model tokenizer; it is the one later handed to PPOTrainer.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The baseline toxicity measurement on the reference model is left disabled.
# Uncomment the next two statements to compute and print it before fine-tuning.
# mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model, toxicity_evaluator=toxicity_evaluator, tokenizer=tokenizer, dataset=dataset["test"], num_samples=10)

# print(f"toxicity [mean, std] before detox: [{mean_before_detoxification, std_before_detoxification}]")

## 3 Perform Fine-Tuning to Detoxify the Summaries

### 3.1 Initialize PPOTrainer

In [16]:
# PPO hyperparameters, collected here so they can be tuned in one place.
learning_rate = 1.41e-5
max_ppo_epochs=1
mini_batch_size=4
batch_size=16

# Bundle the hyperparameters into the configuration object given to PPOTrainer.
config = PPOConfig(
    model_name=model_name,
    learning_rate=learning_rate,
    ppo_epochs= max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size,
)

def collator(data):
    """Regroup a list of per-sample dicts into one dict of per-key lists.

    The keys (and their order) are taken from the first sample; every sample is
    expected to provide all of them.
    """
    keys = data[0]
    return {key: [sample[key] for sample in data] for key in keys}

# You can uncomment the following lines to test the collator
# test_data = [{"key1": "value1", "key2":"value2", "key3": "value3"}]
# print(f'collator input: {test_data}')
# print(f'collator output: {collator(test_data)}')

# Assemble the PPO trainer: the trainable policy, the frozen reference model,
# the tokenizer, the training prompts, and the collator defined above.
ppo_trainer = PPOTrainer(config=config,
                         model=ppo_model,
                         ref_model=ref_model,
                         tokenizer=tokenizer,
                         dataset=dataset['train'],
                         data_collator=collator,
                         )




### 3.2 Fine-Tune the Model

In [17]:
# Each generated summary gets a randomly drawn token budget in this range.
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# Sampling settings shared by every generation call (pure sampling, no filtering).
generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}
# Reward-model call settings: raw logits for all labels.
reward_kwargs = {
    "top_k": None,
    "function_to_apply": "none",
    "batch_size": 16
}

# Upper bound on PPO optimisation steps for this run.
max_ppo_steps = 10

# With `top_k=None` the text-classification pipeline returns every label ordered
# by score, not by class id, so the "nothate" entry is looked up by its label
# rather than by list position. Indexing by position would hand back the "hate"
# logit as the reward whenever "hate" is the top-scoring class.
not_hate_label = sentiment_pipe.model.config.id2label[not_hate_index]

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if step >= max_ppo_steps:
        break

    prompt_tensors = batch['input_ids']

    # Get response from FLAN-T5/PEFT LLM
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()

        # Merge the per-sample budget into a fresh dict instead of mutating the
        # shared `generation_kwargs`.
        summary = ppo_trainer.generate(
            prompt_tensor,
            **{**generation_kwargs, "max_new_tokens": max_new_tokens},
        )

        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    # This needs to be called "response".
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Compute reward outputs.
    query_response_pairs = [q + r for q, r in zip(batch['query'], batch['response'])]
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # Use the 'nothate' logit as the reward, selected by label.
    reward_tensors = [
        torch.tensor({item["label"]: item["score"] for item in reward}[not_hate_label])
        for reward in rewards
    ]

    # Run PPO step
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)

    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    # Same 49-dash separator as before, written directly.
    print('-' * 49)

1it [00:48, 48.62s/it]

objective/kl: 24.6973876953125
ppo/returns/mean: -0.695402979850769
ppo/policy/advantages_mean: -0.0012105926871299744
-------------------------------------------------


2it [01:40, 50.27s/it]

objective/kl: 22.060958862304688
ppo/returns/mean: -0.55558842420578
ppo/policy/advantages_mean: 0.02511385828256607
-------------------------------------------------


3it [02:25, 48.16s/it]

objective/kl: 19.92251968383789
ppo/returns/mean: -0.2632604241371155
ppo/policy/advantages_mean: 0.049807898700237274
-------------------------------------------------


4it [03:28, 53.99s/it]

objective/kl: 27.786954879760742
ppo/returns/mean: -0.9784796237945557
ppo/policy/advantages_mean: 0.1155431792140007
-------------------------------------------------


5it [04:31, 57.35s/it]

objective/kl: 22.715763092041016
ppo/returns/mean: -0.2824663519859314
ppo/policy/advantages_mean: 0.1722594052553177
-------------------------------------------------


6it [05:46, 63.29s/it]

objective/kl: 15.116639137268066
ppo/returns/mean: 0.15190088748931885
ppo/policy/advantages_mean: 0.10655581951141357
-------------------------------------------------


7it [06:30, 56.83s/it]

objective/kl: 14.222274780273438
ppo/returns/mean: 0.24004651606082916
ppo/policy/advantages_mean: -0.01625499688088894
-------------------------------------------------


8it [07:19, 54.56s/it]

objective/kl: 13.79407024383545
ppo/returns/mean: 0.19849753379821777
ppo/policy/advantages_mean: -0.026937760412693024
-------------------------------------------------


9it [08:06, 52.12s/it]

objective/kl: 15.047981262207031
ppo/returns/mean: -0.11166799068450928
ppo/policy/advantages_mean: 0.09892699867486954
-------------------------------------------------


10it [08:41, 52.18s/it]

objective/kl: 9.896147727966309
ppo/returns/mean: 0.7504624724388123
ppo/policy/advantages_mean: 0.14878997206687927
-------------------------------------------------





### 3.4 Evaluate the Model Qualitatively 

In [None]:
# Number of test prompts to compare before/after detoxification.
batch_size = 20
compare_results = {}

df_batch = dataset["test"][0:batch_size]

compare_results["query"] = df_batch["query"]
prompt_tensors = df_batch["input_ids"]

summary_tensors_ref = []
summary_tensors = []

# Get response from PPO and base model
for i in tqdm(range(batch_size)):
    gen_len = output_length_sampler()
    # Fresh dict per sample so the shared `generation_kwargs` is not mutated.
    sampling_kwargs = {**generation_kwargs, "max_new_tokens": gen_len}
    # Both models receive the same prompt as a batch of one.
    prompt = torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0)

    summary = ref_model.generate(input_ids=prompt, **sampling_kwargs).squeeze()[-gen_len:]
    summary_tensors_ref.append(summary)

    summary = ppo_model.generate(input_ids=prompt, **sampling_kwargs).squeeze()[-gen_len:]
    summary_tensors.append(summary)

# Decode responses
compare_results["response_before"] = [tokenizer.decode(summary_tensors_ref[i]) for i in range(batch_size)]
compare_results["response_after"] = [tokenizer.decode(summary_tensors[i]) for i in range(batch_size)]

# With `top_k=None` the pipeline returns all labels ordered by score, not by class
# id, so the "nothate" logit is selected by label instead of by list position.
not_hate_label = sentiment_pipe.model.config.id2label[not_hate_index]

# Sentiment analysis of query/response pairs before/after
texts_before = [d + s for d, s in zip(compare_results["query"], compare_results["response_before"])]
rewards_before = sentiment_pipe(texts_before, **reward_kwargs)
compare_results["reward_before"] = [
    {item["label"]: item["score"] for item in reward}[not_hate_label]
    for reward in rewards_before
]

texts_after = [d + s for d, s in zip(compare_results["query"], compare_results["response_after"])]
rewards_after = sentiment_pipe(texts_after, **reward_kwargs)
compare_results["reward_after"] = [
    {item["label"]: item["score"] for item in reward}[not_hate_label]
    for reward in rewards_after
]

In [None]:
# Widen text columns so queries and summaries are readable in the rendered table.
pd.set_option('display.max_colwidth', 500)

# One row per prompt, plus the reward gained (or lost) by fine-tuning.
df_compare_results = pd.DataFrame(compare_results).assign(
    reward_diff=lambda frame: frame["reward_after"] - frame["reward_before"]
)

# Largest improvements first.
df_compare_results_sorted = (
    df_compare_results
    .sort_values(by=["reward_diff"], ascending=False)
    .reset_index(drop=True)
)
df_compare_results_sorted