[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/emasquil/ppo/blob/main/ppo.ipynb)

# Proximal Policy Optimization (PPO) playground

Notebook for training and evaluating PPO on simple continuous-control environments from OpenAI Gym.

## Install dependencies (only on Google Colab)

Run the following cell. Once everything has been installed, restart the runtime so the newly installed packages are picked up.

In [None]:
# Installing our own implementation (editable install, so changes to the cloned repo are used directly)
! git clone https://github.com/emasquil/ppo.git
! pip install -e /content/ppo

# Visualization stuff: xvfb backs the virtual display created with pyvirtualdisplay,
# ffmpeg is used for video encoding of the rendered frames
!sudo apt-get update
!sudo apt-get install -y xvfb ffmpeg freeglut3-dev

# Dependencies needed for running mujoco on colab
# (OpenGL / GLEW / OSMesa development libraries)
!apt-get install -y \
    libgl1-mesa-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libosmesa6-dev \
    software-properties-common

# NOTE(review): presumably needed by mujoco-py when it builds its extension — confirm
!apt-get install -y patchelf

# NOTE(review): package versions are unpinned, so a fresh runtime may pick up incompatible
# releases; consider pinning them for reproducibility.
!pip install free-mujoco-py
!pip install imageio-ffmpeg

## Imports

In [None]:
# Load the TensorBoard notebook extension (provides the %tensorboard magic to inspect
# the logs that the training/evaluation functions write under `log_dir`)
%load_ext tensorboard

# Load autoreload extension; mode 2 reloads all modules before each cell is executed,
# so edits to the editable-installed `ppo` package are picked up without a restart
%load_ext autoreload
%autoreload 2

In [None]:
import base64
import imageio
import IPython
import tqdm.notebook as tq
import numpy as np
import os
import jax

from acme import specs
from trax.jaxboard import SummaryWriter
import pyvirtualdisplay

# Set up a virtual display for rendering (presumably required so that
# `environment.render(mode="rgb_array")` works on a headless machine such as Colab).
# Bound to `virtual_display` rather than `display` so IPython's built-in `display()`
# function is not shadowed in the notebook namespace.
virtual_display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()

from ppo.agents import VanillaPPO, general_advantage_estimation
from ppo.env_wrapper import PendulumEnv, ReacherEnv
from ppo.networks import PolicyNetFixedSigma, PolicyNetComplete, ValueNetwork
from ppo.replay_buffers import DataLoader

### Visualization functions

In [None]:
def display_video(frames, filename="temp.mp4", frame_repeat=1, fps=60):
    """Encode frames into a video file and return an inline HTML player for it.

    Args:
        frames: Iterable of frames accepted by ``imageio``'s writer, e.g. the
            RGB arrays returned by ``environment.render(mode="rgb_array")``.
        filename: Path of the video file to write (overwritten if it exists).
        frame_repeat: Number of times each frame is written; values > 1 slow
            playback down without changing ``fps``.
        fps: Frames per second of the written video (60 by default).

    Returns:
        ``IPython.display.HTML`` containing a ``<video>`` tag whose source is the
        base64-encoded file, so it renders as the cell's output.
    """
    # Write video
    with imageio.get_writer(filename, fps=fps) as writer:
        for frame in frames:
            for _ in range(frame_repeat):
                writer.append_data(frame)
    # Read the encoded file back; the `with` block guarantees the handle is closed.
    with open(filename, "rb") as video_file:
        video_bytes = video_file.read()
    b64_video = base64.b64encode(video_bytes).decode()
    video_tag = ('<video  width="320" height="240" controls alt="test" ' 'src="data:video/mp4;base64,{0}">').format(
        b64_video
    )
    return IPython.display.HTML(video_tag)

## Definitions

Definitions of all the components used in the learning loop: the environment, the hyperparameters, the policy and value networks, and the agent.

In [None]:
# Build the environment and derive its acme environment spec; the action bounds
# from the spec are used below to scale the policy's sigma.
# (ReacherEnv is the other wrapper imported above.)
environment = PendulumEnv()
environment_spec = specs.make_environment_spec(environment)

