In [7]:
# Proximal Policy Optimization
# Original Implementation: https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ppo
# Algorithm Doc: https://spinningup.openai.com/en/latest/algorithms/ppo.html

In [8]:
import numpy as np
import scipy.signal
import gymnasium as gym
from gymnasium.spaces import Discrete, Box
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical

In [9]:
def combined_shape(length, shape=None):
    '''
    Shape tuple for a buffer holding ``length`` items, each of shape ``shape``.

    ``None`` -> ``(length,)``; a scalar ``n`` -> ``(length, n)``;
    a tuple ``(a, b, ...)`` -> ``(length, a, b, ...)``.
    '''
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def layers(sizes, activation, output_activation=nn.Identity):
    '''
    Build a fully connected network as an ``nn.Sequential``.

    Each consecutive pair in ``sizes`` defines one ``nn.Linear``. Every linear
    layer is followed by a fresh ``activation()`` instance, except the last one,
    which is followed by ``output_activation()``.
    '''
    n_linear = len(sizes) - 1
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx == n_linear - 1 else activation
        modules.extend([nn.Linear(fan_in, fan_out), act_cls()])
    return nn.Sequential(*modules)

def count_vars(module):
    '''Total number of scalar entries across all parameters of ``module``.'''
    total = 0
    for param in module.parameters():
        total = total + np.prod(param.shape)
    return total

def discount_cumsum(x, discount):
    '''
    Discounted cumulative sum along axis 0 (the rllab trick).

    For x = [x0, x1, x2] returns
    [x0 + discount*x1 + discount**2*x2,  x1 + discount*x2,  x2].

    The IIR filter y[t] = x[t] + discount * y[t-1], run over the reversed
    sequence and then reversed back, evaluates exactly this backward recursion.
    '''
    reversed_x = x[::-1]
    filtered = scipy.signal.lfilter([1], [1, float(-discount)], reversed_x, axis=0)
    return filtered[::-1]

In [10]:
class Actor(nn.Module):
    '''
    Abstract policy head. Subclasses build the action distribution and compute
    log-probabilities under it; ``forward`` wires the two together.
    '''

    def _distribution(self, obs):
        '''Return the action distribution for ``obs``. Must be overridden.'''
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        '''Return the log-probability of ``act`` under ``pi``. Must be overridden.'''
        raise NotImplementedError

    def forward(self, obs, act=None):
        '''
        Return ``(pi, logp_a)``: the action distribution for ``obs`` and, when
        ``act`` is given, the log-probability of ``act`` under it (``None`` otherwise).
        '''
        pi = self._distribution(obs)
        if act is None:
            return pi, None
        return pi, self._log_prob_from_distribution(pi, act)

class MLPCategoricalActor(Actor):
    '''Categorical policy for discrete action spaces: an MLP maps observations to action logits.'''

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        sizes = [obs_dim, *hidden_sizes, act_dim]
        self.logits_net = layers(sizes, activation)

    def _distribution(self, obs):
        return Categorical(logits=self.logits_net(obs))

    def _log_prob_from_distribution(self, pi, act):
        # A categorical action is a single index, so log_prob needs no reduction.
        return pi.log_prob(act)


class MLPGaussianActor(Actor):
    '''
    Diagonal-Gaussian policy for continuous action spaces.

    An MLP over the observation produces the mean. The log standard deviation is a
    learned parameter that does not depend on the state; it starts at -0.5 in
    every action dimension.
    '''

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        init_log_std = np.full(act_dim, -0.5, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(init_log_std))
        self.mu_net = layers([obs_dim, *hidden_sizes, act_dim], activation)

    def _distribution(self, obs):
        return Normal(self.mu_net(obs), self.log_std.exp())

    def _log_prob_from_distribution(self, pi, act):
        # Action dimensions are independent, so the joint log-prob is the sum over the last axis.
        return pi.log_prob(act).sum(dim=-1)


