# Fine-Tune FLAN-T5 with Reinforcement Learning (PPO) and PEFT to Generate Confidential Prompts

In [1]:
import os

%pip install torch torchdata  --index-url https://download.pytorch.org/whl/cu118 --quiet

%pip install transformers datasets evaluate rouge_score peft trl --quiet

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification, AutoModelForSeq2SeqLM, GenerationConfig, pipeline
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

torch.cuda.is_available()

  from .autonotebook import tqdm as notebook_tqdm


True

In [3]:
DATASET = "DIBT/10k_prompts_ranked"

PEFT_CHECKTPOINT = f"./model_checkpoint"

ORIGINAL_MODEL_NAME = 'google/flan-t5-small'

REWARD_MODEL_NAME = "dslim/distilbert-NER"

## 2 - Load FLAN-T5 Model, Reward Model

### 2.1 - Load Data and FLAN-T5 Model Fine-Tuned with Confidential Prompts as autoencoder

In [4]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

original_model = AutoModelForSeq2SeqLM.from_pretrained(ORIGINAL_MODEL_NAME, torch_dtype=torch.bfloat16).to('cuda')

peft_model = PeftModel.from_pretrained(original_model, 
                                       PEFT_CHECKTPOINT, 
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=True).to('cuda')

print(f'Peft fined tuned model:\n{print_number_of_trainable_model_parameters(peft_model)}\n')

ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True).to('cuda')

tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL_NAME, device_map="auto")

print(f'\nPPO model parameters to be updated:\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)

ref_model = create_reference_model(ppo_model).to('cuda')

print(f'\nReference model PPO KL:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

Peft fined tuned model:
trainable model parameters: 1376256
all model parameters: 78337408
percentage of trainable model parameters: 1.76%


PPO model parameters to be updated:
trainable model parameters: 1376769
all model parameters: 78337921
percentage of trainable model parameters: 1.76%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=512, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)

Reference model PPO KL:
trainable model parameters: 0
all model parameters: 78337921
percentage of trainable model parameters: 0.00%



In [5]:
dataset = load_dataset(DATASET, split='train')

dataset = dataset.filter(lambda x: len(x["prompt"]) > 50 and len(x["prompt"]) <= 300, batched=False)

dataset = dataset.train_test_split(test_size=0.1)

dataset

DatasetDict({
    train: Dataset({
        features: ['prompt', 'quality', 'metadata', 'avg_rating', 'num_responses', 'agreement_ratio', 'raw_responses', 'kind', 'cluster_description', 'topic'],
        num_rows: 4680
    })
    test: Dataset({
        features: ['prompt', 'quality', 'metadata', 'avg_rating', 'num_responses', 'agreement_ratio', 'raw_responses', 'kind', 'cluster_description', 'topic'],
        num_rows: 521
    })
})

In [6]:
def make_safe_prompt(prompt):
    return f"""Create a safe prompt from the following prompt:

{prompt}

Prompt:"""

def tokenize_function(sample):
    sample['input_ids'] = tokenizer.encode( make_safe_prompt(sample['prompt']) )

    sample['query'] = tokenizer.decode(sample['input_ids'])
    return sample

# The dataset actually contains 3 diff splits: train, validation, test.
# The tokenize_function code is handling all data across all splits in batches.
tokenized_datasets = dataset.map(tokenize_function, batched=False)

tokenized_datasets = tokenized_datasets.remove_columns(['prompt', 'quality', 'metadata', 'avg_rating', 'num_responses', 'agreement_ratio', \
                                                        'raw_responses', 'kind', 'cluster_description', 'topic'])

tokenized_datasets

Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4680/4680 [00:02<00:00, 1912.90 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 521/521 [00:00<00:00, 1943.59 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'query'],
        num_rows: 4680
    })
    test: Dataset({
        features: ['input_ids', 'query'],
        num_rows: 521
    })
})

### 2.2 - Reward Model

**Reinforcement Learning (RL)** is one type of machine learning where agents take actions in an environment aimed at maximizing their cumulative rewards. The agent's behavior is defined by the **policy**. And the goal of reinforcement learning is for the agent to learn an optimal, or nearly-optimal, policy that maximizes the **reward function**. 


In [7]:
reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_NAME)
reward_model = AutoModelForTokenClassification.from_pretrained(REWARD_MODEL_NAME).to('cuda')

ner_model = pipeline("ner", model=reward_model, tokenizer=reward_tokenizer, device="cuda")

def reward(data):
    ners = ner_model(data)
    score = sum(m['score'] for m in ners) / len(ners) if len(ners) else 0.0

    return {'label': 'ok' if score < 0.5 else 'protect', 'score': score}

test = "This is a private John's name. Generate a fake loan with it."
print(f'Results model: {ner_model(test)}')

print(f'Rewards: {reward(test)}')

Results model: [{'entity': 'B-PER', 'score': 0.7190356, 'index': 5, 'word': 'John', 'start': 18, 'end': 22}]
Rewards: {'label': 'protect', 'score': 0.7190356254577637}


