<a href="https://colab.research.google.com/github/joaoppadua/cup_book/blob/main/RM_toy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import numpy as np
from tqdm import tqdm


In [3]:
# 1. First, let's create a simple dataset class for paired comparisons
class PreferenceDataset(Dataset):
    """Paired-comparison dataset for reward-model training.

    Each item pairs one prompt with a preferred ("better") and a less
    preferred ("worse") response. Both ``"<prompt> <response>"`` strings are
    tokenized with the ``bert-base-uncased`` tokenizer, then truncated and
    padded to ``max_length`` tokens.

    Args:
        prompts: Sequence of prompt strings.
        better_responses: Preferred response for each prompt, in the same order.
        worse_responses: Less preferred response for each prompt, same order.
        max_length: Length every encoded sequence is truncated/padded to.
    """

    def __init__(self, prompts, better_responses, worse_responses, max_length=512):
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.prompts = prompts
        self.better_responses = better_responses
        self.worse_responses = worse_responses
        self.max_length = max_length

    def __len__(self):
        """Return the number of prompts (one comparison pair per prompt)."""
        return len(self.prompts)

    def _encode(self, text):
        """Tokenize *text* and return ``(input_ids, attention_mask)`` tensors."""
        encoded = self.tokenizer(text,
                                 truncation=True,
                                 max_length=self.max_length,
                                 padding='max_length',
                                 return_tensors='pt')
        # Drop only the leading size-1 batch axis. A bare ``squeeze()`` would
        # also collapse the sequence axis when ``max_length == 1``, yielding
        # 0-d tensors that no longer batch as (batch, seq).
        return (encoded['input_ids'].squeeze(0),
                encoded['attention_mask'].squeeze(0))

    def __getitem__(self, idx):
        """Return the encoded better/worse pair for the prompt at *idx*.

        Returns:
            A dict with keys ``better_input_ids``, ``better_attention_mask``,
            ``worse_input_ids`` and ``worse_attention_mask``.
        """
        prompt = self.prompts[idx]

        # Each candidate response is scored in the context of its prompt, so
        # the prompt is prepended to both.
        better_ids, better_mask = self._encode(
            f"{prompt} {self.better_responses[idx]}")
        worse_ids, worse_mask = self._encode(
            f"{prompt} {self.worse_responses[idx]}")

        return {
            'better_input_ids': better_ids,
            'better_attention_mask': better_mask,
            'worse_input_ids': worse_ids,
            'worse_attention_mask': worse_mask,
        }