class MLPCritic(nn.Module):
    '''State-value function V(s): an MLP with a single output unit.'''

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        self.v_net = layers([obs_dim, *hidden_sizes, 1], activation)

    def forward(self, obs):
        # Drop the trailing size-1 output axis so a batch of observations yields a
        # 1-D tensor of values rather than a (batch, 1) column.
        return self.v_net(obs).squeeze(-1)


class MLPActorCritic(nn.Module):
    '''
    Combined actor-critic: a policy ``pi`` and a state-value function ``v``.

    The policy is a diagonal Gaussian for ``Box`` action spaces and a categorical
    distribution for ``Discrete`` ones. Any other action space raises
    ``NotImplementedError``.
    '''
    def __init__(self, observation_space, action_space, hidden_sizes=(64,64), activation=nn.Tanh):
        super().__init__()
        obs_dim = observation_space.shape[0]

        # Actor network depends on action space
        if isinstance(action_space, Box):
            self.pi = MLPGaussianActor(obs_dim, action_space.shape[0], hidden_sizes, activation)
        elif isinstance(action_space, Discrete):
            self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation)
        else:
            # Fail here instead of leaving self.pi unset, which would only surface
            # later as a confusing AttributeError.
            raise NotImplementedError(
                f'Unsupported action space type: {type(action_space).__name__}')

        # Critic network
        self.v = MLPCritic(obs_dim, hidden_sizes, activation)

    def step(self, obs):
        '''
        Sample an action for ``obs`` without tracking gradients.

        Returns numpy arrays ``(a, v, logp_a)``: the sampled action, the value
        estimate V(obs) of the observation, and the log-probability of ``a``
        under the current policy.
        '''
        with torch.no_grad():
            pi = self.pi._distribution(obs)
            a = pi.sample()
            logp_a = self.pi._log_prob_from_distribution(pi, a)
            v = self.v(obs)
        return a.numpy(), v.numpy(), logp_a.numpy()

    def act(self, obs):
        '''Return only the sampled action from ``step()``.'''
        return self.step(obs)[0]

In [11]:
class PPOBuffer:
    '''
    A buffer for storing trajectories experienced by a PPO agent interacting
    with the environment, and using Generalized Advantage Estimation (GAE-Lambda)
    for calculating the advantages of state-action pairs.

    It holds exactly ``size`` timesteps of on-policy data. Call ``store`` once per
    step, ``finish_path`` at the end of every trajectory (or when the epoch cuts
    one off), and ``get`` once the buffer is full.
    '''
    def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.95):
        '''
        Args:
            obs_dim, act_dim: per-step observation / action shapes, in any form
                accepted by ``combined_shape`` (None, an int, or a tuple).
            size: number of timesteps the buffer holds.
            gamma: discount factor for rewards-to-go and TD errors.
            lam: GAE-Lambda parameter.
        '''
        self.obs_buf = np.zeros(combined_shape(size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros(combined_shape(size, act_dim), dtype=np.float32)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.gamma = gamma
        self.lam = lam
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size

    def store(self, obs, act, rew, val, logp):
        '''Append one timestep of agent-environment interaction to the buffer.'''
        assert self.ptr < self.max_size # Ensures buffer has available space
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_buf[self.ptr] = logp
        self.ptr += 1

    def finish_path(self, last_val=0):
        '''
        Call this at the end of a trajectory, or when one gets cut off
        by an epoch ending. It looks back over the trajectory and uses rewards
        and value estimates to compute advantage estimates, and rewards-to-go
        for each state to use as targets for the value function.

        Args:
            last_val: 0 if the trajectory ended in a terminal state, otherwise
                the value estimate V(s_T) used to bootstrap the cut-off tail.
        '''
        path_slice = slice(self.path_start_idx, self.ptr)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)

        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1] # TD errors
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam) # GAE-Lambda advantages
        self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1] # Bootstrapped rewards-to-go

        self.path_start_idx = self.ptr

    def get(self, eps=1e-8):
        '''
        Call this at the end of an epoch to get all of the data from the buffer,
        with advantages normalized to mean 0 and std 1. Also resets the pointers.

        Args:
            eps: added to the advantage std to avoid division by zero (NaN
                advantages) when every advantage in the batch is equal.

        Returns:
            dict of float32 tensors with keys 'obs', 'act', 'ret', 'adv', 'logp'.
            ``torch.as_tensor`` does not copy when the dtype already matches, so
            these tensors may share memory with the internal buffers. Consume them
            before storing more data.
        '''
        assert self.ptr == self.max_size # Ensures buffer is full before collecting
        self.ptr, self.path_start_idx = 0, 0 # Reset pointers
        adv_mean, adv_std = np.mean(self.adv_buf), np.std(self.adv_buf) # Not using MPI
        self.adv_buf = (self.adv_buf - adv_mean) / (adv_std + eps)
        data = dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf, adv=self.adv_buf, logp=self.logp_buf)
        return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in data.items()}

