[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/emasquil/ppo/blob/main/ppo.ipynb)

# Proximal Policy Optimization (PPO) playground

Notebook for training and evaluating a PPO agent on simple OpenAI Gym environments (Pendulum by default; a Reacher wrapper is also imported).

## Install dependencies (only on Google Colab)

In [None]:
# Colab-only setup: remove one leading "# " from every line below to enable it.
# Leave the cell commented out when running from a local checkout.
# # Installing our own implementation
# ! git clone https://github.com/emasquil/ppo.git
# ! pip install -e /content/ppo

# # Visualization stuff
# !sudo apt-get update
# !sudo apt-get install -y xvfb ffmpeg freeglut3-dev

# # Dependencies needed for running mujoco on colab
# !apt-get install -y \
#     libgl1-mesa-dev \
#     libgl1-mesa-glx \
#     libglew-dev \
#     libosmesa6-dev \
#     software-properties-common

# !apt-get install -y patchelf

# !pip install free-mujoco-py
# !pip install imageio-ffmpeg

## Imports

In [None]:
# Load the TensorBoard notebook extension (provides the %tensorboard magic used further down)
%load_ext tensorboard

# Load autoreload extension
%load_ext autoreload
# Mode 2 reloads imported modules before each cell runs, so edits to the
# installed ppo package are picked up without restarting the kernel
%autoreload 2

In [None]:
import base64
import os

import imageio
import IPython
import jax
import numpy as np
import pyvirtualdisplay
import tqdm.notebook as tq
from acme import specs
from trax.jaxboard import SummaryWriter

# Set up a virtual display for rendering.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()

from ppo.agents import VanillaPPO, general_advantage_estimation
from ppo.env_wrapper import PendulumEnv, ReacherEnv
from ppo.networks import PolicyNetwork, ValueNetwork
from ppo.replay_buffers import DataLoader

### Visualization functions

In [None]:
def display_video(frames, filename="temp.mp4", frame_repeat=1, fps=60):
    """Save frames as a video file and return an inline HTML player for it.

    Args:
        frames: iterable of frames; each one is passed to the imageio writer's
            ``append_data``.
        filename: path of the video file that is written and then read back.
        frame_repeat: number of times every frame is appended (values above 1
            slow the playback down).
        fps: frame rate passed to ``imageio.get_writer``.

    Returns:
        ``IPython.display.HTML`` wrapping a ``<video>`` tag whose source is the
        base64-encoded content of ``filename``.
    """
    # Write video
    with imageio.get_writer(filename, fps=fps) as video:
        for frame in frames:
            for _ in range(frame_repeat):
                video.append_data(frame)
    # Read the encoded file back inside a context manager so the handle is
    # closed deterministically instead of being left to the garbage collector
    with open(filename, "rb") as video_file:
        b64_video = base64.b64encode(video_file.read())
    video_tag = ('<video  width="320" height="240" controls alt="test" ' 'src="data:video/mp4;base64,{0}">').format(
        b64_video.decode()
    )
    return IPython.display.HTML(video_tag)

## Definitions

Definition of all the parts used in the learning loop: environment, agent, etc.

In [None]:
# PendulumEnv is the default; ReacherEnv is imported as well and can be swapped in here
environment = PendulumEnv()
# The spec is read further down for the agent's observation spec and the policy's action spec
environment_spec = specs.make_environment_spec(environment)

In [None]:
# Training
num_training_iterations = 50  # outer iterations: one rollout phase followed by one learning phase
timesteps_per_iteration = 2000  # environment steps collected in every rollout phase
gae_lambda = 0.95  # read by training_loop from the notebook namespace (not passed as an argument)
num_epochs = 10  # passes over the collected rollout in every learning phase
batch_size = 32  # minibatch size handed to the DataLoader
learning_rate_params = {
    "annealing": True,
    "initial_learning_rate": 3e-4,
    "last_learning_rate": 1e-6,
    # iterations * minibatches per epoch * epochs; presumably the total number of
    # agent.update calls (assumes the DataLoader yields ceil(timesteps / batch_size)
    # batches - TODO confirm). np.ceil returns a float, so this value is a float.
    "annealing_duration": num_training_iterations * np.ceil(timesteps_per_iteration / batch_size) * num_epochs,
}  # if "annealing" = False then "initial_learning_rate" is taken as the steady value
clipping_ratio_threshold = 0.2
max_grad_norm = 0.5
discount = 0.99  # passed to the agent and also read by training_loop from the notebook namespace
kl_threshold = None  # None disables the KL-based early stop of the epoch loop in training_loop

# Network
policy_hidden_layers = [
    {"output_size": 64, "std": np.sqrt(2), "bias": 0},
    {"output_size": 64, "std": np.sqrt(2), "bias": 0},
]
# NOTE(review): the policy also receives environment_spec.actions, so confirm how
# PolicyNetwork uses this "output_size" for its last layer
policy_last_layer = {"output_size": 64, "std": 0.01, "bias": 0}
value_hidden_layers = [
    {"output_size": 64, "std": np.sqrt(2), "bias": 0},
    {"output_size": 64, "std": np.sqrt(2), "bias": 0},
]
# NOTE(review): confirm that ValueNetwork does not emit 64 outputs from this last layer
value_last_layer = {"output_size": 64, "std": 1, "bias": 0}

