In [1]:
#!pip install git+https://github.com/lvwerra/trl.git@25fa1bd

In [2]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

In [3]:
model_name="google/flan-t5-base"
#model_name="google/flan-t5-large"
huggingface_dataset_name = "knkarthick/dialogsum"

dataset_original = load_dataset(huggingface_dataset_name)

dataset_original

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})

In [4]:
def build_dataset(model_name,
                  dataset_name,
                  input_min_text_length, 
                  input_max_text_length):

    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model_name (str): Tokenizer model name.
    - dataset_name (str): Name of the dataset to load.
    - input_min_text_length (int): Minimum length of the dialogues.
    - input_max_text_length (int): Maximum length of the dialogues.
        
    Returns:
    - dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.
    """
    
    # load dataset (only "train" part will be enough for this lab).
    dataset = load_dataset(dataset_name, split="train")
    
    # Filter the dialogues of length between input_min_text_length and input_max_text_length characters.
    dataset = dataset.filter(lambda x: len(x["dialogue"]) > input_min_text_length and len(x["dialogue"]) <= input_max_text_length, batched=False)

    # Prepare tokenizer. Setting device_map="auto" allows to switch between GPU and CPU automatically.
    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
    
    def tokenize(sample):
        
        # Wrap each dialogue with the instruction.
        prompt = f"""
Summarize the following conversation.

{sample["dialogue"]}

Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)
        
        # This must be called "query", which is a requirement of our PPO library.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    # Tokenize each dialogue.
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")
    
    # Split the dataset into train and test parts.
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)

    return dataset_splits

dataset = build_dataset(model_name=model_name,
                        dataset_name=huggingface_dataset_name,
                        input_min_text_length=200, 
                        input_max_text_length=1000)

print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})


In [5]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

In [6]:
lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, 
                                              torch_dtype=torch.bfloat16)

peft_model = PeftModel.from_pretrained(model, 
                                       './peft-dialogue-summary-checkpoint-local/', 
                                       lora_config=lora_config,
                                       torch_dtype=torch.bfloat16, 
                                       device_map="auto",                                       
                                       is_trainable=True)

print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n')

PEFT model parameters to be updated:

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%



In [7]:
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


PPO model parameters to be updated (ValueHead + 769 params):

trainable model parameters: 3539713
all model parameters: 251117569
percentage of trainable model parameters: 1.41%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


In [8]:
ref_model = create_reference_model(ppo_model)

print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

Reference model parameters to be updated:

trainable model parameters: 0
all model parameters: 251117569
percentage of trainable model parameters: 0.00%



In [9]:
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map="auto")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")
print(toxicity_model.config.id2label)

{0: 'nothate', 1: 'hate'}


In [10]:
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids

#logits = toxicity_model(input_ids=toxicity_input_ids).logits
logits = toxicity_model(input_ids=toxicity_input_ids.to("cuda")).logits

print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')

logits [not hate, hate]: [3.114102363586426, -2.489619016647339]
probabilities [not hate, hate]: [0.9963293671607971, 0.0036706042010337114]
reward (high): [3.114102363586426]


In [11]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids
toxicity_input_ids = toxicity_input_ids.to(device)

logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')

logits [not hate, hate]: [3.114102363586426, -2.489619016647339]
probabilities [not hate, hate]: [0.9963293671607971, 0.0036706042010337114]
reward (high): [3.114102363586426]


In [12]:
toxic_text = "#Person 1# tells Tommy that the movie was terrible, dumb and stupid."

toxicity_input_ids = toxicity_tokenizer(toxic_text, return_tensors="pt").input_ids
toxicity_input_ids = toxicity_input_ids.to(device)

logits = toxicity_model(toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# Get the logits for "not hate" - this is the reward!
nothate_reward = (logits[:, not_hate_index]).tolist() 
print(f'reward (low): {nothate_reward}')

logits [not hate, hate]: [-0.6921154856681824, 0.372269868850708]
probabilities [not hate, hate]: [0.25647228956222534, 0.7435277104377747]
reward (low): [-0.6921154856681824]


In [13]:
device = 0 if torch.cuda.is_available() else "cpu"

sentiment_pipe = pipeline("sentiment-analysis", 
                          model=toxicity_model_name, 
                          device=device)
reward_logits_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # Set to "none" to retrieve raw logits.
    "batch_size": 16
}

reward_probabilities_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "softmax", # Set to "softmax" to apply softmax and retrieve probabilities.
    "batch_size": 16
}