In [8]:
def evaluator(model, ds, column):
    global score
    global count
    score = 0
    count = 0
    def evaluate(sample):
        global score
        global count
        r = model(sample[column])
        if r['label'] == 'protect':
            count += 1
            score += r['score']

        return sample

    ds.map(evaluate, batched=False)

    return count, score / count if count > 0 else 0

count, score = evaluator(reward, tokenized_datasets['train'], 'query')
print(f"Found {count} prompts to protect out of {len(tokenized_datasets['train'])}. Average score {score}")

Map:   0%|                                                                                                                                                                      | 0/4680 [00:00<?, ? examples/s]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4680/4680 [00:40<00:00, 116.02 examples/s]

Found 1804 prompts to protect out of 4680. Average score 0.8473334156436817





In [18]:
def evaluate_ner_generation(model, size):

    generation_config = GenerationConfig(max_new_tokens=100, top_k=0.0, top_p=1.0, do_sample=True)

    scores = []
    for i, sample in tqdm(enumerate(tokenized_datasets['test'].select(range(size)))):
        prompt = sample["query"]
            
        inp = tokenizer(sample["query"], return_tensors="pt", padding=True).to('cuda').input_ids
        
        gen_ids = model.generate(input_ids=inp, generation_config=generation_config)
        
        generated_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
        
        r = reward(sample["query"] + " " + generated_text)

        scores.append( r['score'] )

    mean = np.mean(scores)
    std = np.std(scores)
        
    return mean, std

peft_mean, peft_std = evaluate_ner_generation(ref_model, size=len(tokenized_datasets['test']))

print(f"Average need protect score: {peft_mean} . Std: {peft_std}")

521it [07:35,  1.14it/s]

Average need protect score: 0.3070756154536768 . Std: 0.40759072981704814





<a name='3'></a>
## 3 - Perform Fine-Tuning to protect prompts
Optimize a RL policy against the reward model using Proximal Policy Optimization (PPO).

<a name='3.1'></a>
### 3.1 - Initialize `PPOTrainer`
 
For the `PPOTrainer` initialization, you will need a collator. Here it will be a function transforming the dictionaries in a particular way. You can define and test it:

Set up the configuration parameters. Load the `ppo_model` and the tokenizer. You will also load a frozen version of the model `ref_model`. The first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This works as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original LLM.

In [10]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

config = PPOConfig(
    model_name="ppo_model",    
    learning_rate=1.41e-5,
    ppo_epochs=1,
    mini_batch_size=1,
    batch_size=2
)

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=tokenizer, 
                         dataset=tokenized_datasets["train"], 
                         data_collator=collator)

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}


<a name='3.2'></a>
### 3.2 - Fine-Tune the Model

The fine-tuning loop consists of the following main steps:
1. Get the query responses from the policy LLM (PEFT model).
2. Get reward values using distilbert-ner
3. Optimize policy with PPO using the (query, response, reward) triplet.

The operation is running if you see the following metrics appearing:
* `objective/kl`: minimize kl divergence,
* `ppo/returns/mean`: maximize mean returns,
* `ppo/policy/advantages_mean`: maximize advantages.

In [11]:
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):

    reply_tensors = []
    promp_tensors = []

    for prompt_ids in batch["input_ids"]:
        max_new_tokens = output_length_sampler()

        generation_kwargs["max_new_tokens"] = max_new_tokens

        inp = torch.as_tensor(prompt_ids).to('cuda')
        promp_tensors.append(inp)
        
        prompts = ppo_trainer.generate(inp, **generation_kwargs)
        
        reply_tensors.append(prompts.squeeze()[-max_new_tokens:])
        
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in reply_tensors]

    # Compute reward outputs.
    rewards = [reward(q + r) for q, r in zip(batch["query"], batch["response"])]
    reward_tensors = [ - torch.tensor(r["score"]) for r in rewards]    

    # Run PPO step.
    stats = ppo_trainer.step(promp_tensors, reply_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)

    if step % 20 == 0:
        print(f'objective/kl: {stats["objective/kl"]}')
        print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
        print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
        print('-'.join('' for x in range(100)))

1it [00:01,  1.83s/it]

objective/kl: 7.874390125274658
ppo/returns/mean: -0.386505126953125
ppo/policy/advantages_mean: -0.020643368363380432
---------------------------------------------------------------------------------------------------


21it [00:59,  3.04s/it]

objective/kl: 16.69438934326172
ppo/returns/mean: -0.10692700743675232
ppo/policy/advantages_mean: 0.0515250563621521
---------------------------------------------------------------------------------------------------


41it [01:53,  3.42s/it]

objective/kl: 3.9284586906433105
ppo/returns/mean: 0.14606568217277527
ppo/policy/advantages_mean: -0.25949299335479736
---------------------------------------------------------------------------------------------------


61it [02:40,  2.37s/it]

objective/kl: 5.835698127746582
ppo/returns/mean: 0.16853585839271545
ppo/policy/advantages_mean: -0.011575572192668915
---------------------------------------------------------------------------------------------------


