# Fine-Tune FLAN-T5 with Reinforcement Learning (PPO) and PEFT to Generate Less-Toxic Summaries




In [1]:
%pip install --upgrade pip
%pip install --disable-pip-version-check \
    torch \
    torchdata --quiet
%pip install \
    torch==1.13.1+cu117 torchvision==0.14.1+cu117\
    --extra-index-url\
    https://download.pytorch.org/whl/cu117

%pip install \
    transformers \
    datasets \
    evaluate \
    rouge_score \
    peft --quiet

# Installing the Reinforcement Learning library directly from github.
%pip install git+https://github.com/lvwerra/trl.git@25fa1bd    

Defaulting to user installation because normal site-packages is not writeable
[33mDEPRECATION: distro-info 0.23ubuntu1 has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of distro-info or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m[33mDEPRECATION: python-debian 0.1.36ubuntu1 has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of python-debian or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.
[33mDEPRECATION: distro-info 0.23ubuntu1 has a non-standard version number. pip 23.3 will enforce this behaviou

[33mDEPRECATION: distro-info 0.23ubuntu1 has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of distro-info or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m[33mDEPRECATION: python-debian 0.1.36ubuntu1 has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of python-debian or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig,  TrainingArguments, Trainer
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

### First we need the fine-tuned model to summarize the text.
#### I use LoRA to fine-tune the Flan-T5 model.

In [3]:
# load dialogue and summary dataset
model_name="google/flan-t5-base"
huggingface_dataset_name = "knkarthick/dialogsum"

dataset_original = load_dataset(huggingface_dataset_name)

dataset_original

Found cached dataset csv (/home/azadeh/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-931380d0e19583fc/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
})

In [4]:
# load pre-trained model and its tokenizer
model_name="google/flan-t5-base"
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [5]:
def tokenize_function(example):
    prompt = []
    summ = []
    
    for idx, dialogue in enumerate(example["dialogue"]):
        prompt.append(f"""Summarize the following conversation.

        {dialogue}

        Summary:
        """)
        
        
    

        
    example = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt")
    
    
#     i = torch.tensor(item).to('cuda:0')
#         f.append(i)


    return example



tlr_datasets = dataset_original.map(tokenize_function, batched=True)
        


Loading cached processed dataset at /home/azadeh/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-931380d0e19583fc/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-0760e73a230b4eed.arrow
Loading cached processed dataset at /home/azadeh/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-931380d0e19583fc/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-90349f0fee0bae0d.arrow


Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [6]:
# define the map function for the tokenizer to modify the prompt

def tokenize_function(example):
    prompt = []
    summ = []
    
    
    for idx, dialogue in enumerate(example["dialogue"]):
        prompt.append(f"""Summarize the following conversation.

        {dialogue}

        Summary:
        """)
        summ.append(example['summary'][idx])
        
#         prompt.append(torch.tensor(tokenizer(p, padding="max_length", truncation=True, return_tensors="pt").input_ids))


#     print(prompt)  
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    
    example['labels'] = tokenizer(summ, padding="max_length", truncation=True, return_tensors="pt").input_ids
#     example['attention_mask'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").attention_mask
    example['query'] =  prompt
    return example



tokenized_datasets = dataset_original.map(tokenize_function, batched=True)
        
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary',])
print(tokenized_datasets)

Loading cached processed dataset at /home/azadeh/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-931380d0e19583fc/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-d194f84673c2fb74.arrow
Loading cached processed dataset at /home/azadeh/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-931380d0e19583fc/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-eb8feb074be70ebe.arrow
Loading cached processed dataset at /home/azadeh/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-931380d0e19583fc/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-7f8d536b7a155430.arrow


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels', 'query'],
        num_rows: 12460
    })
    test: Dataset({
        features: ['input_ids', 'labels', 'query'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['input_ids', 'labels', 'query'],
        num_rows: 500
    })
})


In [7]:
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)



peft_model = get_peft_model(original_model, 
                            lora_config)

In [8]:
output_dir ='./model'

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3, # Higher learning rate than full fine-tuning.
    num_train_epochs=1,
    
)
    
peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_datasets["train"],
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [9]:
peft_trainer.train()




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: Currently logged in as: [33maz-mozaffari[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
500,1.1418
1000,0.122
1500,0.1157


TrainOutput(global_step=1558, training_loss=0.4471316453887195, metrics={'train_runtime': 478.9049, 'train_samples_per_second': 26.018, 'train_steps_per_second': 3.253, 'total_flos': 8667537195663360.0, 'train_loss': 0.4471316453887195, 'epoch': 1.0})

In [10]:
# load fine tuned model from the memory
peft_model_path="./model"

peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)




('./model/tokenizer_config.json',
 './model/special_tokens_map.json',
 './model/tokenizer.json')

In [11]:
model_name="google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

peft_model_path="./model"
tokenizer = AutoTokenizer.from_pretrained(peft_model_path)

# generate some queries to test the accuracy of the new model and compare it to the original model
dialogues = dataset_original['test'][0:20]['dialogue']
human_baseline_summaries = dataset_original['test'][0:20]['summary']

original_model_summaries = []
instruct_model_summaries = []
peft_model_summaries = []
# original_model = original_model.to('cpu')

for idx, dialogue in enumerate(dialogues):
    prompt = f"""
Summarize the following conversation.

{dialogue}

Summary: """
    
#     input_ids = tokenizer_original(prompt, return_tensors="pt").input_ids
#     input_ids = input_ids.to('cuda')
    input_ids  = tokenizer.encode(prompt, return_tensors="pt")
    input_ids = input_ids.to('cuda:0')
    human_baseline_text_output = human_baseline_summaries[idx]
    
    original_model_outputs = original_model.generate( input_ids, generation_config=GenerationConfig(max_new_tokens=200))
    original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to('cuda:0')
    
    peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))
    peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

    original_model_summaries.append(original_model_text_output)
    peft_model_summaries.append(peft_model_text_output)
    


In [12]:
# evaluate the accuracy of the trained peft model and compare to original model using rouge metric 
rouge = evaluate.load('rouge')

original_model_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries[0:len(original_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)



peft_model_results = rouge.compute(
    predictions=peft_model_summaries,
    references=human_baseline_summaries[0:len(peft_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

print('ORIGINAL MODEL:')
print(original_model_results)
print('PEFT MODEL:')
print(peft_model_results)

ORIGINAL MODEL:
{'rouge1': 0.3625816328737917, 'rouge2': 0.10143738651442691, 'rougeL': 0.2867149055312152, 'rougeLsum': 0.2871078425652874}
PEFT MODEL:
{'rouge1': 0.3916678434489629, 'rouge2': 0.1328566293652899, 'rougeL': 0.3052362517745412, 'rougeLsum': 0.3052010416808383}


###  Load FLAN-T5 Fine Tuned Model, Prepare Reward Model and Toxicity Evaluator

In [13]:
# the peft model from memory


peft_model_base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", torch_dtype=torch.bfloat16)
tokenizer_original = AutoTokenizer.from_pretrained("google/flan-t5-base")

peft_model = PeftModel.from_pretrained(peft_model_base, 
                                       './model', 
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=False)
tokenizer_peft = AutoTokenizer.from_pretrained("./model")


In [14]:
# define lora model for training with rl 
lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

model_name = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, 
                                              torch_dtype=torch.bfloat16,
                                             device_map="auto")


# lora model should be trainable 
peft_model = PeftModel.from_pretrained(model, 
                                       './model', 
                                       lora_config=lora_config,
                                       torch_dtype=torch.bfloat16, 
                                       device_map="auto",                                       
                                       is_trainable=True)


In [15]:
# number of trainable parameters
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

In [16]:
# check the number of trainable parameters of base model
print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n')


PEFT model parameters to be updated:

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%



In [17]:
#define the training model with one layer Value head more for RL training  
# peft_model = peft_model.to('cpu')
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               device_map="auto",
                                                               is_trainable=True)

print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)

PPO model parameters to be updated (ValueHead + 769 params):

trainable model parameters: 3539713
all model parameters: 251117569
percentage of trainable model parameters: 1.41%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


In [18]:
# creat a refrence model with value head for KL divergence
ref_model = create_reference_model(ppo_model)

print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

Reference model parameters to be updated:

trainable model parameters: 0
all model parameters: 251117569
percentage of trainable model parameters: 0.00%



In [19]:
### reward model for the toxicity 
# the logit layer score would be kept for the not-hate output as the reward
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map="auto")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")
print(toxicity_model.config.id2label)

{0: 'nothate', 1: 'hate'}


In [20]:
# test a sample 
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids

toxicity_input_ids  = toxicity_input_ids.to('cuda:0')

logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')

logits [not hate, hate]: [3.114102840423584, -2.489619255065918]
probabilities [not hate, hate]: [0.9963293671607971, 0.003670602338388562]
reward (high): [3.114102840423584]


In [21]:
# define reward pipeline for rl model 
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

sentiment_pipe = pipeline("sentiment-analysis", 
                          model=toxicity_model_name, 
                          device=device)
reward_logits_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # Set to "none" to retrieve raw logits.
    "batch_size": 16
}

reward_probabilities_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "softmax", # Set to "softmax" to apply softmax and retrieve probabilities.
    "batch_size": 16
}

