In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random



In [None]:

# Environment setup (simplified)
class TextEnvironment:
    """Toy single-turn dialogue environment with four canned responses.

    An action is an index into ``actions``; ``step`` answers with a fixed
    string and makes that string the new state. The reward is always 0 and
    ``done`` is always False -- preference signal comes from outside the
    environment.
    """

    def __init__(self):
        self.actions = ["Say hello", "Ask about weather", "Tell a joke", "Share a fact"]
        self.state = "start"

    def reset(self):
        """Put the environment back in the ``"start"`` state and return it."""
        self.state = "start"
        return self.state

    def step(self, action_idx):
        """Apply ``action_idx`` and return ``(state, reward, done, info)``.

        Indices 0, 1 and 2 each have a dedicated reply; any other index that is
        valid for ``self.actions`` (including negative ones) gets the honey fact.

        Raises:
            IndexError: if ``action_idx`` is out of range for ``self.actions``.
        """
        # Evaluated only for its bounds check: invalid indices raise IndexError here.
        self.actions[action_idx]

        dedicated_replies = (
            "Hello there!",
            "It's sunny today.",
            "Why don't scientists trust atoms? Because they make up everything!",
        )
        reply = "Did you know honey never spoils?"
        for idx, text in enumerate(dedicated_replies):
            if action_idx == idx:
                reply = text
                break

        self.state = reply
        return reply, 0, False, {}


In [None]:

# Reward Model (predicts human preference)
class RewardModel(nn.Module):
    """Preference scorer: MLP ``input_size -> 64 -> 32 -> 1`` with a sigmoid head.

    For an input whose last dimension is ``input_size`` the output has a
    trailing dimension of size 1 with values between 0 and 1.
    """

    def __init__(self, input_size):
        super().__init__()
        widths = (input_size, 64, 32)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Final projection to a single score, squashed into (0, 1).
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        # Kept as one Sequential named ``fc`` so state_dict keys stay fc.0 ... fc.4.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid preference score(s) for ``x``."""
        score = self.fc(x)
        return score