81it [03:32,  2.47s/it]

objective/kl: 1.113680124282837
ppo/returns/mean: 0.17919711768627167
ppo/policy/advantages_mean: 0.12866327166557312
---------------------------------------------------------------------------------------------------


101it [04:24,  2.09s/it]

objective/kl: 10.460097312927246
ppo/returns/mean: -0.5558268427848816
ppo/policy/advantages_mean: 0.0817338079214096
---------------------------------------------------------------------------------------------------


121it [05:16,  3.12s/it]

objective/kl: 8.776378631591797
ppo/returns/mean: -0.3434141278266907
ppo/policy/advantages_mean: 0.23436737060546875
---------------------------------------------------------------------------------------------------


141it [06:11,  2.77s/it]

objective/kl: 7.459585189819336
ppo/returns/mean: -0.4541173279285431
ppo/policy/advantages_mean: 5.960464477539063e-08
---------------------------------------------------------------------------------------------------


161it [07:01,  2.38s/it]

objective/kl: 12.359760284423828
ppo/returns/mean: -0.6478800773620605
ppo/policy/advantages_mean: -0.32652804255485535
---------------------------------------------------------------------------------------------------


181it [07:49,  2.52s/it]

objective/kl: 2.5769102573394775
ppo/returns/mean: 0.2929069399833679
ppo/policy/advantages_mean: -0.08970504999160767
---------------------------------------------------------------------------------------------------


201it [08:39,  2.92s/it]

objective/kl: 11.039070129394531
ppo/returns/mean: -0.062153398990631104
ppo/policy/advantages_mean: -0.3090556561946869
---------------------------------------------------------------------------------------------------


221it [09:25,  2.32s/it]

objective/kl: -2.9055495262145996
ppo/returns/mean: 0.22249266505241394
ppo/policy/advantages_mean: -0.13908767700195312
---------------------------------------------------------------------------------------------------


241it [10:21,  3.32s/it]

objective/kl: 12.588629722595215
ppo/returns/mean: -0.6037713289260864
ppo/policy/advantages_mean: -0.4630015194416046
---------------------------------------------------------------------------------------------------


261it [11:13,  3.51s/it]

objective/kl: 26.878517150878906
ppo/returns/mean: -1.2588640451431274
ppo/policy/advantages_mean: 0.15761269629001617
---------------------------------------------------------------------------------------------------


281it [12:01,  2.21s/it]

objective/kl: 8.932064056396484
ppo/returns/mean: -1.0093388557434082
ppo/policy/advantages_mean: 0.010890878736972809
---------------------------------------------------------------------------------------------------


301it [12:47,  2.91s/it]

objective/kl: 7.194850444793701
ppo/returns/mean: -0.2959769368171692
ppo/policy/advantages_mean: -0.014455273747444153
---------------------------------------------------------------------------------------------------


321it [13:36,  1.96s/it]

objective/kl: 1.5449353456497192
ppo/returns/mean: -0.3995529115200043
ppo/policy/advantages_mean: 0.197078138589859
---------------------------------------------------------------------------------------------------


341it [14:23,  1.81s/it]

objective/kl: 3.525902032852173
ppo/returns/mean: -0.34936124086380005
ppo/policy/advantages_mean: -0.1597183644771576
---------------------------------------------------------------------------------------------------


361it [15:13,  2.76s/it]

objective/kl: 3.8998613357543945
ppo/returns/mean: -0.03330712765455246
ppo/policy/advantages_mean: 0.06938782334327698
---------------------------------------------------------------------------------------------------


381it [16:05,  2.43s/it]

objective/kl: 5.408855438232422
ppo/returns/mean: -0.3032090663909912
ppo/policy/advantages_mean: -0.26647791266441345
---------------------------------------------------------------------------------------------------


401it [16:59,  2.78s/it]

objective/kl: 5.804335594177246
ppo/returns/mean: -0.23455144464969635
ppo/policy/advantages_mean: 0.02741359733045101
---------------------------------------------------------------------------------------------------


421it [17:53,  2.38s/it]

objective/kl: 6.3468217849731445
ppo/returns/mean: -0.10778060555458069
ppo/policy/advantages_mean: -0.19353090226650238
---------------------------------------------------------------------------------------------------


441it [18:33,  2.15s/it]

objective/kl: 5.648125648498535
ppo/returns/mean: -0.682228684425354
ppo/policy/advantages_mean: -0.4032168686389923
---------------------------------------------------------------------------------------------------


461it [19:33,  3.12s/it]

objective/kl: 40.1257209777832
ppo/returns/mean: -2.1247599124908447
ppo/policy/advantages_mean: 0.0
---------------------------------------------------------------------------------------------------


481it [20:30,  3.06s/it]

objective/kl: 2.71347713470459
ppo/returns/mean: 0.02128533273935318
ppo/policy/advantages_mean: 0.24191167950630188
---------------------------------------------------------------------------------------------------


