In [1]:
!pip install git+https://github.com/lvwerra/trl.git@25fa1bd

Collecting git+https://github.com/lvwerra/trl.git@25fa1bd
  Cloning https://github.com/lvwerra/trl.git (to revision 25fa1bd) to /tmp/pip-req-build-2k3cm3l5
  Running command git clone --filter=blob:none --quiet https://github.com/lvwerra/trl.git /tmp/pip-req-build-2k3cm3l5
[0m  Running command git checkout -q 25fa1bd
  Resolved https://github.com/lvwerra/trl.git to commit 25fa1bd
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: trl
  Building wheel for trl (setup.py) ... [?25ldone
[?25h  Created wheel for trl: filename=trl-0.4.2.dev0-py3-none-any.whl size=67533 sha256=72c911b073f53c7babc812242a349cf76e392b705b732adbf807da7c201d12ab
  Stored in directory: /tmp/pip-ephem-wheel-cache-e8wayng8/wheels/24/b4/20/2fa3a1e47c0411c39e198029315e3af2a2c1d59132913f136f
Successfully built trl
Installing collected packages: trl
Successfully installed trl-0.4.2.dev0


In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

In [2]:
model_name="google/flan-t5-base"
#model_name="google/flan-t5-large"
huggingface_dataset_name = "knkarthick/dialogsum"

dataset_original = load_dataset(huggingface_dataset_name)

dataset_original

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})

In [3]:
def build_dataset(model_name,
                  dataset_name,
                  input_min_text_length, 
                  input_max_text_length):

    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model_name (str): Tokenizer model name.
    - dataset_name (str): Name of the dataset to load.
    - input_min_text_length (int): Minimum length of the dialogues.
    - input_max_text_length (int): Maximum length of the dialogues.
        
    Returns:
    - dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.
    """
    
    # load dataset (only "train" part will be enough for this lab).
    dataset = load_dataset(dataset_name, split="train")
    
    # Filter the dialogues of length between input_min_text_length and input_max_text_length characters.
    dataset = dataset.filter(lambda x: len(x["dialogue"]) > input_min_text_length and len(x["dialogue"]) <= input_max_text_length, batched=False)

    # Prepare tokenizer. Setting device_map="auto" allows to switch between GPU and CPU automatically.
    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
    
    def tokenize(sample):
        
        # Wrap each dialogue with the instruction.
        prompt = f"""
Summarize the following conversation.

{sample["dialogue"]}

Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)
        
        # This must be called "query", which is a requirement of our PPO library.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    # Tokenize each dialogue.
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")
    
    # Split the dataset into train and test parts.
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)

    return dataset_splits

dataset = build_dataset(model_name=model_name,
                        dataset_name=huggingface_dataset_name,
                        input_min_text_length=200, 
                        input_max_text_length=1000)

print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})


In [4]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

In [5]:
lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, 
                                              torch_dtype=torch.bfloat16)

peft_model = PeftModel.from_pretrained(model, 
                                       './peft-dialogue-summary-checkpoint-local/', 
                                       lora_config=lora_config,
                                       torch_dtype=torch.bfloat16, 
                                       device_map="auto",                                       
                                       is_trainable=True)

print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n')

PEFT model parameters to be updated:

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%



In [6]:
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)

PPO model parameters to be updated (ValueHead + 769 params):

trainable model parameters: 3539713
all model parameters: 251117569
percentage of trainable model parameters: 1.41%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


In [7]:
ref_model = create_reference_model(ppo_model)

print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

Reference model parameters to be updated:

trainable model parameters: 0
all model parameters: 251117569
percentage of trainable model parameters: 0.00%



In [8]:
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map="auto")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")
print(toxicity_model.config.id2label)

{0: 'nothate', 1: 'hate'}


In [9]:
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids

#logits = toxicity_model(input_ids=toxicity_input_ids).logits
logits = toxicity_model(input_ids=toxicity_input_ids.to("cuda")).logits

print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')

logits [not hate, hate]: [3.1140496730804443, -2.489488124847412]
probabilities [not hate, hate]: [0.9963287711143494, 0.003671277780085802]
reward (high): [3.1140496730804443]


In [10]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."

toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_ids
toxicity_input_ids = toxicity_input_ids.to(device)

