# Installs & Imports

In [None]:
!pip install dm-env
!pip install dm-haiku
!pip install dm-tree
!pip install optax
!pip install distrax
!pip install chex
!pip install imageio
!pip install gym
!pip install gym[classic_control]
!pip install free-mujoco-py
!pip install wandb

!apt install -y x11-utils xvfb x11-utils ffmpeg
!apt install -y libgl1-mesa-dev libgl1-mesa-glx libglew-dev libosmesa6-dev software-properties-common
!apt install -y patchelf

!pip install pyglet
!pip install gym pyvirtualdisplay
!pip install imageio-ffmpeg

from IPython.display import clear_output
clear_output()

In [None]:
import IPython
from IPython.display import HTML
from IPython import display as ipythondisplay

# Python stuff
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# RL stuff
import dm_env
import gym
import mujoco_py

# DL stuff
from jax import tree_util
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import distrax

from base64 import b64encode
from collections import namedtuple
from typing import *
from tqdm import tqdm
import base64
import chex
import collections
import enum
import functools
import imageio
import io
import itertools
import multiprocessing as mp
import multiprocessing.connection
import random
import time
import tree
import warnings
import wandb

# Headless rendering setup so gym can render `rgb_array` frames without a
# physical screen. This is environment-specific (the notebook appears to
# target Colab); if rendering fails elsewhere, start debugging here.
import pyglet
# The options are set before `pyglet.window` is imported below.
# NOTE(review): presumably they must be set before the window module loads.
pyglet.options['search_local_libs'] = False
pyglet.options['shadow_window']=False
from pyglet.window import xlib
# NOTE(review): overrides a private pyglet xlib flag to disable its UTF-8
# handling; presumably a workaround for the virtual X server - confirm.
xlib._have_utf8 = False

# Start an invisible 1400x900 virtual X display. It is not stopped anywhere
# in the code shown, so it stays up for the rest of the session.
# NOTE: this rebinds the name `display`, shadowing IPython's built-in
# `display()` function in this notebook.
from pyvirtualdisplay import Display
display = Display(visible=False, size=(1400, 900))
display.start()

<pyvirtualdisplay.display.Display at 0x7f8488194550>

# Environment

In [None]:
class _GymEnvAdapter(dm_env.Environment):
    """Wrap a (pre-0.26 API) ``gym`` environment as a ``dm_env.Environment``.

    Subclasses choose the gym environment via ``_GYM_ID`` and describe its
    spaces in ``observation_spec`` / ``action_spec``.

    Args:
        for_evaluation: when true, an RGB frame is rendered after every
            ``reset`` and ``step`` and appended to ``self.screens`` (the
            attribute only exists in that case).
    """

    _GYM_ID = None  # id passed to gym.make; set by each subclass

    def __init__(self, for_evaluation: bool) -> None:
        self._env = gym.make(self._GYM_ID)
        self._for_evaluation = for_evaluation
        if self._for_evaluation:
            self.screens = []

    def _maybe_record_frame(self) -> None:
        """Append the current rendered frame to ``screens`` in evaluation mode."""
        if self._for_evaluation:
            self.screens.append(self._env.render(mode='rgb_array'))

    def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep:
        """Advance the gym environment one step and convert the result."""
        new_obs, reward, done, info = self._env.step(action)
        self._maybe_record_frame()
        if not done:
            return dm_env.transition(reward, new_obs)
        # gym's TimeLimit wrapper sets 'TimeLimit.truncated' when the episode
        # was cut off by the step limit rather than ending naturally. Report
        # that as a truncation (discount 1) instead of a termination
        # (discount 0); both have step_type LAST.
        if info.get('TimeLimit.truncated', False):
            return dm_env.truncation(reward, new_obs)
        return dm_env.termination(reward, new_obs)

    def reset(self) -> dm_env.TimeStep:
        """Reset the gym environment and return the FIRST time step."""
        obs = self._env.reset()
        self._maybe_record_frame()
        return dm_env.restart(obs)

    def close(self) -> None:
        """Close the underlying gym environment."""
        self._env.close()