501it [21:19,  2.10s/it]

objective/kl: 2.55240535736084
ppo/returns/mean: -0.06154848635196686
ppo/policy/advantages_mean: 0.10642561316490173
---------------------------------------------------------------------------------------------------


521it [22:14,  2.95s/it]

objective/kl: 10.469079971313477
ppo/returns/mean: -0.6006935238838196
ppo/policy/advantages_mean: -0.05022788047790527
---------------------------------------------------------------------------------------------------


541it [23:00,  2.18s/it]

objective/kl: 8.270356178283691
ppo/returns/mean: -0.7361651062965393
ppo/policy/advantages_mean: 0.19523733854293823
---------------------------------------------------------------------------------------------------


561it [23:50,  2.67s/it]

objective/kl: 2.9512383937835693
ppo/returns/mean: -0.0677398294210434
ppo/policy/advantages_mean: -0.2594880759716034
---------------------------------------------------------------------------------------------------


581it [24:38,  1.89s/it]

objective/kl: 17.188737869262695
ppo/returns/mean: -1.514052391052246
ppo/policy/advantages_mean: -0.15222612023353577
---------------------------------------------------------------------------------------------------


601it [25:29,  3.22s/it]

objective/kl: -1.0590934753417969
ppo/returns/mean: -0.1711270809173584
ppo/policy/advantages_mean: -0.5872310400009155
---------------------------------------------------------------------------------------------------


621it [26:31,  3.55s/it]

objective/kl: 9.131011962890625
ppo/returns/mean: -0.41598033905029297
ppo/policy/advantages_mean: -0.05222555994987488
---------------------------------------------------------------------------------------------------


641it [27:21,  2.68s/it]

objective/kl: 3.7299444675445557
ppo/returns/mean: -0.6166903972625732
ppo/policy/advantages_mean: 0.055070266127586365
---------------------------------------------------------------------------------------------------


661it [28:11,  2.91s/it]

objective/kl: 11.30679988861084
ppo/returns/mean: -0.4447242021560669
ppo/policy/advantages_mean: -0.17597651481628418
---------------------------------------------------------------------------------------------------


681it [29:03,  1.97s/it]

objective/kl: -2.1056551933288574
ppo/returns/mean: -0.018422961235046387
ppo/policy/advantages_mean: -0.06275340914726257
---------------------------------------------------------------------------------------------------


701it [29:50,  2.50s/it]

objective/kl: 5.916505336761475
ppo/returns/mean: -0.39060255885124207
ppo/policy/advantages_mean: -0.052865296602249146
---------------------------------------------------------------------------------------------------


721it [30:41,  2.40s/it]

objective/kl: 11.515992164611816
ppo/returns/mean: -1.2261472940444946
ppo/policy/advantages_mean: -0.026699185371398926
---------------------------------------------------------------------------------------------------


741it [31:33,  2.67s/it]

objective/kl: 6.887491703033447
ppo/returns/mean: -0.5275769829750061
ppo/policy/advantages_mean: 0.05892634019255638
---------------------------------------------------------------------------------------------------


761it [32:18,  2.03s/it]

objective/kl: 22.060874938964844
ppo/returns/mean: -1.6011658906936646
ppo/policy/advantages_mean: 0.15318182110786438
---------------------------------------------------------------------------------------------------


781it [33:17,  2.48s/it]

objective/kl: 62.15116500854492
ppo/returns/mean: -3.532172679901123
ppo/policy/advantages_mean: 0.195514976978302
---------------------------------------------------------------------------------------------------


801it [34:11,  2.84s/it]

objective/kl: 6.571223735809326
ppo/returns/mean: -0.6686107516288757
ppo/policy/advantages_mean: 0.1551368236541748
---------------------------------------------------------------------------------------------------


821it [35:05,  2.73s/it]

objective/kl: 8.994508743286133
ppo/returns/mean: -0.608680248260498
ppo/policy/advantages_mean: 0.19320009648799896
---------------------------------------------------------------------------------------------------


841it [35:53,  2.31s/it]

objective/kl: 3.286473035812378
ppo/returns/mean: -0.8351767659187317
ppo/policy/advantages_mean: -0.027204006910324097
---------------------------------------------------------------------------------------------------


861it [36:45,  2.57s/it]

objective/kl: 20.844745635986328
ppo/returns/mean: -1.4172494411468506
ppo/policy/advantages_mean: 0.002569134347140789
---------------------------------------------------------------------------------------------------


881it [37:46,  3.48s/it]

objective/kl: 9.547528266906738
ppo/returns/mean: -0.6274033188819885
ppo/policy/advantages_mean: -0.07033631205558777
---------------------------------------------------------------------------------------------------


901it [38:40,  3.28s/it]

objective/kl: 12.204246520996094
ppo/returns/mean: -0.9778375625610352
ppo/policy/advantages_mean: 0.06257912516593933
---------------------------------------------------------------------------------------------------


921it [39:30,  2.17s/it]