logits = toxicity_model(input_ids=toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# get the logits for "not hate" - this is the reward!
not_hate_index = 0
nothate_reward = (logits[:, not_hate_index]).tolist()
print(f'reward (high): {nothate_reward}')

logits [not hate, hate]: [3.1140496730804443, -2.489488124847412]
probabilities [not hate, hate]: [0.9963287711143494, 0.003671277780085802]
reward (high): [3.1140496730804443]


In [11]:
toxic_text = "#Person 1# tells Tommy that the movie was terrible, dumb and stupid."

toxicity_input_ids = toxicity_tokenizer(toxic_text, return_tensors="pt").input_ids
toxicity_input_ids = toxicity_input_ids.to(device)

logits = toxicity_model(toxicity_input_ids).logits
print(f'logits [not hate, hate]: {logits.tolist()[0]}')

# Print the probabilities for [not hate, hate]
probabilities = logits.softmax(dim=-1).tolist()[0]
print(f'probabilities [not hate, hate]: {probabilities}')

# Get the logits for "not hate" - this is the reward!
nothate_reward = (logits[:, not_hate_index]).tolist() 
print(f'reward (low): {nothate_reward}')

logits [not hate, hate]: [-0.6939336657524109, 0.3737802803516388]
probabilities [not hate, hate]: [0.2558380663394928, 0.7441619634628296]
reward (low): [-0.6939336657524109]


In [12]:
device = 0 if torch.cuda.is_available() else "cpu"

sentiment_pipe = pipeline("sentiment-analysis", 
                          model=toxicity_model_name, 
                          device=device)
reward_logits_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # Set to "none" to retrieve raw logits.
    "batch_size": 16
}

reward_probabilities_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "softmax", # Set to "softmax" to apply softmax and retrieve probabilities.
    "batch_size": 16
}

print("Reward model output:")
print("For non-toxic text")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))
print("For toxic text")
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

Reward model output:
For non-toxic text
[{'label': 'nothate', 'score': 3.1140496730804443}, {'label': 'hate', 'score': -2.489488124847412}]
[{'label': 'nothate', 'score': 0.9963287711143494}, {'label': 'hate', 'score': 0.003671277780085802}]
For toxic text
[{'label': 'hate', 'score': 0.3737802803516388}, {'label': 'nothate', 'score': -0.6939336657524109}]
[{'label': 'hate', 'score': 0.7441619634628296}, {'label': 'nothate', 'score': 0.2558380663394928}]


In [13]:
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 3.1140496730804443}, {'label': 'hate', 'score': -2.489488124847412}]
[{'label': 'nothate', 'score': 0.9963287711143494}, {'label': 'hate', 'score': 0.003671277780085802}]


In [14]:
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))

[{'label': 'hate', 'score': 0.3737802803516388}, {'label': 'nothate', 'score': -0.6939336657524109}]
[{'label': 'hate', 'score': 0.7441619634628296}, {'label': 'nothate', 'score': 0.2558380663394928}]


In [15]:
text = "I love you."
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 4.6204657554626465}, {'label': 'hate', 'score': -4.193373203277588}]
[{'label': 'nothate', 'score': 0.9998513460159302}, {'label': 'hate', 'score': 0.0001486393593950197}]


In [16]:
text = "How are you doing today?"
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'nothate', 'score': 4.629106521606445}, {'label': 'hate', 'score': -4.079707145690918}]
[{'label': 'nothate', 'score': 0.9998348951339722}, {'label': 'hate', 'score': 0.00016509677516296506}]




In [17]:
text = "#Person 1# tells Tommy that he was terrible, dumb and stupid."
print(sentiment_pipe(text, **reward_logits_kwargs))
print(sentiment_pipe(text, **reward_probabilities_kwargs))

[{'label': 'hate', 'score': 0.9259346723556519}, {'label': 'nothate', 'score': -1.225987434387207}]
[{'label': 'hate', 'score': 0.8958483338356018}, {'label': 'nothate', 'score': 0.10415174067020416}]


In [18]:
toxicity_evaluator = evaluate.load("toxicity", 
                                    toxicity_model_name,
                                    module_type="measurement",
                                    toxic_label="hate")

