In [1]:
import os
import torch
from datasets import load_dataset, load_from_disk
from transformers import AutoModelForSequenceClassification, AutoTokenizer

In [2]:
# Local cache directory for the hh-rlhf test split; reused on later runs so the
# dataset is only fetched with load_dataset when no local copy exists yet.
HH_RLHF_DIR = "./datasets/hh-rlhf"

if not os.path.exists(HH_RLHF_DIR):
    # No local copy: load the "test" split and store it for subsequent runs.
    dataset = load_dataset("Anthropic/hh-rlhf", split="test")
    dataset.save_to_disk(HH_RLHF_DIR)
else:
    dataset = load_from_disk(HH_RLHF_DIR)

In [3]:
def prepare_texts(conversation):
    """Split a conversation transcript into (context, final response).

    The transcript is expected to consist of turns separated by blank lines
    ("\n\n"), each starting with a speaker prefix such as "Human:" or
    "Assistant:". The response is everything from the last "Assistant:" turn
    to the end of the transcript (speaker prefix kept), so a reply that itself
    contains blank lines is returned whole. The context is every preceding
    turn, including the final human prompt, joined with single spaces.

    Args:
        conversation: Full transcript as one string.

    Returns:
        A ``(context, response)`` tuple of strings. If the transcript has no
        "\n\nAssistant:" marker, the last blank-line-separated paragraph is
        used as the response; an empty transcript yields ``("", "")``.
    """
    marker = "\n\nAssistant:"
    head, found, tail = conversation.rpartition(marker)
    if found:
        # Keep the speaker prefix on the response, drop only the separator.
        response = "Assistant:" + tail
    else:
        # No assistant marker: fall back to the last paragraph as the response.
        head, _, response = conversation.rpartition("\n\n")

    # Skip empty segments (e.g. the one produced by a leading "\n\n") so the
    # joined context carries no stray spaces.
    context = " ".join(turn for turn in head.split("\n\n") if turn.strip())
    return context, response

In [4]:
class RewardScorer:
    """Callable that scores the chosen and rejected sides of a preference record.

    Each call splits ``record["chosen"]`` and ``record["rejected"]`` with
    ``prepare_texts``, feeds every (context, response) pair to a sequence
    classification model, and returns the texts together with both scores.
    Instances can be passed directly to ``Dataset.map``.
    """

    def __init__(self, reward_model="OpenAssistant/reward-model-deberta-v3-large-v2", max_length=512):
        """Load the reward model and its tokenizer.

        Args:
            reward_model: Model identifier or path passed to ``from_pretrained``.
            max_length: Maximum number of tokens per (context, response) pair;
                longer inputs are truncated by the tokenizer.
        """
        self.reward_model = reward_model
        self.max_length = max_length
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = AutoModelForSequenceClassification.from_pretrained(self.reward_model).to(self.device)
        # Scoring is inference only; make evaluation mode explicit.
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.reward_model)

    def _score(self, context, response):
        """Return the model's scalar score for one (context, response) pair."""
        inputs = self.tokenizer(
            context,
            response,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt').to(self.device)
        # No gradients are needed for scoring; skipping autograd bookkeeping
        # saves memory and time on every record.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Row 0 is the only example in the batch. ``.item()`` requires that row
        # to hold exactly one value, i.e. a single-logit model head.
        return logits[0].item()

    def __call__(self, record):
        """Score one record.

        Args:
            record: Mapping with "chosen" and "rejected" conversation strings.

        Returns:
            Dict with the chosen conversation's context, both responses, and
            the score of each response. The rejected conversation's own context
            is used for scoring but is not included in the result.
        """
        c_context, c_response = prepare_texts(record["chosen"])
        r_context, r_response = prepare_texts(record["rejected"])

        return {
            "context": c_context,
            "chosen_response": c_response,
            "reject_response": r_response,
            "chosen_reward": self._score(c_context, c_response),
            "reject_reward": self._score(r_context, r_response),
        }

In [5]:
# Build the scorer once with its default reward model. Construction loads the
# model and tokenizer via from_pretrained, so this cell can take a while.
scorer = RewardScorer()

In [6]:
# Score every record and keep only the fields the scorer returns: all of the
# dataset's current columns are dropped from the result.
# NOTE: `dataset` is rebound to the scored result, so re-running this cell
# without reloading the dataset will fail (the scorer reads record["chosen"],
# which is among the removed columns).
original_columns = dataset.column_names
dataset = dataset.map(function=scorer, remove_columns=original_columns)

  0%|          | 0/8552 [00:00<?, ?ex/s]

In [7]:
# Persist the scored dataset next to the raw copy, under its own directory.
scored_output_dir = "./datasets/hh-rlhf-ai"
dataset.save_to_disk(scored_output_dir)