objective/kl: 5.192238807678223
ppo/returns/mean: -0.42756128311157227
ppo/policy/advantages_mean: -0.012915387749671936
---------------------------------------------------------------------------------------------------


941it [40:12,  1.93s/it]

objective/kl: 15.488338470458984
ppo/returns/mean: -1.1100099086761475
ppo/policy/advantages_mean: 0.32112300395965576
---------------------------------------------------------------------------------------------------


961it [41:00,  2.62s/it]

objective/kl: 12.056196212768555
ppo/returns/mean: -1.1712760925292969
ppo/policy/advantages_mean: -0.07806328684091568
---------------------------------------------------------------------------------------------------


981it [41:51,  2.24s/it]

objective/kl: 15.33996868133545
ppo/returns/mean: -1.1205298900604248
ppo/policy/advantages_mean: 0.04920737445354462
---------------------------------------------------------------------------------------------------


1001it [42:43,  2.97s/it]

objective/kl: 4.993995666503906
ppo/returns/mean: -0.7739098072052002
ppo/policy/advantages_mean: -0.07645297050476074
---------------------------------------------------------------------------------------------------


1021it [43:35,  2.13s/it]

objective/kl: 15.163532257080078
ppo/returns/mean: -0.7594738006591797
ppo/policy/advantages_mean: -0.14679023623466492
---------------------------------------------------------------------------------------------------


1041it [44:43,  4.26s/it]

objective/kl: 18.563060760498047
ppo/returns/mean: -0.9673831462860107
ppo/policy/advantages_mean: 0.0750507116317749
---------------------------------------------------------------------------------------------------


1061it [45:35,  2.76s/it]

objective/kl: 4.02427864074707
ppo/returns/mean: -0.627548336982727
ppo/policy/advantages_mean: -0.23578087985515594
---------------------------------------------------------------------------------------------------


1081it [46:21,  2.07s/it]

objective/kl: 1.8246928453445435
ppo/returns/mean: -0.3135833740234375
ppo/policy/advantages_mean: 0.0589674673974514
---------------------------------------------------------------------------------------------------


1101it [47:20,  2.72s/it]

objective/kl: 5.890929222106934
ppo/returns/mean: -0.5812664031982422
ppo/policy/advantages_mean: 0.42388254404067993
---------------------------------------------------------------------------------------------------


1121it [48:13,  3.51s/it]

objective/kl: 13.816459655761719
ppo/returns/mean: -1.1047382354736328
ppo/policy/advantages_mean: 0.5100609064102173
---------------------------------------------------------------------------------------------------


1141it [49:13,  2.28s/it]

objective/kl: 3.369213581085205
ppo/returns/mean: -0.34434932470321655
ppo/policy/advantages_mean: 0.0018508033826947212
---------------------------------------------------------------------------------------------------


1161it [50:05,  2.02s/it]

objective/kl: 3.3875532150268555
ppo/returns/mean: -0.3877819776535034
ppo/policy/advantages_mean: 0.10479603707790375
---------------------------------------------------------------------------------------------------


1181it [51:05,  2.52s/it]

objective/kl: 5.830746173858643
ppo/returns/mean: -0.8040759563446045
ppo/policy/advantages_mean: 0.06407357007265091
---------------------------------------------------------------------------------------------------


1201it [52:02,  2.73s/it]

objective/kl: 33.95820617675781
ppo/returns/mean: -1.4091434478759766
ppo/policy/advantages_mean: 0.4443027973175049
---------------------------------------------------------------------------------------------------


1221it [53:00,  3.13s/it]

objective/kl: 64.03408813476562
ppo/returns/mean: -2.424380302429199
ppo/policy/advantages_mean: 0.5296339988708496
---------------------------------------------------------------------------------------------------


1241it [54:12,  2.80s/it]

objective/kl: 4.457647323608398
ppo/returns/mean: -0.37329521775245667
ppo/policy/advantages_mean: 0.09763671457767487
---------------------------------------------------------------------------------------------------


1261it [55:03,  2.15s/it]

objective/kl: 6.063701629638672
ppo/returns/mean: -0.8730283379554749
ppo/policy/advantages_mean: 0.0014304015785455704
---------------------------------------------------------------------------------------------------


1281it [55:53,  2.34s/it]

objective/kl: 10.076835632324219
ppo/returns/mean: -0.7490701079368591
ppo/policy/advantages_mean: 0.09263700246810913
---------------------------------------------------------------------------------------------------


1301it [56:56,  2.65s/it]

objective/kl: 14.540630340576172
ppo/returns/mean: -1.4111915826797485
ppo/policy/advantages_mean: -0.5703574419021606
---------------------------------------------------------------------------------------------------


1321it [57:53,  3.13s/it]

objective/kl: 8.54103946685791
ppo/returns/mean: -0.9671183824539185
ppo/policy/advantages_mean: -0.10177192091941833
---------------------------------------------------------------------------------------------------


1341it [58:58,  3.03s/it]