In [19]:
toxicity_score = toxicity_evaluator.compute(predictions=[
    non_toxic_text
])

print("Toxicity score for non-toxic text:")
print(toxicity_score["toxicity"])

toxicity_score = toxicity_evaluator.compute(predictions=[
    toxic_text
])

print("\nToxicity score for toxic text:")
print(toxicity_score["toxicity"])

Toxicity score for non-toxic text:
[0.0036706042010337114]

Toxicity score for toxic text:
[0.7435290217399597]


In [20]:
toxicity_score = toxicity_evaluator.compute(predictions=[
    text
])

print("Toxicity score for text:")
print(toxicity_score["toxicity"])

Toxicity score for text:
[0.8959075212478638]


In [21]:
def evaluate_toxicity(model, 
                      toxicity_evaluator, 
                      tokenizer, 
                      dataset, 
                      num_samples):
    
    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model (trl model): Model to be evaluated.
    - toxicity_evaluator (evaluate_modules toxicity metrics): Toxicity evaluator.
    - tokenizer (transformers tokenizer): Tokenizer to be used.
    - dataset (dataset): Input dataset for the evaluation.
    - num_samples (int): Maximum number of samples for the evaluation.
        
    Returns:
    tuple: A tuple containing two numpy.float64 values:
    - mean (numpy.float64): Mean of the samples toxicity.
    - std (numpy.float64): Standard deviation of the samples toxicity.
    """

    max_new_tokens=100

    toxicities = []
    input_texts = []
    for i, sample in tqdm(enumerate(dataset)):
        input_text = sample["query"]

        if i > num_samples:
            break
            
        input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
        generation_config = GenerationConfig(max_new_tokens=max_new_tokens,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)

        response_token_ids = model.generate(input_ids=input_ids.to("cuda"),
                                            generation_config=generation_config)
        
        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)
        
        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " " + generated_text)])

        toxicities.extend(toxicity_score["toxicity"])

    # Compute mean & std using np.
    mean = np.mean(toxicities)
    std = np.std(toxicities)
        
    return mean, std

In [22]:
print(dataset["test"][0]["query"])

Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: So then I can't use my phone if I'm on the internet? #Person2#: That's correct. With DEL you can do both. Summary: </s>


In [23]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

#Person1# wants to order dial-up Internet. #Person2# recommends dial-up and tells #P1# that it doesn't tie up the phone.


In [24]:
input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = ref_model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

#Person1# wants to order some DEL or dial-up internet with #P1#'s help.


In [25]:
input_text = dataset["test"][0]["query"]
input_ids = tokenizer(input_text, return_tensors="pt", padding=True).input_ids
        
generation_config = GenerationConfig(max_new_tokens=100,
                                             tok_k=0.0,
                                             top_p=1.0,
                                             do_sample=True)
response_token_ids = peft_model.generate(input_ids=input_ids,
                                            generation_config=generation_config)
print(tokenizer.decode(response_token_ids[0], skip_special_tokens=True))

#Person1# wants to buy some Internet but doesn't have telephones to do both. #Person2# recommends DEL, which doesn't tie up the phone.


In [26]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
ref_model = ref_model.to("cuda")
mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model, 
                                                                          toxicity_evaluator=toxicity_evaluator, 
                                                                          tokenizer=tokenizer, 
                                                                          dataset=dataset["test"], 
                                                                          num_samples=10)

print(f'toxicity [mean, std] before detox: [{mean_before_detoxification}, {std_before_detoxification}]')

11it [00:08,  1.34it/s]

toxicity [mean, std] before detox: [0.032659146850082005, 0.0489372657443597]





In [27]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}


In [28]:
learning_rate=1.41e-5
max_ppo_epochs=1
mini_batch_size=4
batch_size=16

config = PPOConfig(
    model_name=model_name,    
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=tokenizer, 
                         dataset=dataset["train"], 
                         data_collator=collator)


In [29]:
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # You want the raw logits without softmax.
    "batch_size": 16
}

max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when you reach max_steps.
    if step >= max_ppo_steps:
        break   

    prompt_tensors = batch["input_ids"]
    #prompt_tensors = prompt_tensors.to(device)

    # Get response from FLAN-T5/PEFT LLM.
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()        
            
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])
        
    # This needs to be called "response".
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Compute reward outputs.
    query_response_pairs = [q + r for q, r in zip(batch["query"], batch["response"])]    
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # You use the `nothate` item because this is the score for the positive `nothate` class.
    reward_tensors = [torch.tensor(reward[not_hate_index]["score"]) for reward in rewards]    

    # Run PPO step.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    
    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-'.join('' for x in range(100)))

0it [00:00, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
1it [00:14, 14.17s/it]

objective/kl: 26.96736717224121
ppo/returns/mean: -0.327860951423645
ppo/policy/advantages_mean: -8.772149229230308e-09
---------------------------------------------------------------------------------------------------


2it [00:27, 13.43s/it]

objective/kl: 26.13813591003418
ppo/returns/mean: -0.303859144449234
ppo/policy/advantages_mean: 1.6285028436868743e-08
---------------------------------------------------------------------------------------------------


3it [00:38, 12.54s/it]

objective/kl: 17.057994842529297
ppo/returns/mean: 0.10496333241462708
ppo/policy/advantages_mean: -2.3792967596136805e-11
---------------------------------------------------------------------------------------------------


4it [00:51, 12.52s/it]

objective/kl: 22.364551544189453
ppo/returns/mean: -0.07384422421455383
ppo/policy/advantages_mean: -2.4391914266175263e-08
---------------------------------------------------------------------------------------------------


5it [01:04, 13.02s/it]

objective/kl: 25.84668731689453
ppo/returns/mean: -0.19504928588867188
ppo/policy/advantages_mean: 2.7601054775061584e-09
---------------------------------------------------------------------------------------------------


6it [01:19, 13.46s/it]

objective/kl: 29.3574275970459
ppo/returns/mean: -0.4261414110660553
ppo/policy/advantages_mean: -2.697810330687389e-08
---------------------------------------------------------------------------------------------------


7it [01:33, 13.83s/it]

objective/kl: 25.613330841064453
ppo/returns/mean: -0.2685316205024719
ppo/policy/advantages_mean: -1.4470316500592162e-08
---------------------------------------------------------------------------------------------------


8it [01:48, 13.98s/it]

objective/kl: 26.55628204345703
ppo/returns/mean: -0.27858686447143555
ppo/policy/advantages_mean: 5.630215849095066e-09
---------------------------------------------------------------------------------------------------


9it [02:00, 13.34s/it]

objective/kl: 21.39884376525879
ppo/returns/mean: -0.10612675547599792
ppo/policy/advantages_mean: 4.03273539006932e-08
---------------------------------------------------------------------------------------------------


10it [02:11, 13.16s/it]

objective/kl: 22.514467239379883
ppo/returns/mean: -0.021777454763650894
ppo/policy/advantages_mean: -8.535186779567994e-09
---------------------------------------------------------------------------------------------------





In [30]:
#ppo_model = ppo_model.to("cpu")
ppo_model = ppo_model.to("cuda")
mean_after_detoxification, std_after_detoxification = evaluate_toxicity(model=ppo_model, 
                                                                        toxicity_evaluator=toxicity_evaluator, 
                                                                        tokenizer=tokenizer, 
                                                                        dataset=dataset["test"],
                                                                        #dataset=dataset["test"].to(device),
                                                                        num_samples=10)
print(f'toxicity [mean, std] after detox: [{mean_after_detoxification}, {std_after_detoxification}]')

11it [00:08,  1.29it/s]

toxicity [mean, std] after detox: [0.02680596171624281, 0.029276909798787762]





In [31]:
mean_improvement = (mean_before_detoxification - mean_after_detoxification) / mean_before_detoxification
std_improvement = (std_before_detoxification - std_after_detoxification) / std_before_detoxification

print(f'Percentage improvement of toxicity score after detoxification:')
print(f'mean: {mean_improvement*100:.2f}%')
print(f'std: {std_improvement*100:.2f}%')

Percentage improvement of toxicity score after detoxification:
mean: 17.92%
std: 40.17%


In [32]:
ppo_model = ppo_model.to(device)

batch_size = 20
compare_results = {}

df_batch = dataset["test"][0:batch_size]

compare_results["query"] = df_batch["query"]
prompt_tensors = df_batch["input_ids"]

summary_tensors_ref = []
summary_tensors = []

# Get response from ppo and base model.
for i in tqdm(range(batch_size)):
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    
    summary = ref_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors_ref.append(summary)

    summary = ppo_model.generate(
        input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device), 
        **generation_kwargs
    ).squeeze()[-gen_len:]
    summary_tensors.append(summary)

# Decode responses.
compare_results["response_before"] = [tokenizer.decode(summary_tensors_ref[i]) for i in range(batch_size)]
compare_results["response_after"] = [tokenizer.decode(summary_tensors[i]) for i in range(batch_size)]

# Sentiment analysis of query/response pairs before/after.
texts_before = [d + s for d, s in zip(compare_results["query"], compare_results["response_before"])]
rewards_before = sentiment_pipe(texts_before, **reward_kwargs)
compare_results["reward_before"] = [reward[not_hate_index]["score"] for reward in rewards_before]

texts_after = [d + s for d, s in zip(compare_results["query"], compare_results["response_after"])]
rewards_after = sentiment_pipe(texts_after, **reward_kwargs)
compare_results["reward_after"] = [reward[not_hate_index]["score"] for reward in rewards_after]

100%|██████████| 20/20 [00:27<00:00,  1.39s/it]


In [33]:
pd.set_option('display.max_colwidth', 500)
df_compare_results = pd.DataFrame(compare_results)
df_compare_results["reward_diff"] = df_compare_results['reward_after'] - df_compare_results['reward_before']
df_compare_results_sorted = df_compare_results.sort_values(by=['reward_diff'], ascending=False).reset_index(drop=True)
df_compare_results_sorted

Unnamed: 0,query,response_before,response_after,reward_before,reward_after,reward_diff
0,"Summarize the following conversation. #Person1#: Oh, my God! What's this? #Person2#: What? #Person1#: Look! This window is open. #Person2#: Did you open it before we left? #Person1#: Are you kidding? It's winter. Why would I open it? #Person2#: I don't know. Wait. Is this yours? #Person1#: No! Oh, my God! Someone has broken into the house. #Person2#: It looks that way. That's probably why the door wasn't locked when we came in. #Person1#: I locked it when I left though. #Person2#: Yes, but t...",<pad> #Person1# tells Allen that it's winter and people who stole the house aregies. They talk about top secretly their movements and the knowledge of what is happening.</s>,"<pad> Allen breaks into his house and breaks in at first. He opens the door and finds an old fridge. Then he hears an attack but then it's death, leaving Allen with nowhere to go.</s>",0.811695,2.056078,1.244384
1,"Summarize the following conversation. #Person1#: It smells like an ashtray in here! #Person2#: Hi honey! What's wrong? Why do you have that look on your face? #Person1#: What's wrong? I thought we agreed that you were gonna quit smoking. #Person2#: No! I said I was going to cut down which is very different. You can't just expect me to go cold turkey overnight! #Person1#: Look, there are other ways to quit. You can try the nicotine patch, or nicotine chewing gum. We spend a fortune on cigaret...","<pad> #Person2# smells like an ashtray. #Person1# thinks the inhaler could kill #Person2#, because of a bad taste in the air. #Person2# took a panic, and #Person1# says he shouldn't smoke. #Person2# will not quit hanging around.</s>",<pad> #Person1# tells #Person2# about finding another way to quit smoking and establishing a problem which isn't easy to understand. #Person2# feels bad and move on.</s>,1.195921,1.815888,0.619967
2,"Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: S...","<pad> When his phone is on the line, he can't use #Person1#'s phone while on the internet.</s>",<pad> #Person1# orders DEL in the Internet Shop and there are two kinds of Internet out.</s>,2.035398,2.58277,0.547373
3,"Summarize the following conversation. #Person1#: Hello? #Person2#: Hello? #Person1#: Can I speak to Li Hong, please? #Person2#: Speaking. #Person1#: Hi, Li Hong. This is Alice. #Person2#: Hi, Alice. How are you? #Person1#: Not bad. Li Hong, I am sorry that I can't go to see Mrs. Brown with you tomorrow morning. My mother is ill. I must take care of her. #Person2#: I'm sorry to hear that. You'd better stay at home. After all, we can visit Mrs. Brown later #Person1#: OK. Bye - bye. #Person2#: ...",<pad> Li Hong can't go to see Mrs. Brown tomorrow morning because her mother is sick.</s>,<pad> Li Hong can't come to Hangzhou tomorrow morning because her mother is ill.</s>,1.659602,2.131713,0.472111
4,"Summarize the following conversation. #Person1#: Amanda, how do you like this peaked cap? #Person2#: Didn't you say you want to buy a top hat? #Person1#: But I think this one fits me Well. Why don't you try on the sombrero in black? #Person2#: I don't like caps at all. Summary: </s>",<pad> Amanda tells #Person1# that she's impressed by the hat. #Person2# doesn't like caps at all.</s>,<pad> Amanda likes the peaked cap in #Person1# best.</s>,1.425087,1.689063,0.263976
5,"Summarize the following conversation. #Person1#: I'd like to have this cashed, please. #Person2#: Please put you name and address here. May I see your passport? #Person1#: Yes. #Person2#: How would you like it? #Person1#: Ten hundreds and ten twenties, and the rest in small change, please. #Person2#: OK. Here you are. Summary: </s>",<pad> #Person1# wants to cashed as many hundreds and ten twenties in small change with #P2#'s help.</s>,<pad> #Person1# wants to cashed ten hundreds and ten twenties in small change. #Person2# helps her go out and pay.</s>,1.049919,1.288102,0.238183
6,"Summarize the following conversation. #Person1#: Judy, what is everybody talking about? #Person2#: Haven't you heard? Richard was fired by our manager. #Person1#: You're kidding. It can't be true. #Person2#: Believe it or not. Everybody is talking about it in the company. #Person1#: Really? I'm surprised. #Person2#: Me too. Summary: </s>",<pad> Judy thinks Richard who was fired by his manager is too fun to be true.</s>,<pad> Richard was fired by the manager because #Person1# thinks everyone is talking about it in the company.</s>,1.156641,1.364785,0.208144
7,"Summarize the following conversation. #Person1#: Today more and more families have personal computers. People have wider range of choice to communicate with the outside world. #Person2#: Right. With the establishment of Internet and a lot of web companies, people are getting more and more dependent on the web. #Person1#: One of the common uses of PC is that people can buy goods through it without going out to the physical stores. #Person2#: Can you tell me how it is done? #Person1#: If a cus...",<pad> #Person1# tells about the advantages of using computer technology. #Person1# explains in detail the many advantages of computer equipment. #Person2# also explains the nature of computer equipment. Then #Person1# shows the advantages of computer's user interface.</s>,<pad> #Person1# tells with reference #P2#’s why computers are becoming more common. Latest product of a new invention is the ordering of items from web domain. They believe they are useful for ships: 1tons and strick tickets diesel among other machines.</s>,2.409236,2.610445,0.201209
8,"Summarize the following conversation. #Person1#: Could you help me, Sir? My flight got in 15 minutes ago. Everyone else has picked up the luggage but mine hasn't come through. #Person2#: I'm sorry, Madam, I'll go and find out if there is any more to come. Summary: </s>","<pad> I turned up the reservation because there was a few people in the flight, but #Person1#'s already already checked in. #Person2# says there's no need for my luggage.</s>",<pad> #Person1# wants to know why the flight got in 15 minutes ago.</s>,2.246698,2.309108,0.062409
9,"Summarize the following conversation. #Person1#: Hello. I want to reconfirm our flight to London. #Person2#: Yes, sir. Did you call the airline? #Person1#: Yes, I did. But I couldn't communicate with them in English. They speak only Spanish. So I need your help. #Person2#: Certainly, sir. What is the flight number and when are you leaving? #Person1#: We are taking IB 385 to London tomorrow at 1 p. m. #Person2#: Oh, I see, sir. We have the airline office inside the hotel. They have an English...",<pad> #Person1# questions the airlines for reconfirming the flight to London tomorrow at 1 p.m. #Person1# has mellified to the hotel's assistance.</s>,<pad> #Person1# calls to confirm a flight to London. #Person2# asks the number and time of departure. Then they dial 535.</s>,1.946794,2.002423,0.055629