In [None]:
# Training
num_training_iterations = 50
timesteps_per_iteration = 2000
gae_lambda = 0.95
num_epochs = 10
batch_size = 32
learning_rate_params = {
    "annealing": False,
    "policy": {
        "initial_learning_rate": 1e-4,
        "last_learning_rate": 1e-6,
    },
    "value": {
        "initial_learning_rate": 3e-3,
        "last_learning_rate": 1e-5,
    },
    # Intended as the total number of minibatch updates over the whole run
    # (assumes ceil(timesteps_per_iteration / batch_size) batches per epoch).
    "annealing_duration": num_training_iterations * np.ceil(timesteps_per_iteration / batch_size) * num_epochs,
}  # if "annealing" = False then "initial_learning_rate" is taken as the steady value
clipping_ratio_threshold = 0.2
max_grad_norm = 0.5
discount = 0.99
kl_threshold = None  # if kl_threshold is None we're not using it for early stopping

# Width of the action space, used to scale the policy's standard deviation.
action_range = environment_spec.actions.maximum - environment_spec.actions.minimum
# Set to True to use a fixed std: the policy net will then only predict the mean
# and `policy_net_sigma` is used as std.
use_fixed_sigma = False
policy_net_sigma = 0.3 * action_range if use_fixed_sigma else None
# Otherwise the net also predicts sigma, and these 2 parameters set a proper scale for it.
min_sigma = 1e-6 * action_range
init_sigma = 0.3 * action_range

# Network
# NOTE(review): both last layers declare output_size 64; confirm whether the network
# classes override it (action dimension for the policy, 1 for the value head).
policy_hidden_layers = [
    {"output_size": 64, "std": np.sqrt(2), "bias": 0},
    {"output_size": 64, "std": np.sqrt(2), "bias": 0},
]
policy_last_layer = {"output_size": 64, "std": 0.01, "bias": 0}
value_hidden_layers = [
    {"output_size": 64, "std": np.sqrt(2), "bias": 0},
    {"output_size": 64, "std": np.sqrt(2), "bias": 0},
]
value_last_layer = {"output_size": 64, "std": 1, "bias": 0}

# Logs
log_dir = "experiments"
experiment_name = "test"

# Keys: one independent PRNG key per consumer
seed = 0
key = jax.random.PRNGKey(seed)
key, key_init_networks = jax.random.split(key)
key, key_sampling_policy = jax.random.split(key)
key_dataloader, key_replay_buffer = jax.random.split(key)

# Per-timestep statistics are logged every `logger_freq` environment steps
logger_freq = 200

In [None]:
# Create the agent.
# `policy_net_sigma` may be a numpy array (one std per action dimension), so compare it
# against None explicitly: `if array:` raises for arrays with more than one element.
if policy_net_sigma is not None:

    def policy_network(observations):
        """Policy with a fixed std: the network only predicts the action mean."""
        return PolicyNetFixedSigma(
            policy_hidden_layers, policy_last_layer, environment_spec.actions, "policy", policy_net_sigma
        )(observations)

else:

    def policy_network(observations):
        """Policy that predicts both the action mean and its std (scaled by min/init sigma)."""
        return PolicyNetComplete(
            policy_hidden_layers,
            policy_last_layer,
            environment_spec.actions,
            min_sigma=min_sigma,
            init_sigma=init_sigma,
            name="policy",
        )(observations)


def value_network(observations):
    """State-value network applied to a batch of observations."""
    return ValueNetwork(value_hidden_layers, value_last_layer, "value")(observations)


agent = VanillaPPO(
    environment_spec=environment_spec,
    policy_network=policy_network,
    value_network=value_network,
    key_init_networks=key_init_networks,
    key_sampling_policy=key_sampling_policy,
    key_replay_buffer=key_replay_buffer,
    learning_rate_params=learning_rate_params,
    discount=discount,
    clipping_ratio_threshold=clipping_ratio_threshold,
    max_grad_norm=max_grad_norm,
)

## Interaction loop

In [None]:
def _log_timestep_stats(writer, agent, timestep, action, t):
    """Log the reward, every action component and the policy's mu/sigma at global step `t`."""
    writer.scalar("per_timestep/reward", timestep.reward, t)
    for i, a in enumerate(action):
        writer.scalar(f"per_timestep/action[{i}]", a, t)
    # Re-evaluate the policy on the new observation to log its mean and std
    mu, sigma = agent.policy_network.apply(agent.policy_params, timestep.observation)
    for i in range(len(mu)):
        writer.scalar(f"per_timestep/mu[{i}]", mu[i], t)
        writer.scalar(f"per_timestep/sigma[{i}]", sigma[i], t)