objective/kl: 5.837609767913818
ppo/returns/mean: -0.7319106459617615
ppo/policy/advantages_mean: 0.1689254343509674
---------------------------------------------------------------------------------------------------


1361it [59:45,  2.07s/it]

objective/kl: 3.581906795501709
ppo/returns/mean: -0.4071638584136963
ppo/policy/advantages_mean: -0.007364027202129364
---------------------------------------------------------------------------------------------------


1381it [1:00:36,  2.65s/it]

objective/kl: 6.744108200073242
ppo/returns/mean: -1.015115737915039
ppo/policy/advantages_mean: -0.022092312574386597
---------------------------------------------------------------------------------------------------


1401it [1:01:33,  2.22s/it]

objective/kl: 8.100587844848633
ppo/returns/mean: -0.8756448030471802
ppo/policy/advantages_mean: 0.10615743696689606
---------------------------------------------------------------------------------------------------


1421it [1:02:22,  2.17s/it]

objective/kl: 9.142379760742188
ppo/returns/mean: -0.8055477142333984
ppo/policy/advantages_mean: 0.05483047664165497
---------------------------------------------------------------------------------------------------


1441it [1:03:16,  2.95s/it]

objective/kl: 3.965038776397705
ppo/returns/mean: -0.3262704908847809
ppo/policy/advantages_mean: -0.04006745666265488
---------------------------------------------------------------------------------------------------


1461it [1:04:05,  1.94s/it]

objective/kl: 5.720600605010986
ppo/returns/mean: -0.6937193870544434
ppo/policy/advantages_mean: -0.20279456675052643
---------------------------------------------------------------------------------------------------


1481it [1:04:56,  2.40s/it]

objective/kl: 2.8742384910583496
ppo/returns/mean: -0.6741130352020264
ppo/policy/advantages_mean: -0.43787479400634766
---------------------------------------------------------------------------------------------------


1501it [1:05:50,  2.93s/it]

objective/kl: 4.585004806518555
ppo/returns/mean: -0.7705327272415161
ppo/policy/advantages_mean: -0.24883109331130981
---------------------------------------------------------------------------------------------------


1521it [1:06:39,  2.24s/it]

objective/kl: 13.03388786315918
ppo/returns/mean: -1.3999173641204834
ppo/policy/advantages_mean: 0.13071100413799286
---------------------------------------------------------------------------------------------------


1541it [1:07:36,  2.41s/it]

objective/kl: 6.899394989013672
ppo/returns/mean: -0.9269323348999023
ppo/policy/advantages_mean: 0.264494925737381
---------------------------------------------------------------------------------------------------


1561it [1:08:26,  2.89s/it]

objective/kl: 10.744851112365723
ppo/returns/mean: -1.2610018253326416
ppo/policy/advantages_mean: 0.14237062633037567
---------------------------------------------------------------------------------------------------


1581it [1:09:21,  2.64s/it]

objective/kl: 16.125852584838867
ppo/returns/mean: -1.2561428546905518
ppo/policy/advantages_mean: 0.14634917676448822
---------------------------------------------------------------------------------------------------


1601it [1:10:09,  2.20s/it]

objective/kl: 1.3226767778396606
ppo/returns/mean: -0.5187605619430542
ppo/policy/advantages_mean: -0.22924628853797913
---------------------------------------------------------------------------------------------------


1621it [1:10:59,  2.03s/it]

objective/kl: 2.9385647773742676
ppo/returns/mean: -0.3443501889705658
ppo/policy/advantages_mean: -0.011588122695684433
---------------------------------------------------------------------------------------------------


1641it [1:12:01,  2.72s/it]

objective/kl: 2.095628023147583
ppo/returns/mean: -0.6853137016296387
ppo/policy/advantages_mean: -0.31556859612464905
---------------------------------------------------------------------------------------------------


1661it [1:12:49,  2.49s/it]

objective/kl: 9.776711463928223
ppo/returns/mean: -0.9062880277633667
ppo/policy/advantages_mean: 0.15939980745315552
---------------------------------------------------------------------------------------------------


1681it [1:13:44,  2.15s/it]

objective/kl: 5.971773147583008
ppo/returns/mean: -1.274552345275879
ppo/policy/advantages_mean: 0.3672305643558502
---------------------------------------------------------------------------------------------------


1701it [1:14:31,  2.64s/it]

objective/kl: 10.43311882019043
ppo/returns/mean: -1.009933590888977
ppo/policy/advantages_mean: 0.017566099762916565
---------------------------------------------------------------------------------------------------


1721it [1:15:19,  2.62s/it]

objective/kl: 13.387677192687988
ppo/returns/mean: -1.2929476499557495
ppo/policy/advantages_mean: 0.046932995319366455
---------------------------------------------------------------------------------------------------


1741it [1:16:21,  4.15s/it]

objective/kl: 18.067546844482422
ppo/returns/mean: -1.2567973136901855
ppo/policy/advantages_mean: 0.034054264426231384
---------------------------------------------------------------------------------------------------


1761it [1:17:24,  2.64s/it]