class InvertedPendulumEnv(_GymEnvAdapter):
    """``InvertedPendulum-v2``: 4-dim observation, 1-dim action in [-3, 3]."""

    _GYM_ID = 'InvertedPendulum-v2'

    def observation_spec(self) -> dm_env.specs.Array:
        return dm_env.specs.Array(shape=(4,), dtype=np.float32)

    def action_spec(self) -> dm_env.specs.BoundedArray:
        return dm_env.specs.BoundedArray(shape=(1,), minimum=-3., maximum=3., dtype=np.float32)


class ReacherEnv(_GymEnvAdapter):
    """``Reacher-v2``: 11-dim observation, 2-dim action in [-1, 1]."""

    _GYM_ID = "Reacher-v2"

    def observation_spec(self) -> dm_env.specs.Array:
        return dm_env.specs.Array(shape=(11,), dtype=np.float32)

    def action_spec(self) -> dm_env.specs.BoundedArray:
        return dm_env.specs.BoundedArray(shape=(2,), minimum=-1., maximum=1., dtype=np.float32)


class InvertedDoublePendulumEnv(_GymEnvAdapter):
    """``InvertedDoublePendulum-v2``: 11-dim observation, 1-dim action in [-1, 1]."""

    _GYM_ID = "InvertedDoublePendulum-v2"

    def observation_spec(self) -> dm_env.specs.Array:
        return dm_env.specs.Array(shape=(11,), dtype=np.float32)

    def action_spec(self) -> dm_env.specs.BoundedArray:
        return dm_env.specs.BoundedArray(shape=(1,), minimum=-1., maximum=1., dtype=np.float32)


class CartPole(_GymEnvAdapter):
    """``CartPole-v1``: 4-dim observation, discrete actions."""

    _GYM_ID = 'CartPole-v1'

    def observation_spec(self) -> dm_env.specs.Array:
        return dm_env.specs.Array(shape=(4,), dtype=np.float32)

    def action_spec(self) -> dm_env.specs.DiscreteArray:
        return dm_env.specs.DiscreteArray(num_values=self._env.action_space.n, dtype=np.int32)

class TransitionBuffer():
    """Rollout storage holding one Python list per transition field.

    Every ``push`` appends exactly one entry to each list, so index ``i`` of
    every attribute describes the same transition.
    """

    def __init__(self):
        # Seven independent empty lists, assigned in attribute order.
        (self.action,
         self.action_log_prob,
         self.state_old,
         self.state_new,
         self.state_type_old,
         self.state_type_new,
         self.reward) = ([] for _ in range(7))

    def push(self, ts, next_ts, a, log_prob):
        """Record the transition ``ts --a--> next_ts``.

        Args:
            ts: time step the action was taken from (its ``observation`` and
                ``step_type`` are stored).
            next_ts: resulting time step (its ``observation``, ``step_type``
                and ``reward`` are stored).
            a: the action taken.
            log_prob: log-probability of ``a``, as returned by the sampler.
        """
        columns_and_values = (
            (self.action, a),
            (self.action_log_prob, log_prob),
            (self.state_old, ts.observation),
            (self.state_type_old, ts.step_type),
            (self.state_new, next_ts.observation),
            (self.state_type_new, next_ts.step_type),
            (self.reward, next_ts.reward),
        )
        for column, value in columns_and_values:
            column.append(value)


# Parameters