def training_loop(
    environment,
    agent,
    num_training_iterations,
    num_epochs,
    timesteps_per_iteration,
    batch_size,
    log_dir,
    experiment_name,
    kl_threshold,
    key_dataloader,
    environment_spec,
    logger_freq,
    discount=None,
    gae_lambda=None,
):
    """Main PPO training loop: alternate a rollout phase and a learning phase.

    Each iteration collects `timesteps_per_iteration` environment steps with the
    current policy, computes advantages with `general_advantage_estimation`, then
    runs `num_epochs` passes of minibatch updates over the collected buffer.
    Metrics are written to TensorBoard under `log_dir/experiment_name`.

    Args:
        environment: Environment with `reset()`/`step(action)` returning timesteps;
            its inner `_env.returns` is read to track episodic returns.
        agent: PPO agent exposing the rollout/buffer/update methods used below.
        num_training_iterations: Number of rollout + learning iterations.
        num_epochs: Passes over the rollout buffer per iteration.
        timesteps_per_iteration: Environment steps collected per iteration.
        batch_size: Minibatch size passed to the DataLoader.
        log_dir: Root directory of the TensorBoard logs.
        experiment_name: Sub-directory of `log_dir` for this run.
        kl_threshold: If not None, stop the epochs of an iteration early once the
            approximate KL returned by the last minibatch update exceeds it.
        key_dataloader: JAX PRNG key, split once per iteration for the DataLoader.
        environment_spec: Environment spec; only used to log the action bounds.
        logger_freq: Log per-timestep statistics every `logger_freq` steps.
        discount: Discount used for advantage estimation. Defaults to the
            notebook-level `discount` (previously read implicitly).
        gae_lambda: GAE lambda. Defaults to the notebook-level `gae_lambda`
            (previously read implicitly).
    """
    # Backward compatibility: these used to be read from the notebook's globals.
    if discount is None:
        discount = globals()["discount"]
    if gae_lambda is None:
        gae_lambda = globals()["gae_lambda"]

    writer = SummaryWriter(os.path.join(log_dir, experiment_name))
    try:
        writer.text("info/title", f"Playing {str(environment)}")
        writer.text(
            "info/specs",
            f"Maximum action: {environment_spec.actions.maximum} | Minimum action: {environment_spec.actions.minimum}",
        )
        # Counter to keep track of the global timestep
        t = 0

        pbar = tq.tqdm(range(num_training_iterations), position=0)
        pbar.set_description("Training loop iteration")
        for iteration in pbar:

            # --------Rollout phase--------
            # Returns of the episodes completed during this iteration
            episodic_returns = []
            # Reset any counts and start the environment.
            agent.clear_memory()
            timestep = environment.reset()
            episodic_return = 0.0

            # Make the first observation.
            agent.observe_first(timestep)

            pbar_rollout = tq.tqdm(range(timesteps_per_iteration), position=1, leave=False)
            pbar_rollout.set_description("Rollout step")
            for rollout_step in pbar_rollout:
                t += 1
                value = agent.get_value(timestep.observation)
                action, log_prob = agent.select_action_and_prob(timestep.observation)
                timestep = environment.step(action)
                agent.observe(value, log_prob, action, timestep)

                if t % logger_freq == 0:
                    _log_timestep_stats(writer, agent, timestep, action, t)

                if timestep.last():
                    # Plot the episodic return
                    writer.scalar("per_timestep/episodic_return", episodic_return, t)
                    episodic_returns.append(episodic_return)
                    # Bootstrap value for the finished trajectory
                    agent.add_last_value(timestep)
                    # Only restart the environment if it's not the last timestep we should run
                    if rollout_step != timesteps_per_iteration - 1:
                        timestep = environment.reset()
                        episodic_return = 0.0
                        agent.observe_first(timestep)
                else:
                    # The wrapper's internal returns get reset at the last step, so keep track
                    # of them before that.
                    # NOTE(review): as a consequence the final step's reward is not included
                    # in the logged episodic return — confirm whether that is intended.
                    episodic_return = environment._env.returns

            # Average return of the episodes that finished in this iteration
            # (skipped when none did, instead of logging the nan from np.mean([]))
            if episodic_returns:
                writer.scalar(
                    "per_training_it/avg_episodic_return",
                    np.mean(episodic_returns),
                    iteration,
                )

            # If last trajectory is not completed, we still need to add the last value
            if not timestep.last():
                agent.add_last_value(timestep)

            # --------Learning phase--------
            # Cast the buffer to np arrays
            agent.cast_to_numpy()
            # Compute advantages
            advantages = general_advantage_estimation(
                agent.replay_buffer.values_t,
                agent.replay_buffer.dones_tp1,
                agent.replay_buffer.rewards_tp1,
                discount,
                gae_lambda,
            )
            agent.add_advantages(advantages)

            key_dataloader, rng = jax.random.split(key_dataloader)
            dataloader = DataLoader(agent.replay_buffer, batch_size, rng)

            pbar_epochs = tq.tqdm(range(num_epochs), leave=False, position=1)
            pbar_epochs.set_description("Epoch")
            for epoch in pbar_epochs:
                # Monotonic x-axis for all per-epoch curves across iterations
                global_epoch = epoch + iteration * num_epochs
                dataloader.shuffle()
                value_losses = []
                policy_losses = []
                for batch in dataloader:
                    value_loss, policy_loss, kl_approximation = agent.update(batch)
                    value_losses.append(value_loss)
                    policy_losses.append(policy_loss)

                writer.scalar("per_epoch/value_loss", np.mean(value_losses), global_epoch)
                writer.scalar("per_epoch/policy_loss", np.mean(policy_losses), global_epoch)
                # KL approximation of the last minibatch of the epoch
                writer.scalar("per_epoch/kl_divergence", kl_approximation, global_epoch)
                policy_lr, value_lr = agent.get_learning_rate()
                writer.scalar("per_epoch/learning_rate_policy", policy_lr, global_epoch)
                writer.scalar("per_epoch/learning_rate_value", value_lr, global_epoch)

                if kl_threshold is not None and kl_approximation > kl_threshold:
                    # Close the bar explicitly since the loop is left before exhaustion
                    pbar_epochs.close()
                    break
    finally:
        # Release the event file so the logs are complete even if training is interrupted
        writer.close()