objective/kl: 10.604499816894531
ppo/returns/mean: -1.6983667612075806
ppo/policy/advantages_mean: -0.06434249877929688
---------------------------------------------------------------------------------------------------


1781it [1:18:19,  2.90s/it]

objective/kl: 18.381168365478516
ppo/returns/mean: -2.268251895904541
ppo/policy/advantages_mean: 0.1501816213130951
---------------------------------------------------------------------------------------------------


1801it [1:19:16,  3.43s/it]

objective/kl: 41.97243118286133
ppo/returns/mean: -2.808316469192505
ppo/policy/advantages_mean: 0.08548218011856079
---------------------------------------------------------------------------------------------------


1821it [1:20:06,  1.91s/it]

objective/kl: 2.437648296356201
ppo/returns/mean: -0.5473542213439941
ppo/policy/advantages_mean: -0.14743661880493164
---------------------------------------------------------------------------------------------------


1841it [1:20:54,  3.34s/it]

objective/kl: 40.4603271484375
ppo/returns/mean: -1.9442297220230103
ppo/policy/advantages_mean: 0.23988328874111176
---------------------------------------------------------------------------------------------------


1861it [1:21:47,  2.90s/it]

objective/kl: 8.014890670776367
ppo/returns/mean: -0.9569429755210876
ppo/policy/advantages_mean: 0.174189031124115
---------------------------------------------------------------------------------------------------


1881it [1:22:39,  2.74s/it]

objective/kl: 11.478837966918945
ppo/returns/mean: -1.1787699460983276
ppo/policy/advantages_mean: 0.12903811037540436
---------------------------------------------------------------------------------------------------


1901it [1:23:34,  2.45s/it]

objective/kl: 5.075638771057129
ppo/returns/mean: -0.7703702449798584
ppo/policy/advantages_mean: 0.12747707962989807
---------------------------------------------------------------------------------------------------


1921it [1:24:21,  2.40s/it]

objective/kl: 3.280702829360962
ppo/returns/mean: -0.576866626739502
ppo/policy/advantages_mean: -0.0182047002017498
---------------------------------------------------------------------------------------------------


1941it [1:25:16,  2.41s/it]

objective/kl: 0.007048487663269043
ppo/returns/mean: -0.32077598571777344
ppo/policy/advantages_mean: 0.021973520517349243
---------------------------------------------------------------------------------------------------


1961it [1:26:00,  1.90s/it]

objective/kl: 9.36141586303711
ppo/returns/mean: -1.1972787380218506
ppo/policy/advantages_mean: 0.061178289353847504
---------------------------------------------------------------------------------------------------


1981it [1:26:55,  2.30s/it]

objective/kl: 3.5085229873657227
ppo/returns/mean: -0.2587331533432007
ppo/policy/advantages_mean: 0.023602046072483063
---------------------------------------------------------------------------------------------------


2001it [1:27:59,  3.62s/it]

objective/kl: 21.36741065979004
ppo/returns/mean: -1.7440836429595947
ppo/policy/advantages_mean: -0.3662722408771515
---------------------------------------------------------------------------------------------------


2021it [1:28:53,  2.63s/it]

objective/kl: 3.9985849857330322
ppo/returns/mean: -0.7820138931274414
ppo/policy/advantages_mean: -0.03708180785179138
---------------------------------------------------------------------------------------------------


2041it [1:29:55,  3.05s/it]

objective/kl: 8.83263874053955
ppo/returns/mean: -0.9825858473777771
ppo/policy/advantages_mean: 0.49922168254852295
---------------------------------------------------------------------------------------------------


2061it [1:30:46,  2.45s/it]

objective/kl: 5.681432723999023
ppo/returns/mean: -1.2939579486846924
ppo/policy/advantages_mean: -0.7909479141235352
---------------------------------------------------------------------------------------------------


2081it [1:31:44,  2.56s/it]

objective/kl: 14.182760238647461
ppo/returns/mean: -1.1207013130187988
ppo/policy/advantages_mean: -0.10208620131015778
---------------------------------------------------------------------------------------------------


2101it [1:32:35,  2.84s/it]

objective/kl: 7.467329978942871
ppo/returns/mean: -0.9650772213935852
ppo/policy/advantages_mean: -0.16898222267627716
---------------------------------------------------------------------------------------------------


2121it [1:33:29,  2.66s/it]

objective/kl: 6.935825824737549
ppo/returns/mean: -1.224590539932251
ppo/policy/advantages_mean: -0.047514840960502625
---------------------------------------------------------------------------------------------------


2141it [1:34:25,  2.94s/it]

objective/kl: 9.995491981506348
ppo/returns/mean: -0.8783331513404846
ppo/policy/advantages_mean: 0.11017302423715591
---------------------------------------------------------------------------------------------------


2161it [1:35:17,  2.26s/it]

objective/kl: 6.943464756011963
ppo/returns/mean: -0.9672530889511108
ppo/policy/advantages_mean: -0.08253878355026245
---------------------------------------------------------------------------------------------------


