# Reinforcement Learning from Human Feedback (RLHF)

## Enhancing T5-Base Summarization with Proximal Policy Optimization (PPO) and PEFT Fine-Tuning


Reinforcement Learning from Human Feedback, commonly known as **RLHF**, is a specialized machine learning approach that amalgamates traditional reinforcement learning techniques and human expertise. This union offers a unique pathway to training artificial intelligence agents.

---

### Key Insights:

1. **Nature of RLHF**: RLHF can be understood as an iterative procedure. The system undergoes continuous improvement, adapting its learning function based on newly acquired human feedback.
  
2. **Safety and Trust**: Incorporating human feedback ensures the system not only comprehends the tasks it should execute but also recognizes actions it should avoid. This dual capability fosters safer and more trustworthy systems.
  
3. **Performance Enhancements**: A 2022 study reported that RLHF outperforms conventional supervised learning (SL). The advantage is attributed to RLHF's ability to assess the cumulative reward of a whole, coherent conversation, a signal that SL does not capture.

---

RLHF has proven instrumental in guiding language models, molding them to align better with intricate human values. As we venture into this notebook, we'll deep-dive into the methodologies and applications of RLHF.



Useful references: 

https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

https://www.kaggle.com/code/paultimothymooney/fine-tune-flan-t5-with-ppo-deeplearning-ai

https://github.com/huggingface/trl/blob/main/tests/test_ppo_trainer.py


Reward model: ideally a sequence-classification model (`AutoModelForSequenceClassification`). We will use BERT.

Policy model: ideally a sequence-to-sequence language model (Seq2SeqLM). We will use T5.

![Alt text](image-1.png)

Image source: https://huggingface.co/docs/trl/index

## Process Overview

In this notebook, we embark on the journey of aligning a model using Reinforcement Learning from Human Feedback (RLHF). We'll employ various specialized models and leverage a structured training loop for this purpose.

---

### Models Utilized:

1. **Reward Model**:
   - A fine-tuned model that assigns rewards to the outputs of the policy model.

2. **Base Model (Policy Model)**:
   - The core model we aim to align using RLHF.
   - During the RL process, this model becomes the "policy model", driving decisions and actions.

3. **Reference Model**:
   - A frozen replica of the base model.
   - Its primary role is to act as a benchmark, monitoring the evolution of the policy model throughout the RL process.

---

### Training loop Overview:

We begin by initializing the Proximal Policy Optimization (PPO) training class. The training process encompasses the following steps:

- **Generation of Summaries**: 
  - Derived from the policy model.
  
- **Reward Assignment**:
  - The generated summaries are channeled through the rewards model.
  - Based on these summaries, rewards are determined, reflecting the alignment of the policy model with human preferences.
  
- **Model Adjustment via PPO**:
  - Utilizing the acquired rewards, PPO refines the weights of the policy model, nudging it closer to human preferences.
  
This iterative training loop continues for a predefined number of steps.
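
The control flow of the three steps above can be sketched in plain Python. This is only an illustration of the generate, score, update cycle: `rlhf_loop`, `toy_generate`, `toy_score`, and `toy_update` are hypothetical stand-ins invented here, not TRL functions, and the "update" simply lets the toy policy produce longer summaries rather than running PPO.

```python
def rlhf_loop(prompts, generate, score, update, num_steps):
    """Run the generate -> score -> update cycle for a fixed number of steps."""
    mean_rewards = []
    for _ in range(num_steps):
        # 1. Generation of summaries from the current policy
        summaries = [generate(prompt) for prompt in prompts]
        # 2. Reward assignment by the reward model
        rewards = [score(summary) for summary in summaries]
        # 3. Policy adjustment (PPO in the real notebook)
        update(prompts, summaries, rewards)
        mean_rewards.append(sum(rewards) / len(rewards))
    return mean_rewards


# Toy stand-ins (assumptions for illustration only).
state = {"length": 3}


def toy_generate(prompt):
    # The "policy" copies the first `length` words of the prompt.
    return " ".join(prompt.split()[: state["length"]])


def toy_score(summary):
    # The "reward model" simply prefers longer summaries.
    return float(len(summary.split()))


def toy_update(prompts, summaries, rewards):
    # The "optimizer" nudges the policy toward what was rewarded.
    state["length"] += 1


mean_rewards = rlhf_loop(
    ["a b c d e f g h"], toy_generate, toy_score, toy_update, num_steps=3
)
print(mean_rewards)
```

The mean reward rises from step to step because each update moves the toy policy toward what the toy reward model prefers, which is the same feedback structure the real loop relies on.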

---

## Evaluation:

After training, we evaluate the policy model to determine how well it mirrors human preferences.

---



### Install dependencies

In [None]:
!pip install torch
!pip install transformers
!pip install datasets
!pip install trl
!pip install peft
!pip install numpy
!pip install pandas
!pip install tqdm
!pip install evaluate
!pip install wandb


### A quick hack to link this notebook to Weights & Biases (wandb)

In this case it is redundant because the transformers libraries will do it, but for educational purposes, this is how you could set up wandb in a notebook whose libraries are not already integrated with it.

In [2]:
import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="rlhf_ppo_v1",
    
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: jcolanotoro (jcolano). Use `wandb login --relogin` to force relogin


In [3]:
import torch 

from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration

from torch.utils.data import DataLoader, Dataset as TorchDataset
from torch.optim import AdamW

from datasets import load_dataset, Dataset as HFDataset

from peft import PeftModel, PeftConfig,  TaskType

from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    LoraConfig,
)

# AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
# https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead

# trl: Transformer Reinforcement Learning library
import trl 
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead # https://huggingface.co/docs/trl/quickstart
from trl import create_reference_model
from trl.core import LengthSampler

import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()


  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

# Reward Model

![Alt text](image-2.png)

Image source: https://huggingface.co/blog/rlhf

## Reward Model in Reinforcement Learning (RL)

In RL, a **reward model** is a mechanism providing feedback to the agent about its performance in its environment. Instead of predefined reward functions, reward models infer the reward signal from human feedback, especially useful in complex scenarios where crafting a reward function is challenging.

### Why is it Important?

- **Feedback Mechanism**: It's how agents determine if actions are beneficial or detrimental.
- **Facilitates Learning**: Agents use these signals to update their policies to maximize rewards.
- **Handles Complexity**: For real-world problems where explicit reward functions are difficult, a learned reward model is valuable.
- **Safety and Alignment**: They ensure RL agents' objectives align with human intentions, reducing potential harmful behaviors.

In our code, we're initializing a reward model (based on a transformer like BERT) for Reinforcement Learning from Human Feedback (RLHF). This model generates reward signals from the agent's outputs, steering its learning process.


In [32]:
# Local directory or Hugging Face Hub repository ID of the saved reward model and tokenizer
reward_model_directory = "JuanKO/rlhf_reward_model"

rm_model = AutoModelForSequenceClassification.from_pretrained(reward_model_directory)
rm_tokenizer = AutoTokenizer.from_pretrained(reward_model_directory)
rm_model.to(device)


Some weights of the model checkpoint at ./model_bert_hf_experiment2/ were not used when initializing BertForSequenceClassification: ['bert.encoder.layer.3.attention.self.value.lora_B.default.weight', 'bert.encoder.layer.3.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.10.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.10.attention.self.value.lora_A.default.weight', 'bert.encoder.layer.7.attention.self.value.lora_B.default.weight', 'bert.encoder.layer.1.attention.self.value.lora_B.default.weight', 'bert.encoder.layer.9.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.5.attention.self.value.lora_A.default.weight', 'bert.encoder.layer.2.attention.self.query.lora_B.default.weight', 'bert.encoder.layer.4.attention.self.value.lora_A.default.weight', 'bert.encoder.layer.7.attention.self.query.lora_A.default.weight', 'bert.encoder.layer.2.attention.self.value.lora_B.default.weight', 'bert.encoder.layer.4.attention.self.query.lora_B.default.wei

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

## Function: `score_summaries`

### Description:
The `score_summaries` function scores two summaries, `chosen_summary` and `rejected_summary`, within a Reinforcement Learning from Human Feedback (RLHF) loop. It tokenizes the inputs, obtains the logits from a given model, computes the softmax probabilities, and finally extracts the scores (probabilities) and logits associated with each summary.

### Parameters:

- **model** (`torch.nn.Module`): 
    - The PyTorch model that produces logits given an input.
  
- **tokenizer** (`transformers.PreTrainedTokenizer`): 
    - A tokenizer object used to tokenize input summaries.
  
- **chosen_summary** (`str`): 
    - The chosen summary string that needs to be scored.
  
- **rejected_summary** (`str`): 
    - The rejected summary string that needs to be scored.

### Returns:

- **chosen_score** (`float`): 
    - The probability score associated with the `chosen_summary` being positive or "good".

- **rejected_score** (`float`): 
    - The probability score associated with the `rejected_summary` being positive or "good".

- **chosen_logit** (`float`): 
    - The logit value associated with the `chosen_summary`.

- **rejected_logit** (`float`): 
    - The logit value associated with the `rejected_summary`.

### Function Flow:

1. **Tokenization**: 
    - The input summaries, `chosen_summary` and `rejected_summary`, are tokenized using the provided tokenizer. These tokenized inputs are padded or truncated to a maximum length of 512 tokens.

2. **Move to Device**: 
    - The tokenized tensors are transferred to the device (likely a GPU or CPU) where the model resides.

3. **Obtain Logits**: 
    - The tokenized tensors are passed through the model to obtain logits. This is done in a no-gradient context to ensure computational efficiency and prevent any updates to the model.

4. **Compute Probabilities**: 
    - The obtained logits are passed through a softmax function to get the associated probabilities. This helps in understanding how likely each summary is deemed "good" by the model.

5. **Extract Scores and Logits**: 
    - The function then extracts the probability and logit associated with the positive class (assumed to be the second class in the logits) for both summaries.

### Notes:
- The function assumes that the positive class (indicating the summary is "good") is the second class in the logits.
- The softmax function ensures that the logits are converted into probabilities that sum up to 1.
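
To make the relationship between logits and scores concrete, here is a small pure-Python sketch of a two-class softmax. The helper names and the logit values are made up for illustration. It shows that the positive-class probability depends on the gap between the two logits, so a summary with a lower positive-class logit can still receive a higher score.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def positive_probability(logits):
    """Probability of the second class, assumed here to mean 'good'."""
    return softmax(logits)[1]


# Hypothetical [negative, positive] logits for two summaries.
logits_a = [-1.0, 0.5]  # positive logit 0.5, gap of 1.5
logits_b = [1.0, 1.5]   # positive logit 1.5, gap of only 0.5

score_a = positive_probability(logits_a)
score_b = positive_probability(logits_b)
print(f"A: {score_a:.4f}  B: {score_b:.4f}")
```

Summary B has the larger positive-class logit, yet summary A gets the higher score, because for two classes the softmax reduces to a sigmoid of the logit difference.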


In [33]:
import torch.nn.functional as F


def score_summaries(model, tokenizer, chosen_summary, rejected_summary):
    # Tokenize the inputs
    chosen_tokens = tokenizer(chosen_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    rejected_tokens = tokenizer(rejected_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    
    # Move the tokenized inputs to the same device as the model
    chosen_tokens = chosen_tokens.to(device)
    rejected_tokens = rejected_tokens.to(device)
    
    # Get logits from the model
    with torch.no_grad():
        chosen_logits = model(**chosen_tokens).logits
        rejected_logits = model(**rejected_tokens).logits
    
    # Apply softmax to get probabilities
    chosen_probs = F.softmax(chosen_logits, dim=-1)
    rejected_probs = F.softmax(rejected_logits, dim=-1)

    # Assuming the positive class (indicating 'chosen' is good) is the second one
    chosen_score = chosen_probs[0][1].item()
    rejected_score = rejected_probs[0][1].item()
    
    # Extract logits for each summary
    chosen_logit = chosen_logits[0][1].item()
    rejected_logit = rejected_logits[0][1].item()

    return chosen_score, rejected_score, chosen_logit, rejected_logit

#### Run some examples to test the function


In this test, we evaluate the `score_summaries` function using two sample summaries: one labeled as `chosen_summary` and the other as `rejected_summary`. These summaries are tokenized, scored, and the associated logits are obtained using our reward model (`rm_model`) and its tokenizer (`rm_tokenizer`).

### Sample Summaries:

- **Chosen Summary**: 
    - "Water meter in another condo is not in our condo. What can we do legally to restore water to my condo complex?"
    
- **Rejected Summary**: 
    - "Go fix the problem."

### Test Execution:

The `score_summaries` function is called with the provided model, tokenizer, and the sample summaries. The returned scores and logits for each summary are then printed.

### Expected Output:

- **Chosen Score**: 
    - This gives the probability score of the `chosen_summary` being perceived as "good" or positive by the model.
  
- **Rejected Score**: 
    - This gives the probability score of the `rejected_summary` being perceived as "good" or positive by the model.
  
- **Chosen Logit**:
    - This returns the raw logit value associated with the `chosen_summary`.
  
- **Rejected Logit**:
    - This returns the raw logit value associated with the `rejected_summary`.

### Notes:
- Higher scores indicate a higher probability of the summary being perceived as positive or "good".
- The logit values provide insight into the raw outputs of the model before being passed through the softmax function.


In [34]:

chosen_summary = "Water meter in another condo is not in our condo. What can we do legally to restore water to my condo complex?"
rejected_summary = "Go fix the problem."

chosen_score, rejected_score, chosen_logit, rejected_logit = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)

print(f"Chosen Score: {chosen_score:.4f}")
print(f"Rejected Score: {rejected_score:.4f}")

print(f"Chosen Logit: {chosen_logit:.4f}")
print(f"Rejected Logit: {rejected_logit:.4f}")

## Loading the T5 Model for RLHF Fine-Tuning

### Overview:

T5, short for "Text-to-Text Transfer Transformer", is a model designed to handle a wide range of tasks by casting each one as text-to-text. In this section, we'll load a T5 model that is intended to be fine-tuned using the Reinforcement Learning from Human Feedback (RLHF) approach.

### Steps:

1. **Model Selection**:
    - We've selected the T5 model for our fine-tuning process. Specifically, we'll be working with the "t5-base" variant which offers a balance between computational efficiency and performance.

2. **Loading Model and Tokenizer**:
    - `policy_model_path`: Specifies where our pre-trained (or fine-tuned) T5 model is stored. `from_pretrained` accepts either a local directory or a Hugging Face Hub repository ID.
    - `policy_model_name`: Indicates the model name, which in this case is "t5-base".
    - Using the `T5ForConditionalGeneration.from_pretrained` method, we load the model weights from our specified path.
    - Similarly, the corresponding tokenizer, which is essential for converting text into a format that the T5 model can understand, is loaded using the `T5Tokenizer.from_pretrained` method.

3. **Device Allocation**:
    - The model is assigned to a computation device (either CPU or GPU) using the `.to(device)` method. This ensures efficient computation, especially when working with large datasets.

### Test the Model:

After loading, it's a good practice to perform some inference tests to ensure that the model is loaded correctly and is functioning as expected.



In [7]:
policy_model_path = "JuanKO/rlhf_base_model"
policy_model_name = "t5-base" 

policy_model = T5ForConditionalGeneration.from_pretrained(policy_model_path)
policy_model.to(device)
policy_tokenizer = T5Tokenizer.from_pretrained(policy_model_path)

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.37k [00:00<?, ?B/s]

### Testing the T5 Model for Summarization


After loading our T5 model, we'll test its summarization capabilities on a sample text from the r/relationships subreddit. This test will help us understand the model's performance and its readiness for RLHF fine-tuning.

### Steps:

1. **Setting the Task Prefix**:
    - We use the prefix "summarize: " to indicate to the T5 model the type of task we want it to perform.

2. **Sample Text**:
    - We have selected a post from the r/relationships subreddit to be summarized. This text provides context about a user's relationship concerns related to her bisexuality.

3. **Generating the Summary**:
    - We feed the concatenated task prefix and text into our T5 model.
    - The model then processes this input and returns a concise summary. The `generate` function is used to obtain this output, and we've set a max length of 100 tokens for our summary.

4. **Decoding the Summary**:
    - The output from the T5 model is in the form of token IDs. Using the T5 tokenizer's `decode` method, we convert these tokens back into human-readable text.

5. **Scoring the Summary using the Reward Model**:
    - With the generated summary in hand, we then use our previously defined `score_summaries` function to evaluate the quality of the summary.
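    - This function returns a score and logit value for both the chosen summary and a rejected (blank) summary. A higher score suggests better alignment with what the reward model considers a good summary. The positive-class logit is reported for inspection only: on its own it is not comparable across inputs, because the softmax score also depends on the negative-class logit.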

### Results:

By examining the printed scores and logits, we can gauge the perceived quality of the generated summary according to our reward model.


In [38]:
task_prefix = "summarize: " 

text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
#text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
prompt = f"{task_prefix}{text}"
input_ids = policy_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
outputs = policy_model.generate(input_ids, max_length=100).to(device)

strOutput = policy_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(strOutput)

chosen_score, rejected_score, chosen_logit, rejected_logit = score_summaries(rm_model, rm_tokenizer, strOutput, "")

print(f"Chosen Score: {chosen_score:.4f}")
print(f"Rejected Score: {rejected_score:.4f}")

print(f"Chosen Logit: {chosen_logit:.4f}")
print(f"Rejected Logit: {rejected_logit:.4f}")


TL;DR: I'm bisexual and I'm in a hetero relationship. Is it necessary to tell my boyfriend that I'm bisexual? When do you think is the right time?
Chosen Score: 0.5943
Rejected Score: 0.5193
Chosen Logit: 0.0889
Rejected Logit: 0.2162


## Preparing the T5 Model for Peft + LoRA

### Overview:

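PEFT (Parameter-Efficient Fine-Tuning) is a family of techniques for fine-tuning pre-trained models by training only a small number of parameters. LoRA (Low-Rank Adaptation) is one such technique: it freezes the original weights and adds small trainable low-rank matrices to selected layers. Here, we'll configure the T5 model for this process.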

### Steps:

1. **Setting up the LoRA Configuration**:
    - `LoraConfig` provides the configuration settings for Low-Rank Adaptation.
        - `r`: Rank of the low-rank structure. In this instance, it's set to 8.
        - `lora_alpha`: Scaling factor for the low-rank update, which is multiplied by `lora_alpha / r`. Set to 32 here.
        - `target_modules`: Specifies which parts of the model to apply LoRA. Here, we're targeting the "q" (query) and "v" (value) modules.
        - `lora_dropout`: Dropout rate for the low-rank parameters. Set to 0.10, or 10%.
        - `bias`: Controls which bias parameters are trained ("none", "all", or "lora_only"). We've chosen "none", so no biases are updated.
        - `task_type`: Indicates the type of task. As we're using T5, the task type is set to `SEQ_2_SEQ_LM`.

2. **Applying LoRA Configuration to T5**:
    - Using the `get_peft_model` function, we apply the LoRA configuration to our pre-loaded T5 model.
    - The returned model (`policy_peft_model`) is equipped with the Peft + LoRA modifications and is ready for fine-tuning.

### Summary of this section:

Our T5 model is now prepared with Peft + LoRA adjustments. This configuration optimizes the model for more efficient fine-tuning on specific tasks while leveraging the powerful pre-trained knowledge.
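
As a rough illustration of what a LoRA-adapted linear layer computes, the sketch below implements `h = W x + (alpha / r) * B (A x)` with plain Python lists. The tiny matrices and the helper names `matvec` and `lora_forward` are assumptions made for this example and are not part of the PEFT API. The real layers are 768-dimensional with `r=8`.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    """Frozen projection W x plus the scaled low-rank update B (A x)."""
    base = matvec(W, x)
    low_rank = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * delta for b, delta in zip(base, low_rank)]


# Toy sizes: d_in = d_out = 3, rank r = 1 (illustrative values only).
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]]  # frozen weights
A = [[0.5, -0.5, 1.0]]                                    # r x d_in, trainable
B_zero = [[0.0], [0.0], [0.0]]                            # d_out x r, zero at start
B_trained = [[0.1], [0.0], [-0.2]]                        # after some training
x = [1.0, 2.0, 3.0]

out_initial = lora_forward(W, A, B_zero, x, alpha=4, r=1)
out_adapted = lora_forward(W, A, B_trained, x, alpha=4, r=1)
print(out_initial, out_adapted)
```

With `B` at zero the adapted layer reproduces the frozen layer exactly, so fine-tuning starts from the pre-trained behavior and only the small `A` and `B` matrices need gradients.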


In [40]:
lora_config = LoraConfig(
    r=8, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.10,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # T5
)

policy_peft_model = get_peft_model(policy_model, lora_config)
policy_peft_model.to(device)

PeftModelForSeq2SeqLM(
  (base_model): LoraModel(
    (model): T5ForConditionalGeneration(
      (shared): Embedding(32128, 768)
      (encoder): T5Stack(
        (embed_tokens): Embedding(32128, 768)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(
                    in_features=768, out_features=768, bias=False
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=768, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=768, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B): Pa

### Analyzing Trainable Parameters in the Peft + LoRA Configured T5 Model

After applying the Peft + LoRA configuration to our T5 model, it's essential to inspect the model's parameters to understand its structure better.

### Key Insights:

1. **Trainable Parameters**:
    - This refers to the parameters that will be updated during the training process.
    - In our configured model, there are **884,736** trainable parameters.

2. **Total Parameters**:
    - This indicates the complete count of parameters present in the model, including those that are non-trainable.
    - The model consists of **223,788,288** total parameters.

3. **Percentage of Trainable Parameters**:
    - It's useful to know the fraction of the model's parameters that are trainable, as this can influence training time and model flexibility.
    - Only about **0.3953%** (or roughly 0.4%) of the entire model's parameters are trainable.

### Summary of this section:

The Peft + LoRA configuration results in a model where only a small fraction of parameters are trainable. This approach offers a balance, as it allows for specific fine-tuning while leveraging a vast pre-trained structure. The advantage is that it can lead to faster training times and might prevent overfitting, especially when training data is limited.


In [41]:
policy_peft_model.print_trainable_parameters()

trainable params: 884736 || all params: 223788288 || trainable%: 0.3953450861557152
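
The trainable-parameter count can be reproduced by hand. The short calculation below assumes the standard t5-base layout of 12 encoder and 12 decoder blocks (the hidden size of 768 is visible in the printed module tree). The variable names are chosen for this example.

```python
r = 8                 # LoRA rank from lora_config
d_model = 768         # in_features and out_features of q and v
encoder_blocks = 12   # assumption: standard t5-base depth
decoder_blocks = 12   # assumption: standard t5-base depth

# Each encoder block has one attention module (self-attention); each decoder
# block has two (self-attention and cross-attention).
attention_modules = encoder_blocks + 2 * decoder_blocks

targets_per_module = 2                         # "q" and "v"
params_per_target = r * d_model + d_model * r  # lora_A plus lora_B

trainable = attention_modules * targets_per_module * params_per_target
total = 223_788_288   # "all params" reported by print_trainable_parameters()
percentage = 100 * trainable / total

print(f"trainable params: {trainable} || trainable%: {percentage}")
```

The result matches the reported figure, which confirms that the LoRA matrices on the `q` and `v` projections are the only trainable weights.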


![Alt text](image-3.png)

Image source: https://magazine.sebastianraschka.com/p/llm-training-rlhf-and-its-alternatives

## Instantiating the PPO Model with Value Head

Proximal Policy Optimization (PPO) is a reinforcement learning algorithm. In this step, we set up the model for PPO training using our earlier `policy_peft_model`.
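
For readers new to PPO, its central idea is a clipped surrogate objective that limits how far a single update can move the policy. The sketch below evaluates that objective for one action in plain Python. The function name and the clip range of 0.2 are illustrative assumptions, and this is a simplification of what a full trainer does (which also includes value and KL terms).

```python
import math


def clipped_surrogate(new_logprob, old_logprob, advantage, clip_range=0.2):
    """PPO clipped objective for a single action (to be maximized)."""
    # Probability ratio between the updated policy and the policy that
    # generated the sample.
    ratio = math.exp(new_logprob - old_logprob)
    clipped_ratio = max(1.0 - clip_range, min(1.0 + clip_range, ratio))
    # Taking the minimum removes any incentive to push the ratio outside
    # the clip range in the direction the advantage favors.
    return min(ratio * advantage, clipped_ratio * advantage)


# The policy became much more likely to take a good action: gain is capped.
capped_gain = clipped_surrogate(new_logprob=-0.5, old_logprob=-1.0, advantage=1.0)

# The policy became more likely to take a bad action: the penalty is not capped.
full_penalty = clipped_surrogate(new_logprob=-0.5, old_logprob=-1.0, advantage=-1.0)

print(capped_gain, full_penalty)
```

The asymmetry is deliberate: improvements are rewarded only up to the clip boundary, while moves in a harmful direction are penalized in full.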

### Key Components:

1. **AutoModelForSeq2SeqLMWithValueHead**:
    - An extension of the transformers model that includes a scalar output for each token, aiding in reinforcement learning.
    - This model can capture the value function, an estimate of future rewards.

2. **Inputs**:
    - We pass in our `policy_peft_model`, which has been configured with Peft + LoRA, as the foundation for our PPO model.
    - We set `torch_dtype` to `torch.bfloat16` to reduce memory use, at the cost of some numerical precision.
    - The `is_trainable` flag is set to `True`, allowing us to further fine-tune the model using our RL loop.

3. **Device Assignment**:
    - We transfer our instantiated model to the appropriate device (`device`) for computation, ensuring efficient training.

### Summary of this section:

With our PPO model instantiated, we're poised to fine-tune our summarization model using reinforcement learning with human feedback. This approach is aimed at improving the model's performance in generating summaries based on human preferences and judgments.
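
Conceptually, a value head is a small projection that maps each token's hidden state to one scalar. The pure-Python sketch below shows that idea with made-up numbers. The function name `value_head`, the weights, and the hidden size of 4 are assumptions for illustration and do not reproduce the TRL implementation.

```python
def value_head(hidden_states, weights, bias):
    """Map each token's hidden vector to a single scalar value estimate."""
    return [
        sum(h * w for h, w in zip(token_vector, weights)) + bias
        for token_vector in hidden_states
    ]


# Three tokens with a toy hidden size of 4 (t5-base uses 768).
hidden_states = [
    [1.0, 0.0, -1.0, 2.0],
    [0.5, 0.5, 0.5, 0.5],
    [0.0, 0.0, 0.0, 0.0],
]
weights = [0.5, -1.0, 0.25, 1.0]  # hypothetical learned weights
bias = 0.1                        # hypothetical learned bias

values = value_head(hidden_states, weights, bias)
print(values)
```

There is one value per token, which is what lets PPO estimate how much future reward to expect at every position of the generated summary.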

[More on PPO and TRL](https://huggingface.co/docs/trl/quickstart)


In [42]:
# https://huggingface.co/docs/trl/quickstart
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(policy_peft_model,                                                               
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

ppo_model.to(device)

AutoModelForSeq2SeqLMWithValueHead(
  (pretrained_model): PeftModelForSeq2SeqLM(
    (base_model): LoraModel(
      (model): T5ForConditionalGeneration(
        (shared): Embedding(32128, 768)
        (encoder): T5Stack(
          (embed_tokens): Embedding(32128, 768)
          (block): ModuleList(
            (0): T5Block(
              (layer): ModuleList(
                (0): T5LayerSelfAttention(
                  (SelfAttention): T5Attention(
                    (q): Linear(
                      in_features=768, out_features=768, bias=False
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.1, inplace=False)
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      

### Defining the Reference Model

In reinforcement learning, especially when fine-tuning models using methods like Proximal Policy Optimization (PPO), it's helpful to have a reference model. This model represents the initial state or behavior of the learner model (in this case, the Language Model) before any alignment or optimization. It is used to compute a KL-divergence penalty that discourages the policy from drifting too far from its starting behavior, which helps keep PPO updates stable. (PPO's importance sampling ratio, by contrast, compares the policy before and after each update step, not the policy and the reference model.)

### Key Components:

1. **create_reference_model**:
    - A function provided by Huggingface's TRL (Transformer Reinforcement Learning) library.
    - Creates a duplicate of the passed model which acts as a reference during the RL fine-tuning process.

2. **Inputs**:
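    - The `policy_model` we previously defined serves as the input. This model acts as the basis for our reference model. Note that `get_peft_model` injected the LoRA layers into `policy_model` in place, so the copy also contains them, as the printed module tree shows.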

3. **Device Assignment**:
    - Once instantiated, we move our reference model to the specified device (`device`) for computations.

### Summary of this section:

By defining a reference model, we set a stable baseline against which we can measure and guide the progress and changes of our main model during the reinforcement learning process.
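
One common way a frozen reference model is used in RLHF is to shape the reward with a per-token KL penalty. The sketch below is a simplified pure-Python version of that idea. The function name, the log-probabilities, and the coefficient of 0.2 are illustrative assumptions rather than values taken from this notebook.

```python
def kl_shaped_rewards(policy_logprobs, ref_logprobs, score, kl_coef=0.2):
    """Penalize per-token drift from the reference; add the score at the end."""
    # Per-token KL estimate: log pi(token) - log pi_ref(token)
    per_token_kl = [p - q for p, q in zip(policy_logprobs, ref_logprobs)]
    rewards = [-kl_coef * kl for kl in per_token_kl]
    # The reward model's score for the whole summary lands on the last token.
    rewards[-1] += score
    return rewards


# Hypothetical log-probabilities for a three-token summary.
policy_logprobs = [-1.0, -0.5, -2.0]
ref_logprobs = [-1.0, -1.0, -1.5]

rewards = kl_shaped_rewards(policy_logprobs, ref_logprobs, score=1.0)
print(rewards)
```

Tokens that the policy now favors more strongly than the reference receive a small negative reward, so the policy can only collect the reward model's score without straying far from its original behavior.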

[More on TRL and Reference Models](https://huggingface.co/docs/trl/models#trl.create_reference_model)


In [43]:
ref_model = create_reference_model(policy_model)
ref_model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(
                in_features=768, out_features=768, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=768, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(
                i

### Preparing the Dataset for Reinforcement Learning

Reinforcement learning (RL) requires a dataset to simulate experiences and provide feedback. In our RL setup for fine-tuning a language model, we utilize a comparison dataset.

### Steps:

1. **Load Dataset**:
    - Using Huggingface's `datasets` library, we fetch the 'CarperAI/openai_summarize_comparisons' dataset's test split.

2. **Filtering**:
    - We want to ensure the prompt lengths are manageable. 
    - Filtering by word count: We retain samples where the prompt has ≤ 450 words.
    - (Alternative Filtering by character count is commented out for reference.)

3. **Shuffling and Sampling**:
    - To ensure a diverse set of samples, we shuffle the dataset.
    - We then select a subset (2,000 samples in this instance) for the RL process.

4. **Feature Extraction**:
    - From our shuffled dataset, we focus on the `prompt` and `chosen` fields. 
    - Rename the 'chosen' field to 'response' to align with the PPO library's requirements.

5. **Dataset Conversion**:
    - Convert the dictionary containing our features into a Huggingface Dataset format.

6. **Train-Eval Split**:
    - Split the dataset into training and evaluation subsets. 
    - Here, 80% of samples are designated for training, and the remaining 20% are for evaluation.

### Outcome:

By the end of this process, we will have a training dataset and an evaluation dataset ready for the RL process. These datasets will be essential in guiding the model's fine-tuning and assessing its performance during the RL loop.
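
The filtering, sampling, renaming, and splitting steps can be tried out on plain Python lists before touching the real dataset. The sketch below mirrors those steps on toy records. `prepare_splits` and the sample data are hypothetical, and unlike `select(range(2000))` the slice used here does not fail when fewer records survive the filter.

```python
import random


def prepare_splits(records, max_words=450, sample_size=2000,
                   split_ratio=0.8, seed=42):
    """Filter by prompt length, shuffle, sample, rename, and split."""
    # Keep only prompts with at most `max_words` whitespace-separated words.
    kept = [r for r in records if len(r["prompt"].split()) <= max_words]
    # Shuffle reproducibly, then take the first `sample_size` records.
    random.Random(seed).shuffle(kept)
    sampled = kept[:sample_size]
    # Keep the prompt and rename 'chosen' to 'response'.
    pairs = [{"prompt": r["prompt"], "response": r["chosen"]} for r in sampled]
    # First 80% for training, remainder for evaluation.
    num_train = int(split_ratio * len(pairs))
    return pairs[:num_train], pairs[num_train:]


# Toy records: nine short prompts and one that exceeds the word limit.
toy_records = [
    {"prompt": f"post number {i}", "chosen": f"summary {i}", "rejected": "n/a"}
    for i in range(9)
]
toy_records.append({"prompt": "word " * 500, "chosen": "too long", "rejected": "n/a"})

train_split, eval_split = prepare_splits(toy_records, sample_size=5)
print(len(train_split), len(eval_split))
```

The `datasets` library also provides `Dataset.train_test_split`, which could replace the manual index arithmetic if a random split is acceptable.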


In [44]:
# Load the dataset
orig_dataset = load_dataset('CarperAI/openai_summarize_comparisons', split='test')

# Keep samples whose prompt has at most 450 whitespace-separated words
filtered_dataset = orig_dataset.filter(lambda example: len(example['prompt'].split()) <= 450) # By word
#filtered_dataset = orig_dataset.filter(lambda example: len(example['prompt']) <= 1250) # By character

# Shuffle and select the first 2,000 samples
#shuffled_dataset = orig_dataset.shuffle(seed=42).select(range(1000))
shuffled_dataset = filtered_dataset.shuffle(seed=42).select(range(2000)) 


# Extract the desired features, renaming 'chosen' to 'response' to follow the PPO library's requirements.
new_dataset_dict = {
    "prompt": shuffled_dataset["prompt"],
    "response": shuffled_dataset["chosen"]
}

# Convert the dictionary to a new Dataset
dataset = HFDataset.from_dict(new_dataset_dict)

# Split the new_dataset into train_dataset and eval_dataset
split_ratio = 0.8  # 80% for training, 20% for evaluation
num_train_samples = int(split_ratio * len(dataset))
train_dataset = dataset.select(range(num_train_samples))
eval_dataset = dataset.select(range(num_train_samples, len(dataset)))

Found cached dataset parquet (C:/Users/juan_/.cache/huggingface/datasets/CarperAI___parquet/CarperAI--openai_summarize_comparisons-79d2c222a15dc8fb/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at C:\Users\juan_\.cache\huggingface\datasets\CarperAI___parquet\CarperAI--openai_summarize_comparisons-79d2c222a15dc8fb\0.0.0\2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec\cache-d5c2170aaeb9b06c.arrow
Loading cached shuffled indices for dataset at C:\Users\juan_\.cache\huggingface\datasets\CarperAI___parquet\CarperAI--openai_summarize_comparisons-79d2c222a15dc8fb\0.0.0\2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec\cache-f81caef5de9ecb07.arrow


In [45]:
print(train_dataset[0].keys())
print(eval_dataset[0].keys())

dict_keys(['prompt', 'response'])
dict_keys(['prompt', 'response'])
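
The filter, rename, and split steps above can be sanity-checked on toy data without downloading anything. The sketch below uses plain Python lists in place of a Hugging Face `Dataset`; the `records` list is hypothetical data made up for illustration, and only the 450-word threshold and the 80/20 ratio come from the notebook.

```python
# Toy records standing in for rows of the real dataset (hypothetical data).
records = [
    {"prompt": "short post about a cat", "chosen": "cat summary"},
    {"prompt": "word " * 500, "chosen": "too long"},
    {"prompt": "another short post", "chosen": "short summary"},
    {"prompt": "a third short post here", "chosen": "third summary"},
    {"prompt": "fourth post", "chosen": "fourth summary"},
    {"prompt": "fifth post with a few words", "chosen": "fifth summary"},
]

MAX_WORDS = 450  # same word-count threshold as the notebook

# Keep prompts with at most MAX_WORDS whitespace-separated words.
kept = [r for r in records if len(r["prompt"].split()) <= MAX_WORDS]

# Rename 'chosen' to 'response'.
pairs = [{"prompt": r["prompt"], "response": r["chosen"]} for r in kept]

# 80/20 split by position, mirroring the select(range(...)) calls.
num_train = int(0.8 * len(pairs))
train, evaluation = pairs[:num_train], pairs[num_train:]

print(len(kept), len(train), len(evaluation))
```

Because the split is positional, it relies on the data having been shuffled beforehand, as the notebook does with `shuffle(seed=42)`.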


### Tokenization of Datasets

For reinforcement learning, it is crucial that the data is in a format understood by the model. This requires tokenizing our textual data into numerical tokens. Here, we'll use the tokenizer associated with our model (T5 in this case) to process our datasets.

### Steps:

1. **Tokenizer Initialization**:
    - Instantiate the tokenizer corresponding to our model (T5). If you use a different model, ensure you fetch the right tokenizer.

2. **Tokenization Function**:
    - Define a function (`tokenize_function`) that:
        - Processes the 'prompt' in each example of the dataset.
        - Truncates the tokenized prompt to a maximum length of 512 tokens (no padding is applied, so shorter prompts keep their original length).
        - Returns the tokenized 'input_ids' for each 'prompt' and retains the associated 'response'.

3. **Apply Tokenization**:
    - Apply the `tokenize_function` to both the training and evaluation datasets using the `map` function.

### Outcome:

The datasets (`train_dataset` and `eval_dataset`) are now tokenized and in a suitable format for model ingestion during the reinforcement learning loop.
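
To see what truncation without padding implies, here is a minimal sketch with a toy tokenizer. `toy_tokenize` is a hypothetical stand-in (one ID per whitespace-separated word), not the T5 tokenizer; it only illustrates that truncated sequences can still differ in length.

```python
def toy_tokenize(text, max_length):
    """Hypothetical stand-in for a real tokenizer: one id per word, truncated, never padded."""
    vocab = {}
    ids = [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]
    return ids[:max_length]

prompts = ["a b c", "a b c d e f g h", "a b"]
encoded = [toy_tokenize(p, max_length=5) for p in prompts]
lengths = [len(ids) for ids in encoded]
print(lengths)  # [3, 5, 2]
```

Sequences longer than the limit are cut to it, while shorter ones are left untouched, so a batch of prompts is a list of sequences with different lengths rather than one rectangular array.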


In [46]:
from transformers import T5Tokenizer

# Instantiate your tokenizer (replace T5Tokenizer with your model's tokenizer if different).
# Note: this loads the "t5-small" tokenizer; make sure it matches the tokenizer of the policy model you train.
tokenizer = T5Tokenizer.from_pretrained("t5-small")

def tokenize_function(example):
    # Tokenize the prompt and store it as input_ids. Also return the response.
    return {
        "input_ids": tokenizer(example["prompt"], return_tensors="pt", truncation=True, max_length=512)["input_ids"].squeeze(),
        "response": example["response"],
    }

# Tokenize the training and evaluation datasets
train_dataset = train_dataset.map(tokenize_function, batched=False)
eval_dataset = eval_dataset.map(tokenize_function, batched=False)



In [47]:
train_dataset 

Dataset({
    features: ['prompt', 'response', 'input_ids'],
    num_rows: 1600
})

In [48]:
# Let's check one sample of the train_dataset
print(train_dataset[0])  # print the first example from the training dataset

{'prompt': "SUBREDDIT: r/relationship_advice\nTITLE: [20/m] My girlfriend [20/f] has become very distant and weird\nPOST: I have been in a relationship with my girlfriend for a little bit over 1 year. We recently had a breakup because I was distant and she thought I was cheating on her (which I wasn't). Before the breakup, she wanted to spend as much time with me as she could, but recently she has been very distant. We used to go to eachothers places overnight almost daily, but nowadays she does not want to come over to my place or want me to go over to hers (We both live on our own). She also used to talk to me all the time on facebook, but now she pretty much only replies to what I talk, and does not try to keep the conversation going. She has became pretty slow at replying, but when I'm with her, she replies instantly to her other friends who text her. \n\nI'm really lost at this situation, because I feel like she does not want to be with me anymore. I know that she's taking SSRI me

### Hyperparameter Initialization

Before training the model using reinforcement learning, we need to define several hyperparameters that will guide and constrain the training process.

### Data Collation:

- **`collator` Function**: 
    - A helper function that takes a list of data samples and merges them into a single batch, making it suitable for processing by the model.
    - For instance, given an input of individual key-value data samples, the function groups the values by their keys.

    Example:
    ```python
    test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}, {"key1": "value4", "key2": "value5", "key3": "value6"}]
    collated_data = collator(test_data)
    ```

- **Sample Data**:
    - To visually validate the output of the `collator`, a sample is taken from the training dataset and processed.

### Key Hyperparameters:

- **`learning_rate`**: 
    - Controls the step size at each iteration while moving towards a minimum in the loss function. Set to `1.41e-5`.

- **`max_ppo_epochs`**: 
    - Specifies the number of optimization passes PPO makes over each collected batch of generated summaries (passed to `PPOConfig` as `ppo_epochs`). Set to `3`.

- **`mini_batch_size`** & **`batch_size`**: 
    - Determines the number of samples in each mini-batch (`4`) and the overall batch size (`16`).

- **`DEFAULT_REJECTED_SUMMARY_TEXT`**: 
    - A placeholder text for a bad summary, used as the fixed "rejected" candidate when each generated summary is scored. It may also act as a regularizer during training, though this effect has not been verified.

- **Generation Constraints** (`generation_kwargs`):
    - `temperature`: Controls the randomness of predictions by scaling the logits before applying softmax. Set to `1.0`.
    - `min_length`: Minimum length of the generated text. Set to `5`.
    - `top_k` & `top_p`: Control top-k sampling and nucleus (top-p) sampling, respectively. Here, `top_k` is set to `0.0` and `top_p` to `1.0`, which disables both filters, so tokens are sampled from the full distribution.
    - `do_sample`: Boolean value determining whether to sample the outputs. Set to `True`.

- **Output Length Sampling**:
    - `output_min_length` & `output_max_length`: Define the lower (`100`) and upper (`400`) bounds for the number of new tokens generated per summary.
    - `output_length_sampler`: Draws a length between these bounds for each generation; the loop passes it to `generate` as `max_new_tokens`.

- **`max_ppo_steps`**: 
    - Determines the total number of PPO steps during training. Set to `100`.
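
To make the output length sampling concrete, here is a minimal pure-Python stand-in. `ToyLengthSampler` is a hypothetical class written for illustration; it assumes that the real `LengthSampler` draws uniformly from `range(min, max)`, which would make the upper bound exclusive. The `seed` argument is also an addition for reproducibility.

```python
import random

class ToyLengthSampler:
    """Hypothetical stand-in for LengthSampler, assuming an exclusive upper bound."""

    def __init__(self, min_value, max_value, seed=None):
        self.values = list(range(min_value, max_value))
        self.rng = random.Random(seed)

    def __call__(self):
        # Draw one candidate length uniformly at random.
        return self.rng.choice(self.values)

sampler = ToyLengthSampler(100, 400, seed=0)
draws = [sampler() for _ in range(1000)]
print(min(draws), max(draws))
```

Under that assumption, every sampled `max_new_tokens` value falls between 100 and 399, so each summary gets a different length budget during training.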


In [None]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}, {"key1": "value4", "key2": "value5", "key3": "value6"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

# Lets sample what the collator generates:
sample_data = [train_dataset[i] for i in range(3)]  # take first three examples
collated_data = collator(sample_data)
print(collated_data.keys())

learning_rate=1.41e-5
max_ppo_epochs=3
mini_batch_size=4
batch_size=16

# This is a HACK; let's see how it works out. It may cause bias or it may help.
# The upside is that a constant rejected summary could provide some regularization,
# preventing the model from gravitating too much towards any specific pattern in the training data. Just a thought.
DEFAULT_REJECTED_SUMMARY_TEXT = "This is a bad summary"

# Some initial values
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# These hyperparams guide the generation of the completion in the policy model. Other generation parameters could be added here.
generation_kwargs = {
    "temperature": 1.0,
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

max_ppo_steps = 100
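
To illustrate why `top_k = 0.0` and `top_p = 1.0` amount to "no truncation", here is a small sketch of the two filters applied to a toy probability list. `filter_probs` is a hypothetical helper written for this example; it mimics the idea of top-k and nucleus filtering and is not the `transformers` implementation.

```python
def filter_probs(probs, top_k=0, top_p=1.0):
    """Toy top-k / top-p filter over a list of token probabilities (illustrative only)."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order)
    if top_k and top_k > 0:
        # Top-k: keep only the k most likely tokens.
        keep &= set(order[:int(top_k)])
    if top_p < 1.0:
        # Nucleus: keep the smallest prefix whose cumulative mass reaches top_p.
        cumulative, nucleus = 0.0, set()
        for i in order:
            nucleus.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        keep &= nucleus
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]  # renormalize the surviving tokens

probs = [0.5, 0.3, 0.15, 0.05]
unfiltered = filter_probs(probs, top_k=0.0, top_p=1.0)  # notebook settings: nothing removed
top_two = filter_probs(probs, top_k=2)                  # only the two most likely tokens survive
nucleus = filter_probs(probs, top_p=0.9)                # the least likely token is dropped
print(unfiltered, top_two, nucleus)
```

With the notebook's settings every token keeps its original probability, so sampling is driven only by `temperature` and `do_sample`.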


### Configuration for PPO Training

We leverage the `PPOConfig` from the Hugging Face `trl` library to set up the configuration required for the Proximal Policy Optimization (PPO) training.

The `PPOConfig` requires and/or allows for a number of arguments that define the behavior of the PPO training loop:

- **`model_name`**: 
    - Name of the model. Here, it is set as `policy_model_name`.

- **`learning_rate`**: 
    - The rate at which the model adjusts based on the error during training. We've set it to the previously initialized value of `learning_rate`.

- **`ppo_epochs`**: 
    - Specifies the number of optimization epochs run over each batch of generated samples. Set to the previously defined `max_ppo_epochs`.

- **`mini_batch_size`**: 
    - The size of the smaller batches that the main batch is divided into, during training. Set to the previously initialized value of `mini_batch_size`.

- **`batch_size`**: 
    - The number of data samples processed during each training step. We've set it to the previously initialized value of `batch_size`.

For a more detailed understanding and potential additional configurations, one can refer to the [Hugging Face documentation on `trl.trainer`](https://huggingface.co/docs/trl/trainer).


In [52]:
# Check out https://huggingface.co/docs/trl/trainer

config = PPOConfig(
    model_name=policy_model_name,    
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)
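
The interaction between these settings is easier to see with a little arithmetic. The sketch below only uses values that appear in this notebook (batch size 16, mini-batch size 4, 3 PPO epochs, 100 steps, 1,600 training samples); the assumption that every mini-batch pass corresponds to one optimizer update holds only if no gradient accumulation is configured.

```python
batch_size = 16
mini_batch_size = 4
ppo_epochs = 3
max_ppo_steps = 100
num_train_samples = 1600

# Each PPO step splits the batch into mini-batches...
mini_batches_per_epoch = batch_size // mini_batch_size
# ...and revisits all of them once per PPO epoch.
mini_batch_passes_per_step = mini_batches_per_epoch * ppo_epochs

# Prompts consumed over the whole run, and batches the dataloader can supply.
samples_seen = max_ppo_steps * batch_size
batches_available = num_train_samples // batch_size

print(mini_batches_per_epoch, mini_batch_passes_per_step, samples_seen, batches_available)
```

With these values, 100 PPO steps consume exactly one pass over the 1,600 training prompts, and each step performs 12 mini-batch passes.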

### Setting Up the PPO Trainer

To fine-tune the model using Proximal Policy Optimization (PPO), we use the `PPOTrainer` class from Hugging Face's `trl` library.

The `PPOTrainer` class is initialized with several key arguments:

- **`config`**: 
    - The configuration object created using `PPOConfig`. This contains the hyperparameters required for PPO training.

- **`model`**: 
    - The model that will be fine-tuned. In this case, it is the `ppo_model` which was previously instantiated.

- **`ref_model`**: 
    - The reference model, representing the model before alignment. We use `ref_model` for this purpose.

- **`tokenizer`**: 
    - The tokenizer responsible for converting text into tokens suitable for the model's input. Here, it's the `policy_tokenizer` we set up before.

- **`dataset`**: 
    - The training dataset. We use the tokenized `train_dataset`.

- **`data_collator`**: 
    - A function to transform a list of samples to a batch. We use the `collator` function we defined earlier.

This trainer will be used to conduct the PPO training loop, enabling us to fine-tune the model using reinforcement learning. 

For a deeper dive into the functionalities provided by the `PPOTrainer` class, one can refer to the [Hugging Face documentation on `trl.trainer`](https://huggingface.co/docs/trl/trainer).



In [53]:
# Check out https://huggingface.co/docs/trl/trainer

ppo_trainer = PPOTrainer(config=config, 
                         model=ppo_model, 
                         ref_model=ref_model, 
                         tokenizer=policy_tokenizer, 
                         dataset=train_dataset, 
                         data_collator=collator)
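
The reference model matters because PPO-based RLHF typically penalizes the policy for drifting away from it. The sketch below shows one common formulation, which we believe matches what `trl` does in spirit: each generated token receives a penalty proportional to the log-probability gap between policy and reference, and the reward-model score is added on the last token. The function name, the toy numbers, and the coefficient `0.2` are illustrative assumptions.

```python
def kl_penalized_rewards(logprobs, ref_logprobs, score, kl_coef=0.2):
    """Per-token rewards: a KL penalty on every token plus the score on the last one."""
    # Per-token log-ratio between the policy and the reference model.
    kl = [lp - ref_lp for lp, ref_lp in zip(logprobs, ref_logprobs)]
    rewards = [-kl_coef * k for k in kl]
    rewards[-1] += score  # the reward model's score is credited at the end of the summary
    return rewards

policy_logprobs = [-1.0, -0.5, -2.0]     # toy log-probs of the sampled tokens under the policy
reference_logprobs = [-1.2, -0.5, -1.5]  # toy log-probs of the same tokens under the reference
rewards = kl_penalized_rewards(policy_logprobs, reference_logprobs, score=0.4)
print(rewards)
```

Tokens that the policy makes more likely than the reference does are penalized, which discourages the policy from exploiting the reward model with text that is far from the original model's behavior.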

## Fine-Tuning with Reinforcement Learning

Reinforcement learning offers a unique approach to fine-tuning models. The underlying principle is to allow the model to learn by receiving feedback (rewards) on its actions. In this context, an action would be generating a summary for a given text prompt.

### Training Loop Overview

The training loop we've crafted here follows this sequence of steps:

1. **Model Prediction**: Using the policy language model (`ppo_model`, called through `ppo_trainer.generate`), we generate predicted summaries.
2. **Score Generation**: We then pass these summaries to a reward model to assign a score (reward) based on the quality of the generated summary.
3. **Model Update**: With the generated summaries and their respective scores, we use Proximal Policy Optimization (PPO) to update our policy language model.

### Detailed Breakdown

#### **1. Model Prediction**:

- We iterate through our training data in batches (`prompt_tensors`).
- For each prompt, we predict a summary (`summary_tensors`). This prediction is based on the generation hyperparameters we've specified (`generation_kwargs`), which guide the sampling strategy.

#### **2. Score Generation**:

- For each summary, we calculate a score by comparing it with a default rejected summary. 
- This step uses a separate reward model (`rm_model`), which assesses the quality of summaries.

#### **3. Model Update**:

- Using PPO, we update our policy model based on:
  - The initial input (`prompt_tensors`).
  - The generated summary (`summary_tensors`).
  - The assigned reward (`reward_tensors`).
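
The update itself relies on PPO's clipped surrogate objective. The sketch below shows that objective for a single token in plain Python; the function name and the clip range of `0.2` are illustrative assumptions, and a real implementation averages this loss over all tokens in a mini-batch.

```python
import math

def clipped_policy_loss(new_logprob, old_logprob, advantage, clip_range=0.2):
    """PPO clipped surrogate loss for one token (to be minimized)."""
    # Probability ratio between the updated policy and the policy that generated the sample.
    ratio = math.exp(new_logprob - old_logprob)
    clipped_ratio = max(1.0 - clip_range, min(ratio, 1.0 + clip_range))
    # Take the more pessimistic of the two surrogate objectives, then negate it.
    return -min(ratio * advantage, clipped_ratio * advantage)

# The ratio exp(0.5) is about 1.65, above the clip limit of 1.2.
gain_is_capped = clipped_policy_loss(0.5, 0.0, advantage=1.0)
penalty_is_kept = clipped_policy_loss(0.5, 0.0, advantage=-1.0)
print(gain_is_capped, penalty_is_kept)
```

When the advantage is positive, the benefit of increasing a token's probability is capped once the ratio leaves the clip range, which is what keeps individual updates small.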
  
### Key Metrics:

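- `objective/kl`: An estimate of the KL divergence between the current policy and the frozen reference model (`ref_model`), computed on the generated summaries. It indicates how far the policy has drifted from the model it started from; PPO penalizes large values to avoid drastic changes. Because it is estimated from samples, it can be slightly negative.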
  
- `ppo/returns/mean`: This is the average return achieved by the agent. Higher is better.

- `ppo/policy/advantages_mean`: Measures how much better an action is than the average action at a given state. An advantage of zero means the action is just average, a positive advantage means it's better than average, and a negative one means it's worse than average.
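
Two properties of these metrics can look surprising in the logs: a KL value can be negative, and the mean advantage tends to hover around zero. The sketch below illustrates both with toy numbers. It assumes the KL value is estimated from log-probability differences on sampled tokens and that advantages are normalized (whitened) within a batch, which is our understanding of common PPO implementations; `sample_kl` and `whiten` are hypothetical helpers.

```python
import statistics

def sample_kl(logprobs_a, logprobs_b):
    """Sample-based KL estimate: sum of log-probability differences on the sampled tokens."""
    return sum(a - b for a, b in zip(logprobs_a, logprobs_b))

def whiten(values):
    """Shift and scale values to zero mean and (approximately) unit variance."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    return [(v - mean) / (std + 1e-8) for v in values]

# The sampled tokens happen to be less likely under the first distribution.
kl_estimate = sample_kl([-1.0, -2.0], [-0.8, -1.9])
whitened = whiten([1.0, 2.0, 3.0, 6.0])
print(kl_estimate, statistics.fmean(whitened))
```

A true KL divergence is never negative, but an estimate computed on a handful of sampled tokens can be, and a whitened batch of advantages has a mean of zero up to rounding.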

### Important Notes:

- **HACK** Alert: The training loop includes a workaround for prompts that arrive as plain lists of token IDs with varying lengths. Workarounds like this are typically added to get past a specific issue during development, so it is worth revisiting them to see whether a cleaner approach exists.

- **Reward Model**: It's crucial that the reward model (`rm_model`) is robust. The quality of the model training largely depends on the feedback it provides.

### References:

- [PPOTrainer in Hugging Face's TRL library](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer)
- [Using Transformer Reinforcement Learning to detoxify generative language models](https://medium.com/@ben.burtenshaw/using-transformer-reinforcement-learning-to-detoxify-generative-language-models-5198446d6786)
- HuggingFace's example scripts in their GitHub repository.

The success of reinforcement learning is deeply intertwined with the feedback mechanism and the quality of the reward signal.


In [54]:
for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if step >= max_ppo_steps:  # Stop once we reach max_ppo_steps.
        break

    # HACK: input_ids may arrive as plain lists of token IDs with varying lengths.
    # PPOTrainer takes a list of 1-D tensors (one per prompt), so convert each sequence
    # on its own instead of padding them into a single rectangular tensor.
    prompt_tensors = [torch.as_tensor(seq).to(device) for seq in batch["input_ids"]]

    # 1. Model prediction: sample one summary per prompt from the policy model.
    summary_tensors = []
    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    batch["response"] = [policy_tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # 2. Score generation: score each summary against the fixed rejected summary.
    chosen_summaries = batch["response"]
    rejected_summaries = [DEFAULT_REJECTED_SUMMARY_TEXT] * len(chosen_summaries)

    reward_tensors = []
    for chosen_summary, rejected_summary in zip(chosen_summaries, rejected_summaries):
        chosen_score, _, _, _ = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)
        reward_tensors.append(torch.tensor(chosen_score))

    # 3. Model update: run one PPO step and log the statistics.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)

    print(f'objective/kl: {stats["objective/kl"]}')  # Estimated KL divergence between the policy and the reference model.
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')  # Average return achieved by the agent. Higher is better.
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')  # How much better an action is than the average action at a given state.
    print('-' * 99)

1it [00:16, 16.83s/it]

objective/kl: 0.0
ppo/returns/mean: 0.3847939372062683
ppo/policy/advantages_mean: 0.00486169196665287
---------------------------------------------------------------------------------------------------


2it [00:34, 17.17s/it]

objective/kl: -0.005069880746304989
ppo/returns/mean: 0.38824570178985596
ppo/policy/advantages_mean: 0.001255544601008296
---------------------------------------------------------------------------------------------------


3it [00:53, 17.98s/it]

objective/kl: -0.027740254998207092
ppo/returns/mean: 0.4024064540863037
ppo/policy/advantages_mean: 0.005844333209097385
---------------------------------------------------------------------------------------------------


4it [01:08, 16.89s/it]

objective/kl: 0.008470938540995121
ppo/returns/mean: 0.4360160231590271
ppo/policy/advantages_mean: 0.008006769232451916
---------------------------------------------------------------------------------------------------


5it [01:26, 17.16s/it]

objective/kl: -0.0028168456628918648
ppo/returns/mean: 0.4166437089443207
ppo/policy/advantages_mean: 0.0039040702395141125
---------------------------------------------------------------------------------------------------


6it [01:43, 17.11s/it]

objective/kl: -0.015527371317148209
ppo/returns/mean: 0.4231906831264496
ppo/policy/advantages_mean: 0.0028870929963886738
---------------------------------------------------------------------------------------------------


7it [02:00, 17.38s/it]

objective/kl: -0.0020126374438405037
ppo/returns/mean: 0.4269492030143738
ppo/policy/advantages_mean: 0.0030096229165792465
---------------------------------------------------------------------------------------------------


8it [02:20, 17.93s/it]

objective/kl: 0.06836467236280441
ppo/returns/mean: 0.41612517833709717
ppo/policy/advantages_mean: 0.0038263590540736914
---------------------------------------------------------------------------------------------------


9it [02:37, 17.70s/it]

objective/kl: 0.08298622071743011
ppo/returns/mean: 0.44846484065055847
ppo/policy/advantages_mean: 0.0022622086107730865
---------------------------------------------------------------------------------------------------


10it [02:53, 17.15s/it]

objective/kl: 0.005825418047606945
ppo/returns/mean: 0.46129149198532104
ppo/policy/advantages_mean: 0.003759412094950676
---------------------------------------------------------------------------------------------------


11it [03:11, 17.38s/it]

objective/kl: 0.005642858799546957
ppo/returns/mean: 0.4399201571941376
ppo/policy/advantages_mean: 0.0013956755865365267
---------------------------------------------------------------------------------------------------


12it [03:27, 17.02s/it]

objective/kl: 0.005834147334098816
ppo/returns/mean: 0.4514174461364746
ppo/policy/advantages_mean: 0.003677933244034648
---------------------------------------------------------------------------------------------------


13it [03:44, 17.21s/it]

objective/kl: 0.011828919872641563
ppo/returns/mean: 0.4530404210090637
ppo/policy/advantages_mean: 0.0011673825792968273
---------------------------------------------------------------------------------------------------


14it [04:01, 16.92s/it]

objective/kl: 0.033650025725364685
ppo/returns/mean: 0.4805707633495331
ppo/policy/advantages_mean: 0.004944793879985809
---------------------------------------------------------------------------------------------------


15it [04:19, 17.21s/it]

objective/kl: 0.04556742310523987
ppo/returns/mean: 0.4878147542476654
ppo/policy/advantages_mean: 0.00022911513224244118
---------------------------------------------------------------------------------------------------


16it [04:36, 17.38s/it]

objective/kl: -0.052357759326696396
ppo/returns/mean: 0.47049078345298767
ppo/policy/advantages_mean: 0.0005096103996038437
---------------------------------------------------------------------------------------------------


17it [04:54, 17.58s/it]

objective/kl: -0.009967220947146416
ppo/returns/mean: 0.4759097099304199
ppo/policy/advantages_mean: -0.00103578413836658
---------------------------------------------------------------------------------------------------


18it [05:12, 17.57s/it]

objective/kl: -0.02261665277183056
ppo/returns/mean: 0.4794938266277313
ppo/policy/advantages_mean: 0.004009743221104145
---------------------------------------------------------------------------------------------------


19it [05:31, 18.14s/it]

objective/kl: 0.0016215275973081589
ppo/returns/mean: 0.4834381937980652
ppo/policy/advantages_mean: 0.002844936214387417
---------------------------------------------------------------------------------------------------


20it [05:51, 18.66s/it]

objective/kl: -0.014883600175380707
ppo/returns/mean: 0.4874545931816101
ppo/policy/advantages_mean: 0.001672036712989211
---------------------------------------------------------------------------------------------------


21it [06:11, 18.87s/it]

objective/kl: 0.038705311715602875
ppo/returns/mean: 0.48216789960861206
ppo/policy/advantages_mean: 0.005550017114728689
---------------------------------------------------------------------------------------------------


22it [06:27, 18.22s/it]

objective/kl: -0.0019519738852977753
ppo/returns/mean: 0.4977434575557709
ppo/policy/advantages_mean: 0.002608383074402809
---------------------------------------------------------------------------------------------------


23it [06:46, 18.22s/it]

objective/kl: 0.015746327117085457
ppo/returns/mean: 0.4924009442329407
ppo/policy/advantages_mean: -0.0029126708395779133
---------------------------------------------------------------------------------------------------


24it [07:03, 17.84s/it]

objective/kl: 0.13578671216964722
ppo/returns/mean: 0.49223825335502625
ppo/policy/advantages_mean: -0.0014680366730317473
---------------------------------------------------------------------------------------------------


25it [07:18, 17.11s/it]

objective/kl: 0.18424507975578308
ppo/returns/mean: 0.5160964727401733
ppo/policy/advantages_mean: 0.0035873521119356155
---------------------------------------------------------------------------------------------------


26it [07:36, 17.30s/it]

objective/kl: 0.07853008806705475
ppo/returns/mean: 0.5125787854194641
ppo/policy/advantages_mean: 0.0018142350018024445
---------------------------------------------------------------------------------------------------


27it [07:54, 17.66s/it]

objective/kl: 0.1536068469285965
ppo/returns/mean: 0.5129472017288208
ppo/policy/advantages_mean: -0.0031952625140547752
---------------------------------------------------------------------------------------------------


28it [08:13, 17.97s/it]

objective/kl: -0.02651580050587654
ppo/returns/mean: 0.5035720467567444
ppo/policy/advantages_mean: 0.0005308630643412471
---------------------------------------------------------------------------------------------------


29it [08:33, 18.51s/it]

objective/kl: 0.04060036689043045
ppo/returns/mean: 0.5135884284973145
ppo/policy/advantages_mean: 0.002005363814532757
---------------------------------------------------------------------------------------------------


30it [08:52, 18.82s/it]

objective/kl: 0.06144071742892265
ppo/returns/mean: 0.5074781179428101
ppo/policy/advantages_mean: 0.006083873566240072
---------------------------------------------------------------------------------------------------


31it [09:09, 18.14s/it]

objective/kl: 0.1026630848646164
ppo/returns/mean: 0.497572124004364
ppo/policy/advantages_mean: -0.004282218404114246
---------------------------------------------------------------------------------------------------


32it [09:24, 17.31s/it]

objective/kl: 0.11476826667785645
ppo/returns/mean: 0.5155744552612305
ppo/policy/advantages_mean: -0.0009203864028677344
---------------------------------------------------------------------------------------------------


33it [09:41, 17.14s/it]

objective/kl: 0.018813099712133408
ppo/returns/mean: 0.5048032999038696
ppo/policy/advantages_mean: 0.0009756293147802353
---------------------------------------------------------------------------------------------------


34it [09:59, 17.34s/it]

objective/kl: -0.004020830616354942
ppo/returns/mean: 0.50152587890625
ppo/policy/advantages_mean: -0.0017703983467072248
---------------------------------------------------------------------------------------------------


35it [10:16, 17.33s/it]

objective/kl: 0.028400693088769913
ppo/returns/mean: 0.5098116397857666
ppo/policy/advantages_mean: -0.00064779695821926
---------------------------------------------------------------------------------------------------


36it [10:35, 17.76s/it]

objective/kl: 0.004725713282823563
ppo/returns/mean: 0.5155090093612671
ppo/policy/advantages_mean: 0.005265518091619015
---------------------------------------------------------------------------------------------------


37it [10:53, 17.94s/it]

objective/kl: 0.03030284121632576
ppo/returns/mean: 0.5224802494049072
ppo/policy/advantages_mean: 0.000699183321557939
---------------------------------------------------------------------------------------------------


38it [11:12, 18.25s/it]

objective/kl: 0.0028322823345661163
ppo/returns/mean: 0.5142924785614014
ppo/policy/advantages_mean: 0.00593021884560585
---------------------------------------------------------------------------------------------------


39it [11:30, 18.30s/it]

objective/kl: 0.06499931961297989
ppo/returns/mean: 0.5349175930023193
ppo/policy/advantages_mean: -0.0069395094178617
---------------------------------------------------------------------------------------------------


40it [11:50, 18.67s/it]

objective/kl: 0.015764307230710983
ppo/returns/mean: 0.5187349915504456
ppo/policy/advantages_mean: -0.0023731673136353493
---------------------------------------------------------------------------------------------------


41it [12:07, 18.06s/it]

objective/kl: 0.018244529142975807
ppo/returns/mean: 0.5299811363220215
ppo/policy/advantages_mean: -0.00019256211817264557
---------------------------------------------------------------------------------------------------


42it [12:23, 17.57s/it]

objective/kl: -0.058882132172584534
ppo/returns/mean: 0.5251530408859253
ppo/policy/advantages_mean: -0.0028166293632239103
---------------------------------------------------------------------------------------------------


43it [12:42, 18.02s/it]

objective/kl: -0.01506706140935421
ppo/returns/mean: 0.5216745138168335
ppo/policy/advantages_mean: -0.007766125723719597
---------------------------------------------------------------------------------------------------


44it [13:00, 17.97s/it]

objective/kl: 0.0036642001941800117
ppo/returns/mean: 0.5320895910263062
ppo/policy/advantages_mean: -0.000767915858887136
---------------------------------------------------------------------------------------------------


45it [13:18, 17.86s/it]

objective/kl: -0.03506341576576233
ppo/returns/mean: 0.5339188575744629
ppo/policy/advantages_mean: 0.010133009403944016
---------------------------------------------------------------------------------------------------


46it [13:37, 18.21s/it]

objective/kl: -0.03634041175246239
ppo/returns/mean: 0.5373767018318176
ppo/policy/advantages_mean: -0.0011674811830744147
---------------------------------------------------------------------------------------------------


47it [13:55, 18.19s/it]

objective/kl: 0.0006244629621505737
ppo/returns/mean: 0.5391950011253357
ppo/policy/advantages_mean: -0.008310972712934017
---------------------------------------------------------------------------------------------------


48it [14:11, 17.56s/it]

objective/kl: 0.09847931563854218
ppo/returns/mean: 0.5256244540214539
ppo/policy/advantages_mean: -0.004047184716910124
---------------------------------------------------------------------------------------------------


49it [14:27, 17.18s/it]

objective/kl: 0.007130159065127373
ppo/returns/mean: 0.5404835939407349
ppo/policy/advantages_mean: -0.0006340897525660694
---------------------------------------------------------------------------------------------------


50it [14:44, 17.19s/it]

objective/kl: 0.09234826266765594
ppo/returns/mean: 0.5280537605285645
ppo/policy/advantages_mean: 0.00522202393040061
---------------------------------------------------------------------------------------------------


51it [15:01, 17.10s/it]

objective/kl: -0.011652393266558647
ppo/returns/mean: 0.5378813743591309
ppo/policy/advantages_mean: 0.006647405680269003
---------------------------------------------------------------------------------------------------


52it [15:18, 17.06s/it]

objective/kl: -0.010729705914855003
ppo/returns/mean: 0.5417889356613159
ppo/policy/advantages_mean: -0.0059839775785803795
---------------------------------------------------------------------------------------------------


53it [15:36, 17.42s/it]

objective/kl: 0.0522296279668808
ppo/returns/mean: 0.5455126166343689
ppo/policy/advantages_mean: 0.003960065543651581
---------------------------------------------------------------------------------------------------


54it [15:53, 17.02s/it]

objective/kl: 0.07429303228855133
ppo/returns/mean: 0.5376936793327332
ppo/policy/advantages_mean: 0.003455133643001318
---------------------------------------------------------------------------------------------------


55it [16:09, 16.90s/it]

objective/kl: -0.015849264338612556
ppo/returns/mean: 0.5331861972808838
ppo/policy/advantages_mean: -0.0017464521806687117
---------------------------------------------------------------------------------------------------


56it [16:26, 17.00s/it]

objective/kl: 0.10449449717998505
ppo/returns/mean: 0.5360128879547119
ppo/policy/advantages_mean: 0.0010660793632268906
---------------------------------------------------------------------------------------------------


57it [16:43, 16.93s/it]

objective/kl: 0.018445100635290146
ppo/returns/mean: 0.5541489124298096
ppo/policy/advantages_mean: 0.00361231598071754
---------------------------------------------------------------------------------------------------


58it [16:59, 16.50s/it]

objective/kl: 0.13619235157966614
ppo/returns/mean: 0.5405611991882324
ppo/policy/advantages_mean: 0.002603059634566307
---------------------------------------------------------------------------------------------------


59it [17:16, 16.88s/it]

objective/kl: 0.054558731615543365
ppo/returns/mean: 0.5453513860702515
ppo/policy/advantages_mean: -0.0017076923977583647
---------------------------------------------------------------------------------------------------


60it [17:34, 17.17s/it]

objective/kl: 0.10614021122455597
ppo/returns/mean: 0.5431622266769409
ppo/policy/advantages_mean: 0.002949802204966545
---------------------------------------------------------------------------------------------------


61it [17:52, 17.46s/it]

objective/kl: 0.008595196530222893
ppo/returns/mean: 0.5485168695449829
ppo/policy/advantages_mean: 0.0012846844037994742
---------------------------------------------------------------------------------------------------


62it [18:10, 17.48s/it]

objective/kl: 0.06707711517810822
ppo/returns/mean: 0.5445464849472046
ppo/policy/advantages_mean: 0.0017537561943754554
---------------------------------------------------------------------------------------------------


63it [18:28, 17.63s/it]

objective/kl: 0.059157513082027435
ppo/returns/mean: 0.5462327003479004
ppo/policy/advantages_mean: 0.007264412939548492
---------------------------------------------------------------------------------------------------


64it [18:48, 18.35s/it]

objective/kl: 0.019886992871761322
ppo/returns/mean: 0.5484727621078491
ppo/policy/advantages_mean: 0.005174936726689339
---------------------------------------------------------------------------------------------------


65it [19:06, 18.11s/it]

objective/kl: 0.15200024843215942
ppo/returns/mean: 0.548679530620575
ppo/policy/advantages_mean: 0.0005625371704809368
---------------------------------------------------------------------------------------------------


66it [19:22, 17.60s/it]

objective/kl: 0.07403381168842316
ppo/returns/mean: 0.5488724112510681
ppo/policy/advantages_mean: -0.003060923656448722
---------------------------------------------------------------------------------------------------


67it [19:42, 18.27s/it]

objective/kl: 0.05367913842201233
ppo/returns/mean: 0.5376991033554077
ppo/policy/advantages_mean: 0.004013826604932547
---------------------------------------------------------------------------------------------------


68it [19:59, 17.96s/it]

objective/kl: 0.03038051351904869
ppo/returns/mean: 0.5573604106903076
ppo/policy/advantages_mean: 0.0009522270411252975
---------------------------------------------------------------------------------------------------


69it [20:14, 17.15s/it]

objective/kl: -0.0014824792742729187
ppo/returns/mean: 0.5564751625061035
ppo/policy/advantages_mean: -0.0011386226397007704
---------------------------------------------------------------------------------------------------


70it [20:30, 16.67s/it]

objective/kl: 0.04659070819616318
ppo/returns/mean: 0.5504528880119324
ppo/policy/advantages_mean: -0.006692121736705303
---------------------------------------------------------------------------------------------------


71it [20:48, 17.14s/it]

objective/kl: -0.09201271086931229
ppo/returns/mean: 0.5490429401397705
ppo/policy/advantages_mean: -0.006128481589257717
---------------------------------------------------------------------------------------------------


72it [21:06, 17.54s/it]

objective/kl: -0.06070361286401749
ppo/returns/mean: 0.5590073466300964
ppo/policy/advantages_mean: -0.0016573065659031272
---------------------------------------------------------------------------------------------------


73it [21:24, 17.43s/it]

objective/kl: 0.03430064767599106
ppo/returns/mean: 0.5551648736000061
ppo/policy/advantages_mean: 0.002447400940582156
---------------------------------------------------------------------------------------------------


74it [21:42, 17.73s/it]

objective/kl: -0.06824635714292526
ppo/returns/mean: 0.5516186952590942
ppo/policy/advantages_mean: -0.004368680063635111
---------------------------------------------------------------------------------------------------


75it [21:59, 17.45s/it]

objective/kl: 0.019973933696746826
ppo/returns/mean: 0.5668906569480896
ppo/policy/advantages_mean: 0.0025545943062752485
---------------------------------------------------------------------------------------------------


76it [22:18, 18.05s/it]

objective/kl: -0.015159569680690765
ppo/returns/mean: 0.5665764212608337
ppo/policy/advantages_mean: 0.00242402171716094
---------------------------------------------------------------------------------------------------


77it [22:36, 18.08s/it]

objective/kl: 0.05545371025800705
ppo/returns/mean: 0.5569909811019897
ppo/policy/advantages_mean: 0.0012659328058362007
---------------------------------------------------------------------------------------------------


78it [22:54, 18.01s/it]

objective/kl: 0.01855643279850483
ppo/returns/mean: 0.5615234375
ppo/policy/advantages_mean: -0.006460611708462238
---------------------------------------------------------------------------------------------------


79it [23:12, 17.98s/it]

objective/kl: 0.11219315975904465
ppo/returns/mean: 0.5545932650566101
ppo/policy/advantages_mean: 0.00518676545470953
---------------------------------------------------------------------------------------------------


80it [23:30, 17.95s/it]

objective/kl: -0.04506971687078476
ppo/returns/mean: 0.562433660030365
ppo/policy/advantages_mean: -0.0017066728323698044
---------------------------------------------------------------------------------------------------


81it [23:45, 17.18s/it]

objective/kl: 0.09495706856250763
ppo/returns/mean: 0.5601772665977478
ppo/policy/advantages_mean: 0.0048878975212574005
---------------------------------------------------------------------------------------------------


82it [24:02, 17.00s/it]

objective/kl: 0.027638882398605347
ppo/returns/mean: 0.5567706823348999
ppo/policy/advantages_mean: -0.005091940518468618
---------------------------------------------------------------------------------------------------


83it [24:18, 16.82s/it]

objective/kl: 0.017184820026159286
ppo/returns/mean: 0.5517389178276062
ppo/policy/advantages_mean: 0.002571543212980032
---------------------------------------------------------------------------------------------------


84it [24:34, 16.38s/it]

objective/kl: 0.08351362496614456
ppo/returns/mean: 0.5540848970413208
ppo/policy/advantages_mean: -0.00670247245579958
---------------------------------------------------------------------------------------------------


85it [24:51, 16.67s/it]

objective/kl: 0.09037359058856964
ppo/returns/mean: 0.5606344938278198
ppo/policy/advantages_mean: -0.0025991201400756836
---------------------------------------------------------------------------------------------------


86it [25:09, 16.92s/it]

objective/kl: -0.007849142886698246
ppo/returns/mean: 0.5723594427108765
ppo/policy/advantages_mean: -0.0017869044095277786
---------------------------------------------------------------------------------------------------


87it [25:25, 16.84s/it]

objective/kl: 0.05883844196796417
ppo/returns/mean: 0.5611841678619385
ppo/policy/advantages_mean: -0.003919322043657303
---------------------------------------------------------------------------------------------------


88it [25:43, 17.15s/it]

objective/kl: 0.1022462546825409
ppo/returns/mean: 0.5655121803283691
ppo/policy/advantages_mean: -0.007099844049662352
---------------------------------------------------------------------------------------------------


89it [25:59, 16.64s/it]

objective/kl: 0.08246717602014542
ppo/returns/mean: 0.5580060482025146
ppo/policy/advantages_mean: -0.0024183127097785473
---------------------------------------------------------------------------------------------------


90it [26:16, 16.82s/it]

objective/kl: 0.05359852313995361
ppo/returns/mean: 0.5517401695251465
ppo/policy/advantages_mean: -0.0005698163877241313
---------------------------------------------------------------------------------------------------


91it [26:35, 17.51s/it]

objective/kl: 0.041515789926052094
ppo/returns/mean: 0.5618153810501099
ppo/policy/advantages_mean: -0.010931842029094696
---------------------------------------------------------------------------------------------------


92it [26:54, 17.86s/it]

objective/kl: -0.043928734958171844
ppo/returns/mean: 0.5648149847984314
ppo/policy/advantages_mean: -0.001079285517334938
---------------------------------------------------------------------------------------------------


93it [27:10, 17.46s/it]

objective/kl: 0.04435547813773155
ppo/returns/mean: 0.5669693350791931
ppo/policy/advantages_mean: -0.0014983447035774589
---------------------------------------------------------------------------------------------------


94it [27:26, 16.81s/it]

objective/kl: -0.01745159551501274
ppo/returns/mean: 0.5599812865257263
ppo/policy/advantages_mean: -0.001669015153311193
---------------------------------------------------------------------------------------------------


95it [27:43, 17.11s/it]

objective/kl: -0.019577058032155037
ppo/returns/mean: 0.5756101012229919
ppo/policy/advantages_mean: -0.0012546818470582366
---------------------------------------------------------------------------------------------------


96it [28:02, 17.71s/it]

objective/kl: 0.004033096134662628
ppo/returns/mean: 0.5709392428398132
ppo/policy/advantages_mean: -0.00309437932446599
---------------------------------------------------------------------------------------------------


97it [28:19, 17.46s/it]

objective/kl: 0.027891982346773148
ppo/returns/mean: 0.5759447813034058
ppo/policy/advantages_mean: 0.006105329841375351
---------------------------------------------------------------------------------------------------


98it [28:37, 17.57s/it]

objective/kl: -0.030279502272605896
ppo/returns/mean: 0.5550879836082458
ppo/policy/advantages_mean: 0.000660043559037149
---------------------------------------------------------------------------------------------------


99it [28:56, 18.04s/it]

objective/kl: 0.04733710736036301
ppo/returns/mean: 0.5625548958778381
ppo/policy/advantages_mean: -0.013566880486905575
---------------------------------------------------------------------------------------------------


100it [29:13, 17.54s/it]

objective/kl: 0.12012763321399689
ppo/returns/mean: 0.5627652406692505
ppo/policy/advantages_mean: 0.008588030003011227
---------------------------------------------------------------------------------------------------
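
The metrics printed above are easier to read as a trend than line by line. The following is a minimal sketch, assuming the log has been captured as plain text in the same `name: value` layout. The helper name `parse_metrics` is hypothetical, and the excerpt contains only the first and last entries shown in this section.

```python
import re
from statistics import mean

# Two entries copied from the training log above (steps 51 and 100).
LOG_EXCERPT = """
51it [15:01, 17.10s/it]

objective/kl: -0.011652393266558647
ppo/returns/mean: 0.5378813743591309
ppo/policy/advantages_mean: 0.006647405680269003
----------

100it [29:13, 17.54s/it]

objective/kl: 0.12012763321399689
ppo/returns/mean: 0.5627652406692505
ppo/policy/advantages_mean: 0.008588030003011227
----------
"""

# Matches lines such as "ppo/returns/mean: 0.53"; progress-bar and separator
# lines do not match and are skipped.
METRIC_LINE = re.compile(r"^([\w/]+):\s+(-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)$")


def parse_metrics(log_text):
    """Collect every 'name: value' line into a dict of name -> list of floats."""
    metrics = {}
    for line in log_text.splitlines():
        match = METRIC_LINE.match(line.strip())
        if match:
            name, value = match.groups()
            metrics.setdefault(name, []).append(float(value))
    return metrics


metrics = parse_metrics(LOG_EXCERPT)
for name, values in metrics.items():
    print(f"{name}: first={values[0]:.4f} last={values[-1]:.4f} mean={mean(values):.4f}")
```

Applied to the full log, the same helper gives one list per metric, which can then be plotted or averaged over a window.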





## Saving the Model and Tokenizer

After fine-tuning, save the model's weights and the tokenizer's configuration so they can be reused for inference, further training, or sharing with the community.

### 1. Saving the Model

To preserve the state of the model after training, call `save_pretrained` on both the model and the tokenizer, writing them to the same directory:


In [57]:
ppo_model_path = "./YOUR/PATH/HERE"

# Save the model
ppo_model.save_pretrained(ppo_model_path)

# Save the tokenizer
policy_tokenizer.save_pretrained(ppo_model_path)

('./model_ppo_jco_v1\\tokenizer_config.json',
 './model_ppo_jco_v1\\special_tokens_map.json',
 './model_ppo_jco_v1\\spiece.model',
 './model_ppo_jco_v1\\added_tokens.json')
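
A quick check that the save succeeded is to confirm the expected files are on disk before moving on. The sketch below is an illustration only: the helper `missing_files` is hypothetical, and the file names are taken from the tokenizer output above. `added_tokens.json` is left out on the assumption that it may not be written when no tokens were added, and the model files are not checked because they are not listed in that output.

```python
import tempfile
from pathlib import Path

# File names taken from the tokenizer output above.
EXPECTED_TOKENIZER_FILES = (
    "tokenizer_config.json",
    "special_tokens_map.json",
    "spiece.model",
)


def missing_files(directory, expected=EXPECTED_TOKENIZER_FILES):
    """Return the expected file names that are absent from `directory`."""
    root = Path(directory)
    return [name for name in expected if not (root / name).is_file()]


# Demonstration on a temporary directory that holds only one of the files.
with tempfile.TemporaryDirectory() as tmp_dir:
    (Path(tmp_dir) / "tokenizer_config.json").write_text("{}")
    print("Missing:", missing_files(tmp_dir))
```

In the notebook, the call would be `missing_files(ppo_model_path)`; an empty list means all listed tokenizer files were written.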

## Inference using the Fine-tuned Model

With the fine-tuned model saved, the next step is to use it to generate summaries. Its outputs reflect what the model learned during the RL fine-tuning process.

### Loading the Model

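To load the model, we use the `AutoModelForSeq2SeqLMWithValueHead` class from the `trl` library. This class wraps a sequence-to-sequence model and adds the value head that the Proximal Policy Optimization (PPO) algorithm required during training. In the cell below, the model and tokenizer are loaded from the Hugging Face Hub repository `JuanKO/rlhf_ppo_model` rather than from the local directory used when saving: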


In [8]:
ppo_saved_model_path = "JuanKO/rlhf_ppo_model"

from trl import AutoModelForSeq2SeqLMWithValueHead # https://huggingface.co/docs/trl/quickstart
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(ppo_saved_model_path)

from transformers import AutoTokenizer
policy_tokenizer = AutoTokenizer.from_pretrained(ppo_saved_model_path)

Downloading (…)/adapter_config.json:   0%|          | 0.00/338 [00:00<?, ?B/s]

Downloading adapter_model.bin:   0%|          | 0.00/3.59M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/4.16k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.37k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

### Function for Generating Summaries

To simplify inference on new prompts, we define a dedicated function, `generate_summary`. It takes the trained model, its tokenizer, the generation arguments, and an output-length sampler, and returns a concise summary of the input text.


In [14]:
def generate_summary(prompt: str, model, tokenizer, generation_kwargs, output_length_sampler) -> str:
    """
    Generate a summary for a given prompt using a trained policy model.
    
    Args:
    - prompt (str): The input text for which a summary needs to be generated.
    - model: The trained policy model.
    - tokenizer: The tokenizer used for the policy model.
    - generation_kwargs (dict): Arguments used for response generation.
    - output_length_sampler (func): Function to sample the length of the output.

    Returns:
    - str: Generated summary.
    """

    # Tokenize the prompt and move it to `device` (assumed to be defined earlier in the notebook)
    prompt_tensor = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    # Ensure it's only one tensor and check its shape
    assert prompt_tensor.dim() == 2, f"Unexpected tensor shape: {prompt_tensor.shape}"
    
    # Sample an output length and merge it into a copy of the generation
    # arguments, so the caller's dictionary is not modified
    gen_kwargs = {**generation_kwargs, "max_new_tokens": output_length_sampler()}
    
    # Generate a summary
    summary_tensor = model.generate(input_ids=prompt_tensor, **gen_kwargs)
    
    # Decode and return the summary
    summary = tokenizer.decode(summary_tensor[0], skip_special_tokens=True)
    return summary
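
`generate_summary` expects `output_length_sampler` to be a zero-argument callable that returns an integer, and `generation_kwargs` to be a plain dictionary. Both are defined earlier in the notebook. The sketch below shows one way such inputs could look: `UniformLengthSampler` is a hypothetical stand-in (the earlier cells may well use a sampler provided by TRL instead), and the dictionary values are assumptions chosen for illustration.

```python
import random


class UniformLengthSampler:
    """Hypothetical stand-in: each call returns an integer in [min_value, max_value)."""

    def __init__(self, min_value, max_value, seed=None):
        if min_value >= max_value:
            raise ValueError("min_value must be smaller than max_value")
        self._values = range(min_value, max_value)
        self._rng = random.Random(seed)

    def __call__(self):
        return self._rng.choice(self._values)


# Assumed example values; the notebook's own settings may differ.
example_generation_kwargs = {"do_sample": True, "top_k": 0, "top_p": 1.0}
example_sampler = UniformLengthSampler(20, 60, seed=0)

# Build the final arguments without changing the original dictionary.
merged_kwargs = {**example_generation_kwargs, "max_new_tokens": example_sampler()}
print(merged_kwargs)
```

Sampling the length on every call means repeated summaries of the same prompt can differ in length, which matches how responses are typically generated during PPO training.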


In [13]:
# text = "SUBREDDIT: r/relationships TITLE: How do I/do I at all [20 F] tell my boyfriend [23 M] that I'm bisexual? POST: I've had two serious relationships prior to this one, both with women. They had no problem with me being bisexual and it was something known before the relationship -- my first girlfriend was also bisexual. I am now in a relationship with a guy. We've been exclusive for about a month. Having never faced this issue, I come to you, Reddit. Is this something that he needs to know? Is it really relevant to a hetero relationship, regardless of if one of the participants in the relationship is bisexual? If you guys think it is necessary, when do you think is the right time? I think my biggest fear is losing him because of it. I know that I should be with someone who is fine with who I am, but I really like the guy and I'd hate for my sexual orientation to be the thing that kills this."
# text = "SUBREDDIT: r/legaladvice TITLE: What can I do legally to restore water to my condominium!? POST: Hi, I live in SE Michigan in a condominium complex. Our water was shut off due to non-payment. (we recieved no notice) and we had to pay all that was due ($1500) We payed this yesterday at 2, they said the water would be turned on immediately. It wasn't. It's now the next day. The lady in our assosciation keeps insisting that the water meter is in another condo. Which we can't access because the person living there is never there (it's being rented) Now we're stuck with no water, no shower, no teeth brushing, no toilets, and no food for certain meals.... Please help us... What can we do? We called the police and they say that we can file a civil report for the lady not doing her job..."
# text = "SUBREDDIT: r/relationships TITLE: To go or not to go? Old friend (f, 23) getting married, I (f 23) don't want to because I have to go from here in the Netherlands to USA. POST: So, I have had this friend for a long time and we have always been there for each other. But about 6 months ago I moved here to the Netherlands to be with my partner (m23). This is our first place together here and we had to buy our own furniture. Needless to say we don't really have any money for trips. My friend is getting married in March in the USA and I feel really guilty out of obligation but I really don't want to go. I don't have the money for it and I don't want to leave here and miss my partner. Reasons for not wanting to go: 1. Money 2. Missing my partner. 3. Being incredibly bored once I'm there! I won't have a car or a way to get around, so I'll just be sitting in my parents house all day. I know it's bad that I don't want to go, but I am just really dreading it. Reddit, what do I do?"
# text = "SUBREDDIT: r/Advice TITLE: Bike tour around the world? POST: Hi there redditors! First of all I'd like to apologize for my English, but as you will see (I hope not), I'm not a native speaker. I'm 23-year-old who recently graduated from university and just stared my first job. Now, you see, my job is interesting and all, but it's an office job and I feel I'm not suited for this. I'm the adventures type, I want something happening around me and going to work from 9 to 6 is just killing me. The one thing that I thought of is a bike trip mostly in Europe, Asia and North Africa. The problem is that I'm from a country with an average salary around 350 euros or 450 USD. My salary is a bit higher - around 450 euros, but still not enough according to what I read is needed for such a trip, witch is about 30000 USD. My question is if somebody has done something like this without any money and if they have some tips for me. I'm thinking about sleeping outdoors or helping some locals for food and a place to crash. Is this something that could work out? I'm planning to go with my girlfriend and I think not too many people would take us in. Any help would be greatly appreciated!"
text = "SUBREDDIT: r/Parenting TITLE: Question about saying 'no' to 18 month old POST: When I tell my son 'no' to something that is either dangerous (like sitting on the arm of the couch or trying to climb onto the television) or something that is an unwanted behavior (biting, hitting etc.) he looks at me and giggles before continuing to do whatever the hell he wants to do. When my husband tells him 'no' he stops what he's doing and sometimes gets upset to the point of crying (I think because his feelings are hurt). I guess the question is, how do I get him to listen to me and not just to his father? I have tried to make my voice sound louder and more masculine, but that just makes him laugh even harder."


In [75]:
prompt = f"{task_prefix}{text}"
generated_summary = generate_summary(prompt, ppo_model, policy_tokenizer, generation_kwargs, output_length_sampler)
print(generated_summary)

TL;DR: man says 'no' to dangerous behavior, stops biting his fathers, but sometimes gets upset. working on making my voice feel more masculine because of poor child behavior.


## Conclusion and Recap

In this notebook, we applied Reinforcement Learning from Human Feedback (RLHF) to improve text summarization. The two main components of this approach are the policy model (here, a T5 model) and the reward model (based on BERT). To recap the steps we took:

1. **Loading the Policy Model (T5)**:
   - We began by initializing the T5 model which would act as our policy model for generating text summaries.
  
2. **Loading the Reward Model (BERT)**:
   - To score the summaries generated by the T5 model, we used a BERT-based model that was trained on a mixture of model-written summaries and human feedback.

3. **Training Loop with Proximal Policy Optimization (PPO)**:
   - To fine-tune the T5 policy model, we used PPO, a widely used policy-gradient reinforcement learning algorithm.
   - We set up a loop in which the T5 model proposed summaries that were then scored by the BERT-based reward model. Using these rewards, the T5 model was fine-tuned to better align with human preferences.
   - Throughout this loop, we monitored the KL divergence (`objective/kl`), the mean returns (`ppo/returns/mean`), and the mean advantages (`ppo/policy/advantages_mean`) to check that training was progressing as expected.

4. **Inference**:
   - After the RLHF process, we tested the fine-tuned T5 model by generating summaries for new input text with a dedicated function.
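
The KL term monitored in step 3 is usually estimated from the sampled tokens as a sum of log-probability differences between the policy and a frozen reference model. Because it is a sample estimate, it can come out negative, as several `objective/kl` values in the training log do. The sketch below illustrates a common formulation of the KL-penalized reward under stated assumptions: the function name, the coefficient `beta`, and all numbers are made up for illustration and are not taken from the notebook.

```python
def kl_penalized_rewards(policy_logprobs, ref_logprobs, score, beta):
    """Return per-token rewards and the sampled KL estimate for one response.

    Each token is penalized by -beta * (policy log-prob - reference log-prob),
    and the reward-model score is added on the last token.
    """
    if not policy_logprobs or len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob lists must be non-empty and of equal length")
    log_ratios = [p - r for p, r in zip(policy_logprobs, ref_logprobs)]
    rewards = [-beta * ratio for ratio in log_ratios]
    rewards[-1] += score
    return rewards, sum(log_ratios)


# Made-up log-probabilities for a three-token response.
policy_logprobs = [-1.0, -2.0, -0.5]
ref_logprobs = [-1.5, -1.0, -0.5]

rewards, kl_estimate = kl_penalized_rewards(
    policy_logprobs, ref_logprobs, score=0.6, beta=0.2
)
print("per-token rewards:", rewards)
print("sampled KL estimate:", kl_estimate)  # negative for these numbers
```

The penalty keeps the policy close to the reference model while the score on the final token pulls it toward summaries the reward model prefers.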

By combining a T5 policy model with a BERT-based reward model and fine-tuning with PPO, we aimed to produce summaries that are more in line with human preferences.

Future efforts can focus on refining the training process, experimenting with different RL algorithms, or scaling up the training data to further improve the performance.

Thank you for joining us on this journey, and happy summarizing!


Notebook developed by [Pano Evangeliou](https://www.linkedin.com/in/p-evangeliou/) and [Juan Olano](https://www.linkedin.com/in/juan-olano-b9a330112/) - Sept.2023