In [55]:
import gymnasium as gym
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

In [56]:
# Torch device string used for the model and every tensor built in this notebook.
DEVICE = "cpu"

In [57]:
# Policy and value model
class ActorCritic(nn.Module):
    """Actor-critic network: one shared trunk feeding a policy head and a value head.

    The trunk maps an observation of size ``state_dim`` to a 64-dimensional
    feature vector. The policy head turns those features into ``action_dim``
    unnormalised logits; the value head turns them into a single scalar.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 64

        # Sub-modules are created in trunk -> policy -> value order so that
        # parameter initialisation consumes the RNG in a fixed sequence.
        self.shared_layers = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )

        self.policy_layers = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
        )

        self.value_layers = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def value(self, obs):
        """Return only the value-head output for ``obs``."""
        return self.value_layers(self.shared_layers(obs))

    def policy(self, obs):
        """Return only the policy logits for ``obs``."""
        return self.policy_layers(self.shared_layers(obs))

    def forward(self, obs):
        """Return ``(policy_logits, value)`` from a single pass through the trunk."""
        features = self.shared_layers(obs)
        return self.policy_layers(features), self.value_layers(features)

In [68]:
### Create trainer
class PPOTrainer():
    """
    Runs the PPO policy update and the value-function regression for an
    actor-critic model.

    The model must expose `shared_layers`, `policy_layers` and `value_layers`
    sub-modules as well as `policy(obs)` and `value(obs)` methods. Two Adam
    optimizers are created; the shared layers are included in both, so they
    are updated by the policy loss and by the value loss.

    Args:
        actor_critic: the model to train.
        ppo_clip_value: half-width of the clipping interval around 1 applied
            to the probability ratio.
        target_kl_div: policy updates stop once the approximate KL divergence
            exceeds 1.5 times this value.
        max_policy_train_iters: upper bound on policy gradient steps per call.
        value_train_iters: number of value gradient steps per call.
        policy_lr: Adam learning rate for shared + policy parameters.
        value_lr: Adam learning rate for shared + value parameters.
    """
    def __init__(self,
                 actor_critic,
                 ppo_clip_value=0.2,
                 target_kl_div=0.01,
                 max_policy_train_iters=80,
                 value_train_iters=80,
                 policy_lr=3e-4,
                 value_lr=1e-2):
        self.ac = actor_critic
        self.ppo_clip_val = ppo_clip_value
        self.target_kl_div = target_kl_div
        self.max_policy_train_iters = max_policy_train_iters
        self.value_train_iters = value_train_iters

        policy_params = list(self.ac.shared_layers.parameters()) + list(self.ac.policy_layers.parameters())
        self.policy_optim = optim.Adam(policy_params, lr=policy_lr)

        value_params = list(self.ac.shared_layers.parameters()) + list(self.ac.value_layers.parameters())
        self.value_optim = optim.Adam(value_params, lr=value_lr)

    def train_policy(self, obs, acts, old_log_probs, gaes):
        """
        Take up to `max_policy_train_iters` clipped-surrogate gradient steps.

        Args:
            obs: batch of observations fed to `self.ac.policy`.
            acts: actions taken, one per observation.
            old_log_probs: log-probabilities of `acts` under the policy that
                collected the data.
            gaes: advantage estimates, one per observation.

        The approximate KL divergence between the data-collecting policy and
        the current policy is checked BEFORE each step, so no further update
        is applied once the policy has already moved past the limit.
        """
        for _ in range(self.max_policy_train_iters):
            new_dist = Categorical(logits=self.ac.policy(obs))
            new_log_probs = new_dist.log_prob(acts)
            log_ratio = new_log_probs - old_log_probs

            # Sample estimate of KL(old || new); only used as a stop signal,
            # so it is pulled out of the graph as a plain float.
            approx_kl = (-log_ratio).mean().item()
            if approx_kl > 1.5 * self.target_kl_div:
                break

            policy_ratio = torch.exp(log_ratio)
            clipped_ratio = policy_ratio.clamp(1 - self.ppo_clip_val, 1 + self.ppo_clip_val)

            # Pessimistic (element-wise minimum) of the clipped and unclipped
            # objectives; negated because the optimizer minimises.
            clipped_loss = clipped_ratio * gaes
            full_loss = policy_ratio * gaes
            policy_loss = -torch.min(clipped_loss, full_loss).mean()

            self.policy_optim.zero_grad()
            policy_loss.backward()
            self.policy_optim.step()

    def train_value(self, obs, returns):
        """
        Regress the value head onto `returns` with MSE for
        `value_train_iters` gradient steps.

        Args:
            obs: batch of observations fed to `self.ac.value`.
            returns: regression targets, one per observation. Any layout with
                the same number of elements as the value output is accepted
                (e.g. `(N,)` or `(N, 1)`).

        Raises:
            RuntimeError: if `returns` cannot be reshaped to the shape of the
                value output (element counts differ).
        """
        loss_fn = nn.MSELoss()
        for _ in range(self.value_train_iters):
            self.value_optim.zero_grad()

            values = self.ac.value(obs)
            # Align targets with predictions explicitly: an (N,) target
            # against an (N, 1) prediction would otherwise broadcast to
            # (N, N) and fit every state to the mean return.
            targets = returns.reshape(values.shape)
            value_loss = loss_fn(values, targets)

            value_loss.backward()
            self.value_optim.step()


In [69]:
### Utility functions
def discount_rewards(rewards, gamma=0.99):
    """
    Return the discounted reward-to-go for every step of a trajectory.

    Entry ``t`` of the result is ``rewards[t] + gamma * result[t + 1]``, with
    the last entry equal to the last reward.

    Args:
        rewards: per-step rewards; any indexable sequence with ``len()``
            whose elements convert with ``float()`` (list, numpy array,
            1-D torch tensor).
        gamma: discount factor.

    Returns:
        numpy float64 array with one entry per reward; empty if ``rewards``
        is empty.
    """
    n_steps = len(rewards)
    discounted = np.zeros(n_steps, dtype=np.float64)

    running = 0.0
    for t in reversed(range(n_steps)):
        # float() turns 0-d tensors / numpy scalars into plain Python floats,
        # so the result never depends on numpy coercing a mixed list.
        running = float(rewards[t]) + gamma * running
        discounted[t] = running
    return discounted

def calculate_gaes(rewards, values, gamma=0.99, decay=0.97):
    """
    Compute generalized advantage estimates for one trajectory.

    Working backwards from the last step:
        delta_t = rewards[t] + gamma * values[t + 1] - values[t]
        gae_t   = delta_t + gamma * decay * gae_{t + 1}
    The value after the final step is taken to be 0.

    Args:
        rewards: torch tensor (T,)
        values: torch tensor with T elements, e.g. (T,) or (T, 1). It is
            detached first, so the result carries no autograd graph.
        gamma: discount factor.
        decay: GAE smoothing factor.

    Returns:
        torch tensor shaped like ``rewards`` holding the advantages.

    Raises:
        ValueError: if ``values`` does not hold one element per reward.
    """
    n_steps = len(rewards)
    # Advantages act as constants in the policy loss; detaching avoids
    # building an autograd graph through every timestep. Flattening lets
    # column-shaped (T, 1) value tensors be indexed as scalars.
    values = values.detach().reshape(-1)
    if values.numel() != n_steps:
        raise ValueError(
            f"values has {values.numel()} elements but rewards has {n_steps}")

    gaes = torch.zeros_like(rewards)
    next_value = 0.0
    gae = 0.0

    for t in reversed(range(n_steps)):
        delta = rewards[t] + gamma * next_value - values[t]
        gae = delta + gamma * decay * gae
        gaes[t] = gae
        next_value = values[t]

    return gaes


    


In [84]:
def rollout(model, env, max_steps=1000):
    """
    Run a single episode with the current policy and collect training data.

    The episode ends when the environment reports `terminated` or
    `truncated`, or after `max_steps` steps.

    Returns:
        train_data: list of five per-step collections, in this order:
            [0] observations the actions were chosen from (list)
            [1] sampled actions (list of int)
            [2] rewards (float32 tensor, shape (n_steps,))
            [3] GAE advantages computed from the rewards and the model's
                value estimates (tensor, shape (n_steps,)). Note this slot
                holds advantages, not the raw values.
            [4] log-probabilities of the sampled actions (list of float)
        ep_reward: sum of the rewards received during the episode.
    """
    observations, actions, rewards, values, act_log_probs = [], [], [], [], []
    obs, _ = env.reset()

    ep_reward = 0
    for _ in range(max_steps):
        obs_t = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(DEVICE)

        # Data collection only: nothing here is back-propagated, so skip
        # building an autograd graph for every environment step.
        with torch.no_grad():
            logits, val = model(obs_t)

        act_distribution = Categorical(logits=logits)
        act = act_distribution.sample()
        act_log_prob = act_distribution.log_prob(act).item()
        act = act.item()

        next_obs, reward, terminated, truncated, _ = env.step(act)

        ### Record data for training
        observations.append(obs)
        actions.append(act)
        rewards.append(reward)
        values.append(val.item())  # batch of one -> plain float
        act_log_probs.append(act_log_prob)

        obs = next_obs
        ep_reward += reward

        if terminated or truncated:
            break

    ### Convert rewards/values to tensors and replace values by advantages
    rewards_t = torch.tensor(rewards, dtype=torch.float32, device=DEVICE)
    values_t = torch.tensor(values, dtype=torch.float32, device=DEVICE)
    gaes = calculate_gaes(rewards_t, values_t)

    train_data = [observations, actions, rewards_t, gaes, act_log_probs]
    return train_data, ep_reward

In [85]:
# Build the environment and an actor-critic sized to its observation/action spaces.
env = gym.make('CartPole-v1')
model = ActorCritic(
    state_dim=env.observation_space.shape[0],
    action_dim=env.action_space.n,
).to(DEVICE)

# Smoke test: one rollout with the untrained model.
train_data, reward = rollout(model, env)

In [86]:
### Init PPO trainer and parameters
# Training-loop settings
n_episodes, print_freq = 200, 20

# Trainer hyperparameters, gathered in one place so they are easy to tune
ppo_hparams = dict(
    policy_lr=3e-4,
    value_lr=1e-3,
    target_kl_div=0.02,
    max_policy_train_iters=40,
    value_train_iters=40,
)
ppo = PPOTrainer(model, **ppo_hparams)



In [92]:
# Training loop: one rollout followed by one policy/value update per episode
ep_rewards = []
for episode_idx in range(n_episodes):
    # Collect a fresh episode with the current policy
    train_data, reward = rollout(model, env)
    ep_rewards.append(reward)

    step_obs, step_acts, step_rewards, step_gaes, step_log_probs = train_data

    # A single random ordering shared by every per-step array
    permute_idx = np.random.permutation(len(step_obs))

    # Value targets: discounted returns
    returns = torch.tensor(
        discount_rewards(step_rewards)[permute_idx],
        dtype=torch.float32,
    ).to(DEVICE)

    # Policy inputs
    obs = torch.tensor(np.array(step_obs)[permute_idx], dtype=torch.float32).to(DEVICE)
    acts = torch.tensor(np.array(step_acts)[permute_idx], dtype=torch.int64).to(DEVICE)
    act_log_probs = torch.tensor(
        np.array(step_log_probs)[permute_idx],
        dtype=torch.float32,
    ).to(DEVICE)
    gaes = step_gaes.detach()[permute_idx].to(DEVICE)

    # Update the policy first, then the value function
    ppo.train_policy(obs, acts, act_log_probs, gaes)
    ppo.train_value(obs, returns.unsqueeze(-1))

    # Report the mean reward over the last `print_freq` episodes
    if (episode_idx + 1) % print_freq == 0:
        print(f'Episode: {episode_idx + 1}, Reward: {np.mean(ep_rewards[-print_freq:])}')

Episode: 20, Reward: 63.85
Episode: 40, Reward: 140.15
Episode: 60, Reward: 152.9
Episode: 80, Reward: 232.15
Episode: 100, Reward: 72.75
Episode: 120, Reward: 180.15
Episode: 140, Reward: 169.55
Episode: 160, Reward: 224.55
Episode: 180, Reward: 244.0
Episode: 200, Reward: 222.25