2181it [1:36:23,  2.86s/it]

objective/kl: 14.079084396362305
ppo/returns/mean: -1.8730268478393555
ppo/policy/advantages_mean: 0.008827544748783112
---------------------------------------------------------------------------------------------------


2201it [1:37:17,  3.39s/it]

objective/kl: 24.461318969726562
ppo/returns/mean: -1.6157257556915283
ppo/policy/advantages_mean: 0.01202082633972168
---------------------------------------------------------------------------------------------------


2221it [1:38:13,  2.22s/it]

objective/kl: -2.0607054233551025
ppo/returns/mean: -0.563808023929596
ppo/policy/advantages_mean: 0.019701752811670303
---------------------------------------------------------------------------------------------------


2241it [1:39:01,  2.18s/it]

objective/kl: 5.010007381439209
ppo/returns/mean: -1.0263731479644775
ppo/policy/advantages_mean: 0.041366782039403915
---------------------------------------------------------------------------------------------------


2261it [1:39:55,  2.45s/it]

objective/kl: 34.498287200927734
ppo/returns/mean: -2.217802047729492
ppo/policy/advantages_mean: -0.2095203548669815
---------------------------------------------------------------------------------------------------


2281it [1:40:55,  3.83s/it]

objective/kl: 6.453924655914307
ppo/returns/mean: -0.951710045337677
ppo/policy/advantages_mean: -0.17228515446186066
---------------------------------------------------------------------------------------------------


2301it [1:41:43,  3.04s/it]

objective/kl: 11.228034973144531
ppo/returns/mean: -1.4084060192108154
ppo/policy/advantages_mean: -0.37908563017845154
---------------------------------------------------------------------------------------------------


2321it [1:42:27,  1.98s/it]

objective/kl: 4.705312252044678
ppo/returns/mean: -0.4427275061607361
ppo/policy/advantages_mean: 0.3497820496559143
---------------------------------------------------------------------------------------------------


2340it [1:43:21,  2.65s/it]


<a name='3.3'></a>
### 3.3 - Evaluate the Model Quantitatively

Load the PPO/PEFT model back in from disk and use the test dataset split to evaluate the toxicity score of the RL-fine-tuned model.

In [20]:
ppo_mean, ppo_std = evaluate_ner_generation(ppo_model, size=len(tokenized_datasets['test']))

print(f"Average need protect score after PPO: {ppo_mean} . Std: {ppo_std}")

521it [08:26,  1.03it/s]

Average need protect score after PPO: 0.33328257935858674 . Std: 0.4139923147111121





In [13]:
mean_improvement = (peft_mean - ppo_mean) / peft_mean
std_improvement = (peft_std - ppo_std) / peft_std

print(f'Percentage improvement of protect score after PPO:')
print(f'mean: {mean_improvement*100:.2f}%')
print(f'std: {std_improvement*100:.2f}%')

Percentage improvement of protect score after PPO:
mean: -44.48%
std: -13.55%


In [16]:
PPO_MODEL_PATH=f"./ppo_model_checkpoint"

ppo_model.save_pretrained(PPO_MODEL_PATH)
tokenizer.save_pretrained(PPO_MODEL_PATH)

('./ppo_model_checkpoint\\tokenizer_config.json',
 './ppo_model_checkpoint\\special_tokens_map.json',
 './ppo_model_checkpoint\\tokenizer.json')

In [24]:
index = 50

orginal_prompt = dataset['test'][index]['prompt']

def make_safe_prompt(prompt):
    return f"""Create a confidential prompt from the following prompt:

{prompt}

Prompt:"""

prompt = make_safe_prompt(orginal_prompt)

inputs = tokenizer(prompt, return_tensors='pt').to('cuda')

peft_output = tokenizer.decode(
    ref_model.generate( input_ids=inputs["input_ids"], max_new_tokens=200)[0], 
    skip_special_tokens=True
)

ppo_output = tokenizer.decode(
    ppo_model.generate( input_ids=inputs["input_ids"], max_new_tokens=200)[0], 
    skip_special_tokens=True
)

dash_line = '-'.join('' for x in range(100))
print(dash_line)
print(f'INPUT PROMPT:\n{orginal_prompt}')
print(dash_line)
print(f'MODEL GENERATION - PEFT MODEL:\n{peft_output}')
print(dash_line)
print(f'MODEL GENERATION - PPO MODEL:\n{ppo_output}')

---------------------------------------------------------------------------------------------------
INPUT PROMPT:
What are the most important events and festivals to attend in Rio de Janeiro, and how do they showcase the city's vibrant culture and traditions?
---------------------------------------------------------------------------------------------------
MODEL GENERATION - PEFT MODEL:
What are the most important events and festivals to attend in Rio de Janeiro, and how do they showcase the city's vibrant culture and traditions?
---------------------------------------------------------------------------------------------------
MODEL GENERATION - PPO MODEL:
What are the most important events and festivals to attend in Rio de Janeiro, and how do they showcase the city's vibrant culture and traditions?
