In [1]:
#!pip install git+https://github.com/lvwerra/trl.git@25fa1bd

In [2]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

In [3]:
model_name="google/flan-t5-base"
#model_name="google/flan-t5-large"
huggingface_dataset_name = "knkarthick/dialogsum"

dataset_original = load_dataset(huggingface_dataset_name)

dataset_original

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})

In [4]:
def build_dataset(model_name,
                  dataset_name,
                  input_min_text_length, 
                  input_max_text_length):

    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model_name (str): Tokenizer model name.
    - dataset_name (str): Name of the dataset to load.
    - input_min_text_length (int): Minimum length of the dialogues.
    - input_max_text_length (int): Maximum length of the dialogues.
        
    Returns:
    - dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.
    """
    
    # load dataset (only "train" part will be enough for this lab).
    dataset = load_dataset(dataset_name, split="train")
    
    # Filter the dialogues of length between input_min_text_length and input_max_text_length characters.
    dataset = dataset.filter(lambda x: len(x["dialogue"]) > input_min_text_length and len(x["dialogue"]) <= input_max_text_length, batched=False)

    # Prepare tokenizer. Setting device_map="auto" allows to switch between GPU and CPU automatically.
    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
    
    def tokenize(sample):
        
        # Wrap each dialogue with the instruction.
        prompt = f"""
Summarize the following conversation.

{sample["dialogue"]}

Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)
        
        # This must be called "query", which is a requirement of our PPO library.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    # Tokenize each dialogue.
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")
    
    # Split the dataset into train and test parts.
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)

    return dataset_splits

dataset = build_dataset(model_name=model_name,
                        dataset_name=huggingface_dataset_name,
                        input_min_text_length=200, 
                        input_max_text_length=1000)

print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})


In [5]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

In [6]:
lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, 
                                              torch_dtype=torch.bfloat16)

peft_model = PeftModel.from_pretrained(model, 
                                       './peft-dialogue-summary-checkpoint-local/', 
                                       lora_config=lora_config,
                                       torch_dtype=torch.bfloat16, 
                                       device_map="auto",                                       
                                       is_trainable=True)

print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n')

PEFT model parameters to be updated:

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%



In [7]:
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


PPO model parameters to be updated (ValueHead + 769 params):

trainable model parameters: 3539713
all model parameters: 251117569
percentage of trainable model parameters: 1.41%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


In [8]:
ref_model = create_reference_model(ppo_model)

print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

Reference model parameters to be updated:

trainable model parameters: 0
all model parameters: 251117569
percentage of trainable model parameters: 0.00%



In [9]:
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map="auto")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")
print(toxicity_model.config.id2label)

{0: 'nothate', 1: 'hate'}


In [10]:
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids

#logits = toxicity_model(input_ids=toxicity_input_ids).logits
logits = toxicity_model(input_ids=toxicity_input_ids.to("cuda")).logits

print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')

logits [not hate, hate]: [3.1142148971557617, -2.4896438121795654]
probabilities [not hate, hate]: [0.9963299632072449, 0.0036701022181659937]
reward (high): [3.1142148971557617]


In [11]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids
toxicity_input_ids = toxicity_input_ids.to(device)

logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')

logits [not hate, hate]: [3.1142148971557617, -2.4896438121795654]
probabilities [not hate, hate]: [0.9963299632072449, 0.0036701022181659937]
reward (high): [3.1142148971557617]


In [12]:
toxic_text = "#Person 1# tells Tommy that the movie was terrible, dumb and stupid."

toxicity_input_ids = toxicity_tokenizer(toxic_text, return_tensors="pt").input_ids
toxicity_input_ids = toxicity_input_ids.to(device)

logits = toxicity_model(toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# Get the logits for "not hate" - this is the reward!
nothate_reward = (logits[:, not_hate_index]).tolist() 
print(f'reward (low): {nothate_reward}')

logits [not hate, hate]: [-0.692075252532959, 0.3721935451030731]
probabilities [not hate, hate]: [0.25649452209472656, 0.7435054779052734]
reward (low): [-0.692075252532959]


In [13]:
device = 0 if torch.cuda.is_available() else "cpu"

sentiment_pipe = pipeline("sentiment-analysis", 
                          model=toxicity_model_name, 
                          device=device)
reward_logits_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # Set to "none" to retrieve raw logits.
    "batch_size": 16
}