In [12]:
def ppo(env_name='HalfCheetah-v4',
        actor_critic=MLPActorCritic,
        hidden_sizes=(64, 64),
        steps_per_epoch=5000,
        epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=0.0004,
        vf_lr=0.003,
        train_pi_iters=80,
        train_v_iters=80,
        lam=0.97,
        target_kl=0.01,
        max_ep_length=1000,
        seed=None):
    '''
    Train an actor-critic agent with PPO-Clip on a Gymnasium environment.

    Args:
        env_name: environment id passed to ``gym.make``.
        actor_critic: constructor called as
            ``actor_critic(observation_space, action_space, hidden_sizes=...)``;
            the result must expose ``pi``, ``v`` and ``step`` like ``MLPActorCritic``.
        hidden_sizes: hidden-layer widths of the policy and value networks.
        steps_per_epoch: environment steps collected per epoch (the buffer size).
        epochs: number of collect-then-update cycles.
        gamma: discount factor.
        clip_ratio: the probability ratio is clipped to [1 - clip_ratio, 1 + clip_ratio].
        pi_lr, vf_lr: Adam learning rates for the policy and value networks.
        train_pi_iters, train_v_iters: gradient steps per epoch for each network.
        lam: GAE-Lambda parameter.
        target_kl: policy updates stop early once approx KL exceeds 1.5 * target_kl.
        max_ep_length: episodes are cut off (and bootstrapped) after this many steps.
        seed: optional seed for torch, numpy and the environment. None leaves them unseeded.

    Returns:
        (ac, epoch_ret): the trained actor-critic and, for each epoch, the mean
        return of the episodes that finished (terminated or timed out) during
        that epoch. The value is NaN for an epoch in which no episode finished.
    '''
    if seed is not None:
        # Seed before building the networks so their initial weights are reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)

    env = gym.make(env_name)
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    ac = actor_critic(env.observation_space, env.action_space, hidden_sizes=hidden_sizes)
    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.v])
    print(f'Number of parameters: \t pi: {var_counts[0]}, \t v: {var_counts[1]}\n')

    # PPO is on-policy: this holds exactly one epoch of fresh experience and is
    # emptied by every update. It is not a replay buffer.
    buf = PPOBuffer(obs_dim, act_dim, steps_per_epoch, gamma, lam)

    def compute_loss_pi(data):
        '''Clipped surrogate policy loss and approximate KL(old || new) for one batch.'''
        obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data['logp']
        _, logp = ac.pi(obs, act)  # log-probs of the stored actions under the current policy
        ratio = torch.exp(logp - logp_old)  # pi_new(a|s) / pi_old(a|s)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()
        approx_kl = (logp_old - logp).mean().item()
        return loss_pi, approx_kl

    def compute_loss_v(data):
        '''Mean squared error between V(s) and the rewards-to-go targets.'''
        obs, ret = data['obs'], data['ret']
        return ((ac.v(obs) - ret) ** 2).mean()

    # Optimizers
    pi_optimizer = Adam(ac.pi.parameters(), lr=pi_lr)
    vf_optimizer = Adam(ac.v.parameters(), lr=vf_lr)

    def update():
        '''Run the policy and value-function gradient steps on this epoch's data.'''
        data = buf.get()

        for i in range(train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi, approx_kl = compute_loss_pi(data)
            if approx_kl > 1.5 * target_kl:
                print('Early stopping at step %d due to reaching max KL' % i)
                break
            loss_pi.backward()
            pi_optimizer.step()

        for _ in range(train_v_iters):
            vf_optimizer.zero_grad()
            loss_v = compute_loss_v(data)
            loss_v.backward()
            vf_optimizer.step()

    epoch_ret = []
    try:
        o, _ = env.reset(seed=seed)
        ep_ret, ep_len = 0, 0

        # Main experiment loop: collect experience in env and update each epoch
        for epoch in range(epochs):
            ep_rets = []  # returns of episodes that finished during THIS epoch only
            for t in range(steps_per_epoch):
                a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))
                next_o, r, terminated, truncated, _ = env.step(a)
                ep_ret += r
                ep_len += 1
                buf.store(o, a, r, v, logp)
                o = next_o

                # Truncation by the env's own time limit counts as a timeout too.
                timeout = truncated or ep_len == max_ep_length
                terminal = terminated or timeout
                epoch_ended = t == steps_per_epoch - 1

                if terminal or epoch_ended:
                    # Only a true terminal state has zero future value. Timeouts and
                    # epoch cut-offs are bootstrapped from V(s_T).
                    if terminated:
                        v = 0
                    else:
                        _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
                    buf.finish_path(v)
                    if terminal:
                        # A trajectory cut off by the epoch is not a complete episode; don't log it.
                        ep_rets.append(ep_ret)

                    o, _ = env.reset()
                    ep_ret, ep_len = 0, 0

            update()
            epoch_ret.append(float(np.mean(ep_rets)) if ep_rets else float('nan'))
            print('Epoch: %3d \t Mean epoch return %.3f \t ' % (epoch, epoch_ret[-1]))
    finally:
        env.close()

    return ac, epoch_ret

