In [1]:
#!/usr/bin/env python3
import numpy as np
import gym
import os
import multiprocessing

import matplotlib.pyplot as plt
import time
from IPython import display
%matplotlib notebook

from stable_baselines.common.cmd_util import mujoco_arg_parser
from stable_baselines import bench, logger
from stable_baselines.common import set_global_seeds
from stable_baselines.common.vec_env.vec_normalize import VecNormalize
from stable_baselines.ppo2 import PPO2
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy
from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines.results_plotter import load_results, ts2xy


def train(env_id, num_timesteps, seed):
    """
    Train PPO2 model for Mujoco environment, for testing purposes

    This function relies on three module-level names that must be defined
    before it is called: ``log_dir`` (directory passed to the Monitor wrapper,
    to ``load_results`` and to PPO2 as ``tensorboard_log``) and the callback
    state ``n_steps`` / ``best_mean_reward``, which the callback updates
    through ``global``.

    Two checkpoints are written to the current working directory:
    ``best_model_<env_id>`` whenever the mean reward over the last (up to) 100
    logged episodes improves, and ``model_<env_id>`` once training finishes.

    :param env_id: (str) the environment id string
    :param num_timesteps: (int) the number of timesteps to run
    :param seed: (int) Used to seed the random generator.
    :return: (PPO2, VecNormalize) the trained model and the normalized
        vectorized environment it was trained on
    """
    def make_env(_env_id):
        # Monitor writes episode statistics under log_dir; the callback reads
        # them back with load_results().
        env_out = gym.make(_env_id)
        env_out = bench.Monitor(env_out, log_dir, allow_early_resets=True)
        return env_out

    # Single in-process environment wrapped in VecNormalize.
    # NOTE(review): the VecNormalize object is only returned, never written to
    # disk here; a checkpoint reloaded later presumably needs its statistics
    # as well -- confirm and persist them if so.
    env = DummyVecEnv([lambda: make_env(env_id)])
    env = VecNormalize(env)

    def callback(_locals, _globals):
        """
        Progress hook handed to ``model.learn``.

        Every 100th invocation it reloads the Monitor results and, when the
        mean reward of the most recent episodes beats the best value seen so
        far, saves the model as the new best checkpoint.

        :param _locals: (dict) local variables of the learn loop
        :param _globals: (dict) global variables of the learn loop (unused)
        :return: (bool) always True so that training is never aborted
        """
        global n_steps, best_mean_reward
        print("Step:", n_steps)

        if (n_steps + 1) % 100 == 0:
            x, y = ts2xy(load_results(log_dir), 'timesteps')
            if len(x) > 0:
                mean_reward = np.mean(y[-100:])
                print(x[-1], 'timesteps')
                print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(best_mean_reward, mean_reward))

                if mean_reward > best_mean_reward:
                    best_mean_reward = mean_reward
                    print("Saving new best model")
                    # Use a dedicated file name: the final save below writes
                    # "model_<env_id>" and would otherwise overwrite the best
                    # checkpoint with the last one.
                    _locals['self'].save("best_model_{}".format(env_id))
        n_steps += 1
        # stable-baselines documents that a callback returning False aborts
        # training, so report True to run for the full timestep budget.
        return True

    set_global_seeds(seed)
    model = PPO2(policy=MlpPolicy, env=env, n_steps=2048, nminibatches=1, lam=0.95, gamma=0.99, noptepochs=10,
                 ent_coef=0.0, learning_rate=3e-4, cliprange=0.2, verbose=1, tensorboard_log=log_dir)
    model.learn(total_timesteps=num_timesteps, callback=callback)
    # Final (not necessarily best) policy.
    model.save("model_{}".format(env_id))

    return model, env

  from ._conv import register_converters as _register_converters


In [None]:
# --- Run configuration -------------------------------------------------------
env_id = 'Ant-v2'
num_timesteps = 20000000
seed = 343

