In [1]:
!pip install git+https://github.com/lvwerra/trl.git@25fa1bd

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting git+https://github.com/lvwerra/trl.git@25fa1bd
  Cloning https://github.com/lvwerra/trl.git (to revision 25fa1bd) to /tmp/pip-req-build-dd6ervmu
  Running command git clone --quiet https://github.com/lvwerra/trl.git /tmp/pip-req-build-dd6ervmu
[0m  Running command git checkout -q 25fa1bd
  Resolved https://github.com/lvwerra/trl.git to commit 25fa1bd
  Preparing metadata (setup.py) ... [?25ldone


In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

In [2]:
model_name="google/flan-t5-base"
#model_name="google/flan-t5-large"
huggingface_dataset_name = "knkarthick/dialogsum"

dataset_original = load_dataset(huggingface_dataset_name)

dataset_original

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})

In [3]:
def build_dataset(model_name,
                  dataset_name,
                  input_min_text_length, 
                  input_max_text_length):

    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model_name (str): Tokenizer model name.
    - dataset_name (str): Name of the dataset to load.
    - input_min_text_length (int): Minimum length of the dialogues.
    - input_max_text_length (int): Maximum length of the dialogues.
        
    Returns:
    - dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.
    """
    
    # load dataset (only "train" part will be enough for this lab).
    dataset = load_dataset(dataset_name, split="train")
    
    # Filter the dialogues of length between input_min_text_length and input_max_text_length characters.
    dataset = dataset.filter(lambda x: len(x["dialogue"]) > input_min_text_length and len(x["dialogue"]) <= input_max_text_length, batched=False)

    # Prepare tokenizer. Setting device_map="auto" allows to switch between GPU and CPU automatically.
    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
    
    def tokenize(sample):
        
        # Wrap each dialogue with the instruction.
        prompt = f"""
Summarize the following conversation.

{sample["dialogue"]}

Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)
        
        # This must be called "query", which is a requirement of our PPO library.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    # Tokenize each dialogue.
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")
    
    # Split the dataset into train and test parts.
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)

    return dataset_splits

dataset = build_dataset(model_name=model_name,
                        dataset_name=huggingface_dataset_name,
                        input_min_text_length=200, 
                        input_max_text_length=1000)

print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})


In [4]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

In [5]:
lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, 
                                              torch_dtype=torch.bfloat16)

peft_model = PeftModel.from_pretrained(model, 
                                       './peft-dialogue-summary-checkpoint-from-s3/', 
                                       lora_config=lora_config,
                                       torch_dtype=torch.bfloat16, 
                                       device_map="auto",                                       
                                       is_trainable=True)

print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n')

PEFT model parameters to be updated:

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%



In [6]:
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


PPO model parameters to be updated (ValueHead + 769 params):

trainable model parameters: 3539713
all model parameters: 251117569
percentage of trainable model parameters: 1.41%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


In [7]:
ref_model = create_reference_model(ppo_model)

print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

Reference model parameters to be updated:

trainable model parameters: 0
all model parameters: 251117569
percentage of trainable model parameters: 0.00%



In [8]:
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map="auto")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")
print(toxicity_model.config.id2label)

{0: 'nothate', 1: 'hate'}


In [9]:
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids

#logits = toxicity_model(input_ids=toxicity_input_ids).logits
logits = toxicity_model(input_ids=toxicity_input_ids.to("cuda")).logits

print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')

logits [not hate, hate]: [3.1140496730804443, -2.489488124847412]
probabilities [not hate, hate]: [0.9963287711143494, 0.003671277780085802]
reward (high): [3.1140496730804443]


In [10]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids
toxicity_input_ids = toxicity_input_ids.to(device)

logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')

logits [not hate, hate]: [3.1140496730804443, -2.489488124847412]
probabilities [not hate, hate]: [0.9963287711143494, 0.003671277780085802]
reward (high): [3.1140496730804443]


In [11]:
toxic_text = "#Person 1# tells Tommy that the movie was terrible, dumb and stupid."