print("Reward model output:")
print("For non-toxic text")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))
print("For toxic text")
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

Reward model output:
For non-toxic text
[{'label': 'nothate', 'score': 3.114102363586426}, {'label': 'hate', 'score': -2.489619016647339}]
[{'label': 'nothate', 'score': 0.9963293671607971}, {'label': 'hate', 'score': 0.0036706042010337114}]
For toxic text
[{'label': 'hate', 'score': 0.372269868850708}, {'label': 'nothate', 'score': -0.6921154856681824}]
[{'label': 'hate', 'score': 0.7435277700424194}, {'label': 'nothate', 'score': 0.25647228956222534}]


In [14]:
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 3.114102363586426}, {'label': 'hate', 'score': -2.489619016647339}]
[{'label': 'nothate', 'score': 0.9963293671607971}, {'label': 'hate', 'score': 0.0036706042010337114}]


In [15]:
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

[{'label': 'hate', 'score': 0.372269868850708}, {'label': 'nothate', 'score': -0.6921154856681824}]
[{'label': 'hate', 'score': 0.7435277700424194}, {'label': 'nothate', 'score': 0.25647228956222534}]


In [16]:
text = "I love you."
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 4.6205291748046875}, {'label': 'hate', 'score': -4.193256378173828}]
[{'label': 'nothate', 'score': 0.9998513460159302}, {'label': 'hate', 'score': 0.0001486473047407344}]


In [17]:
text = "How are you doing today?"
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 4.629112720489502}, {'label': 'hate', 'score': -4.079700469970703}]
[{'label': 'nothate', 'score': 0.9998348951339722}, {'label': 'hate', 'score': 0.00016509693523403257}]




In [18]:
text = "#Person 1# tells Tommy that he was terrible, dumb and stupid."
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'hate', 'score': 0.9263187050819397}, {'label': 'nothate', 'score': -1.226240634918213}]
[{'label': 'hate', 'score': 0.8959077000617981}, {'label': 'nothate', 'score': 0.1040923148393631}]


In [19]:
toxicity_evaluator = evaluate.load("toxicity", 
                                    toxicity_model_name,
                                    module_type="measurement",
                                    toxic_label="hate")

In [20]:
toxicity_score = toxicity_evaluator.compute(predictions=[
    non_toxic_text
])

print("Toxicity score for non-toxic text:")
print(toxicity_score["toxicity"])

toxicity_score = toxicity_evaluator.compute(predictions=[
    toxic_text
])

print("\nToxicity score for toxic text:")
print(toxicity_score["toxicity"])

Toxicity score for non-toxic text:
[0.0036706007085740566]

Toxicity score for toxic text:
[0.7435290217399597]


In [21]:
toxicity_score = toxicity_evaluator.compute(predictions=[
    text
])

print("Toxicity score for text:")
print(toxicity_score["toxicity"])

Toxicity score for text:
[0.895907461643219]