# State shared with the callback inside train() (accessed there via `global`).
best_mean_reward, n_steps = -np.inf, 0

# Root directory for all runs. The original absolute path stays the default;
# set the PPO_LOG_DIR environment variable to use a different location.
base_dir = os.environ.get('PPO_LOG_DIR', '/home/nathan/ppo_logs')
# Create the root first: os.listdir() raises FileNotFoundError on a fresh machine.
os.makedirs(base_dir, exist_ok=True)

# Start from the number of existing entries mentioning this env id, then skip
# forward until the name is unused so an earlier run's directory is never
# reused (counts can collide after old runs are deleted or when another env id
# contains this one as a substring).
prev = [f for f in os.listdir(base_dir) if env_id in f]
run_index = len(prev)
log_dir = os.path.join(base_dir, '{}-{}'.format(env_id, run_index))
while os.path.exists(log_dir):
    run_index += 1
    log_dir = os.path.join(base_dir, '{}-{}'.format(env_id, run_index))
os.makedirs(log_dir, exist_ok=True)

print('Logging to {}'.format(log_dir))

logger.configure()
model, env = train(env_id, num_timesteps, seed)

Logging to /home/nathan/ppo_logs/Ant-v2-0
Logging to /home/nathan/openai_logs
Step: 0
--------------------------------------
| approxkl           | 0.0004054485  |
| clipfrac           | 0.0005371094  |
| ep_rewmean         | -85.8         |
| eplenmean          | 81.6          |
| explained_variance | -0.511        |
| fps                | 329           |
| nupdates           | 1             |
| policy_entropy     | 11.348761     |
| policy_loss        | -0.0072244657 |
| serial_timesteps   | 2048          |
| time_elapsed       | 4.05e-06      |
| total_timesteps    | 2048          |
| value_loss         | 0.6423418     |
--------------------------------------
Step: 1
--------------------------------------
| approxkl           | 0.000264716   |
| clipfrac           | 0.0           |
| ep_rewmean         | -91.6         |
| eplenmean          | 87            |
| explained_variance | -5.38         |
| fps                | 352           |
| nupdates           | 2             |
| policy_

Step: 14
--------------------------------------
| approxkl           | 0.00026014578 |
| clipfrac           | 0.0           |
| ep_rewmean         | -92.9         |
| eplenmean          | 86.6          |
| explained_variance | -0.554        |
| fps                | 346           |
| nupdates           | 15            |
| policy_entropy     | 11.236638     |
| policy_loss        | -0.0051333    |
| serial_timesteps   | 30720         |
| time_elapsed       | 82.5          |
| total_timesteps    | 30720         |
| value_loss         | 0.0936635     |
--------------------------------------
Step: 15
--------------------------------------
| approxkl           | 0.00036001444 |
| clipfrac           | 0.0           |
| ep_rewmean         | -81.1         |
| eplenmean          | 75.2          |
| explained_variance | -0.746        |
| fps                | 343           |
| nupdates           | 16            |
| policy_entropy     | 11.222844     |
| policy_loss        | -0.0056270435 |
| seria

Step: 28
-------------------------------------
| approxkl           | 0.0004772142 |
| clipfrac           | 0.0005371094 |
| ep_rewmean         | -118         |
| eplenmean          | 116          |
| explained_variance | -0.134       |
| fps                | 353          |
| nupdates           | 29           |
| policy_entropy     | 11.096403    |
| policy_loss        | -0.006441079 |
| serial_timesteps   | 59392        |
| time_elapsed       | 165          |
| total_timesteps    | 59392        |
| value_loss         | 0.09851854   |
-------------------------------------
Step: 29
-------------------------------------
| approxkl           | 0.0003932124 |
| clipfrac           | 9.765625e-05 |
| ep_rewmean         | -101         |
| eplenmean          | 98.8         |
| explained_variance | -0.142       |
| fps                | 346          |
| nupdates           | 30           |
| policy_entropy     | 11.078777    |
| policy_loss        | -0.005785974 |
| serial_timesteps   | 61440    