In [None]:
# Run experiment: 50 epochs x 5000 steps per epoch = 250,000 total environment interactions.
ac, epoch_ret = ppo(steps_per_epoch=5000, epochs=50)

In [15]:
# Visualisation: roll out one episode with the trained policy and record RGB frames.
# ac.step samples from the policy distribution, so each rollout is stochastic.
# NOTE(review): every frame is a full RGB array; a long episode holds many of them in memory.
env = gym.make('HalfCheetah-v4', render_mode='rgb_array')
o, _ = env.reset()

frames = [env.render()]  # initial state
terminated = False
truncated = False
while not (terminated or truncated):
    a, _, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
    o, _, terminated, truncated, _ = env.step(a)
    frames.append(env.render())  # render after stepping so the final state is captured too
env.close()  # release the renderer's resources

from PIL import Image
def create_gif(frames, filename='ppo.gif', duration=50):
    '''
    Save a sequence of frames as an infinitely looping GIF.

    Args:
        frames: non-empty sequence of image arrays accepted by ``Image.fromarray``
            (e.g. the RGB arrays returned by ``env.render()`` in 'rgb_array' mode).
        filename: output path.
        duration: display time per frame, in milliseconds. GIF stores delays in
            hundredths of a second, so values under 10 ms round down to 0 and most
            viewers then substitute their own, much slower, default delay. The
            default of 50 ms plays at 20 fps. NOTE(review): compare against
            env.metadata['render_fps'] for real-time playback.

    Raises:
        ValueError: if ``frames`` is empty.
    '''
    if len(frames) == 0:
        raise ValueError('create_gif() needs at least one frame')
    images = [Image.fromarray(frame) for frame in frames]
    images[0].save(filename, save_all=True, append_images=images[1:], optimize=False,
                   duration=duration, loop=0)

create_gif(frames, filename='ppo.gif')  # write the recorded rollout to ppo.gif