reward_probabilities_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "softmax", # Set to "softmax" to apply softmax and retrieve probabilities.
    "batch_size": 16
}

print("Reward model output:")
print("For non-toxic text")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))
print("For toxic text")
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

Reward model output:
For non-toxic text
[{'label': 'nothate', 'score': 3.1142148971557617}, {'label': 'hate', 'score': -2.4896438121795654}]
[{'label': 'nothate', 'score': 0.9963299632072449}, {'label': 'hate', 'score': 0.0036701022181659937}]
For toxic text
[{'label': 'hate', 'score': 0.3721935451030731}, {'label': 'nothate', 'score': -0.692075252532959}]
[{'label': 'hate', 'score': 0.7435054779052734}, {'label': 'nothate', 'score': 0.25649455189704895}]


In [14]:
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 3.1142148971557617}, {'label': 'hate', 'score': -2.4896438121795654}]
[{'label': 'nothate', 'score': 0.9963299632072449}, {'label': 'hate', 'score': 0.0036701022181659937}]


In [15]:
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

[{'label': 'hate', 'score': 0.3721935451030731}, {'label': 'nothate', 'score': -0.692075252532959}]
[{'label': 'hate', 'score': 0.7435054779052734}, {'label': 'nothate', 'score': 0.25649455189704895}]


In [16]:
text = "I love you."
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 4.620580673217773}, {'label': 'hate', 'score': -4.193358898162842}]
[{'label': 'nothate', 'score': 0.9998513460159302}, {'label': 'hate', 'score': 0.00014862434181850404}]


In [17]:
text = "How are you doing today?"
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 4.629106521606445}, {'label': 'hate', 'score': -4.079741954803467}]
[{'label': 'nothate', 'score': 0.9998348951339722}, {'label': 'hate', 'score': 0.0001650909543968737}]




In [18]:
text = "#Person 1# tells Tommy that he was terrible, dumb and stupid."
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'hate', 'score': 0.9252259135246277}, {'label': 'nothate', 'score': -1.225296139717102}]
[{'label': 'hate', 'score': 0.8957175612449646}, {'label': 'nothate', 'score': 0.10428246110677719}]


In [19]:
toxicity_evaluator = evaluate.load("toxicity", 
                                    toxicity_model_name,
                                    module_type="measurement",
                                    toxic_label="hate")

In [20]:
toxicity_score = toxicity_evaluator.compute(predictions=[
    non_toxic_text
])

print("Toxicity score for non-toxic text:")
print(toxicity_score["toxicity"])

toxicity_score = toxicity_evaluator.compute(predictions=[
    toxic_text
])

print("\nToxicity score for toxic text:")
print(toxicity_score["toxicity"])

Toxicity score for non-toxic text:
[0.0036706007085740566]

Toxicity score for toxic text:
[0.7435290217399597]


In [21]:
toxicity_score = toxicity_evaluator.compute(predictions=[
    text
])

print("Toxicity score for text:")
print(toxicity_score["toxicity"])

Toxicity score for text:
[0.895907461643219]


