# In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random

# ---- Setup Environment ----
# CartPole-v1 with rendering turned off (render_mode=None). This single `env`
# instance is the one handed to the rollout, policy-update and evaluation
# calls in the main loop of this file.
env = gym.make("CartPole-v1", render_mode=None)

# ---- Policy Network ----
class PolicyNet(nn.Module):
    """Two-layer MLP policy that turns an observation into action probabilities.

    Args:
        obs_dim: Size of the last dimension of the observation input.
        action_dim: Number of discrete actions (size of the output).
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        hidden_units = 64
        layers = [
            nn.Linear(obs_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
        ]
        # Stored as ``self.net`` (a Sequential) so the parameter names remain
        # ``net.0.*`` and ``net.2.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return probabilities over actions, normalised along the last dim."""
        return self.net(x).softmax(dim=-1)

# ---- Reward Model ----
class RewardModel(nn.Module):
    """MLP that assigns one scalar score to a whole flattened trajectory.

    Args:
        traj_len: Number of time steps in each trajectory.
        obs_dim: Number of features per time step.

    The input to ``forward`` must have ``traj_len * obs_dim`` features in its
    last dimension.
    """

    def __init__(self, traj_len, obs_dim):
        super().__init__()
        flat_features = traj_len * obs_dim
        stack = (
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        self.net = nn.Sequential(*stack)

    def forward(self, traj_flat):
        """Return the score for ``traj_flat``; the last dimension has size 1."""
        score = self.net(traj_flat)
        return score

# ---- Collect Trajectory ----
def collect_trajectory(env, policy, traj_len=50):
    """Roll out ``policy`` in ``env`` and return the observations it visited.

    Actions are sampled from the probabilities returned by ``policy``. The
    observation recorded at each step is the one the action was chosen from
    (i.e. the state *before* ``env.step``).

    Args:
        env: Environment exposing ``reset()`` -> ``(obs, info)`` and
            ``step(action)`` -> ``(obs, reward, terminated, truncated, info)``.
        policy: Callable mapping a ``(1, obs_dim)`` float32 tensor to a
            ``(1, n_actions)`` tensor of action probabilities.
        traj_len: Maximum number of steps to record; also the number of rows
            in the returned tensor.

    Returns:
        A float32 tensor with ``traj_len`` rows, one observation per row. If
        the episode ends early the remaining rows are all zeros.
    """
    obs = env.reset()[0]
    traj = []
    # Sampling only: nothing here is back-propagated, so skip graph building.
    with torch.no_grad():
        for _ in range(traj_len):
            # as_tensor avoids the slow "tensor from a list of ndarrays" path.
            obs_tensor = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            probs = policy(obs_tensor)
            action = torch.multinomial(probs, 1).item()
            next_obs, _, terminated, truncated, _ = env.step(action)
            traj.append(obs)
            obs = next_obs
            if terminated or truncated:
                break
    # Pad short episodes with zero observations up to traj_len rows.
    if traj:
        traj.extend([np.zeros_like(traj[0])] * (traj_len - len(traj)))
    # One contiguous ndarray first, then a zero-copy hand-off to torch.
    return torch.from_numpy(np.asarray(traj, dtype=np.float32))

# ---- Simulated Human Feedback ----
def synthetic_preference(t1, t2):
    """Simulate a human choosing the more stable of two trajectories.

    Each trajectory is a 2-D tensor with one observation per row, where
    column 2 holds the pole angle. Rows that are entirely zero are treated as
    padding added after the episode ended (the chance of a genuine all-zero
    observation is assumed negligible).

    Trajectories are ranked first by how many real (non-padded) steps they
    contain and then by the mean absolute pole angle over those steps.
    Averaging over the padded rows as well would make an episode that fails
    early look *more* stable than one that survives, because the padding
    contributes angle 0.

    Args:
        t1: First trajectory tensor.
        t2: Second trajectory tensor.

    Returns:
        ``0`` if ``t1`` is preferred, otherwise ``1`` (ties go to ``t2``).
    """
    def stability_key(traj):
        real_steps = traj.abs().sum(dim=1) > 0  # False on zero padding rows
        n_real = int(real_steps.sum())
        if n_real == 0:
            return (0, 0.0)
        mean_abs_angle = traj[real_steps, 2].abs().mean().item()
        # Larger tuple == better: longer survival, then smaller angle.
        return (n_real, -mean_abs_angle)

    return 0 if stability_key(t1) > stability_key(t2) else 1

# ---- Train Reward Model ----
def train_reward_model(reward_model, pairs, labels, epochs=5):
    """Fit ``reward_model`` to pairwise preferences, one pair per Adam step.

    For every ``(t1, t2)`` pair the model scores both flattened trajectories
    and the score difference ``r1 - r2`` is used as the logit of "t1 is
    preferred" in a binary cross-entropy loss.

    Args:
        reward_model: Module mapping a ``(1, features)`` tensor to a
            single-element score.
        pairs: Iterable of ``(t1, t2)`` trajectory tensors.
        labels: Iterable aligned with ``pairs``; ``0`` means ``t1`` was
            preferred, anything else means ``t2`` was.
        epochs: Number of passes over the data.

    Returns:
        None. ``reward_model`` is updated in place.
    """
    optimizer = optim.Adam(reward_model.parameters(), lr=1e-3)
    # The two possible BCE targets, built once instead of once per sample.
    first_preferred = torch.tensor([[1.0]])
    second_preferred = torch.tensor([[0.0]])

    for _epoch in range(epochs):
        for (first, second), label in zip(pairs, labels):
            optimizer.zero_grad()
            score_first = reward_model(first.view(1, -1))
            score_second = reward_model(second.view(1, -1))
            margin = score_first - score_second
            target = first_preferred if label == 0 else second_preferred
            pair_loss = nn.functional.binary_cross_entropy_with_logits(margin, target)
            pair_loss.backward()
            optimizer.step()

# ---- Policy Update ----
def _rollout_with_log_probs(env, policy, traj_len):
    """Sample one episode; return (zero-padded observations, sum of action log-probs)."""
    obs = env.reset()[0]
    visited = []
    log_prob_sum = torch.zeros(())
    for _ in range(traj_len):
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
        probs = policy(obs_tensor)
        action = torch.multinomial(probs, 1)  # shape (1, 1), not differentiable
        # log pi(a|s) is the only quantity that carries gradient to the policy.
        log_prob_sum = log_prob_sum + torch.log(probs.gather(1, action)).sum()
        next_obs, _, terminated, truncated, _ = env.step(action.item())
        visited.append(obs)
        obs = next_obs
        if terminated or truncated:
            break
    if visited:
        # Same layout the reward model is scored on: zero rows after the end.
        visited.extend([np.zeros_like(visited[0])] * (traj_len - len(visited)))
    traj = torch.from_numpy(np.asarray(visited, dtype=np.float32))
    return traj, log_prob_sum


def update_policy(policy, reward_model, env, steps=20, traj_len=50):
    """Improve ``policy`` with REINFORCE, using ``reward_model`` as the return.

    The reward model's score of a sampled trajectory is a function of
    environment observations only, so it has no autograd path back to the
    policy parameters; minimising ``-reward`` directly leaves the policy
    untouched. Instead, each step samples one trajectory, scores it with the
    (frozen) reward model and ascends
    ``(reward - baseline) * sum_t log pi(a_t | s_t)``.

    Args:
        policy: Module mapping a ``(1, obs_dim)`` tensor to ``(1, n_actions)``
            action probabilities. Updated in place.
        reward_model: Module mapping a ``(1, traj_len * obs_dim)`` tensor to a
            single-element score. Not modified.
        env: Environment exposing ``reset()`` -> ``(obs, info)`` and
            ``step(action)`` -> ``(obs, reward, terminated, truncated, info)``.
        steps: Number of sampled trajectories / optimizer steps.
        traj_len: Rows per trajectory; must match what ``reward_model`` was
            built for.

    Returns:
        None.
    """
    opt = optim.Adam(policy.parameters(), lr=1e-2)
    baseline = None  # running mean of rewards, used to reduce variance
    for _ in range(steps):
        traj, log_prob_sum = _rollout_with_log_probs(env, policy, traj_len)
        if not log_prob_sum.requires_grad:
            # No action was taken (e.g. traj_len == 0): nothing to learn from.
            continue
        # The reward model is a fixed critic here; do not build or keep grads.
        with torch.no_grad():
            reward = reward_model(traj.view(1, -1)).item()
        if baseline is None:
            baseline = reward
        advantage = reward - baseline
        baseline = 0.9 * baseline + 0.1 * reward
        loss = -advantage * log_prob_sum
        opt.zero_grad()
        loss.backward()
        opt.step()

# ---- Evaluation ----
def evaluate_policy(policy, env, episodes=5):
    """Run ``policy`` greedily and return the mean undiscounted episode return.

    At every step the action with the highest probability is taken. Each
    episode runs until the environment reports ``terminated`` or
    ``truncated``.

    Args:
        policy: Callable mapping a ``(1, obs_dim)`` float32 tensor to a
            ``(1, n_actions)`` tensor of action probabilities.
        env: Environment exposing ``reset()`` -> ``(obs, info)`` and
            ``step(action)`` -> ``(obs, reward, terminated, truncated, info)``.
        episodes: Number of episodes to average over.

    Returns:
        Mean of the per-episode reward sums; NaN when ``episodes`` is 0.
    """
    total_rewards = []
    # Evaluation never back-propagates, so avoid building autograd graphs.
    with torch.no_grad():
        for _ in range(episodes):
            obs = env.reset()[0]
            total = 0
            done = False
            while not done:
                # as_tensor avoids the slow "tensor from a list of ndarrays" path.
                obs_tensor = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
                probs = policy(obs_tensor)
                action = torch.argmax(probs).item()
                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total += reward
            total_rewards.append(total)
    if not total_rewards:
        # Same value np.mean([]) would give, without the empty-slice warning.
        return np.nan
    return np.mean(total_rewards)

# ---- MAIN LOOP ----
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
traj_len = 50

# Fresh, untrained networks: the policy is created first, then the reward model.
policy = PolicyNet(obs_dim, action_dim)
reward_model = RewardModel(traj_len, obs_dim)

# Pre-RLHF Evaluation: ten evaluation runs of the untrained policy.
pre_rlhf_rewards = []
for _ in range(10):
    pre_rlhf_rewards.append(evaluate_policy(policy, env))

# 1. Collect data: 100 trajectory pairs, each labelled by the synthetic judge.
pairs = []
labels = []
for _ in range(100):
    t1 = collect_trajectory(env, policy, traj_len)
    t2 = collect_trajectory(env, policy, traj_len)
    label = synthetic_preference(t1, t2)
    labels.append(label)
    pairs.append((t1, t2))

# 2. Train Reward Model on the labelled pairs.
train_reward_model(reward_model, pairs, labels)

# 3. RLHF Policy Update: 30 rounds of update-then-evaluate.
rlhf_rewards = []
for _ in range(30):
    update_policy(policy, reward_model, env)
    avg_reward = evaluate_policy(policy, env)
    rlhf_rewards.append(avg_reward)

# Post-RLHF Evaluation: ten evaluation runs of the updated policy.
post_rlhf_rewards = []
for _ in range(10):
    post_rlhf_rewards.append(evaluate_policy(policy, env))

# ---- Plotting ----
# Per-round evaluation curve plus dashed reference lines for the pre/post means.
plt.figure(figsize=(10, 5))
plt.plot(rlhf_rewards, label="RLHF Training Reward")
for rewards, colour, caption in (
    (pre_rlhf_rewards, 'r', "Pre-RLHF Avg"),
    (post_rlhf_rewards, 'g', "Post-RLHF Avg"),
):
    plt.axhline(np.mean(rewards), color=colour, linestyle='--', label=caption)
plt.xlabel("RLHF Iteration")
plt.ylabel("Average Reward")
plt.title("CartPole with Simulated RLHF")
plt.legend()
plt.show()