toxicity_input_ids = toxicity_tokenizer(toxic_text, return_tensors="pt").input_ids
toxicity_input_ids = toxicity_input_ids.to(device)

logits = toxicity_model(toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# Get the logits for "not hate" - this is the reward!
nothate_reward = (logits[:, not_hate_index]).tolist() 
print(f'reward (low): {nothate_reward}')

logits [not hate, hate]: [-0.6939336657524109, 0.3737802803516388]
probabilities [not hate, hate]: [0.2558380663394928, 0.7441619634628296]
reward (low): [-0.6939336657524109]


In [12]:
device = 0 if torch.cuda.is_available() else "cpu"

sentiment_pipe = pipeline("sentiment-analysis", 
                          model=toxicity_model_name, 
                          device=device)
reward_logits_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # Set to "none" to retrieve raw logits.
    "batch_size": 16
}

reward_probabilities_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "softmax", # Set to "softmax" to apply softmax and retrieve probabilities.
    "batch_size": 16
}

print("Reward model output:")
print("For non-toxic text")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))
print("For toxic text")
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

Reward model output:
For non-toxic text
[{'label': 'nothate', 'score': 3.1140496730804443}, {'label': 'hate', 'score': -2.489488124847412}]
[{'label': 'nothate', 'score': 0.9963287711143494}, {'label': 'hate', 'score': 0.003671277780085802}]
For toxic text
[{'label': 'hate', 'score': 0.3737802803516388}, {'label': 'nothate', 'score': -0.6939336657524109}]
[{'label': 'hate', 'score': 0.7441619634628296}, {'label': 'nothate', 'score': 0.2558380663394928}]


In [13]:
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 3.1140496730804443}, {'label': 'hate', 'score': -2.489488124847412}]
[{'label': 'nothate', 'score': 0.9963287711143494}, {'label': 'hate', 'score': 0.003671277780085802}]


In [14]:
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

[{'label': 'hate', 'score': 0.3737802803516388}, {'label': 'nothate', 'score': -0.6939336657524109}]
[{'label': 'hate', 'score': 0.7441619634628296}, {'label': 'nothate', 'score': 0.2558380663394928}]


In [15]:
text = "I love you."
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 4.6204657554626465}, {'label': 'hate', 'score': -4.193373203277588}]
[{'label': 'nothate', 'score': 0.9998513460159302}, {'label': 'hate', 'score': 0.0001486393593950197}]


In [16]:
text = "How are you doing today?"
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 4.629106521606445}, {'label': 'hate', 'score': -4.079707145690918}]
[{'label': 'nothate', 'score': 0.9998348951339722}, {'label': 'hate', 'score': 0.00016509677516296506}]




In [17]:
text = "#Person 1# tells Tommy that he was terrible, dumb and stupid."
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'hate', 'score': 0.9259346723556519}, {'label': 'nothate', 'score': -1.225987434387207}]
[{'label': 'hate', 'score': 0.8958483338356018}, {'label': 'nothate', 'score': 0.10415174067020416}]


In [18]:
toxicity_evaluator = evaluate.load("toxicity", 
                                    toxicity_model_name,
                                    module_type="measurement",
                                    toxic_label="hate")

In [19]:
toxicity_score = toxicity_evaluator.compute(predictions=[
    non_toxic_text
])

print("Toxicity score for non-toxic text:")
print(toxicity_score["toxicity"])

toxicity_score = toxicity_evaluator.compute(predictions=[
    toxic_text
])

print("\nToxicity score for toxic text:")
print(toxicity_score["toxicity"])

Toxicity score for non-toxic text:
[0.0036706042010337114]

Toxicity score for toxic text:
[0.7435287833213806]


In [20]:
toxicity_score = toxicity_evaluator.compute(predictions=[
    text
])

print("Toxicity score for text:")
print(toxicity_score["toxicity"])

Toxicity score for text:
[0.8959075212478638]