In [None]:
def evaluate(environment, agent, evaluation_episodes, experiment_name, log_dir=None):
    """Run the agent for several episodes, collecting rendered frames and returns.

    Per-episode and average returns are printed and logged as TensorBoard text
    under `log_dir/experiment_name`.

    Args:
        environment: Environment with `reset()`, `step(action)` and
            `render(mode="rgb_array")`; its inner `_env.returns` is read to track
            the episodic return.
        agent: Object exposing `select_action(observation)`.
        evaluation_episodes: Number of episodes to run.
        experiment_name: Sub-directory of `log_dir` where results are logged.
        log_dir: Root directory of the logs. Defaults to the notebook-level
            `log_dir` (previously read implicitly).

    Returns:
        List of frames rendered before each action, across all episodes.
    """
    # Backward compatibility: `log_dir` used to be read from the notebook's globals.
    if log_dir is None:
        log_dir = globals()["log_dir"]
    writer = SummaryWriter(os.path.join(log_dir, experiment_name))

    frames = []
    episodic_returns = np.zeros(evaluation_episodes)
    try:
        pbar = tq.tqdm(range(evaluation_episodes))
        pbar.set_description("Episode")
        for episode in pbar:
            timestep = environment.reset()
            steps = 0
            while not timestep.last():
                frames.append(environment.render(mode="rgb_array"))
                action = agent.select_action(timestep.observation)
                timestep = environment.step(action)
                steps += 1
                # Read the wrapper's running return before it is reset on the last step.
                # NOTE(review): this excludes the final step's reward — confirm intended.
                if not timestep.last():
                    episodic_returns[episode] = environment._env.returns
            print(f"Episode {episode} ended with reward {episodic_returns[episode]} in {steps} steps")
            writer.text(f"results/returns_{episode}", str(episodic_returns[episode]))
        # Avoid logging the nan from averaging an empty array
        if evaluation_episodes > 0:
            writer.text("results/average_returns", str(np.mean(episodic_returns)))
    finally:
        writer.close()
    return frames

### Train

In [None]:
# Launch training with the hyperparameters defined above
# (arguments listed in the order of `training_loop`'s signature).
training_loop(
    environment=environment,
    agent=agent,
    num_training_iterations=num_training_iterations,
    num_epochs=num_epochs,
    timesteps_per_iteration=timesteps_per_iteration,
    batch_size=batch_size,
    log_dir=log_dir,
    experiment_name=experiment_name,
    kl_threshold=kl_threshold,
    key_dataloader=key_dataloader,
    environment_spec=environment_spec,
    logger_freq=logger_freq,
)

### Evaluate

In [None]:
# Roll out the trained agent, then render the recorded frames as an inline video
# (the video is the cell's last expression, so it is displayed as its output).
evaluation_frames = evaluate(
    environment=environment,
    agent=agent,
    evaluation_episodes=10,
    experiment_name=experiment_name,
)
display_video(evaluation_frames)