In [22]:
def evaluate_toxicity(model, 
                      toxicity_evaluator, 
                      tokenizer, 
                      dataset, 
                      num_samples):
    
    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model (trl model): Model to be evaluated.
    - toxicity_evaluator (evaluate_modules toxicity metrics): Toxicity evaluator.
    - tokenizer (transformers tokenizer): Tokenizer to be used.
    - dataset (dataset): Input dataset for the evaluation.
    - num_samples (int): Maximum number of samples for the evaluation.
        
    Returns:
    tuple: A tuple containing two numpy.float64 values:
    - mean (numpy.float64): Mean of the samples toxicity.
    - std (numpy.float64): Standard deviation of the samples toxicity.
    """

    max_new_tokens=100

    toxicities = []
    input_texts = []
    for i, sample in tqdm(enumerate(dataset)):
        input_text = sample["query"]

        if i > num_samples:
            break
            
        input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
        generation_config = GenerationConfig(max_new_tokens=max_new_tokens,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)

        response_token_ids = model.generate(input_ids=input_ids.to("cuda"),
                                            generation_config=generation_config)
        
        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)
        
        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " " + generated_text)])

        toxicities.extend(toxicity_score["toxicity"])

    # Compute mean & std using np.
    mean = np.mean(toxicities)
    std = np.std(toxicities)
        
    return mean, std

In [23]:
print(dataset["test"][0]["query"])

Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: So then I can't use my phone if I'm on the internet? #Person2#: That's correct. With DEL you can do both. Summary: </s>


In [24]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

#Person1#: You want to get some DEL Internet. #Person2#: And who will be doing the shopping? - Person3#: Who? - Person2#: And who? - Person1#: I hear that there's a DEL Internet for the DEL network. #Person1#: And it's cheaper than the DEL network. You can find it on the internet market.


In [25]:
input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = ref_model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

Go to https://www.internetdirect.com/and choose the best deal.


In [26]:
input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = peft_model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

Select DEL or dial-up cable.


In [27]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
ref_model = ref_model.to("cuda")
mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model, 
                                                                          toxicity_evaluator=toxicity_evaluator, 
                                                                          tokenizer=tokenizer, 
                                                                          dataset=dataset["test"], 
                                                                          num_samples=10)

print(f'toxicity [mean, std] before detox: [{mean_before_detoxification}, {std_before_detoxification}]')

11it [00:04,  2.27it/s]

toxicity [mean, std] before detox: [0.0075488285722465, 0.009509237744984961]





In [28]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}


In [29]:
learning_rate=1.41e-5
max_ppo_epochs=1
mini_batch_size=4
batch_size=16

config = PPOConfig(
    model_name=model_name,    
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=tokenizer, 
                         dataset=dataset["train"], 
                         data_collator=collator)


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [30]:
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # You want the raw logits without softmax.
    "batch_size": 16
}

max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when you reach max_steps.
    if step >= max_ppo_steps:
        break   

    prompt_tensors = batch["input_ids"]
    #prompt_tensors = prompt_tensors.to(device)

    # Get response from FLAN-T5/PEFT LLM.
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()        
            
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])
        
    # This needs to be called "response".
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Compute reward outputs.
    query_response_pairs = [q + r for q, r in zip(batch["query"], batch["response"])]    
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # You use the `nothate` item because this is the score for the positive `nothate` class.
    reward_tensors = [torch.tensor(reward[not_hate_index]["score"]) for reward in rewards]    

    # Run PPO step.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    
    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'.join('' for x in range(100)))

1it [00:08,  8.77s/it]

objective/kl: 0.0
ppo/returns/mean: 1.4007325172424316
ppo/policy/advantages_mean: 4.5106538948402886e-08
---------------------------------------------------------------------------------------------------


2it [00:17,  8.82s/it]

objective/kl: -0.011115435510873795
ppo/returns/mean: 1.293785810470581
ppo/policy/advantages_mean: 9.919183696638356e-09
---------------------------------------------------------------------------------------------------


3it [00:26,  8.94s/it]

objective/kl: -0.021461091935634613
ppo/returns/mean: 1.1473360061645508
ppo/policy/advantages_mean: -1.1248141085218322e-08
---------------------------------------------------------------------------------------------------


4it [00:34,  8.44s/it]

objective/kl: 0.0064226132817566395
ppo/returns/mean: 1.3772499561309814
ppo/policy/advantages_mean: -1.6264074531591177e-09
---------------------------------------------------------------------------------------------------


5it [00:43,  8.52s/it]

objective/kl: -0.0069248308427631855
ppo/returns/mean: 1.5466547012329102
ppo/policy/advantages_mean: -5.4908397828512534e-08
---------------------------------------------------------------------------------------------------


6it [00:54,  9.40s/it]

objective/kl: -0.02902948297560215
ppo/returns/mean: 1.371447205543518
ppo/policy/advantages_mean: 1.2460017906334997e-08
---------------------------------------------------------------------------------------------------


7it [01:04,  9.76s/it]

objective/kl: 0.024337127804756165
ppo/returns/mean: 1.3515512943267822
ppo/policy/advantages_mean: -5.393269475462148e-08
---------------------------------------------------------------------------------------------------


8it [01:14,  9.78s/it]

objective/kl: -0.023060236126184464
ppo/returns/mean: 1.2474091053009033
ppo/policy/advantages_mean: -2.3226455425628956e-08
---------------------------------------------------------------------------------------------------


9it [01:24,  9.73s/it]

objective/kl: -0.05064701288938522
ppo/returns/mean: 1.1414856910705566
ppo/policy/advantages_mean: -5.505519595772057e-09
---------------------------------------------------------------------------------------------------


10it [01:35,  9.56s/it]

objective/kl: -0.05296263098716736
ppo/returns/mean: 1.1406078338623047
ppo/policy/advantages_mean: 4.858181057443289e-08
---------------------------------------------------------------------------------------------------





In [31]:
#ppo_model = ppo_model.to("cpu")
ppo_model = ppo_model.to("cuda")
mean_after_detoxification, std_after_detoxification = evaluate_toxicity(model=ppo_model, 
                                                                        toxicity_evaluator=toxicity_evaluator, 
                                                                        tokenizer=tokenizer, 
                                                                        dataset=dataset["test"],
                                                                        #dataset=dataset["test"].to(device),
                                                                        num_samples=10)
print(f'toxicity [mean, std] after detox: [{mean_after_detoxification}, {std_after_detoxification}]')

11it [00:05,  2.12it/s]

toxicity [mean, std] after detox: [0.011690370278136636, 0.014647131650818414]





In [32]:
mean_improvement = (mean_before_detoxification - mean_after_detoxification) / mean_before_detoxification
std_improvement = (std_before_detoxification - std_after_detoxification) / std_before_detoxification

print(f'Percentage improvement of toxicity score after detoxification:')
print(f'mean: {mean_improvement*100:.2f}%')
print(f'std: {std_improvement*100:.2f}%')

Percentage improvement of toxicity score after detoxification:
mean: -54.86%
std: -54.03%


In [33]:
ppo_model = ppo_model.to(device)

batch_size = 20
compare_results = {}

df_batch = dataset["test"][0:batch_size]

compare_results["query"] = df_batch["query"]
prompt_tensors = df_batch["input_ids"]

summary_tensors_ref = []
summary_tensors = []

# Get response from ppo and base model.
for i in tqdm(range(batch_size)):
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    
    summary = ref_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors_ref.append(summary)

    summary = ppo_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors.append(summary)

# Decode responses.
compare_results["response_before"] = [tokenizer.decode(summary_tensors_ref[i]) for i in range(batch_size)]
compare_results["response_after"] = [tokenizer.decode(summary_tensors[i]) for i in range(batch_size)]

# Sentiment analysis of query/response pairs before/after.
texts_before = [d + s for d, s in zip(compare_results["query"], compare_results["response_before"])]
rewards_before = sentiment_pipe(texts_before, **reward_kwargs)
compare_results["reward_before"] = [reward[not_hate_index]["score"] for reward in rewards_before]

texts_after = [d + s for d, s in zip(compare_results["query"], compare_results["response_after"])]
rewards_after = sentiment_pipe(texts_after, **reward_kwargs)
compare_results["reward_after"] = [reward[not_hate_index]["score"] for reward in rewards_after]

100%|██████████| 20/20 [00:17<00:00,  1.15it/s]


In [34]:
pd.set_option('display.max_colwidth', 500)
df_compare_results = pd.DataFrame(compare_results)
df_compare_results["reward_diff"] = df_compare_results['reward_after'] - df_compare_results['reward_before']
df_compare_results_sorted = df_compare_results.sort_values(by=['reward_diff'], ascending=False).reset_index(drop=True)
df_compare_results_sorted

Unnamed: 0,query,response_before,response_after,reward_before,reward_after,reward_diff
0,"Summarize the following conversation. #Person1#: It smells like an ashtray in here! #Person2#: Hi honey! What's wrong? Why do you have that look on your face? #Person1#: What's wrong? I thought we agreed that you were gonna quit smoking. #Person2#: No! I said I was going to cut down which is very different. You can't just expect me to go cold turkey overnight! #Person1#: Look, there are other ways to quit. You can try the nicotine patch, or nicotine chewing gum. We spend a fortune on cigaret...",<pad> #Person1#: You're acting like an ashtray.</s>,"<pad> #Person2#: It smells like an ashtray. #Person1#: Ahh chapter 17A, that part of me?#Person2#: How can I, honey?#Person1#: Yeah, there is just an awfulcclimate that's causing my sniffles to go paranoid I can't help feeling sounge about it.#Person2#: Let me just gush on through the chapter 17A, you are so soooooo clueless. #Person2#: Well, I know I'm an ashtray, will I?#Person1#: Don't try to, but you really can't compromise.</s>",1.72725,2.414285,0.687034
1,"Summarize the following conversation. #Person1#: I'm forming a music band. #Person2#: Do you already know how to play an instrument? #Person1#: Uh... Yeah! I'Ve told you a thousand times that I'm learning to play the drums. Now that I know how to play well, I would like to form a rock band. #Person2#: Aside from yourself, who are the other members of the band? #Person1#: We have a guy who plays guitar, and another who plays bass. Although we still haven't found anyone to be our singer. You t...","<pad> Not sure if you can hear them, but someone punched you.</s>",<pad> #Person1#: I want to make a rock band.</s>,1.836627,2.49017,0.653543
2,"Summarize the following conversation. #Person1#: Excuse me, could you tell me how to get to the Cross Bakery building? #Person2#: The Cross Bakery building? Oh sure. You're actually walking in the opposite direction. #Person1#: Oh, you're kidding! I thought I was heading east. #Person2#: No, east is the other direction. To get to the Bakery, you need to turn around and go three blocks to Broadway. When you get to the intersection of Broadway and Elm, you hang a left. Go straight down that st...",<pad> Go to Broadway and hit Elm. Walk up Elm a short walk to the Cross Bakery building.</s>,"<pad> #Person1#: The Cross Bakery building on the south, right. Handle the cross, and then corner down, right.</s>",2.421151,2.990093,0.568943
3,"Summarize the following conversation. #Person1#: Could you help me, Sir? My flight got in 15 minutes ago. Everyone else has picked up the luggage but mine hasn't come through. #Person2#: I'm sorry, Madam, I'll go and find out if there is any more to come. Summary: </s>",<pad> #Person1#: My seat is still occupied.</s>,<pad> You should try the e-ticket kiosk.</s>,2.410102,2.753905,0.343804
4,"Summarize the following conversation. #Person1#: Judy, what is everybody talking about? #Person2#: Haven't you heard? Richard was fired by our manager. #Person1#: You're kidding. It can't be true. #Person2#: Believe it or not. Everybody is talking about it in the company. #Person1#: Really? I'm surprised. #Person2#: Me too. Summary: </s>",<pad> Judy thinks Richard hired a new assistant.</s>,"<pad> Person1 wants to know what everyone is talking about because since his was fired, his boss has told everyone.</s>",1.229174,1.553971,0.324797
5,"Summarize the following conversation. #Person1#: Hello. I want to reconfirm our flight to London. #Person2#: Yes, sir. Did you call the airline? #Person1#: Yes, I did. But I couldn't communicate with them in English. They speak only Spanish. So I need your help. #Person2#: Certainly, sir. What is the flight number and when are you leaving? #Person1#: We are taking IB 385 to London tomorrow at 1 p. m. #Person2#: Oh, I see, sir. We have the airline office inside the hotel. They have an English...",<pad> DN 35 please.</s>,"<pad> I want to arrange, please.</s>",2.136233,2.351408,0.215175
6,"Summarize the following conversation. #Person1#: Here is the final draft of our contract. I'm glad that we have reached an agreement on almost every term in our trade. #Person2#: Yes, it seems to me we have come quite a long way. However, let me take a close look at the final draft. #Person1#: Do you have some points to bring up? #Person2#: Well, everything we've discussed seems to be here. #Person1#: Yes, including a description of the shirts you want to purchase this time, the total amount...",<pad> #Person2#: Good morning.</s>,<pad> #Person1#: Thank you. Are you ready to sign?</s>,2.79626,2.914157,0.117896
7,"Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: S...",<pad> #Person1#: I would like to buy DEL or dial-up phone internet today.</s>,<pad> Go to https://www.iinternetdirect.com. Select DEL or dial-up internet. Wait for your DEL package to arrive.</s>,2.545534,2.65685,0.111315
8,"Summarize the following conversation. #Person1#: How much are you asking for this? #Person2#: I'm offering them to you at 150 yuan a piece. Is that all right? #Person1#: Is tax already included in their price? #Person2#: Yes. Our price can't be matched. #Person1#: Would you consider a volume discount? #Person2#: If you buy 1, 000 or more, you'll get a 10 % discount. #Person1#: I'll accept your offer. Summary: </s>",<pad> Our price can't be matched.</s>,<pad> Our prices are 150 yuan a piece to send it to you.</s>,3.149918,3.122718,-0.0272
9,"Summarize the following conversation. #Person1#: What can I do for you, madam? #Person2#: I'd like to buy a toy car for my son. #Person1#: How about this one? #Person2#: It looks nice. How much is it? #Person1#: They're three hundred dollars. #Person2#: Oh, I'm afraid it's too expensive. Can you show me something cheaper? #Person1#: OK, This one is one hundred and twenty. It's the cheapest here. #Person2#: OK, I'll take it. Here's the money. #Person1#: Thank you very much. Summary: </s>","<pad> #Person1#: Does this look good for you, madam? #Person2#: Oh, yes.</s>",<pad> #Person1# Can you show me a cheaper toy car for my son?</s>,1.411773,1.37732,-0.034454