In [4]:
# 2. Create the Reward Model
class RewardModel(nn.Module):
    """Scalar reward model: a pretrained encoder topped with a bounded value head.

    Args:
        base_model: Name or path handed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, base_model='bert-base-uncased'):
        super().__init__()
        # Pretrained encoder that supplies the sequence representation.
        encoder = AutoModel.from_pretrained(base_model)
        self.base_model = encoder

        # Value head: project the hidden state to one scalar, then squash it
        # with tanh so every reward lies in (-1, 1).
        hidden_dim = encoder.config.hidden_size
        projection = nn.Linear(hidden_dim, 1)
        self.value_head = nn.Sequential(projection, nn.Tanh())

    def forward(self, input_ids, attention_mask):
        """Return one reward per sequence in the batch.

        The hidden state at position 0 (the [CLS] token for BERT-style
        tokenizers) is used as the summary of each sequence.
        """
        encoder_out = self.base_model(input_ids=input_ids,
                                      attention_mask=attention_mask)
        first_token_state = encoder_out.last_hidden_state[:, 0]
        reward = self.value_head(first_token_state)
        # Drop the trailing size-1 feature axis left by the linear layer.
        return reward.squeeze(-1)

### Loss Function:

$$
\text{loss}(\theta) = -\frac{1}{\binom{K}{2}} \mathbb{E}_{(x,y_w,y_l)\sim D} [\log(\sigma(r_\theta(x, y_w) - r_\theta(x, y_l)))]
$$

Where:

$\theta$ represents the model parameters

$K$ is the number of responses ranked for each prompt (between 4 and 9 in the paper), so each prompt yields $\binom{K}{2}$ pairwise comparisons

$x$ is the input prompt

$y_w$ is the preferred (winning) completion

$y_l$ is the less preferred (losing) completion

$D$ is the dataset of human comparisons

$r_\theta$ is the reward model's output

$\sigma$ is the sigmoid function

In this implementation, the normalization and the bias have been dropped, and a tanh function is used to bound the scalar rewards to the range $[-1, 1]$.

In [11]:
# 3. Training function
def train_reward_model(model, train_loader, epochs=10, lr=1e-5):
    """Train *model* on pairwise preferences and print the mean loss per epoch.

    The model is moved to CUDA when available, otherwise it stays on the CPU.
    For every batch the loss is the mean of
    ``-log(sigmoid(r_better - r_worse))``, which pushes the score of the
    preferred response above the score of the rejected one.

    Args:
        model: Callable module taking ``(input_ids, attention_mask)`` and
            returning one score per sequence.
        train_loader: Iterable of batches (dicts) with the keys
            ``better_input_ids``, ``better_attention_mask``,
            ``worse_input_ids`` and ``worse_attention_mask``.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate for the AdamW optimizer.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        # Count batches as they are consumed instead of calling
        # len(train_loader): this also works for loaders without __len__ and
        # avoids a ZeroDivisionError when the loader is empty.
        num_batches = 0

        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
            # Move batch to device
            better_input_ids = batch['better_input_ids'].to(device)
            better_attention_mask = batch['better_attention_mask'].to(device)
            worse_input_ids = batch['worse_input_ids'].to(device)
            worse_attention_mask = batch['worse_attention_mask'].to(device)

            # Get reward scores for both completions
            better_score = model(better_input_ids, better_attention_mask)
            worse_score = model(worse_input_ids, worse_attention_mask)

            # Pairwise ranking loss. logsigmoid is the numerically stable form
            # of log(sigmoid(x)): it does not underflow to log(0) = -inf when
            # the score margin is strongly negative.
            loss = -nn.functional.logsigmoid(better_score - worse_score).mean()

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # An empty loader yields no loss to average; report NaN, not a crash.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1} average loss: {avg_loss:.4f}')

In [12]:
# 4. Example usage

# Toy preference data, one row per comparison:
# (prompt, preferred response, rejected response).
# Keeping each triple together makes it impossible for the three lists to
# drift out of alignment.
preference_triples = [
    (
        "List a constitutional right in the First Amendment:",
        "Freedom of speech",
        "Constitutional rights are very important",
    ),
    (
        "Explain the concept of gravity:",
        "Gravity is the force that attracts objects with mass towards each other.",
        "Things fall down because of gravity.",
    ),
    (
        "Describe the weather:",
        "Sunny with clear blue skies and a gentle breeze.",
        "The weather is nice.",
    ),
    (
        "What is the capital of France?",
        "Paris",
        "London",
    ),
    (
        "Translate 'Hello, how are you?' into Spanish.",
        "Hola, ¿cómo estás?",
        "Hello, cómo eres?",
    ),
    (
        "Write a short poem about nature.",
        "Green leaves whisper secrets to the wind, / A gentle stream flows where the flowers begin.",
        "Nature is beautiful",
    ),
    (
        "Summarize the plot of Hamlet.",
        "Hamlet, Prince of Denmark, seeks revenge against his uncle, Claudius, who murdered his father and married his mother.",
        "Hamlet is a play by Shakespeare.",
    ),
    (  # Legal example
        "What are the main differences between a civil case and a criminal case?",
        "In a civil case, two parties dispute a private matter, while in a criminal case, the government prosecutes an individual for a crime.",
        "Civil cases are easy, and criminal cases are hard.",
    ),
    (  # Legal example
        "Explain the concept of 'habeas corpus' in law.",
        "'Habeas corpus' is a legal principle that protects individuals from unlawful imprisonment by requiring the government to justify their detention.",
        "Habeas corpus is a Latin term.",
    ),
    (  # Legal example
        "Draft a simple contract for the sale of goods.",
        "This contract is for the sale of [goods] from [seller] to [buyer] for the price of [price].",
        "I agree to sell you stuff.",
    ),
]