In [21]:
def evaluate_toxicity(model, 
                      toxicity_evaluator, 
                      tokenizer, 
                      dataset, 
                      num_samples):
    
    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model (trl model): Model to be evaluated.
    - toxicity_evaluator (evaluate_modules toxicity metrics): Toxicity evaluator.
    - tokenizer (transformers tokenizer): Tokenizer to be used.
    - dataset (dataset): Input dataset for the evaluation.
    - num_samples (int): Maximum number of samples for the evaluation.
        
    Returns:
    tuple: A tuple containing two numpy.float64 values:
    - mean (numpy.float64): Mean of the samples toxicity.
    - std (numpy.float64): Standard deviation of the samples toxicity.
    """

    max_new_tokens=100

    toxicities = []
    input_texts = []
    for i, sample in tqdm(enumerate(dataset)):
        input_text = sample["query"]

        if i > num_samples:
            break
            
        input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
        generation_config = GenerationConfig(max_new_tokens=max_new_tokens,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)

        response_token_ids = model.generate(input_ids=input_ids.to("cuda"),
                                            generation_config=generation_config)
        
        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)
        
        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " " + generated_text)])

        toxicities.extend(toxicity_score["toxicity"])

    # Compute mean & std using np.
    mean = np.mean(toxicities)
    std = np.std(toxicities)
        
    return mean, std

In [22]:
print(dataset["test"][0]["query"])

Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: So then I can't use my phone if I'm on the internet? #Person2#: That's correct. With DEL you can do both. Summary: </s>


In [23]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

#Person2# recommends Dial-Up or DEL to order some internet.


In [24]:
input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = ref_model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

#Person1# wants to buy some internet with dial-up, so #Person2# recommends DEL. Because DEL isn't connected through phone line, then #Person1# can't use their phone if they are online.


In [25]:
input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = peft_model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

#Person1# needs some internet because #Person2# says DEL isn't connected through the phone line.


In [26]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
ref_model = ref_model.to("cuda")
mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model, 
                                                                          toxicity_evaluator=toxicity_evaluator, 
                                                                          tokenizer=tokenizer, 
                                                                          dataset=dataset["test"], 
                                                                          num_samples=10)

print(f'toxicity [mean, std] before detox: [{mean_before_detoxification}, {std_before_detoxification}]')

11it [00:09,  1.14it/s]

toxicity [mean, std] before detox: [0.038271674388934945, 0.044015441967679925]





In [27]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}


In [28]:
learning_rate=1.41e-5
max_ppo_epochs=1
mini_batch_size=4
batch_size=16

config = PPOConfig(
    model_name=model_name,    
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=tokenizer, 
                         dataset=dataset["train"], 
                         data_collator=collator)


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [29]:
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # You want the raw logits without softmax.
    "batch_size": 16
}

max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when you reach max_steps.
    if step >= max_ppo_steps:
        break   

    prompt_tensors = batch["input_ids"]
    #prompt_tensors = prompt_tensors.to(device)

    # Get response from FLAN-T5/PEFT LLM.
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()        
            
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])
        
    # This needs to be called "response".
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Compute reward outputs.
    query_response_pairs = [q + r for q, r in zip(batch["query"], batch["response"])]    
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # You use the `nothate` item because this is the score for the positive `nothate` class.
    reward_tensors = [torch.tensor(reward[not_hate_index]["score"]) for reward in rewards]    

    # Run PPO step.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    
    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'.join('' for x in range(100)))

0it [00:00, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
1it [00:14, 14.85s/it]

objective/kl: 33.82997131347656
ppo/returns/mean: -0.9831958413124084
ppo/policy/advantages_mean: 1.840775354366997e-08
---------------------------------------------------------------------------------------------------


2it [00:28, 14.22s/it]

objective/kl: 31.035005569458008
ppo/returns/mean: -0.806267499923706
ppo/policy/advantages_mean: -7.809353164134336e-09
---------------------------------------------------------------------------------------------------


3it [00:40, 13.26s/it]

objective/kl: 27.124052047729492
ppo/returns/mean: -0.721274733543396
ppo/policy/advantages_mean: -5.676632941487014e-09
---------------------------------------------------------------------------------------------------


4it [00:52, 12.47s/it]

objective/kl: 20.720809936523438
ppo/returns/mean: -0.23343631625175476
ppo/policy/advantages_mean: 1.4579657481306185e-08
---------------------------------------------------------------------------------------------------


5it [01:04, 12.59s/it]

objective/kl: 28.249019622802734
ppo/returns/mean: -0.5034594535827637
ppo/policy/advantages_mean: -6.191809731603826e-09
---------------------------------------------------------------------------------------------------


6it [01:17, 12.59s/it]

objective/kl: 31.5592098236084
ppo/returns/mean: -0.8670631647109985
ppo/policy/advantages_mean: 2.4060563319494577e-08
---------------------------------------------------------------------------------------------------


7it [01:29, 12.47s/it]

objective/kl: 26.614458084106445
ppo/returns/mean: -0.5831557512283325
ppo/policy/advantages_mean: -5.418604231977042e-09
---------------------------------------------------------------------------------------------------


8it [01:43, 12.84s/it]

objective/kl: 25.330913543701172
ppo/returns/mean: -0.5219548344612122
ppo/policy/advantages_mean: 3.408430426787845e-09
---------------------------------------------------------------------------------------------------


9it [01:55, 12.59s/it]

objective/kl: 27.744409561157227
ppo/returns/mean: -0.789953351020813
ppo/policy/advantages_mean: -7.9428614796484e-09
---------------------------------------------------------------------------------------------------


10it [02:06, 12.67s/it]

objective/kl: 23.706890106201172
ppo/returns/mean: -0.32528597116470337
ppo/policy/advantages_mean: -8.270869322757335e-09
---------------------------------------------------------------------------------------------------





In [30]:
#ppo_model = ppo_model.to("cpu")
ppo_model = ppo_model.to("cuda")
mean_after_detoxification, std_after_detoxification = evaluate_toxicity(model=ppo_model, 
                                                                        toxicity_evaluator=toxicity_evaluator, 
                                                                        tokenizer=tokenizer, 
                                                                        dataset=dataset["test"],
                                                                        #dataset=dataset["test"].to(device),
                                                                        num_samples=10)
print(f'toxicity [mean, std] after detox: [{mean_after_detoxification}, {std_after_detoxification}]')

11it [00:09,  1.21it/s]

toxicity [mean, std] after detox: [0.020302407792769372, 0.01911036627594313]





In [31]:
mean_improvement = (mean_before_detoxification - mean_after_detoxification) / mean_before_detoxification
std_improvement = (std_before_detoxification - std_after_detoxification) / std_before_detoxification

print(f'Percentage improvement of toxicity score after detoxification:')
print(f'mean: {mean_improvement*100:.2f}%')
print(f'std: {std_improvement*100:.2f}%')

Percentage improvement of toxicity score after detoxification:
mean: 46.95%
std: 56.58%


In [32]:
ppo_model = ppo_model.to(device)

batch_size = 20
compare_results = {}

df_batch = dataset["test"][0:batch_size]

compare_results["query"] = df_batch["query"]
prompt_tensors = df_batch["input_ids"]

summary_tensors_ref = []
summary_tensors = []

# Get response from ppo and base model.
for i in tqdm(range(batch_size)):
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    
    summary = ref_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors_ref.append(summary)

    summary = ppo_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors.append(summary)

# Decode responses.
compare_results["response_before"] = [tokenizer.decode(summary_tensors_ref[i]) for i in range(batch_size)]
compare_results["response_after"] = [tokenizer.decode(summary_tensors[i]) for i in range(batch_size)]

# Sentiment analysis of query/response pairs before/after.
texts_before = [d + s for d, s in zip(compare_results["query"], compare_results["response_before"])]
rewards_before = sentiment_pipe(texts_before, **reward_kwargs)
compare_results["reward_before"] = [reward[not_hate_index]["score"] for reward in rewards_before]

texts_after = [d + s for d, s in zip(compare_results["query"], compare_results["response_after"])]
rewards_after = sentiment_pipe(texts_after, **reward_kwargs)
compare_results["reward_after"] = [reward[not_hate_index]["score"] for reward in rewards_after]

100%|██████████| 20/20 [00:27<00:00,  1.38s/it]


In [33]:
pd.set_option('display.max_colwidth', 500)
df_compare_results = pd.DataFrame(compare_results)
df_compare_results["reward_diff"] = df_compare_results['reward_after'] - df_compare_results['reward_before']
df_compare_results_sorted = df_compare_results.sort_values(by=['reward_diff'], ascending=False).reset_index(drop=True)
df_compare_results_sorted

Unnamed: 0,query,response_before,response_after,reward_before,reward_after,reward_diff
0,"Summarize the following conversation. #Person1#: Oh, my God! What's this? #Person2#: What? #Person1#: Look! This window is open. #Person2#: Did you open it before we left? #Person1#: Are you kidding? It's winter. Why would I open it? #Person2#: I don't know. Wait. Is this yours? #Person1#: No! Oh, my God! Someone has broken into the house. #Person2#: It looks that way. That's probably why the door wasn't locked when we came in. #Person1#: I locked it when I left though. #Person2#: Yes, but t...",<pad> #Person2# tells Allen that someone has left the door open. Allen finds out that a robber broke in during winter so it does not enter the house. Allen tells #Person1# that the TV's still in the house and the stereo are on the door. They fight a fight against the robber.</s>,<pad> Allen believes a burglar broke into the house after opening the window. Nobody is sure if someone is still there.</s>,2.005516,2.900229,0.894712
1,"Summarize the following conversation. #Person1#: Amanda, how do you like this peaked cap? #Person2#: Didn't you say you want to buy a top hat? #Person1#: But I think this one fits me Well. Why don't you try on the sombrero in black? #Person2#: I don't like caps at all. Summary: </s>",<pad> Amanda likes the peaked hat but #Person2# wants to buy a top hat for herself. Amanda also just complains about the collar.</s>,<pad> Amanda praises her favorite peak cap and the cap she was purchasing. #Person2# isn't surprised by it.</s>,1.169476,1.5657,0.396224
2,"Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: S...",<pad> #Person1# wants to buy some dial-up internet. #Person2# suggests getting DEL Internet. #Person1# wouldn't use #Person1#'s phone if #Person1# wanted it.</s>,<pad> #Person2# suggests dial-up instead of DEL because the phone cannot be done while on the internet.</s>,2.097442,2.409575,0.312133
3,"Summarize the following conversation. #Person1#: So how did you like the restaurant? #Person2#: Actually, it could have been better. #Person1#: What didn't you like about it? #Person2#: It is a new restaurant. I don't think they have their act together yet. #Person1#: What did you think about the food? #Person2#: I felt that the food was pretty mediocre. #Person1#: The service wasn't that great, either. #Person2#: I agree. The service was not good. #Person1#: Do you think that you want to tr...",<pad> #Person2# thinks the restaurant is a new one and the service was not good. #Person2# has too much to choose from.</s>,<pad> #Person2# hasn't tried this restaurant yet and #Person2# has had enough of it.</s>,1.941611,2.239094,0.297483
4,"Summarize the following conversation. #Person1#: Hello? #Person2#: Hello? #Person1#: Can I speak to Li Hong, please? #Person2#: Speaking. #Person1#: Hi, Li Hong. This is Alice. #Person2#: Hi, Alice. How are you? #Person1#: Not bad. Li Hong, I am sorry that I can't go to see Mrs. Brown with you tomorrow morning. My mother is ill. I must take care of her. #Person2#: I'm sorry to hear that. You'd better stay at home. After all, we can visit Mrs. Brown later #Person1#: OK. Bye - bye. #Person2#: ...",<pad> Li Hong will not go to Mrs. Brown with Alice tomorrow morning because her mother is sick. Alice apologizes till she gets her mother to see Li Hong.</s>,<pad> Alice can't go to see Mrs. Brown with Li Hong because her mother is ill. I'm sorry to hear she can't.</s>,1.134586,1.428815,0.294229
5,"Summarize the following conversation. #Person1#: Here is the final draft of our contract. I'm glad that we have reached an agreement on almost every term in our trade. #Person2#: Yes, it seems to me we have come quite a long way. However, let me take a close look at the final draft. #Person1#: Do you have some points to bring up? #Person2#: Well, everything we've discussed seems to be here. #Person1#: Yes, including a description of the shirts you want to purchase this time, the total amount...",<pad> #Person1# tells #Person2# about the final draft of their contract and considers the terms and the quality standard. #Person1# then reminds the 2# to sign the contract right now.</s>,<pad> #Person2# asks #Person1# for the final draft of a contract.</s>,2.475902,2.748296,0.272394
6,"Summarize the following conversation. #Person1#: I'm forming a music band. #Person2#: Do you already know how to play an instrument? #Person1#: Uh... Yeah! I'Ve told you a thousand times that I'm learning to play the drums. Now that I know how to play well, I would like to form a rock band. #Person2#: Aside from yourself, who are the other members of the band? #Person1#: We have a guy who plays guitar, and another who plays bass. Although we still haven't found anyone to be our singer. You t...","<pad> #Person1# wants to form a rock band. #Person1# tells #Person2#, beside them, three other musicians thinks #Person1# is strong in music, but still they've never found anyone. #Person2# invites #Person1# to audition the performances at #Person1#'s house.</s>",<pad> #Person1# is joining a music group and plans to form a rock band later. #Person2# agrees and suggests ensemble inputs from a musician. #Person1# invites #Person2# to audition this weekend.</s>,2.766788,3.006749,0.239962
7,"Summarize the following conversation. #Person1#: Hello. I want to reconfirm our flight to London. #Person2#: Yes, sir. Did you call the airline? #Person1#: Yes, I did. But I couldn't communicate with them in English. They speak only Spanish. So I need your help. #Person2#: Certainly, sir. What is the flight number and when are you leaving? #Person1#: We are taking IB 385 to London tomorrow at 1 p. m. #Person2#: Oh, I see, sir. We have the airline office inside the hotel. They have an English...","<pad> #Person1# wants to confirm a flight to London, it couldn't communicate with the airlines in English. #Person2# takes #Person1# to call the airline office and gives #Person1# a 3655. Then #Person1# is satisfied with the flight localization.</s>","<pad> #Person1# wants to confirm a flight to London. #Person2# recommends an airline but #Person1# didn't know how, so they offer a service.</s>",1.836877,2.032277,0.1954
8,"Summarize the following conversation. #Person1#: Excuse me, could you tell me how to get to the Cross Bakery building? #Person2#: The Cross Bakery building? Oh sure. You're actually walking in the opposite direction. #Person1#: Oh, you're kidding! I thought I was heading east. #Person2#: No, east is the other direction. To get to the Bakery, you need to turn around and go three blocks to Broadway. When you get to the intersection of Broadway and Elm, you hang a left. Go straight down that st...",<pad> #Person1# asks #Person2# to tell #Person1# what to do when heading east at the Cross Bakery building. #Person1# agrees to help #Person2# find the way to get there.</s>,<pad> #Person2# tells #Person1# the direction to the Cross Bakery Building from Broadway to Elm.</s>,2.604495,2.744709,0.140214
9,"Summarize the following conversation. #Person1#: Let's take a coffee break, shall we? #Person2#: I wish I could, but I can't. #Person1#: What keeps you so busy? You've been sitting there for hours. You've got to walk around. You just can't stay on the computer forever. #Person2#: Well, I am up to my neck in work. I've got to finish this report. Sarah needs it by noon. I don't want to be scolded if I can't finish my work by the deadline. #Person1#: I understand that, but you'd feel better if ...",<pad> #Person2# can't accept a break after work. #Person1# tells #Person2# he has to finish the report and shouldn't be scolded. #Person1# doesn't understand and thinks he would feel better if taken a break.</s>,<pad> #Person2# lives because of lack of breaks because she has to finish a report and needs it by noon. #Person1# thinks it would feel better if she took a short break.</s>,1.516051,1.578153,0.062102