In [None]:
# Parameters
logging = True  # if True, log per-episode stats to Weights & Biases (wandb); shadows the stdlib `logging` module name
continuous = True  # True: Gaussian policy over a continuous action vector; False: categorical policy over discrete actions
env_factory = InvertedPendulumEnv  # environment class, called with for_evaluation (bool)
iterations = 1000  # NOTE(review): not referenced by the training loop, which stops after 1,000,000 timesteps
N = 5  # episodes collected per training iteration
T = 2048  # maximum timesteps per collected episode
gamma = 0.99  # discount factor
lbd = 0.95  # GAE lambda (decay of the advantage accumulation, together with gamma)
n_epochs = 10  # update passes over each collected batch
batch_size = 64  # intended PPO minibatch size; NOTE(review): check the update loop actually uses it
eps = 0.2 # Clip loss epsilon (PPO clipping range of the probability ratio; the training loop may anneal it)
c1 = 0.9 # Value Function loss coefficient
c2 = 0.01 # Entropy loss coefficient
key = jax.random.PRNGKey(28042000)  # root PRNG key; split for network init and action sampling
learning_rate = 3e-4  # Adam learning rate, shared by the policy and value optimizers

# Initialize env
# Training environment (no frame recording). Its specs fix the network
# input/output sizes below.
env = env_factory(False)
state_dim = env.observation_spec().shape[0]
if not(continuous):
    action_dim = env.action_spec().num_values
else:
    action_dim = env.action_spec().shape[0]

# Networks
# NOTE: the order of the key splits below determines the initial parameters.
def policy_network(x):
    """Policy MLP: two tanh hidden layers of 64 units, `action_dim` outputs.

    The outputs are used as categorical logits (discrete) or as the Gaussian
    mean (continuous) by the policy helpers.
    """
    x = hk.nets.MLP([64,64,action_dim], activation=jax.numpy.tanh)(x)
    return x
key, subkey = jax.random.split(key)
policy = hk.without_apply_rng(hk.transform(policy_network))
policy_params = policy.init(rng=subkey, x=jnp.ones([1, state_dim]))

def value_network(x):
    """Value MLP: two tanh hidden layers of 64 units, one output per state.

    For a (batch, state_dim) input the output has shape (batch, 1).
    """
    x = hk.nets.MLP([64,64,1], activation=jax.numpy.tanh)(x)
    return x
key, subkey = jax.random.split(key)
value = hk.without_apply_rng(hk.transform(value_network))
value_params = value.init(rng=subkey, x=jnp.ones([1, state_dim]))


# Optimizers
# Separate Adam optimizers (same learning rate) for the two parameter sets.
policy_optimizer = optax.adam(learning_rate)
policy_opt_state = policy_optimizer.init(policy_params)
value_optimizer = optax.adam(learning_rate)
value_opt_state = value_optimizer.init(value_params)

# PPO

In [None]:
@jax.jit
def loss_fct(policy_params, value_params,
             advantage, state, action, target, log_prob,
             eps, c1, c2):
    """Combined PPO loss (to be minimised) for one batch of transitions.

    loss = mean(-L_clip + c1 * L_vf - c2 * entropy)

    Args:
        policy_params, value_params: parameters of the policy / value networks.
        advantage: (B,) advantages of the stored actions.
        state: (B, state_dim) observations the actions were taken from.
        action: (B, ...) stored actions.
        target: (B,) value-function regression targets.
        log_prob: (B,) log-probabilities of `action` under the policy that
            collected them.
        eps: PPO clipping range for the probability ratio.
        c1: value-loss coefficient.
        c2: entropy-bonus coefficient.

    Returns:
        Scalar loss.
    """
    policy_apply_output = policy.apply(policy_params, state)
    entropies, log_prob_new = jax.vmap(policy_entropy_logprob)(policy_apply_output, action)
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space.
    r = jnp.exp(log_prob_new - log_prob)
    # Clipped surrogate objective (to be maximised, hence negated below).
    l_clip = jnp.minimum(r * advantage, jnp.clip(r, 1 - eps, 1 + eps) * advantage)
    # value.apply returns shape (B, 1). Drop the trailing axis so the squared
    # error is element-wise against `target` (B,). Without this, the
    # subtraction broadcasts to a (B, B) matrix.
    values = value.apply(value_params, state).squeeze(-1)
    l_vf = jnp.square(values - target)
    loss = jnp.mean(-l_clip + c1 * l_vf - c2 * entropies)
    return loss

