In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical

# ---- PPO / GAE hyperparameters --------------------------------------------
GAMMA = 0.99        # discount factor applied to the next state's value
LAMBDA = 0.95       # GAE lambda: decay applied when accumulating advantages
CLIP_EPS = 0.2      # probability ratio is clamped to [1 - CLIP_EPS, 1 + CLIP_EPS]

# ---- optimisation schedule -------------------------------------------------
EPOCHS = 10         # passes over each collected trajectory per update
BATCH_SIZE = 64     # minibatch size within one pass
ACTOR_LR = 0.0003   # Adam learning rate for the policy network
CRITIC_LR = 0.001   # Adam learning rate for the value network

# ---- network architecture --------------------------------------------------
HIDDEN = 128        # width of both hidden layers in actor and critic

class Actor(nn.Module):
    """Policy network mapping observations to action probabilities.

    An MLP with two ReLU hidden layers of width ``HIDDEN`` and a final
    softmax over the last dimension, so each output row sums to 1. The
    output is meant to be passed to ``Categorical`` as ``probs``.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        # Layers are created in the same order as they are applied, so the
        # Sequential indices (net.0, net.2, net.4) stay stable for checkpoints.
        layers = [
            nn.Linear(obs_dim, HIDDEN),
            nn.ReLU(),
            nn.Linear(HIDDEN, HIDDEN),
            nn.ReLU(),
            nn.Linear(HIDDEN, act_dim),
            nn.Softmax(dim=-1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        probs = self.net(x)
        return probs

class Critic(nn.Module):
    """Value network mapping observations to a state-value estimate.

    Same trunk shape as the actor (two ReLU hidden layers of width
    ``HIDDEN``), followed by a single linear output unit with no activation.
    The output therefore has a trailing dimension of size 1.
    """

    def __init__(self, obs_dim):
        super().__init__()
        trunk = [
            nn.Linear(obs_dim, HIDDEN), nn.ReLU(),
            nn.Linear(HIDDEN, HIDDEN), nn.ReLU(),
        ]
        value_head = nn.Linear(HIDDEN, 1)
        self.net = nn.Sequential(*trunk, value_head)

    def forward(self, x):
        value = self.net(x)
        return value

class PPO:
    """Minimal PPO agent with separate actor (policy) and critic (value) nets.

    Hyperparameters are read from the module-level constants GAMMA, LAMBDA,
    CLIP_EPS, EPOCHS, BATCH_SIZE, ACTOR_LR and CRITIC_LR.
    """

    def __init__(self, obs_dim, act_dim):
        self.actor = Actor(obs_dim, act_dim)
        self.critic = Critic(obs_dim)
        self.opt_actor = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)
        self.opt_critic = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)

    def get_action(self, state):
        """Sample an action for a single (unbatched) observation.

        Returns:
            (action, log_prob): the sampled action as a Python int and its
            log-probability as a detached tensor of shape (1,).
        """
        state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        # The stored log-prob is the constant "old policy" term of the PPO
        # ratio, so no autograd graph should be built during rollouts.
        with torch.no_grad():
            dist = Categorical(probs=self.actor(state_t))
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.item(), log_prob

    def compute_adv(self, rewards, values, dones, last_value=0.0):
        """Generalized Advantage Estimation over one trajectory.

        Args:
            rewards: per-step rewards.
            values: critic estimates V(s_t), one per step.
            dones: per-step terminal flags; a truthy flag stops both the
                bootstrap from the next value and the advantage propagation.
            last_value: value of the state following the final step, used to
                bootstrap a trajectory that was cut off without terminating.
                The default 0.0 reproduces the original behaviour.

        Returns:
            (advantages, returns) as float32 tensors of length len(rewards),
            where returns = advantages + values.

        Raises:
            ValueError: if the three sequences differ in length.
        """
        n = len(rewards)
        if not (n == len(values) == len(dones)):
            raise ValueError("rewards, values and dones must have the same length")
        adv = [0.0] * n
        ret = [0.0] * n
        gae = 0.0
        next_val = float(last_value)
        # Walk backwards, filling preallocated lists (avoids O(n^2) insert(0, ...)).
        for t in range(n - 1, -1, -1):
            not_done = 1.0 - float(dones[t])
            delta = rewards[t] + GAMMA * next_val * not_done - values[t]
            gae = delta + GAMMA * LAMBDA * not_done * gae
            adv[t] = gae
            ret[t] = gae + values[t]
            next_val = values[t]
        return torch.tensor(adv, dtype=torch.float32), torch.tensor(ret, dtype=torch.float32)

    def update(self, states, actions, old_logps, advs, rets):
        """Run EPOCHS passes of clipped-surrogate PPO over a collected batch.

        Args:
            states: sequence of observations.
            actions: sequence of integer actions that were taken.
            old_logps: log-probs of those actions under the rollout policy;
                floats or one-element tensors (as returned by get_action).
            advs: advantage estimates, one per step.
            rets: critic regression targets, one per step.
        """
        states = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        actions = torch.as_tensor(np.asarray(actions), dtype=torch.long).reshape(-1)
        # Flatten to shape (N,). get_action yields (1,)-shaped log-probs, and
        # stacking them as-is gives (N, 1), which broadcasts against the (B,)
        # new log-probs into a (B, B) ratio matrix and corrupts the loss.
        old_logps = torch.tensor([float(lp) for lp in old_logps], dtype=torch.float32)
        advs = torch.as_tensor(advs, dtype=torch.float32).reshape(-1)
        rets = torch.as_tensor(rets, dtype=torch.float32).reshape(-1)

        n = len(states)
        if n == 0:
            return
        mse = nn.MSELoss()

        for _ in range(EPOCHS):
            # Reshuffle each epoch so minibatches are not always the same
            # contiguous, temporally correlated slices of the trajectory.
            perm = torch.randperm(n)
            for start in range(0, n, BATCH_SIZE):
                idx = perm[start:start + BATCH_SIZE]

                probs = self.actor(states[idx])
                dist = Categorical(probs=probs)
                logps = dist.log_prob(actions[idx])
                ratio = torch.exp(logps - old_logps[idx])
                adv_b = advs[idx]
                surr1 = ratio * adv_b
                surr2 = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * adv_b
                loss_actor = -torch.min(surr1, surr2).mean()

                # squeeze(-1), not squeeze(): a final minibatch of size 1 must
                # stay shape (1,) to match rets[idx].
                vals = self.critic(states[idx]).squeeze(-1)
                loss_critic = mse(vals, rets[idx])

                self.opt_actor.zero_grad()
                loss_actor.backward()
                self.opt_actor.step()

                self.opt_critic.zero_grad()
                loss_critic.backward()
                self.opt_critic.step()

In [3]:
import gym
import imageio
import torch

def train(n_episodes=1000, solve_score=475, seed=None):
    """Train a PPO agent on CartPole-v1, performing one update per episode.

    Args:
        n_episodes: maximum number of training episodes.
        solve_score: stop early once a single episode's undiscounted return
            reaches this value.
        seed: optional seed for torch (network init and action sampling) and
            for the first ``env.reset``; ``None`` leaves everything unseeded.

    Returns:
        The trained ``PPO`` agent.
    """
    if seed is not None:
        torch.manual_seed(seed)
    env = gym.make("CartPole-v1")
    try:
        agent = PPO(env.observation_space.shape[0], env.action_space.n)

        def state_value(obs):
            # Critic estimate V(obs) as a float; no graph is needed at rollout time.
            with torch.no_grad():
                obs_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
                return agent.critic(obs_t).item()

        for ep in range(n_episodes):
            # Seed only the first reset; later resets pass seed=None.
            state, _ = env.reset(seed=seed if ep == 0 else None)
            done = False
            logps, vals, rewards, states, actions, dones = [], [], [], [], [], []
            total = 0

            while not done:
                action, logp = agent.get_action(state)
                value = state_value(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                # End the episode on either flag. Checking only `terminated`
                # keeps stepping past CartPole-v1's time limit.
                done = terminated or truncated

                train_reward = reward
                if truncated and not terminated:
                    # A time-limit cut-off is not a true terminal state, but
                    # compute_adv treats the value after the last step as 0.
                    # Fold the bootstrap term gamma * V(s') into the final
                    # reward so the last TD error is correct.
                    train_reward = reward + GAMMA * state_value(next_state)

                states.append(state)
                actions.append(action)
                rewards.append(train_reward)
                logps.append(logp)
                vals.append(value)
                # Only real terminations should stop bootstrapping in GAE.
                dones.append(terminated)

                state = next_state
                total += reward

            advs, rets = agent.compute_adv(rewards, vals, dones)
            agent.update(states, actions, logps, advs, rets)

            if ep % 10 == 0:
                print(f"Episode {ep}, Reward: {total}")
            # NOTE: single-episode check; noisier than a moving-average criterion.
            if total >= solve_score:
                print(f"Solved at episode {ep}")
                break
    finally:
        env.close()
    return agent

def record_video(agent, path="ppo_cartpole.mp4", max_steps=1000, fps=30):
    """Roll out one CartPole-v1 episode with ``agent`` and save it as a video.

    Args:
        agent: object exposing ``get_action(state) -> (action, _)``, e.g. a
            PPO agent. With PPO, actions are sampled from the policy rather
            than chosen greedily.
        path: output file written by ``imageio.mimsave``.
        max_steps: maximum number of steps (and frames) to record.
        fps: frame rate of the saved video.

    Raises:
        ValueError: if ``max_steps`` < 1, since there would be no frames to save.
    """
    if max_steps < 1:
        raise ValueError("max_steps must be >= 1")
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    frames = []
    try:
        state, _ = env.reset()
        for _ in range(max_steps):
            frames.append(env.render())
            action, _ = agent.get_action(state)
            state, _, terminated, truncated, _ = env.step(action)
            # Stop on either flag; the Gym API expects reset() once an episode
            # is terminated *or* truncated.
            if terminated or truncated:
                break
    finally:
        env.close()
    imageio.mimsave(path, frames, fps=fps)
    print(f"Saved video to {path}")

if __name__ == "__main__":
    # Train until the solve criterion or the episode budget is reached, then
    # save a rollout of the trained policy using record_video's defaults.
    agent = train()
    record_video(agent=agent)


  if not isinstance(terminated, (bool, np.bool8)):


Episode 0, Reward: 11.0
Episode 10, Reward: 38.0
Episode 20, Reward: 15.0
Episode 30, Reward: 126.0
Episode 40, Reward: 338.0
Episode 50, Reward: 225.0
Solved at episode 59




Saved video to ppo_cartpole.mp4