# Logs
log_dir = "experiments"
experiment_name = "pendulum_0"  # events are written to os.path.join(log_dir, experiment_name)

# Keys
seed = 0
key = jax.random.PRNGKey(seed)
# Split one sub-key per consumer so each source of randomness has its own key
key, key_networks = jax.random.split(key)
key, key_sampling_policy = jax.random.split(key)
# Read by training_loop from the notebook namespace when it builds the DataLoader
key, key_shuffling_batch = jax.random.split(key)

In [None]:
# Create the agent
def policy_network(observations):
    """Instantiate the policy network from the notebook configuration and apply it.

    The module is built from ``policy_hidden_layers``, ``policy_last_layer`` and the
    action spec of ``environment_spec`` under the name "policy".
    """
    network = PolicyNetwork(
        policy_hidden_layers,
        policy_last_layer,
        environment_spec.actions,
        "policy",
    )
    return network(observations)


def value_network(observations):
    """Instantiate the value network from the notebook configuration and apply it.

    The module is built from ``value_hidden_layers`` and ``value_last_layer`` under
    the name "value".
    """
    network = ValueNetwork(value_hidden_layers, value_last_layer, "value")
    return network(observations)


# Build the PPO agent from the network factories, PRNG keys and hyperparameters defined above.
# gae_lambda, num_epochs, batch_size and kl_threshold are not passed here; training_loop uses them.
agent = VanillaPPO(
    observation_spec=environment_spec.observations,
    policy_network=policy_network,
    value_network=value_network,
    key_networks=key_networks,
    key_sampling_policy=key_sampling_policy,
    learning_rate_params=learning_rate_params,
    discount=discount,
    clipping_ratio_threshold=clipping_ratio_threshold,
    max_grad_norm=max_grad_norm,
)

## Interaction loop

