<a href="https://colab.research.google.com/github/anshradh/trl_custom/blob/test/03_writing_prompt_reward_model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Writing Prompt Reward Model Training
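The idea is to train a model to predict which response to a Reddit writing prompt was ranked highest. We'll then use it as a reward model when training a language model (LM) to produce human-preferred responses to Reddit writing prompts. Concretely, this is framed as binary classification: the top-ranked response to each prompt is labeled `best` and every other response `not-best`.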

## Prerequisites

In [None]:
# Install needed libraries and log into huggingface
!pip install datasets
!pip install transformers
!pip install accelerate
!pip install huggingface_hub
!apt install git-lfs
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import torch
from tqdm.auto import tqdm
import numpy as np
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam
import torch
import collections
import random
tqdm.pandas()

from datasets import load_dataset, ClassLabel, load_metric, concatenate_datasets

from transformers import AutoModel, AutoTokenizer
from transformers import top_k_top_p_filtering
from torch import nn
from torch.nn import Identity
import torch.nn.functional as F
import torch

from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding, AdamW, get_scheduler

from accelerate import Accelerator

## Data Preprocessing

In [None]:
# Load the Reddit writing-prompt responses from the Hugging Face Hub.
# Only the first 80% of the CSV's "train" split is used in this notebook.
prompt_response_dataset = load_dataset(
    "rewardsignal/reddit_writing_prompts",
    data_files="prompt_responses_full.csv",
    split='train[:80%]',
)

In [None]:
## Tokenize each (prompt, response) pair as one sequence, after turning the reddit "[WP] " tag
## into a readable "Writing Prompt: " prefix and prepending "Response: " to every response.
tokenizer_name = 'distilgpt2'  # swap in another checkpoint name here to use a different tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
prompt_prefix = "Writing Prompt: "
response_prefix = "Response: "

def preprocess_text_function(examples):
  """Prefix a batch of prompts/responses and tokenize them as text pairs.

  `examples` is a batched dict of columns. Its "prompt" and "response" lists are
  overwritten with the prefixed text, and the returned value is the tokenizer output
  for the pairs (truncated by the tokenizer).
  """
  renamed_prompts = []
  for raw_prompt in examples["prompt"]:
    renamed_prompts.append(raw_prompt.replace('[WP] ', prompt_prefix))
  prefixed_responses = list(map(lambda text: response_prefix + text, examples["response"]))
  examples["prompt"], examples["response"] = renamed_prompts, prefixed_responses
  return tokenizer(renamed_prompts, prefixed_responses, truncation=True)

tokenized_reward_dataset = prompt_response_dataset.map(
    preprocess_text_function,
    batched=True,
    num_proc=4,
)

In [None]:
## Here we binarize the labels (best-ranked response receives a label of 1, rest get a label of 0) and remove extraneous dataset columns
def preprocess_labels_function(examples):
  """Attach a binary "labels" column to a batch.

  A response whose "response_rank" is 0 (the best-ranked one) gets label 1; every
  other response gets 0. The same batch dict is mutated and returned, so all
  existing columns are kept.
  """
  examples['labels'] = [int(rank == 0) for rank in examples["response_rank"]]
  return examples
tokenized_reward_dataset = tokenized_reward_dataset.map(preprocess_labels_function, batched=True, num_proc=4)
# `cast_column` returns a NEW dataset instead of modifying this one, so the result must be reassigned.
# Previously the return value was discarded and `labels` never became a ClassLabel feature.
tokenized_reward_dataset = tokenized_reward_dataset.cast_column(
    "labels", ClassLabel(num_classes=2, names=['not-best', 'best'])
)
# Drop the raw-text and metadata columns; presumably only tokenizer outputs and `labels` remain.
tokenized_reward_dataset = tokenized_reward_dataset.remove_columns([
    'Unnamed: 0', 'prompt_id', 'prompt', 'prompt_score', 'prompt_created_utc',
    'response_id', 'response', 'response_score', 'response_created_utc',
    'num_responses', 'response_children', 'score_bin', 'response_rank',
])
# Return PyTorch tensors when indexing, for the DataLoader / DataCollatorWithPadding used later.
tokenized_reward_dataset.set_format("torch")

In [None]:
## Balance our dataset: keep the "best" (label 1) examples and a random, equally sized subset of the
## "not-best" (label 0) examples, which are expected to be the majority.
positive_reward_dataset = tokenized_reward_dataset.filter(lambda example: example['labels'] == 1)
negative_reward_dataset = tokenized_reward_dataset.filter(lambda example: example['labels'] == 0).shuffle(seed=42)
# `select` raises when asked for more rows than exist. If negatives are ever the minority,
# subsample the positives instead of crashing.
num_per_class = min(len(positive_reward_dataset), len(negative_reward_dataset))
if len(positive_reward_dataset) > num_per_class:
    positive_reward_dataset = positive_reward_dataset.shuffle(seed=42).select(range(num_per_class))
negative_reward_dataset = negative_reward_dataset.select(range(num_per_class))
tokenized_reward_dataset = concatenate_datasets([positive_reward_dataset, negative_reward_dataset])
assert len(tokenized_reward_dataset) == 2 * num_per_class, "classes are not balanced"

## Getting ready for training

In [None]:
## Split into training (first 80%) and evaluation (last 20%) datasets.
## The positives/negatives were concatenated in order, so one seeded shuffle mixes them before splitting.
shuffled_reward_dataset = tokenized_reward_dataset.shuffle(seed=42)
num_examples = len(shuffled_reward_dataset)
train_cutoff = 4 * num_examples // 5
reward_train_dataset = shuffled_reward_dataset.select(range(train_cutoff))
reward_eval_dataset = shuffled_reward_dataset.select(range(train_cutoff, num_examples))

In [None]:
## Set up dataloaders for training and evaluating, plus the model, optimizer, accelerator and LR scheduler
reward_model_name = 'distilgpt2'  # swap in another checkpoint name here to train a different base model
SEED = 42
# Seed torch's global RNG so the randomly initialised classification head is reproducible.
torch.manual_seed(SEED)

# distilgpt2's tokenizer has no pad token; reuse EOS so DataCollatorWithPadding can pad each batch.
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
## Load pre-trained sequence classification model with a 2-way head (not-best / best)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_name, num_labels=2)
# GPT-2 sequence classification needs a pad token id to handle padded batches.
reward_model.config.pad_token_id = reward_model.config.eos_token_id

train_dataloader = DataLoader(
    reward_train_dataset, shuffle=True, batch_size=4, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    reward_eval_dataset, batch_size=4, collate_fn=data_collator
)

# `transformers.AdamW` is deprecated (and removed in recent releases); use PyTorch's implementation.
# eps=1e-6 and weight_decay=0.0 reproduce the defaults of the transformers optimizer used previously.
optimizer = torch.optim.AdamW(reward_model.parameters(), lr=3e-5, eps=1e-6, weight_decay=0.0)
accelerator = Accelerator()
train_dataloader, eval_dataloader, reward_model, optimizer = accelerator.prepare(
    train_dataloader, eval_dataloader, reward_model, optimizer
)
num_epochs = 5
# Computed from the *prepared* dataloader, so it reflects per-process batches.
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

progress_bar = tqdm(range(num_training_steps))

## Training

In [None]:
## Train the reward model: one backward pass, optimizer step and LR-scheduler step per batch.
reward_model.train()
for _epoch in range(num_epochs):
    for train_batch in train_dataloader:
        # Batches include `labels`, so the model returns the classification loss directly.
        batch_loss = reward_model(**train_batch).loss
        accelerator.backward(batch_loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

## Evaluation

In [None]:
## Evaluate accuracy of the reward model on the held-out evaluation dataset.
# `accelerator.prepare` already placed the model and the eval batches on `accelerator.device`,
# so no separate device choice (which could disagree with the accelerator's) is needed.
reward_model.eval()
num_correct = 0
num_total = 0
for batch in eval_dataloader:
    with torch.no_grad():
        outputs = reward_model(**batch)
    predictions = torch.argmax(outputs.logits, dim=-1)
    # Gather from all processes (a no-op on a single device) and drop samples duplicated to
    # even out the final batch, so every eval example is counted exactly once.
    predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
    num_correct += (predictions == references).sum().item()
    num_total += references.numel()

# Same shape as the former `load_metric("accuracy").compute()` result; NaN if the eval set is empty.
eval_metrics = {"accuracy": num_correct / num_total if num_total else float("nan")}
eval_metrics

In [None]:
## Push the trained reward model to the Hugging Face Hub.
# The previous code referenced an undefined `model_name` (NameError); the base model name is `reward_model_name`.
# Unwrap first because `accelerator.prepare` may have wrapped the model (e.g. for distributed training),
# and push from the main process only.
if accelerator.is_main_process:
    accelerator.unwrap_model(reward_model).push_to_hub(reward_model_name + "_reward_model", use_temp_dir=True)

## Results and Discussion