In [22]:
def evaluate_toxicity(model, 
                      toxicity_evaluator, 
                      tokenizer, 
                      dataset, 
                      num_samples):
    
    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model (trl model): Model to be evaluated.
    - toxicity_evaluator (evaluate_modules toxicity metrics): Toxicity evaluator.
    - tokenizer (transformers tokenizer): Tokenizer to be used.
    - dataset (dataset): Input dataset for the evaluation.
    - num_samples (int): Maximum number of samples for the evaluation.
        
    Returns:
    tuple: A tuple containing two numpy.float64 values:
    - mean (numpy.float64): Mean of the samples toxicity.
    - std (numpy.float64): Standard deviation of the samples toxicity.
    """

    max_new_tokens=100

    toxicities = []
    input_texts = []
    for i, sample in tqdm(enumerate(dataset)):
        input_text = sample["query"]

        if i > num_samples:
            break
            
        input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
        generation_config = GenerationConfig(max_new_tokens=max_new_tokens,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)

        response_token_ids = model.generate(input_ids=input_ids.to("cuda"),
                                            generation_config=generation_config)
        
        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)
        
        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " " + generated_text)])

        toxicities.extend(toxicity_score["toxicity"])

    # Compute mean & std using np.
    mean = np.mean(toxicities)
    std = np.std(toxicities)
        
    return mean, std

In [23]:
print(dataset["test"][0]["query"])

Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: So then I can't use my phone if I'm on the internet? #Person2#: That's correct. With DEL you can do both. Summary: </s>


In [24]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

:( #1) Thank you for helping. :( #2) ( #1) Your internet connection is not on your phone. :( #3) ( #2) Your internet is on your phone.


In [25]:
input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = ref_model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

#Person1#: What kind of internet do I need? #Person2#: DEL or dial-up.


In [26]:
input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = peft_model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

#Person1#: What's the internet connection? #Person2: I need to get an Ethernet connection.


In [27]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
ref_model = ref_model.to("cuda")
mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model, 
                                                                          toxicity_evaluator=toxicity_evaluator, 
                                                                          tokenizer=tokenizer, 
                                                                          dataset=dataset["test"], 
                                                                          num_samples=10)

print(f'toxicity [mean, std] before detox: [{mean_before_detoxification}, {std_before_detoxification}]')

11it [00:05,  1.91it/s]

toxicity [mean, std] before detox: [0.019630177433348515, 0.028746985290124098]





In [28]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}


In [29]:
learning_rate=1.41e-5
max_ppo_epochs=1
mini_batch_size=4
batch_size=16

config = PPOConfig(
    model_name=model_name,    
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=tokenizer, 
                         dataset=dataset["train"], 
                         data_collator=collator)


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [30]:
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # You want the raw logits without softmax.
    "batch_size": 16
}

max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when you reach max_steps.
    if step >= max_ppo_steps:
        break   

    prompt_tensors = batch["input_ids"]
    #prompt_tensors = prompt_tensors.to(device)

    # Get response from FLAN-T5/PEFT LLM.
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()        
            
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])
        
    # This needs to be called "response".
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Compute reward outputs.
    query_response_pairs = [q + r for q, r in zip(batch["query"], batch["response"])]    
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # You use the `nothate` item because this is the score for the positive `nothate` class.
    reward_tensors = [torch.tensor(reward[not_hate_index]["score"]) for reward in rewards]    

    # Run PPO step.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    
    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'.join('' for x in range(100)))

1it [00:12, 12.50s/it]

objective/kl: 0.0
ppo/returns/mean: 1.1921799182891846
ppo/policy/advantages_mean: 6.9323480289540385e-09
---------------------------------------------------------------------------------------------------


2it [00:21, 10.21s/it]

objective/kl: -0.005525421351194382
ppo/returns/mean: 1.0142624378204346
ppo/policy/advantages_mean: 2.901365547813839e-08
---------------------------------------------------------------------------------------------------


3it [00:27,  8.68s/it]

objective/kl: -0.015313930809497833
ppo/returns/mean: 1.2783902883529663
ppo/policy/advantages_mean: 7.457150275058666e-08
---------------------------------------------------------------------------------------------------


4it [00:34,  8.02s/it]

objective/kl: 0.0009560601320117712
ppo/returns/mean: 1.4632378816604614
ppo/policy/advantages_mean: -3.484988653212895e-08
---------------------------------------------------------------------------------------------------


5it [00:43,  8.15s/it]

objective/kl: 0.01455492340028286
ppo/returns/mean: 1.370133638381958
ppo/policy/advantages_mean: 5.6136784110094595e-08
---------------------------------------------------------------------------------------------------


6it [00:53,  8.87s/it]

objective/kl: -0.0070638507604599
ppo/returns/mean: 1.07023286819458
ppo/policy/advantages_mean: 7.304766569404819e-09
---------------------------------------------------------------------------------------------------


7it [01:01,  8.53s/it]

objective/kl: 0.027559738606214523
ppo/returns/mean: 1.2369341850280762
ppo/policy/advantages_mean: 4.3146012984607296e-08
---------------------------------------------------------------------------------------------------


8it [01:09,  8.43s/it]

objective/kl: 0.03046477399766445
ppo/returns/mean: 1.324051022529602
ppo/policy/advantages_mean: 5.8130364521957745e-08
---------------------------------------------------------------------------------------------------


9it [01:17,  8.29s/it]

objective/kl: 0.00448666512966156
ppo/returns/mean: 1.030814528465271
ppo/policy/advantages_mean: -3.2998727306221554e-08
---------------------------------------------------------------------------------------------------


10it [01:26,  8.68s/it]

objective/kl: -0.03454224392771721
ppo/returns/mean: 1.3974056243896484
ppo/policy/advantages_mean: -9.667988365436031e-08
---------------------------------------------------------------------------------------------------





In [31]:
#ppo_model = ppo_model.to("cpu")
ppo_model = ppo_model.to("cuda")
mean_after_detoxification, std_after_detoxification = evaluate_toxicity(model=ppo_model, 
                                                                        toxicity_evaluator=toxicity_evaluator, 
                                                                        tokenizer=tokenizer, 
                                                                        dataset=dataset["test"],
                                                                        #dataset=dataset["test"].to(device),
                                                                        num_samples=10)
print(f'toxicity [mean, std] after detox: [{mean_after_detoxification}, {std_after_detoxification}]')

11it [00:05,  1.85it/s]

toxicity [mean, std] after detox: [0.021162914159834723, 0.02703214640824284]





In [32]:
mean_improvement = (mean_before_detoxification - mean_after_detoxification) / mean_before_detoxification
std_improvement = (std_before_detoxification - std_after_detoxification) / std_before_detoxification

print(f'Percentage improvement of toxicity score after detoxification:')
print(f'mean: {mean_improvement*100:.2f}%')
print(f'std: {std_improvement*100:.2f}%')

Percentage improvement of toxicity score after detoxification:
mean: -7.81%
std: 5.97%


In [33]:
ppo_model = ppo_model.to(device)

batch_size = 20
compare_results = {}

df_batch = dataset["test"][0:batch_size]

compare_results["query"] = df_batch["query"]
prompt_tensors = df_batch["input_ids"]

summary_tensors_ref = []
summary_tensors = []

# Get response from ppo and base model.
for i in tqdm(range(batch_size)):
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    
    summary = ref_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors_ref.append(summary)

    summary = ppo_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors.append(summary)

# Decode responses.
compare_results["response_before"] = [tokenizer.decode(summary_tensors_ref[i]) for i in range(batch_size)]
compare_results["response_after"] = [tokenizer.decode(summary_tensors[i]) for i in range(batch_size)]

# Sentiment analysis of query/response pairs before/after.
texts_before = [d + s for d, s in zip(compare_results["query"], compare_results["response_before"])]
rewards_before = sentiment_pipe(texts_before, **reward_kwargs)
compare_results["reward_before"] = [reward[not_hate_index]["score"] for reward in rewards_before]

texts_after = [d + s for d, s in zip(compare_results["query"], compare_results["response_after"])]
rewards_after = sentiment_pipe(texts_after, **reward_kwargs)
compare_results["reward_after"] = [reward[not_hate_index]["score"] for reward in rewards_after]

100%|██████████| 20/20 [00:16<00:00,  1.24it/s]


In [34]:
pd.set_option('display.max_colwidth', 500)
df_compare_results = pd.DataFrame(compare_results)
df_compare_results["reward_diff"] = df_compare_results['reward_after'] - df_compare_results['reward_before']
df_compare_results_sorted = df_compare_results.sort_values(by=['reward_diff'], ascending=False).reset_index(drop=True)
df_compare_results_sorted

Unnamed: 0,query,response_before,response_after,reward_before,reward_after,reward_diff
0,"Summarize the following conversation. #Person1#: Today more and more families have personal computers. People have wider range of choice to communicate with the outside world. #Person2#: Right. With the establishment of Internet and a lot of web companies, people are getting more and more dependent on the web. #Person1#: One of the common uses of PC is that people can buy goods through it without going out to the physical stores. #Person2#: Can you tell me how it is done? #Person1#: If a cus...","<pad> Data found anywhere on Britain. As long as you don't know this, it's hard for you to know some vital facts about banking.</s>",<pad> Computers are not for sale anymore. They are for those families that use them. People now have full financial advantage.</s>,2.478276,3.263169,0.784892
1,"Summarize the following conversation. #Person1#: Oh, my God! What's this? #Person2#: What? #Person1#: Look! This window is open. #Person2#: Did you open it before we left? #Person1#: Are you kidding? It's winter. Why would I open it? #Person2#: I don't know. Wait. Is this yours? #Person1#: No! Oh, my God! Someone has broken into the house. #Person2#: It looks that way. That's probably why the door wasn't locked when we came in. #Person1#: I locked it when I left though. #Person2#: Yes, but t...",<pad> Get some money. Go find the robber.</s>,"<pad> The burglar has broken into Allen Allen's house, which he opened the window on.</s>",1.93955,2.559449,0.6199
2,"Summarize the following conversation. #Person1#: Amanda, how do you like this peaked cap? #Person2#: Didn't you say you want to buy a top hat? #Person1#: But I think this one fits me Well. Why don't you try on the sombrero in black? #Person2#: I don't like caps at all. Summary: </s>","<pad> Amanda has changed her mind about the peaked cap, because she usually doesn't like one.</s>",<pad> Person1 asks about the top hat matching to the sombrero.</s>,1.463726,1.917723,0.453996
3,"Summarize the following conversation. #Person1#: I'm forming a music band. #Person2#: Do you already know how to play an instrument? #Person1#: Uh... Yeah! I'Ve told you a thousand times that I'm learning to play the drums. Now that I know how to play well, I would like to form a rock band. #Person2#: Aside from yourself, who are the other members of the band? #Person1#: We have a guy who plays guitar, and another who plays bass. Although we still haven't found anyone to be our singer. You t...",<pad> Can you tell me what about your job at person 1.</s>,"<pad> Pay attention to the room space that you used for recording gear like speakers, microphones, and so on. Consider adding instruments. Ask your friends for help. Maybe even something else excessive.</s>",2.475046,2.838039,0.362993
4,"Summarize the following conversation. #Person1#: Let's take a coffee break, shall we? #Person2#: I wish I could, but I can't. #Person1#: What keeps you so busy? You've been sitting there for hours. You've got to walk around. You just can't stay on the computer forever. #Person2#: Well, I am up to my neck in work. I've got to finish this report. Sarah needs it by noon. I don't want to be scolded if I can't finish my work by the deadline. #Person1#: I understand that, but you'd feel better if ...",<pad> Taking a few silly break is bad news for you.</s>,"<pad> Can't break a sweat? It's getting to be exhausting, and sometimes you need to write or do something to get going.</s>",1.545272,1.856132,0.31086
5,"Summarize the following conversation. #Person1#: Hello. I want to reconfirm our flight to London. #Person2#: Yes, sir. Did you call the airline? #Person1#: Yes, I did. But I couldn't communicate with them in English. They speak only Spanish. So I need your help. #Person2#: Certainly, sir. What is the flight number and when are you leaving? #Person1#: We are taking IB 385 to London tomorrow at 1 p. m. #Person2#: Oh, I see, sir. We have the airline office inside the hotel. They have an English...",<pad> Need to cancel a flight to London.</s>,<pad> #Person1#: Tell me your telephone number.</s>,2.002092,2.020433,0.018341
6,"Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: S...","<pad> If you want dial-up or wireless Internet, choose DEL.</s>","<pad> You can buy DEL or dial-up Internet, there are channels to choose from.</s>",2.586808,2.547638,-0.03917
7,"Summarize the following conversation. #Person1#: Could you help me figure out how to look for a job? #Person2#: We have lots of options, what type of job do you need? #Person1#: I want to work in an office. #Person2#: Do you want to work part-time or full-time? #Person1#: I want to work full-time. #Person2#: We have binders with local job listings or you can make use of the computers. OK? #Person1#: I am confused a bit but I am sure that I can figure it out. #Person2#: If you make an appoint...",<pad> #Person1#: Has there been any difficulty with finding a job?</s>,<pad> Find an appropriate job for you.</s>,2.257139,2.139814,-0.117324
8,"Summarize the following conversation. #Person1#: Could you help me, Sir? My flight got in 15 minutes ago. Everyone else has picked up the luggage but mine hasn't come through. #Person2#: I'm sorry, Madam, I'll go and find out if there is any more to come. Summary: </s>",<pad> Their flight got in 15 minutes ago. It will be 4pm now.</s>,"<pad> Flight 536 has a long wait, because Person1 has already kept her ticket.</s>",2.585227,2.460175,-0.125053
9,"Summarize the following conversation. #Person1#: Where shall I register, please? #Person2#: Here. Do you have a registration card? #Person1#: Yes. Here you are. #Person2#: Please register your information here and pay for it. And I'll make a medical record for you. #Person1#: OK. How much do I need to pay for the registration? #Person2#: Please pay ten yuan for the registration. #Person1#: Here is my money. #Person2#: This is your registration card. Please don't lose it and bring it whenever...",<pad> #Person1#: Are you ready?</s>,"<pad> The first person to find out is Coin Shut. The second person asks if he has a credit card. If not, he can introduce himself as Person1 when he sees the first person but he may look into a transfer and not apply for a card.</s>",1.546453,1.411179,-0.135273
