In [1]:
import jupyter_black

jupyter_black.load()

%load_ext autoreload
%autoreload 2

In [2]:
import time

import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Categorical

from ppo import PPO

# Initialise the environment
env = gym.make("CartPole-v1")

In [3]:
input_dim = env.observation_space.shape[0]
out_dim = env.action_space.n
hidden_dim = 16

In [4]:
actor = nn.Sequential(
    nn.Linear(in_features=input_dim, out_features=hidden_dim),
    nn.Linear(in_features=hidden_dim, out_features=out_dim),
    nn.Softmax(dim=1),
)
critic = nn.Sequential(
    nn.Linear(in_features=input_dim, out_features=hidden_dim),
    nn.Linear(in_features=hidden_dim, out_features=1),
)

In [5]:
ppo = PPO(
    env=env,
    actor=actor,
    critic=critic,
    max_training_samples=2048,
    gamma=0.99,
    epsilon=0.2,
    batch_size=64,
    num_epochs=10,
    num_training_cycles=30,
)

In [6]:
ppo.train()

  0%|          | 0/30 [00:00<?, ?it/s]

Average episode length: 22.051546391752577


  3%|▎         | 1/30 [00:01<00:52,  1.80s/it]

Average episode length: 17.296


  7%|▋         | 2/30 [00:03<00:49,  1.75s/it]

Average episode length: 17.867768595041323


 10%|█         | 3/30 [00:05<00:46,  1.73s/it]

Average episode length: 23.494505494505493


 13%|█▎        | 4/30 [00:06<00:44,  1.70s/it]

Average episode length: 21.019607843137255


 17%|█▋        | 5/30 [00:08<00:42,  1.69s/it]

Average episode length: 17.639344262295083


 20%|██        | 6/30 [00:10<00:41,  1.71s/it]

Average episode length: 21.626262626262626


 23%|██▎       | 7/30 [00:11<00:39,  1.71s/it]

Average episode length: 20.4


 27%|██▋       | 8/30 [00:13<00:36,  1.68s/it]

Average episode length: 23.439560439560438


 30%|███       | 9/30 [00:15<00:34,  1.66s/it]

Average episode length: 19.38738738738739


 33%|███▎      | 10/30 [00:16<00:33,  1.67s/it]

Average episode length: 20.05607476635514


 37%|███▋      | 11/30 [00:18<00:31,  1.68s/it]

Average episode length: 19.232142857142858


 40%|████      | 12/30 [00:20<00:30,  1.67s/it]

Average episode length: 25.03529411764706


 43%|████▎     | 13/30 [00:21<00:28,  1.68s/it]

Average episode length: 19.925925925925927


 47%|████▋     | 14/30 [00:23<00:27,  1.70s/it]

Average episode length: 20.490384615384617


 50%|█████     | 15/30 [00:25<00:25,  1.69s/it]

Average episode length: 21.03030303030303


 53%|█████▎    | 16/30 [00:27<00:23,  1.68s/it]

Average episode length: 26.5125


 57%|█████▋    | 17/30 [00:28<00:21,  1.66s/it]

Average episode length: 18.620689655172413


 60%|██████    | 18/30 [00:30<00:19,  1.66s/it]

Average episode length: 23.307692307692307


 63%|██████▎   | 19/30 [00:32<00:18,  1.68s/it]

Average episode length: 19.0990990990991


 67%|██████▋   | 20/30 [00:33<00:16,  1.67s/it]

Average episode length: 29.816901408450704


 70%|███████   | 21/30 [00:35<00:14,  1.65s/it]

Average episode length: 17.672131147540984


 73%|███████▎  | 22/30 [00:36<00:13,  1.65s/it]

Average episode length: 20.4


 77%|███████▋  | 23/30 [00:38<00:11,  1.65s/it]

Average episode length: 19.37837837837838


 80%|████████  | 24/30 [00:40<00:09,  1.66s/it]

Average episode length: 30.492753623188406


 83%|████████▎ | 25/30 [00:41<00:08,  1.66s/it]

Average episode length: 19.6697247706422


 87%|████████▋ | 26/30 [00:43<00:06,  1.66s/it]

Average episode length: 19.88888888888889


 90%|█████████ | 27/30 [00:45<00:04,  1.67s/it]

Average episode length: 21.43


 93%|█████████▎| 28/30 [00:46<00:03,  1.67s/it]

Average episode length: 31.029411764705884


 97%|█████████▋| 29/30 [00:48<00:01,  1.67s/it]

Average episode length: 18.612068965517242


100%|██████████| 30/30 [00:50<00:00,  1.68s/it]


In [11]:
env = gym.make("CartPole-v1", render_mode="human")

In [12]:
# Set model to eval mode
ppo.critic.eval()

# Reset environment
obs, _ = env.reset()
done = False

while not done:
    # Convert observation to tensor
    state = torch.from_numpy(obs).float().unsqueeze(0)  # shape: (1, 4)

    # Get value estimate from critic
    with torch.no_grad():
        value = ppo.critic(state).item()

    # Pick action randomly or use a policy (if you have ppo.actor)
    action = env.action_space.sample()

    # Step in environment
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated

    time.sleep(0.05)  # Slow down for visualization

env.close()