In [None]:
def training_loop(
    environment,
    agent,
    num_training_iterations,
    num_epochs,
    timesteps_per_iteration,
    batch_size,
    log_dir,
    experiment_name,
    kl_threshold,
):
    """
    Main training loop: alternate rollout collection and agent updates, logging scalars.

    Every training iteration clears the agent's replay buffer, collects
    ``timesteps_per_iteration`` environment steps (resetting the environment whenever
    an episode ends), computes advantages for each stored trajectory and then runs
    ``num_epochs`` passes of minibatch updates over the flattened buffer.

    Note: ``discount``, ``gae_lambda`` and ``key_shuffling_batch`` are not parameters.
    They are read from the notebook namespace, so the configuration cell must have
    been executed before this function is called.

    Args:
        environment: object whose ``reset()`` and ``step(action)`` return timesteps
            exposing ``reward``, ``observation`` and ``last()``.
        agent: agent providing the rollout, memory and update methods called below.
        num_training_iterations: number of rollout + learning iterations.
        num_epochs: passes over the collected data in every learning phase.
        timesteps_per_iteration: environment steps collected per iteration.
        batch_size: minibatch size given to the ``DataLoader``.
        log_dir: root directory for the event files.
        experiment_name: sub-directory of ``log_dir`` used for this run.
        kl_threshold: when not ``None``, the epoch loop of the current iteration stops
            as soon as the KL approximation of the last minibatch exceeds this value.

    Returns:
        None. All results are written through the summary writer.
    """
    writer = SummaryWriter(os.path.join(log_dir, experiment_name))
    # Running mean of the reward over every environment step taken so far
    avg_reward = 0.0
    # Counter to keep track of the global timestep
    t = 0

    try:
        pbar = tq.tqdm(range(num_training_iterations), position=0)
        pbar.set_description("Training loop iteration")
        for iteration in pbar:
            # Rollout phase
            # Rewards of this iteration, one inner list per (possibly truncated) episode
            episodic_rewards = [[]]
            # Reset any counts and start the environment.
            agent.replay_buffer.clear()
            timestep = environment.reset()

            # Make the first observation.
            agent.observe_first(timestep)

            pbar_rollout = tq.tqdm(range(timesteps_per_iteration), position=1, leave=False)
            pbar_rollout.set_description("Rollout step")
            for rollout_step in pbar_rollout:
                pbar_rollout.set_postfix(reward=timestep.reward)

                value = agent.get_value(timestep.observation)
                action, log_prob = agent.select_action_and_prob(timestep.observation)
                timestep = environment.step(action)
                agent.observe(value, log_prob, action, timestep)
                # Incrementing count of time step
                t += 1
                # Avg reward per global timestep
                avg_reward += (timestep.reward - avg_reward) / t
                # Add reward of this timestep
                episodic_rewards[-1].append(timestep.reward)

                # Add to logs
                writer.scalar("per_timestep/avg_reward", avg_reward, t)
                writer.scalar("per_timestep/reward", timestep.reward, t)
                # One curve per action component, whatever the action dimension is
                # (a 1-D action has no component 1, a larger one has more than two)
                for index, component in enumerate(np.ravel(action)):
                    writer.scalar(f"per_timestep/action[{index}]", component, t)
                writer.scalar("per_timestep/training_iteration", iteration, t)

                if timestep.last():
                    # Add last value
                    agent.add_last_value(timestep)
                    # Only restart the environment if it's not the last timestep we should run
                    if rollout_step != timesteps_per_iteration - 1:
                        episodic_rewards.append([])
                        # Restart the episode
                        timestep = environment.reset()
                        agent.observe_first(timestep)

            # Mean reward per step of this iteration. The episodes are flattened first
            # because the last one may be truncated, which makes the nested list ragged;
            # for equal-length episodes this equals the mean of the nested list.
            iteration_rewards = [reward for rewards in episodic_rewards for reward in rewards]
            writer.scalar("per_training_it/avg_episodic_reward", np.mean(iteration_rewards), iteration)
            # Mean over episodes of the summed reward (a truncated episode counts as one episode)
            writer.scalar(
                "per_training_it/episodic_return",
                np.mean([np.sum(episodic_reward) for episodic_reward in episodic_rewards]),
                iteration,
            )

            # If last trajectory is not completed, we still need to add the last value
            if not timestep.last():
                # Add last value
                agent.add_last_value(timestep)

            # Learning phase
            trajectories = agent.get_full_memory()
            # Compute advantages
            advantages = []
            for e, trajectory in enumerate(trajectories):
                value, done = agent.get_last_value_and_done(e)
                advantages.append(general_advantage_estimation(trajectory, value, done, discount, gae_lambda))
            agent.add_advantages(advantages)
            # Flatten the replay buffer
            agent.replay_buffer.flatten_memory()
            dataloader = DataLoader(agent.replay_buffer, batch_size, key_shuffling_batch)

            pbar_epochs = tq.tqdm(range(num_epochs), leave=False, position=1)
            pbar_epochs.set_description("Epoch")
            for epoch in pbar_epochs:
                dataloader.shuffle()
                value_losses = []
                policy_losses = []
                for batch in dataloader:
                    value_loss, policy_loss, kl_approximation = agent.update(batch)
                    value_losses.append(value_loss)
                    policy_losses.append(policy_loss)

                # Global epoch index shared by every per-epoch curve
                global_epoch = epoch + iteration * num_epochs
                writer.scalar("per_epoch/value_loss", np.mean(value_losses), global_epoch)
                writer.scalar("per_epoch/policy_loss", np.mean(policy_losses), global_epoch)
                # kl_approximation is the value returned by the last minibatch of the epoch
                writer.scalar("per_epoch/kl_divergence", kl_approximation, global_epoch)
                writer.scalar("per_epoch/learning_rate", agent.get_learning_rate(), global_epoch)
                writer.scalar("per_epoch/training_iteration", iteration, global_epoch)

                if kl_threshold is not None and kl_approximation > kl_threshold:
                    pbar_epochs.disp(close=True)
                    break
    finally:
        # Close the event file even when training is interrupted
        writer.close()

In [None]:
def evaluate(environment, agent, evaluation_episodes):
    """Run the agent for ``evaluation_episodes`` episodes and collect rendered frames.

    A frame is rendered before every action, so the state reached by the final step
    of an episode is not part of the result. The return and the length of each
    episode are printed once the episode ends.

    Returns:
        List with the frames of all episodes, in the order they were rendered.
    """
    recorded = []

    progress = tq.tqdm(range(evaluation_episodes))
    progress.set_description("Episode")
    for episode in progress:
        total_reward, n_steps = 0, 0
        current = environment.reset()
        while True:
            if current.last():
                break
            recorded.append(environment.render(mode="rgb_array"))
            current = environment.step(agent.select_action(current.observation))
            n_steps += 1
            total_reward += current.reward
        print(f"Episode {episode} ended with reward {total_reward} in {n_steps} steps")
    return recorded

In [None]:
# Uncomment to open TensorBoard inline; the directory must match `log_dir` from the configuration cell
# %tensorboard --logdir experiments

### Train

In [None]:
# Runs num_training_iterations * timesteps_per_iteration environment steps in total;
# scalars are written under os.path.join(log_dir, experiment_name)
training_loop(
    agent=agent,
    environment=environment,
    timesteps_per_iteration=timesteps_per_iteration,
    batch_size=batch_size,
    num_training_iterations=num_training_iterations,
    num_epochs=num_epochs,
    log_dir=log_dir,
    experiment_name=experiment_name,
    kl_threshold=kl_threshold,
)

### Evaluate

In [None]:
# Roll out 5 episodes with the trained agent and show the rendered frames as an inline video
display_video(evaluate(agent=agent, environment=environment, evaluation_episodes=5))