@jax.jit
def compute_advantage(policy_params, value_params,
                      reward, state_old, state_new, state_type_old, state_type_new, action,
                      gamma, lbd):
    """Compute one-step value targets and normalised GAE advantages.

    Args:
        policy_params: unused; kept for signature compatibility.
        value_params: value-network parameters.
        reward: (B,) rewards r_t, as stacked by the training loop.
        state_old, state_new: (B, state_dim) observations s_t and s_{t+1}.
        state_type_old, state_type_new: (B,) dm_env.StepType codes of s_t
            and s_{t+1}.
        action: unused; kept for signature compatibility.
        gamma: discount factor.
        lbd: GAE lambda.

    Returns:
        (target, advantage), both shaped (B,):
        - target_t = r_t + gamma * V(s_{t+1}), with the bootstrap dropped
          when s_{t+1} is LAST.
        - advantage is the GAE advantage, normalised to zero mean and unit
          standard deviation.
    """
    is_first = (state_type_old == dm_env.StepType.FIRST)
    is_last = (state_type_new == dm_env.StepType.LAST)
    # value.apply returns (B, 1); squeeze only that trailing axis.
    next_value = value.apply(value_params, state_new).squeeze(-1)
    target = reward + gamma * jnp.where(is_last, 0., next_value)
    delta = target - value.apply(value_params, state_old).squeeze(-1)

    def gae_step(acc, inputs):
        d, reset = inputs
        acc = lbd * gamma * acc + d
        # Emit the advantage for this step. Then zero the carry after an
        # episode's FIRST transition so that nothing leaks into the previous
        # episode, which is processed next because the scan runs backwards.
        return jnp.where(reset, jnp.zeros_like(acc), acc), acc

    # lax.scan compiles the backward recursion as a loop. A Python loop over
    # a traced array would be unrolled once per element and recompiled for
    # every new buffer length.
    _, advantage = jax.lax.scan(
        gae_step, jnp.zeros_like(delta[0]), (delta, is_first), reverse=True)
    # Small epsilon guards against a zero standard deviation.
    advantage = (advantage - jnp.mean(advantage)) / (jnp.std(advantage) + 1e-8)
    return target, advantage

if continuous:
    # Diagonal Gaussian policy: the network output is the mean and the
    # standard deviation is a fixed 0.2 in every action dimension.
    std = 0.2 * jnp.ones(action_dim)

    @jax.jit
    def policy_entropy_logprob(policy_apply_output, action, std=std):
        """Return (entropy, log-probability of `action`) for one state."""
        gaussian = distrax.MultivariateNormalDiag(loc=policy_apply_output, scale_diag=std)
        return gaussian.entropy(), gaussian.log_prob(action)

    @jax.jit
    def policy_sample(policy_apply_output, key):
        """Sample one action and return it with its log-probability."""
        gaussian = distrax.MultivariateNormalDiag(loc=policy_apply_output, scale_diag=std)
        return gaussian.sample_and_log_prob(seed=key, sample_shape=())
else:
    # Categorical policy: the network output is a vector of logits.
    @jax.jit
    def policy_entropy_logprob(policy_apply_output, action):
        """Return (entropy, log-probability of `action`) for one state."""
        categorical = distrax.Categorical(logits=policy_apply_output)
        return categorical.entropy(), categorical.log_prob(action)

    @jax.jit
    def policy_sample(policy_apply_output, key):
        """Sample one action (cast to int) and return it with its log-probability."""
        categorical = distrax.Categorical(logits=policy_apply_output)
        sampled, log_prob = categorical.sample_and_log_prob(seed=key, sample_shape=())
        return sampled.astype(int), log_prob
        
# Training logs.
full_reward_history = []  # running episode reward after every environment step
episode_reward_history = []  # total reward of every collected episode
episode_timesteps_history = []  # global timestep count when each episode ended
loss_history = []  # loss of every gradient step
ts_count = 0
if logging:
    wandb.init(project="ppo_jax",reinit=True)

# Loss and gradients w.r.t. both parameter sets, jitted once outside the loop.
loss_and_grads = jax.jit(jax.value_and_grad(loss_fct, argnums=(0, 1)))