# Policy Network
class PolicyNetwork(nn.Module):
    """Action-logit MLP: ``input_size -> 64 -> 32 -> output_size`` with ReLUs between layers.

    Returns raw, unnormalised logits; no softmax is applied here.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        dims = [input_size, 64, 32, output_size]
        blocks = []
        for depth, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
            if depth:  # activation between consecutive Linear layers, not before the first
                blocks.append(nn.ReLU())
            blocks.append(nn.Linear(fan_in, fan_out))
        # Single Sequential named ``fc`` keeps state_dict keys fc.0 / fc.2 / fc.4.
        self.fc = nn.Sequential(*blocks)

    def forward(self, x):
        """Return action logits for ``x``."""
        logits = self.fc(x)
        return logits

# Human feedback simulation (in practice, you'd collect real human feedback)
def get_human_feedback(response1, response2):
    """Simulate a human picking the better of two responses.

    Each known response has a fixed score; any other response scores 0.5.
    In a real setup these judgements would come from human annotators.

    Returns:
        1 if ``response1`` scores higher, 2 if ``response2`` does, 0 on a tie.
    """
    scores = {
        "Hello there!": 0.8,
        "It's sunny today.": 0.6,
        "Why don't scientists trust atoms? Because they make up everything!": 0.9,
        "Did you know honey never spoils?": 0.7
    }
    first, second = (scores.get(response, 0.5) for response in (response1, response2))
    if first == second:
        return 0
    return 1 if first > second else 2

# Text embedding (simplified)
def embed_text(text, dim=128):
    """Map ``text`` to a deterministic pseudo-random embedding vector.

    Stand-in for a real text encoder such as BERT. The vector carries no
    meaning, but each distinct text always gets the same vector. That lets
    downstream models condition on the input and makes runs reproducible.
    The global torch RNG is not touched.

    Args:
        text: Input text; converted with ``str()`` and UTF-8 encoded before hashing.
        dim: Embedding length (default 128).

    Returns:
        A 1-D tensor of shape ``(dim,)``.
    """
    import hashlib

    # Python's built-in hash() is salted per process for str, so derive the
    # seed from a stable digest instead.
    digest = hashlib.sha256(str(text).encode("utf-8")).digest()
    generator = torch.Generator().manual_seed(int.from_bytes(digest[:8], "little"))
    return torch.randn(dim, generator=generator)

def _reward_inputs(states, actions, n_actions):
    """Append a one-hot encoding of ``actions`` to ``states``.

    The reward model must see *which* response it is scoring. Scoring the
    state alone gives both candidates of a comparison the same reward.

    Args:
        states: Float tensor ``(batch, state_dim)``.
        actions: Integer tensor ``(batch,)`` of action indices.
        n_actions: Size of the action space.

    Returns:
        Tensor ``(batch, state_dim + n_actions)``.
    """
    one_hot = nn.functional.one_hot(actions, num_classes=n_actions).to(states.dtype)
    return torch.cat([states, one_hot], dim=-1)


def _update_reward_model(reward_model, optimizer, states, chosen, rejected, n_actions, eps=1e-8):
    """Take one gradient step on the pairwise preference loss and return the loss tensor.

    Scores are positive (sigmoid head), so P(chosen preferred) is modelled as
    ``r_c / (r_c + r_r)``. The step minimises its mean negative
    log-likelihood. ``eps`` keeps the logs finite if a score underflows to 0.
    """
    optimizer.zero_grad()
    r_chosen = reward_model(_reward_inputs(states, chosen, n_actions)).squeeze(-1)
    r_rejected = reward_model(_reward_inputs(states, rejected, n_actions)).squeeze(-1)
    loss = (torch.log(r_chosen + r_rejected + eps) - torch.log(r_chosen + eps)).mean()
    loss.backward()
    optimizer.step()
    return loss


def _update_policy(policy, reward_model, optimizer, states, n_actions):
    """Take one REINFORCE step using reward-model scores and return the loss tensor.

    Actions are sampled fresh from the current policy (on-policy) and scored
    by the reward model without gradient. They are centred on the batch mean
    as a variance-reducing baseline.
    """
    optimizer.zero_grad()
    log_probs = torch.log_softmax(policy(states), dim=-1)
    with torch.no_grad():
        actions = torch.multinomial(log_probs.exp(), 1).squeeze(-1)
        rewards = reward_model(_reward_inputs(states, actions, n_actions)).squeeze(-1)
        advantages = rewards - rewards.mean()
    taken_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
    loss = -(taken_log_probs * advantages).mean()
    loss.backward()
    optimizer.step()
    return loss


def train_rlhf(num_episodes=100, batch_size=32, lr=0.001, buffer_size=1000, seed=None, verbose=True):
    """Toy RLHF loop: collect simulated preferences, fit a reward model, then improve the policy.

    Each episode:

    1. Samples two candidate actions from the policy for the same start state.
    2. Asks ``get_human_feedback`` which response it prefers.
    3. Stores ``(state_embedding, chosen_action, rejected_action)``; ties are dropped.

    Once ``batch_size`` comparisons are buffered, every episode runs one
    reward-model step and one policy step on a random minibatch.

    Args:
        num_episodes: Number of comparison episodes.
        batch_size: Minibatch size; training starts once the buffer holds this many.
        lr: Adam learning rate for both networks.
        buffer_size: Maximum number of comparisons kept (oldest evicted first).
        seed: If not None, seeds ``random`` and ``torch`` for a reproducible run.
        verbose: Print losses after each training step.

    Returns:
        ``(policy, reward_model)``, the trained networks.
    """
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    env = TextEnvironment()
    n_actions = len(env.actions)
    state_dim = embed_text(env.reset()).numel()
    policy = PolicyNetwork(state_dim, n_actions)
    # The reward model scores (state, action) pairs, so its input also carries a one-hot action.
    reward_model = RewardModel(state_dim + n_actions)
    optimizer_policy = optim.Adam(policy.parameters(), lr=lr)
    optimizer_reward = optim.Adam(reward_model.parameters(), lr=lr)

    # Preference buffer of (state_embed, chosen_action, rejected_action).
    buffer = deque(maxlen=buffer_size)

    for episode in range(num_episodes):
        state = env.reset()
        state_embed = embed_text(state)

        # Two independent samples from the same action distribution.
        with torch.no_grad():
            probs = torch.softmax(policy(state_embed), dim=-1)
            action1, action2 = torch.multinomial(probs, 2, replacement=True).tolist()

        # Both candidates must answer the same prompt, so reset before each step.
        env.reset()
        response1, _, _, _ = env.step(action1)
        env.reset()
        response2, _, _, _ = env.step(action2)

        preference = get_human_feedback(response1, response2)
        if preference == 1:
            buffer.append((state_embed, action1, action2))
        elif preference == 2:
            buffer.append((state_embed, action2, action1))
        # preference == 0 (tie or identical responses) carries no training signal.

        if len(buffer) < batch_size:
            continue

        batch = random.sample(buffer, batch_size)
        states = torch.stack([item[0] for item in batch])
        chosen = torch.tensor([item[1] for item in batch])
        rejected = torch.tensor([item[2] for item in batch])

        reward_loss = _update_reward_model(reward_model, optimizer_reward, states, chosen, rejected, n_actions)
        policy_loss = _update_policy(policy, reward_model, optimizer_policy, states, n_actions)

        if verbose:
            print(f"Episode {episode}, Reward Loss: {reward_loss.item():.4f}, Policy Loss: {policy_loss.item():.4f}")

    return policy, reward_model

# Entry point: runs the toy RLHF training loop with its default settings.
# NOTE(review): in a Jupyter/IPython kernel __name__ is typically "__main__" as
# well, so executing this cell starts training immediately — confirm for your kernel.
if __name__ == "__main__":
    train_rlhf()