In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# NOTE: `AutoModelForCausalLMWithValueHead` is not exported by `transformers`
# (it is provided by the separate `trl` package), so importing it from here
# raised ImportError on a fresh kernel. This notebook regresses one scalar score
# per sequence from `outputs.logits`, which is what a single-label
# sequence-classification head produces.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

# Load model and tokenizer
model_name = "your-pretrained-model"  # Replace with an appropriate model
# num_labels=1 requests a single output logit per input sequence (scalar score).
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Some checkpoints (typically decoder-only ones) define no pad token, in which
# case tokenizing with padding="max_length" fails. Fall back to the EOS token
# and mirror the id into the model config so the classification head can locate
# the last non-padded position in a batch.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id

# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Define dataset class
class RewardDataset(Dataset):
    """Map-style dataset that turns scored records into tokenized samples.

    Each record of the wrapped ``dataset`` is read through the keys
    ``'question'``, ``'response'``, ``'feedback'`` and ``'score'``. The three
    text fields are joined with single spaces and passed to ``tokenizer``; the
    score becomes a float32 tensor.

    Args:
        dataset: Indexable collection of records supporting ``len()``.
        tokenizer: Callable invoked with the joined text plus ``padding``,
            ``truncation``, ``max_length`` and ``return_tensors`` keyword
            arguments; its result is indexed by ``"input_ids"`` and
            ``"attention_mask"``.
        max_length: Length forwarded to the tokenizer for padding/truncation.
    """

    _TEXT_FIELDS = ("question", "response", "feedback")
    _TENSOR_KEYS = ("input_ids", "attention_mask")

    def __init__(self, dataset, tokenizer, max_length=512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return its tensors together with the label.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"`` (tokenizer outputs
            with their leading dimension squeezed away) and ``"label"`` (float32
            tensor built from the record's ``'score'``).
        """
        record = self.dataset[idx]
        question, response, feedback = (record[field] for field in self._TEXT_FIELDS)
        text = " ".join((question, response, feedback))

        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer is asked for batched ("pt") tensors; drop that leading
        # dimension so the DataLoader can re-batch the samples itself.
        sample = {key: encoded[key].squeeze(0) for key in self._TENSOR_KEYS}
        sample["label"] = torch.tensor(record["score"], dtype=torch.float32)
        return sample

# Load dataset
dataset = load_dataset("your_dataset")  # Replace with your dataset
train_dataset = RewardDataset(dataset["train"], tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Define loss function and optimizer
# Mean-squared error between the predicted scalar and the reference score.
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

# Training loop
epochs = 3
# Guard the per-epoch average against an empty dataloader (would otherwise
# raise ZeroDivisionError when the training split has no rows).
num_batches = max(len(train_dataloader), 1)
model.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    for batch in train_dataloader:
        optimizer.zero_grad()

        # Move the batch onto the same device as the model.
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Squeeze only the trailing dimension. A bare .squeeze() also removes
        # the batch dimension when the final batch holds a single sample,
        # yielding a 0-d tensor that is then broadcast against labels of
        # shape (1,).
        # Assumes the model emits one logit per sequence, i.e. (batch, 1) —
        # TODO confirm against the model class loaded earlier in the notebook.
        predictions = outputs.logits.squeeze(-1)

        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()

        # .item() detaches the scalar so the running sum keeps no graph alive.
        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/num_batches:.4f}")
