# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchrl.envs import EnvBase
from torchrl.trainers import PPOTrainer
from torchrl.objectives import ClipPPOLoss
from torchrl.data import TensorDictReplayBuffer

# -----------------------------
# Define LLM policy + value head
# -----------------------------
class LLMPolicy(nn.Module):
    """Toy causal transformer that emits next-token logits and per-token values.

    Token ids are embedded, passed through a stack of transformer layers
    under a causal attention mask, and projected by two linear heads that
    share the same hidden states. No positional encoding is added.

    Args:
        vocab_size: Number of distinct token ids; sizes the embedding table
            and the last dimension of the logits.
        hidden_dim: Width of the embeddings and transformer layers. PyTorch's
            multi-head attention requires it to be divisible by ``num_heads``.
        num_heads: Attention heads per transformer layer.
        num_layers: Number of stacked transformer layers.
    """

    def __init__(self, vocab_size, hidden_dim, num_heads=8, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.lm = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim, nhead=num_heads, batch_first=True
            ),
            num_layers=num_layers,
        )
        self.lm_head = nn.Linear(hidden_dim, vocab_size)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, tokens):
        """Compute logits and value estimates for every position.

        Args:
            tokens: Integer token ids of shape ``[B, T]``.

        Returns:
            A ``(logits, values)`` tuple with shapes ``[B, T, V]`` and
            ``[B, T]`` respectively.
        """
        x = self.embed(tokens)

        # Without a mask every position attends to the whole sequence, so the
        # logits/value at step t would see tokens that are generated later.
        # True entries mark key positions a query is NOT allowed to attend to.
        seq_len = tokens.size(-1)
        causal_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=tokens.device
        ).triu(diagonal=1)

        x = self.lm(x, mask=causal_mask)
        logits = self.lm_head(x)        # [B, T, V]
        values = self.value_head(x)     # [B, T, 1]
        return logits, values.squeeze(-1)

# -----------------------------
# Define a TorchRL Env wrapper
# -----------------------------
class LLMEnv(EnvBase):
    """Toy text-generation environment: each step appends one token.

    An episode starts from a random prompt and ends once the sequence
    reaches ``max_len`` tokens. After every step the full sequence is scored
    by ``reward_model``.

    NOTE(review): this class overrides the public ``reset``/``step`` and
    returns plain dicts. TorchRL's ``EnvBase`` presumably expects subclasses
    to implement ``_reset``/``_step``/``_set_seed``, declare specs and return
    TensorDicts — confirm against the installed TorchRL version before using
    this env with collectors or trainers.

    Args:
        reward_model: Callable mapping a ``[B, T]`` token tensor to a reward.
        vocab_size: Exclusive upper bound for sampled token ids.
        max_len: Sequence length at which the episode is flagged as done.
        prompt_len: Number of random tokens in the initial prompt.
    """

    def __init__(self, reward_model, vocab_size, max_len=32, prompt_len=8):
        super().__init__()
        self.reward_model = reward_model
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.prompt_len = prompt_len

    def reset(self, tensordict=None):
        """Start a new episode.

        Args:
            tensordict: Accepted for interface compatibility; not read.

        Returns:
            A dict with ``"tokens"``: a ``[1, prompt_len]`` tensor of random
            token ids.
        """
        # Start with a prompt (toy: random tokens)
        prompt = torch.randint(0, self.vocab_size, (1, self.prompt_len))
        return {"tokens": prompt}

    def step(self, tensordict):
        """Append one token to the sequence and score the result.

        Args:
            tensordict: Mapping with ``"tokens"`` (``[B, T]`` token ids) and,
                optionally, ``"action"`` holding the next token id for each
                sequence. Assumes the action is one integer id per sequence
                (not one-hot) — TODO confirm against the policy's output.

        Returns:
            A dict with the extended ``"tokens"``, the ``"reward"`` produced
            by the reward model, and a one-element boolean ``"done"`` tensor.
        """
        tokens = tensordict["tokens"]
        batch_size = tokens.size(0)

        # Use the caller-supplied action so a policy can actually drive the
        # episode; fall back to a random token when none is provided.
        # (`.keys()` is used because it works for dicts and TensorDicts.)
        if "action" in tensordict.keys():
            action = torch.as_tensor(
                tensordict["action"], dtype=tokens.dtype, device=tokens.device
            ).reshape(batch_size, 1)
        else:
            action = torch.randint(
                0, self.vocab_size, (batch_size, 1), device=tokens.device
            )
        new_tokens = torch.cat([tokens, action], dim=1)

        # Reward from reward model
        reward = self.reward_model(new_tokens)

        done = new_tokens.size(1) >= self.max_len
        return {
            "tokens": new_tokens,
            "reward": reward,
            "done": torch.tensor([done]),
        }

# -----------------------------
# PPO training loop
# -----------------------------
def train(
    vocab_size=32000,
    hidden_dim=512,
    *,
    lr=3e-4,
    frames_per_batch=64,
    total_frames=10000,
    buffer_size=1000,
    clip_epsilon=0.2,
    entropy_coef=0.01,
    critic_coef=0.5,
):
    """Wire up the toy policy, environment and PPO trainer, then train.

    Called with no arguments this uses the same hyperparameters that were
    previously hard-coded.

    Args:
        vocab_size: Vocabulary size shared by the policy and the environment.
        hidden_dim: Hidden width of the policy network.
        lr: Adam learning rate.
        frames_per_batch: Forwarded to ``PPOTrainer``.
        total_frames: Forwarded to ``PPOTrainer``.
        buffer_size: Forwarded to ``TensorDictReplayBuffer``.
        clip_epsilon: Forwarded to ``ClipPPOLoss``.
        entropy_coef: Forwarded to ``ClipPPOLoss``.
        critic_coef: Forwarded to ``ClipPPOLoss``.
    """

    def reward_model(tokens):
        # Stub: one uniform random reward per sequence in the batch.
        return torch.rand(tokens.size(0))

    policy = LLMPolicy(vocab_size, hidden_dim)
    env = LLMEnv(reward_model, vocab_size)

    # PPO loss
    # NOTE(review): only the two linear heads are handed to the loss, so the
    # embedding and transformer layers are not part of the actor/critic it
    # sees. ClipPPOLoss presumably expects TensorDict-based actor/critic
    # modules, and its keyword names have changed between TorchRL releases —
    # confirm against the installed version.
    loss_module = ClipPPOLoss(
        actor=policy.lm_head,
        critic=policy.value_head,
        clip_epsilon=clip_epsilon,
        entropy_coef=entropy_coef,
        critic_coef=critic_coef,
    )

    # Replay buffer
    # NOTE(review): verify that `storage_size` is an accepted keyword;
    # TorchRL replay buffers are usually configured with a `storage=` object.
    rb = TensorDictReplayBuffer(storage_size=buffer_size)

    # NOTE(review): confirm the PPOTrainer constructor signature against the
    # installed TorchRL version.
    trainer = PPOTrainer(
        loss_module=loss_module,
        env=env,
        replay_buffer=rb,
        optim=torch.optim.Adam(policy.parameters(), lr=lr),
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
    )

    trainer.train()

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()