# Split the rows back into the three parallel lists the dataset expects.
prompts, better_responses, worse_responses = (
    list(column) for column in zip(*preference_triples)
)

# Wrap the pairs in a dataset and serve them in shuffled mini-batches of two.
dataset = PreferenceDataset(prompts, better_responses, worse_responses)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=2)

# Build the reward model and fit it on the toy comparisons.
model = RewardModel()
train_reward_model(model, dataloader)

# Example inference
def get_reward_score(prompt, completion):
    """Return the reward model's scalar score for *completion* given *prompt*.

    Uses the module-level ``model`` and ``dataset`` (for its tokenizer and
    maximum length), so both must exist before this is called. The text is
    encoded the same way as the training items: ``"<prompt> <completion>"``,
    truncated and padded to ``dataset.max_length``.

    Args:
        prompt: Prompt text.
        completion: Candidate response to score.

    Returns:
        The score as a Python float.
    """
    model.eval()
    # Follow the model to whatever device its weights are on. Re-deriving the
    # device from torch.cuda.is_available() breaks when the model is elsewhere,
    # e.g. a model that was never moved to the GPU.
    device = next(model.parameters()).device
    with torch.no_grad():
        text = f"{prompt} {completion}"
        encoded = dataset.tokenizer(text,
                                    truncation=True,
                                    max_length=dataset.max_length,
                                    padding='max_length',
                                    return_tensors='pt')
        score = model(encoded['input_ids'].to(device),
                      encoded['attention_mask'].to(device))
        return score.item()



Epoch 1: 100%|██████████| 5/5 [00:02<00:00,  2.33it/s]


Epoch 1 average loss: 0.7454


Epoch 2: 100%|██████████| 5/5 [00:02<00:00,  2.39it/s]


Epoch 2 average loss: 0.6035


Epoch 3: 100%|██████████| 5/5 [00:02<00:00,  2.38it/s]


Epoch 3 average loss: 0.5669


Epoch 4: 100%|██████████| 5/5 [00:02<00:00,  2.38it/s]


Epoch 4 average loss: 0.4998


Epoch 5: 100%|██████████| 5/5 [00:02<00:00,  2.35it/s]


Epoch 5 average loss: 0.4284


Epoch 6: 100%|██████████| 5/5 [00:02<00:00,  2.35it/s]


Epoch 6 average loss: 0.3764


Epoch 7: 100%|██████████| 5/5 [00:02<00:00,  2.34it/s]


Epoch 7 average loss: 0.2744


Epoch 8: 100%|██████████| 5/5 [00:02<00:00,  2.36it/s]


Epoch 8 average loss: 0.2504


Epoch 9: 100%|██████████| 5/5 [00:02<00:00,  2.32it/s]


Epoch 9 average loss: 0.2224


Epoch 10: 100%|██████████| 5/5 [00:02<00:00,  2.34it/s]

Epoch 10 average loss: 0.1919





In [15]:
# Sanity check: score an on-topic answer and an irrelevant one for a prompt
# that was not in the training set.
prompt = "Give me a constitutional right:"
completion1 = "Freedom of expression"
completion2 = "Some things are better than others"

score1, score2 = (
    get_reward_score(prompt, candidate)
    for candidate in (completion1, completion2)
)
for position, score in enumerate((score1, score2), start=1):
    print(f"Score for completion {position}: {score:.4f}")


Score for completion 1: 0.7725
Score for completion 2: -0.8029