print("Reward model output:")
print("For non-toxic text")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))


Reward model output:
For non-toxic text
[{'label': 'nothate', 'score': 3.114102840423584}, {'label': 'hate', 'score': -2.489619255065918}]
[{'label': 'nothate', 'score': 0.9963293671607971}, {'label': 'hate', 'score': 0.003670602571219206}]


In [22]:





learning_rate=1.41e-5
max_ppo_epochs=1
mini_batch_size=4
batch_size=512
model_name="google/flan-t5-base"
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

device = 'cuda:0'








model = AutoModelForSeq2SeqLM.from_pretrained(model_name, 
                                              torch_dtype=torch.bfloat16,
                                             device_map="auto")
# lora model should be trainable 
peft_model = PeftModel.from_pretrained(model, 
                                       './model', 
                                       lora_config=lora_config,
                                       torch_dtype=torch.bfloat16, 
                                       device_map="auto",                                       
                                       is_trainable=True)







# peft_model = peft_model.to(device)
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               device_map="auto",
                                                               is_trainable=True)
ref_model = create_reference_model(ppo_model)








config = PPOConfig(
    model_name=model_name,    
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])


tokenizer_peft = AutoTokenizer.from_pretrained('./model',return_tensors="pt")

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=tokenizer_peft, 
                         dataset=tokenized_datasets["train"],
                         data_collator=collator)

In [23]:
generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # You want the raw logits without softmax.
    "batch_size": 512
}

max_ppo_steps = 10
max_len = 200




for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    
    if step >= max_ppo_steps:
        break
        
       
        
        
        
    s = []
    f = []
    for item in batch["input_ids"]: 
        i = torch.tensor(item).to('cuda:0')
        f.append(i)
        summary = ppo_trainer.generate(i,**generation_kwargs)
        s.append(summary[0][0:max_len])
    
    
    batch["response"] = [tokenizer.decode(r) for r in s]
   
    
    rewards = sentiment_pipe(batch["response"], **reward_kwargs)

    reward_tensors = [torch.tensor(reward[not_hate_index]["score"]) for reward in rewards]    

    stats = ppo_trainer.step(f, s, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    
    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'.join('' for x in range(100)))
    
    

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
1it [03:00, 180.68s/it]

objective/kl: 17.10700798034668
ppo/returns/mean: 0.6143479347229004
ppo/policy/advantages_mean: -8.030651699186819e-10
---------------------------------------------------------------------------------------------------


2it [05:55, 177.03s/it]

objective/kl: 15.933660507202148
ppo/returns/mean: 0.7839390635490417
ppo/policy/advantages_mean: -5.490540022634605e-09
---------------------------------------------------------------------------------------------------


3it [08:52, 177.06s/it]

objective/kl: 15.4185791015625
ppo/returns/mean: 0.8828891515731812
ppo/policy/advantages_mean: 7.138047930510538e-10
---------------------------------------------------------------------------------------------------


4it [11:49, 177.21s/it]

objective/kl: 16.317434310913086
ppo/returns/mean: 0.7959517240524292
ppo/policy/advantages_mean: 2.0143748891143787e-09
---------------------------------------------------------------------------------------------------


5it [14:46, 177.05s/it]

objective/kl: 15.308398246765137
ppo/returns/mean: 0.8806829452514648
ppo/policy/advantages_mean: 4.5860595876412447e-10
---------------------------------------------------------------------------------------------------


6it [17:44, 177.27s/it]

objective/kl: 15.318288803100586
ppo/returns/mean: 0.9563939571380615
ppo/policy/advantages_mean: -1.2055442200065158e-09
---------------------------------------------------------------------------------------------------


7it [20:40, 177.04s/it]

objective/kl: 15.013477325439453
ppo/returns/mean: 0.9635177850723267
ppo/policy/advantages_mean: -3.211149346427078e-09
---------------------------------------------------------------------------------------------------


8it [23:42, 178.53s/it]

objective/kl: 13.681775093078613
ppo/returns/mean: 1.1279995441436768
ppo/policy/advantages_mean: -2.384359598650576e-09
---------------------------------------------------------------------------------------------------


9it [26:46, 180.37s/it]

objective/kl: 14.386486053466797
ppo/returns/mean: 1.0782793760299683
ppo/policy/advantages_mean: -2.737688742371347e-09
---------------------------------------------------------------------------------------------------


10it [29:43, 179.32s/it]

objective/kl: 13.230035781860352
ppo/returns/mean: 1.2634953260421753
ppo/policy/advantages_mean: -4.4154462308654274e-09
---------------------------------------------------------------------------------------------------


10it [29:44, 178.41s/it]