Step: 42
--------------------------------------
| approxkl           | 0.00062688935 |
| clipfrac           | 0.0011230469  |
| ep_rewmean         | -81           |
| eplenmean          | 81.2          |
| explained_variance | -0.199        |
| fps                | 343           |
| nupdates           | 43            |
| policy_entropy     | 10.976013     |
| policy_loss        | -0.0071237935 |
| serial_timesteps   | 88064         |
| time_elapsed       | 248           |
| total_timesteps    | 88064         |
| value_loss         | 0.07302257    |
--------------------------------------
Step: 43
--------------------------------------
| approxkl           | 0.0005001226  |
| clipfrac           | 0.00068359374 |
| ep_rewmean         | -94.3         |
| eplenmean          | 93            |
| explained_variance | -0.134        |
| fps                | 347           |
| nupdates           | 44            |
| policy_entropy     | 10.960793     |
| policy_loss        | -0.0059622466 |
| seria

Step: 56
-------------------------------------
| approxkl           | 0.0010117837 |
| clipfrac           | 0.0033203126 |
| ep_rewmean         | -56.9        |
| eplenmean          | 60.2         |
| explained_variance | -0.0164      |
| fps                | 345          |
| nupdates           | 57           |
| policy_entropy     | 10.830796    |
| policy_loss        | -0.009184063 |
| serial_timesteps   | 116736       |
| time_elapsed       | 330          |
| total_timesteps    | 116736       |
| value_loss         | 0.079894766  |
-------------------------------------
Step: 57
-------------------------------------
| approxkl           | 0.0009453875 |
| clipfrac           | 0.0036132813 |
| ep_rewmean         | -75.7        |
| eplenmean          | 77.8         |
| explained_variance | -0.112       |
| fps                | 348          |
| nupdates           | 58           |
| policy_entropy     | 10.814879    |
| policy_loss        | -0.008891371 |
| serial_timesteps   | 118784   

Step: 71
--------------------------------------
| approxkl           | 0.0013297612  |
| clipfrac           | 0.007714844   |
| ep_rewmean         | -110          |
| eplenmean          | 113           |
| explained_variance | -0.0754       |
| fps                | 345           |
| nupdates           | 72            |
| policy_entropy     | 10.694192     |
| policy_loss        | -0.0098602455 |
| serial_timesteps   | 147456        |
| time_elapsed       | 418           |
| total_timesteps    | 147456        |
| value_loss         | 0.054639      |
--------------------------------------
Step: 72
-------------------------------------
| approxkl           | 0.0012271106 |
| clipfrac           | 0.0058105467 |
| ep_rewmean         | -112         |
| eplenmean          | 114          |
| explained_variance | 0.0156       |
| fps                | 349          |
| nupdates           | 73           |
| policy_entropy     | 10.68395     |
| policy_loss        | -0.009751784 |
| serial_timestep

Step: 86
-------------------------------------
| approxkl           | 0.0018993423 |
| clipfrac           | 0.01689453   |
| ep_rewmean         | -72.6        |
| eplenmean          | 72.8         |
| explained_variance | -0.0817      |
| fps                | 348          |
| nupdates           | 87           |
| policy_entropy     | 10.58709     |
| policy_loss        | -0.01362628  |
| serial_timesteps   | 178176       |
| time_elapsed       | 507          |
| total_timesteps    | 178176       |
| value_loss         | 0.07594336   |
-------------------------------------
Step: 87
-------------------------------------
| approxkl           | 0.0016342113 |
| clipfrac           | 0.011279297  |
| ep_rewmean         | -81.9        |
| eplenmean          | 82.1         |
| explained_variance | 0.0138       |
| fps                | 345          |
| nupdates           | 88           |
| policy_entropy     | 10.575261    |
| policy_loss        | -0.0090522   |
| serial_timesteps   | 180224   