while ts_count < 1000000:
    ts_buffer = TransitionBuffer()
    # Rollout: collect N episodes with the current policy, each capped at T steps.
    for _ in range(N):
        ts = env.reset()
        acc_reward = 0
        for _ in range(T):
            if ts.step_type == dm_env.StepType.LAST:
                break
            policy_apply_output = policy.apply(policy_params, jnp.expand_dims(ts.observation, 0))
            # Sample with the fresh `subkey`; `key` is only used to derive
            # future keys, so no key is ever consumed twice.
            key, subkey = jax.random.split(key)
            action, log_prob = policy_sample(jnp.squeeze(policy_apply_output, 0), subkey)
            # Pass gym a NumPy value rather than a jax array. In the gym
            # versions this notebook targets, the Discrete space only
            # accepts Python/NumPy integers.
            next_ts = env.step(np.asarray(action))
            ts_buffer.push(ts, next_ts, action, log_prob)
            ts = next_ts
            acc_reward += next_ts.reward

            # Logging
            full_reward_history.append(acc_reward)
            ts_count += 1
            if (ts_count % 1000000) == 0:
                print(f'{ts_count} timesteps reached')
            # Anneal the clipping range: halve it every 100k timesteps.
            if (ts_count % 100000) == 0:
                eps = eps / 2.
        episode_reward_history.append(acc_reward)
        episode_timesteps_history.append(ts_count)
        if logging:
            wandb.log({'episode_reward':acc_reward, 'episode_timestep':ts_count})

    # Stack the rollout into arrays and compute value targets / advantages.
    action = jnp.array(ts_buffer.action)
    state_old = jnp.array(ts_buffer.state_old)
    state_new = jnp.array(ts_buffer.state_new)
    state_type_old = jnp.array(ts_buffer.state_type_old)
    state_type_new = jnp.array(ts_buffer.state_type_new)
    reward = jnp.array(ts_buffer.reward)
    log_prob = jnp.array(ts_buffer.action_log_prob)

    target, advantage = compute_advantage(policy_params, value_params, reward, state_old, state_new, state_type_old, state_type_new, action, gamma, lbd)

    # PPO update: n_epochs passes over the rollout in shuffled minibatches of
    # `batch_size` transitions. The last minibatch may be smaller.
    num_samples = state_old.shape[0]
    for _ in range(n_epochs):
        key, subkey = jax.random.split(key)
        permutation = jax.random.permutation(subkey, num_samples)
        for start in range(0, num_samples, batch_size):
            idx = permutation[start:start + batch_size]
            loss, (policy_grads, value_grads) = loss_and_grads(
                policy_params, value_params, advantage[idx], state_old[idx],
                action[idx], target[idx], log_prob[idx], eps, c1, c2
            )

            updates, policy_opt_state = policy_optimizer.update(policy_grads, policy_opt_state)
            policy_params = optax.apply_updates(policy_params, updates)

            updates, value_opt_state = value_optimizer.update(value_grads, value_opt_state)
            value_params = optax.apply_updates(value_params, updates)

            loss_history.append(float(loss))

