In [2]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

from tqdm import tqdm
tqdm.pandas()

device = "mps" if torch.backends.mps.is_available() else "cpu"

In [3]:
# Load dataset from hugging face dialogsum

model_name = "google/flan-t5-base"
huggingface_dataset_name = "knkarthick/dialogsum"

dataset_original = load_dataset(huggingface_dataset_name)
dataset_original

README.md:   0%|          | 0.00/4.65k [00:00<?, ?B/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})

In [4]:
# Take a subset of the dataset and filter only long enough and easy to read
# Then wrap each dialog with the instruction and tokenize the prompt
# Save the token_ids in the field token)ids
# Decoded version of prompt in the field query

def build_dataset(model_name, dataset_name, input_min_text_length, input_max_text_length):
    """
    Preprocess the dataset and split into training and test sets.

    Parameters:
    - model_name: name of the model
    - dataset_name: name of the dataset
    - input_min_text_length: minimum text length
    - input_max_text_length: maximum text length

    Returns:
    - dataset_splits (datasets.dataset_dict.DatasetDict): Processed dataset containing training and test sets.
    """
    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.filter(lambda x: len(x["dialogue"]) > input_min_text_length and len(x["dialogue"]) <= input_max_text_length, batched=False)
    tokenizer = AutoTokenizer.from_pretrained(model_name, device="auto")

    def tokenize(sample):

        # wrap each dialogue with the instruction
        prompt = f"""
Summarize the following conversation.

{sample["dialogue"]}

Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)
    return dataset_splits

dataset = build_dataset(model_name, huggingface_dataset_name, 200, 1000)
print(dataset)

Map:   0%|          | 0/10022 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})


In [5]:
def print_number_of_trainable_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

In [6]:
# use fine-tuned PEFT model with summarization instructions from previous project
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16)

peft_model = PeftModel.from_pretrained(model,
                                       "intotheverse/peft-dialogue-summary-checkpoint",
                                       lora_config=lora_config,
                                       torch_dtype=torch.float32,
                                       device_map={"": device},
                                       is_trainable=True)

print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_parameters(peft_model)}')

adapter_config.json:   0%|          | 0.00/334 [00:00<?, ?B/s]

adapter_model.bin:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

PEFT model parameters to be updated:

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%


In [7]:
# Prepare reward model
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,
                                                               torch_dtype=torch.float32,
                                                               is_trainable=True,
                                                               device_map={"": device})
ppo_model.generation_config = ppo_model.config
# ppo_model.base_model_prefix = "encoder"
# ppo_model.encoder = ppo_model.pretrained_model.encoder
print(f'PPO model parameters to be updated (ValueHead + 769 parameters):\n{print_number_of_trainable_parameters(ppo_model)}\n')
print(ppo_model.v_head)

PPO model parameters to be updated (ValueHead + 769 parameters):

trainable model parameters: 3539713
all model parameters: 251117569
percentage of trainable model parameters: 1.41%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


In [8]:
ref_model = create_reference_model(ppo_model).to(device)
print(f'Reference model parameters to be upldated:\n{print_number_of_trainable_parameters(ref_model)}')

Reference model parameters to be upldated:

trainable model parameters: 0
all model parameters: 251117569
percentage of trainable model parameters: 0.00%


In [9]:
# Create a reference model
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map={"": device})
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map={"": device})
print(toxicity_model.config.id2label)

tokenizer_config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/816 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

{0: 'nothate', 1: 'hate'}


In [10]:
# Example of reward for a non-toxic text
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."
toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids.to(device)
logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

not_hate_index = 0
not_hate_reward = (logits[:, not_hate_index].tolist())
print(f'reward (high): {not_hate_reward}')

logits [not hate, hate]: [3.114098310470581, -2.4896152019500732]
probabilities [not hate, hate]: [0.9963293671607971, 0.003670633537694812]
reward (high): [3.114098310470581]


In [11]:
# Example of reward for a toxic text
toxic_text = "#Person 1# tells Tommy that the movie was terrible, dumb and stupid."
toxicity_input_ids = toxicity_tokenizer(toxic_text, return_tensors="pt").input_ids.to(device)
logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

not_hate_reward = (logits[:, not_hate_index].tolist())
print(f'reward (low): {not_hate_reward}')

logits [not hate, hate]: [-0.692116916179657, 0.37227126955986023]
probabilities [not hate, hate]: [0.25647175312042236, 0.7435283064842224]
reward (low): [-0.692116916179657]


In [12]:
# Create a hugging face pipeline to simplify the reward model
sentiment_pipe = pipeline("sentiment-analysis",
                          model=toxicity_model_name,
                          device=device)

reward_logits_kwargs = {
    "top_k": None, # Return all scores (not just the top prediction)
    "function_to_apply": "none", # PPO uses raw logits for reward estimation
    "batch_size": 16
}

reward_probabilities_kwargs = {
    "top_k": None,
    "function_to_apply": "softmax", # Easier to interpret since it shows confidence scores
    "batch_size": 16
}

print("Reward model output:")
print("For non-toxic text")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))
print("For toxic text")
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

Device set to use mps


Reward model output:
For non-toxic text
[{'label': 'nothate', 'score': 3.114098310470581}, {'label': 'hate', 'score': -2.4896152019500732}]
[{'label': 'nothate', 'score': 0.9963293671607971}, {'label': 'hate', 'score': 0.0036706337705254555}]
For toxic text
[{'label': 'hate', 'score': 0.37227126955986023}, {'label': 'nothate', 'score': -0.692116916179657}]
[{'label': 'hate', 'score': 0.7435281872749329}, {'label': 'nothate', 'score': 0.25647175312042236}]


In [13]:
# Set up an evaluation metric for toxicity
toxicity_evaluator = evaluate.load("toxicity",
                                   toxicity_model_name,
                                   module_type="measurement",
                                   toxic_label="hate",
                                   device=device)

Device set to use mps:0


In [14]:
toxicity_score = toxicity_evaluator.compute(predictions=[non_toxic_text, toxic_text])

print(f"Toxicity score for non-toxic text and toxic text: {toxicity_score["toxicity"]}")

Toxicity score for non-toxic text and toxic text: [0.0036706337705254555, 0.7435281872749329]


In [15]:
def evaluate_toxicity(model,
                      toxicity_evaluator,
                      tokenizer,
                      dataset,
                      num_samples,
                      device):
    """
    Preprocess the dataset and split it into training and test sets.

    Parameters:
        - model (trl model): model to be evaluated.
        - toxicity_evaluator: toxicity evaluator.
        - tokenizer (transformers tokenizer): tokenizer to be used.
        - dataset (dataset): dataset to be evaluated.
        - num_samples (int): number of samples to be evaluated.

    Returns:
        tuple: a tuple containing two numpy.float64 values:
            - mean of samples toxicity
            - standard deviation of samples toxicity
    """

    toxicities = []
    for i, sample in tqdm(enumerate(dataset)):
        input_text = sample["query"]
        if i > num_samples:
            break

        input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids.to(device)
        generation_config = GenerationConfig(max_new_tokens=100,
                                             do_sample=True,
                                             top_k=0.0,
                                             top_p=1.0)
        response_token_ids = model.generate(input_ids=input_ids, generation_config=generation_config)
        generated_text_output = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)

        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " " + generated_text_output)])
        toxicities.extend(toxicity_score["toxicity"])

    mean = np.mean(toxicities)
    std = np.std(toxicities)
    return mean, std

In [16]:
# Assess toxicity of PEFT trained model before RLHF
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map={"": device})
mean_before_rlhf, std_before_rlhf = evaluate_toxicity(model=ref_model,
                                                      toxicity_evaluator=toxicity_evaluator,
                                                      tokenizer=tokenizer,
                                                      dataset=dataset["test"],
                                                      num_samples=10,
                                                      device=device)
print(f"Toxicity[mean, std] before RLHF: [{mean_before_rlhf}, {std_before_rlhf}]")

11it [00:30,  2.74s/it]

Toxicity[mean, std] before RLHF: [0.03948024806397205, 0.047522372034200566]





In [17]:
# Create a collator function, a data collector is used to batch process the training data before passing it to PPO
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
test_data = [
    {"query": "Summarize this text", "response": "Summary A", "reward": 0.8},
    {"query": "Explain this topic", "response": "Explanation B", "reward": 0.6}
]
print(collator(test_data))

{'query': ['Summarize this text', 'Explain this topic'], 'response': ['Summary A', 'Explanation B'], 'reward': [0.8, 0.6]}


In [18]:
learning_rate=1.41e-5
max_ppo_epochs=1
mini_batch_size=4
batch_size=16

config = PPOConfig(
    model_name=model_name,
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

ppo_trainer = PPOTrainer(config=config,
                         model=ppo_model,
                         ref_model=ref_model,
                         tokenizer=tokenizer,
                         dataset=dataset["train"],
                         data_collator=collator)



In [None]:
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # You want the raw logits without softmax.
    "batch_size": 16
}

max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when you reach max_steps.
    if step >= max_ppo_steps:
        break

    prompt_tensors = batch["input_ids"]

    # Get response from FLAN-T5/PEFT LLM.
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()

        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)

        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    # This needs to be called "response".
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Compute reward outputs.
    query_response_pairs = [q + r for q, r in zip(batch["query"], batch["response"])]
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # You use the `nothate` item because this is the score for the positive `nothate` class.
    reward_tensors = [torch.tensor(reward[not_hate_index]["score"]) for reward in rewards]

    # Run PPO step.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)

    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'.join('' for x in range(100)))

0it [00:00, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
1it [00:59, 59.91s/it]

objective/kl: 30.84881591796875
ppo/returns/mean: -0.8026288747787476
ppo/policy/advantages_mean: -0.003186337649822235
---------------------------------------------------------------------------------------------------


2it [02:11, 66.95s/it]

objective/kl: 34.12253952026367
ppo/returns/mean: -0.9107413291931152
ppo/policy/advantages_mean: 0.0091189444065094
---------------------------------------------------------------------------------------------------


3it [03:07, 61.77s/it]

objective/kl: 26.020965576171875
ppo/returns/mean: -0.6090663075447083
ppo/policy/advantages_mean: 0.04889874905347824
---------------------------------------------------------------------------------------------------