Step: 100
--------------------------------------
| approxkl           | 0.0018627252  |
| clipfrac           | 0.012988281   |
| ep_rewmean         | -36.2         |
| eplenmean          | 41            |
| explained_variance | -0.0617       |
| fps                | 342           |
| nupdates           | 101           |
| policy_entropy     | 10.504389     |
| policy_loss        | -0.0098338295 |
| serial_timesteps   | 206848        |
| time_elapsed       | 590           |
| total_timesteps    | 206848        |
| value_loss         | 0.1152334     |
--------------------------------------
Step: 101
-------------------------------------
| approxkl           | 0.0017657985 |
| clipfrac           | 0.014355469  |
| ep_rewmean         | -32.7        |
| eplenmean          | 38.2         |
| explained_variance | 0.309        |
| fps                | 349          |
| nupdates           | 102          |
| policy_entropy     | 10.50205     |
| policy_loss        | -0.011741488 |
| serial_timest

Step: 115
-------------------------------------
| approxkl           | 0.0022149757 |
| clipfrac           | 0.018408203  |
| ep_rewmean         | -33.9        |
| eplenmean          | 38.8         |
| explained_variance | -0.0357      |
| fps                | 344          |
| nupdates           | 116          |
| policy_entropy     | 10.340238    |
| policy_loss        | -0.010736255 |
| serial_timesteps   | 237568       |
| time_elapsed       | 678          |
| total_timesteps    | 237568       |
| value_loss         | 0.08007287   |
-------------------------------------
Step: 116
-------------------------------------
| approxkl           | 0.0025456161 |
| clipfrac           | 0.024658203  |
| ep_rewmean         | -22.6        |
| eplenmean          | 26.7         |
| explained_variance | 0.0117       |
| fps                | 342          |
| nupdates           | 117          |
| policy_entropy     | 10.31863     |
| policy_loss        | -0.010479967 |
| serial_timesteps   | 239616 

Step: 130
-------------------------------------
| approxkl           | 0.0019721384 |
| clipfrac           | 0.014746094  |
| ep_rewmean         | -51.9        |
| eplenmean          | 52.7         |
| explained_variance | 0.235        |
| fps                | 345          |
| nupdates           | 131          |
| policy_entropy     | 10.207287    |
| policy_loss        | -0.010272095 |
| serial_timesteps   | 268288       |
| time_elapsed       | 768          |
| total_timesteps    | 268288       |
| value_loss         | 0.09372375   |
-------------------------------------
Step: 131
-------------------------------------
| approxkl           | 0.002418675  |
| clipfrac           | 0.021972656  |
| ep_rewmean         | -20.8        |
| eplenmean          | 24.8         |
| explained_variance | 0.104        |
| fps                | 345          |
| nupdates           | 132          |
| policy_entropy     | 10.197997    |
| policy_loss        | -0.012224285 |
| serial_timesteps   | 270336 

Step: 145
-------------------------------------
| approxkl           | 0.00292897   |
| clipfrac           | 0.03203125   |
| ep_rewmean         | -20.9        |
| eplenmean          | 25.7         |
| explained_variance | 0.288        |
| fps                | 351          |
| nupdates           | 146          |
| policy_entropy     | 10.071734    |
| policy_loss        | -0.012040027 |
| serial_timesteps   | 299008       |
| time_elapsed       | 857          |
| total_timesteps    | 299008       |
| value_loss         | 0.082194746  |
-------------------------------------
Step: 146
-------------------------------------
| approxkl           | 0.0015080968 |
| clipfrac           | 0.008837891  |
| ep_rewmean         | -32.8        |
| eplenmean          | 36.4         |
| explained_variance | 0.122        |
| fps                | 341          |
| nupdates           | 147          |
| policy_entropy     | 10.0532055   |
| policy_loss        | -0.01051674  |
| serial_timesteps   | 301056 