if logging:
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mlmagne[0m (use `wandb login --relogin` to force relogin)




# Results

In [None]:
def moving_average(x, w):
    """Return the mean of every full length-``w`` window of ``x``.

    Equivalent to ``np.convolve(x, np.ones(w), 'valid') / w`` when
    ``len(x) >= w``. When there is no full window, an empty array is returned
    instead of the meaningless result ``np.convolve`` gives after silently
    swapping its operands.

    Args:
        x: 1-D sequence of numbers.
        w: window length (must be positive).

    Returns:
        1-D float array of length ``max(len(x) - w + 1, 0)``.

    Raises:
        ValueError: if ``w`` is not positive.
    """
    if w <= 0:
        raise ValueError(f'window size must be positive, got {w}')
    x = np.asarray(x, dtype=float)
    if x.size < w:
        return np.empty(0)
    return np.convolve(x, np.ones(w), 'valid') / w
# Training curves: the 1000-step moving average of full_reward_history
# (also saved to reward.png), then episode_reward_history and loss_history.
_figures = (
    (moving_average(full_reward_history, 1000), 'rewards', 'reward.png'),
    (episode_reward_history, 'full episode rewards', None),
    (loss_history, 'loss', None),
)
for _series, _title, _save_to in _figures:
    plt.figure(figsize=(8, 6), dpi=80)
    plt.plot(_series)
    plt.title(_title)
    if _save_to is not None:
        plt.tight_layout()
        plt.savefig(_save_to)
    plt.show()

In [None]:
# Initialize env
# Evaluation environment: created with for_evaluation=True, so every
# reset/step renders a frame into `env.screens` for the videos below.
# Close the training environment first so its resources are released
# before the name is rebound.
env.close()
env = env_factory(True)

def display_video(frames, filename='temp.mp4', frame_repeat=1, fps=60):
    """Encode frames to an MP4 file and return an inline HTML video player.

    Args:
        frames: non-empty sequence of RGB frames of equal shape (stacked with
            ``np.stack``), e.g. an evaluation environment's ``screens``.
        filename: path of the MP4 file to write (overwritten if present).
        frame_repeat: number of times each frame is written, to slow playback.
        fps: frame rate of the encoded video.

    Returns:
        An ``IPython.display.HTML`` object embedding the video as base64.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError('display_video() needs at least one frame')
    frames = np.stack(frames, axis=0)
    # Write the video and let the context manager close the writer *before*
    # reading the file back. imageio's ffmpeg writer only finalises the MP4
    # on close, so reading it earlier can yield a truncated file.
    with imageio.get_writer(filename, fps=fps) as writer:
        for frame in frames:
            for _ in range(frame_repeat):
                writer.append_data(frame)
    with open(filename, 'rb') as video_file:
        b64_video = base64.b64encode(video_file.read())
    video_tag = ('<video  width="320" height="240" controls alt="test" '
                 'src="data:video/mp4;base64,{0}"></video>').format(b64_video.decode())
    return IPython.display.HTML(video_tag)

def clip(i):
    """Roll out one evaluation episode with the current policy and show it.

    Uses the module-level evaluation ``env`` (which records frames in
    ``env.screens``), ``policy`` and ``policy_params``.

    Args:
        i: clip index. It is used in the output file name ``temp{i}.mp4`` and
            folded into the sampling key, so different clips draw different
            action noise.

    Returns:
        The ``IPython.display.HTML`` video player, which is also displayed.
    """
    max_steps = 50000
    # Start each clip from an empty frame buffer; otherwise every video would
    # also contain the frames of all previous clips.
    env.screens.clear()
    ts = env.reset()
    acc_reward = 0
    key = jax.random.fold_in(jax.random.PRNGKey(28042000), i)
    for _ in range(max_steps):
        if ts.step_type == dm_env.StepType.LAST:
            break
        policy_apply_output = policy.apply(policy_params, jnp.expand_dims(ts.observation, 0))
        key, subkey = jax.random.split(key)
        # Drop the batch axis added above (as the training rollout does), so
        # the action has the shape the environment expects.
        action, _ = policy_sample(jnp.squeeze(policy_apply_output, 0), subkey)
        ts = env.step(np.asarray(action))
        acc_reward += ts.reward
    print(f'clip {i}: episode reward {acc_reward}')
    html = display_video(env.screens, f'temp{i}.mp4')
    # Display explicitly: a return value inside a loop is not shown by Jupyter.
    IPython.display.display(html)
    return html
    
# Record 10 evaluation clips; clip(i) writes each one to temp{i}.mp4.
for i in range(10):
    clip(i)

In [None]:
import glob

from google.colab import files

# Download the saved reward plot and every evaluation clip. The clips are
# saved as temp{i}.mp4 (one file per clip), so match them by pattern rather
# than requesting a single 'temp.mp4'.
files.download('reward.png')
for video_path in sorted(glob.glob('temp*.mp4')):
    files.download(video_path)