DDPG implementation fails to learn well on at least five MuJoCo-v2 envs for all three noise types. I report steps to reproduce and learning curve plots [and show that PPO2 seems to work fine]. #938

Open
DanielTakeshi opened this issue Jun 20, 2019 · 21 comments


@DanielTakeshi

Dear @pzhokhov @matthiasplappert @christopherhesse et al.,

Thank you for providing an implementation of DDPG. However, I have been unable to get it to learn well on the standard MuJoCo environments using the command provided in the README (and related commands). Here are the steps to reproduce. I apologize for the length of the post, but I want to show everything I tried, both to reduce ambiguity and to counter the potential argument that the problem is due to bad hyperparameters.

First, here's the machine I am using with relevant versions of software:

  • Ubuntu 18.04
  • MuJoCo 2.0
  • A clean Python 3.6.7 virtualenv with all required packages installed via pip: TensorFlow 1.13, gym 0.12.1, and mujoco-py 2.0.2.2. All appear to be installed correctly and show no signs of error.
  • Use baselines master branch, commit ba2b017

Next, here is the set of commands to run. I'm splitting these into three groups based on the three types of noise we can inject into the policy.

Group 1: Parameter Noise

I first decided to take the default command provided in the README because I assumed that hyperparameters here have been tuned to save users the time and compute needed for expensive hyperparameter sweeps.

python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Hopper-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Swimmer-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Walker2d-v2 --num_timesteps=1e6

I used the following plotting script:

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('seaborn-darkgrid')
import argparse
import csv
import os
import numpy as np
from os.path import join

# matplotlib
titlesize = 33
xsize = 30
ysize = 30
ticksize = 25
legendsize = 25
error_region_alpha = 0.25


def smoothed(x, w):
    """Smooth x by averaging over sliding windows of w, assuming sufficient length.
    """
    if len(x) <= w:
        return x
    smooth = []
    for i in range(1, w):
        smooth.append( np.mean(x[0:i]) )
    for i in range(w, len(x)+1):
        smooth.append( np.mean(x[i-w:i]) )
    assert len(x) == len(smooth), "lengths: {}, {}".format(len(x), len(smooth))
    return np.array(smooth)


def _get_stuff_from_monitor(mon):
    """Get stuff from `monitor` log files.

    Monitor files are named `0.envidx.monitor.csv` and have one line for each
    episode that finished in that CPU 'core', with the reward, length (number
    of steps) and the time (in seconds). The lengths are not cumulative, but
    time is cumulative.
    """
    scores = []
    steps  = []
    times  = []
    with open(mon, 'r') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        line_count = 0
        for csv_row in csv_reader:
            # First two lines don't contain interesting stuff.
            if line_count == 0 or line_count == 1:
                line_count += 1
                continue
            scores.append(float(csv_row[0]))
            steps.append(int(csv_row[1]))
            times.append(float(csv_row[2]))
            line_count += 1
    print("finished: {}".format(mon))
    return scores, steps, times


def plot(args):
    """Load monitor curves and the progress csv file. And plot from those.
    """
    nrows, ncols = 1, 2
    fig, ax = plt.subplots(nrows, ncols, squeeze=False, sharey=True, figsize=(11*ncols,7*nrows))
    title = args.title

    # Global statistics across all monitors
    scores_all = []
    steps_all = []
    times_all = []
    total_train_steps = 0
    train_hours = 0

    monitors = sorted(
        [x for x in os.listdir(args.path) if 'monitor.csv' in x and '.swp' not in x]
    )
    progfile = join(args.path,'progress.csv')

    # First row, info from all the monitors, i.e., number of CPUs.
    for env_idx,mon in enumerate(monitors):
        monitor_path = join(args.path, mon)
        scores, steps, times = _get_stuff_from_monitor(monitor_path)

        # Now process to see as a function of episodes and training steps, etc.
        num_episodes = len(scores)
        tr_episodes = np.arange(num_episodes)
        tr_steps = np.cumsum(steps)
        tr_times = np.array(times) / 60.0 # get it in minutes

        # Plot for individual monitors.
        envlabel = 'env {}'.format(env_idx)
        sm_10 = smoothed(scores, w=10)
        ax[0,0].plot(tr_steps, sm_10, label=envlabel+'; avg {:.1f} last {:.1f}'.format(
                np.mean(sm_10), sm_10[-1]))
        sm_100 = smoothed(scores, w=100)
        ax[0,1].plot(tr_times, sm_100, label=envlabel+'; avg {:.1f} last {:.1f}'.format(
                np.mean(sm_100), sm_100[-1]))

        # Handle global stuff.
        total_train_steps += tr_steps[-1]
        train_hours = max(train_hours, tr_times[-1] / 60.0)

    # Bells and whistles
    for row in range(nrows):
        for col in range(ncols):
            ax[row,col].set_ylabel("Scores", fontsize=ysize)
            ax[row,col].tick_params(axis='x', labelsize=ticksize)
            ax[row,col].tick_params(axis='y', labelsize=ticksize)
            leg = ax[row,col].legend(loc="best", ncol=1, prop={'size':legendsize})
            # legendHandles is called legend_handles in newer Matplotlib releases.
            for legobj in leg.legendHandles:
                legobj.set_linewidth(5.0)
    ax[0,0].set_title(title+', Smoothed (w=10)', fontsize=titlesize)
    ax[0,0].set_xlabel("Train Steps (total {})".format(total_train_steps), fontsize=xsize)
    ax[0,1].set_title(title+', Smoothed (w=100)', fontsize=titlesize)
    # tr_times is in minutes, so label the axis in minutes and report the total in hours.
    ax[0,1].set_xlabel("Train Time in Minutes (total {:.2f} hours)".format(train_hours), fontsize=xsize)
    plt.tight_layout()
    figname = '{}.png'.format(title)
    plt.savefig(figname)
    print("\nJust saved: {}".format(figname))


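if __name__ == "__main__":
    pp = argparse.ArgumentParser()
    # --path: directory holding the *.monitor.csv files; --title: plot title and output file name.
    pp.add_argument('--path', type=str, required=True)
    pp.add_argument('--title', type=str, required=True)
    args = pp.parse_args()
    plot(args)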

To use this script, run python [script].py --path [PATH] --title [TITLE], where [PATH] is the directory containing the 0.0.monitor.csv file (i.e., the /tmp/openai-[DATE] directory) and [TITLE] is the plot title. I did this for all five environments above and got these results:

[Plots: ant00, halfcheetah00, hopper00, swimmer00, walker00]

None of these curves appear to do better than random. Maybe Ant-v2 does slightly better than random, but it seems to be stuck at 0, while many, many papers report values far above 0. Perhaps it has something to do with the number of environments? I briefly tried increasing the number of parallel environments to 8, but that did not seem to help:

python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6 --num_env=8
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6 --num_env=8

[Plots: ant-08envs, halfcheetah-08envs]

Incidentally, it seems that with N environments, the actual total number of environment steps increases by a factor of N. This differs from PPO2, where increasing N does not change the total number of time steps at the end; instead, each individual environment executes fewer steps.
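The difference can be made concrete with a little arithmetic. Below is a minimal sketch, assuming (from a reading of the two learn() functions, not verified here) that DDPG splits total_timesteps into epochs of nb_epoch_cycles times nb_rollout_steps without dividing by the number of environments, while PPO2 divides it by nenvs times nsteps. The helper names ddpg_env_steps and ppo2_env_steps and their default values are hypothetical.

```python
def ddpg_env_steps(total_timesteps, num_env, nb_epoch_cycles=20, nb_rollout_steps=100):
    """Hypothetical: total env steps DDPG would take under the assumed budget rule."""
    # The budget is split into epochs without reference to num_env ...
    nb_epochs = total_timesteps // (nb_epoch_cycles * nb_rollout_steps)
    # ... but each rollout step advances all num_env environments at once.
    return nb_epochs * nb_epoch_cycles * nb_rollout_steps * num_env


def ppo2_env_steps(total_timesteps, num_env, nsteps=2048):
    """Hypothetical: total env steps PPO2 would take under the assumed budget rule."""
    nbatch = num_env * nsteps             # steps collected per update, across all envs
    nupdates = total_timesteps // nbatch  # more envs means fewer updates
    return nupdates * nbatch


for n in (1, 8):
    print(n, ddpg_env_steps(1_000_000, n), ppo2_env_steps(1_000_000, n))
```

Under these assumptions, 8 environments give 8M DDPG steps but still just under 1M PPO2 steps (999,424 in both PPO2 cases, because of integer division).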

PS: for some of the above plots, I did not run to exactly 1M steps, i.e., I terminated it near the end if it was clear that the algorithm was not learning well.
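All of the curves in this thread come from the *.monitor.csv files. Here is a compact sketch of reading one, assuming the layout described in the plotting script's docstring: a '#'-prefixed JSON metadata line, then an r,l,t header, then one row per finished episode. read_monitor is a hypothetical helper, and the toy rows below only illustrate the format; they are not results.

```python
import csv
import io
from itertools import accumulate


def read_monitor(f):
    """Hypothetical helper: parse a monitor CSV from an open text file.

    Assumed layout: a '#'-prefixed JSON metadata line, an 'r,l,t' header,
    then one row per finished episode (reward, length, elapsed seconds).
    """
    first = f.readline()
    if not first.startswith('#'):
        raise ValueError("expected a '#'-prefixed metadata line first")
    rows = list(csv.DictReader(f))  # reads the 'r,l,t' header, then the episode rows
    rewards = [float(row['r']) for row in rows]
    lengths = [int(row['l']) for row in rows]
    times = [float(row['t']) for row in rows]
    return rewards, lengths, times


# Toy contents, only to illustrate the format (not real results).
toy = io.StringIO(
    '#{"t_start": 0.0, "env_id": "HalfCheetah-v2"}\n'
    'r,l,t\n'
    '-310.5,1000,12.3\n'
    '-150.2,1000,24.9\n'
)
rewards, lengths, times = read_monitor(toy)
# Lengths are per episode, so their running sum gives the "train steps" x-axis.
cumulative_steps = list(accumulate(lengths))
```

For a real file, call it as `with open(path) as f: read_monitor(f)`. Reading columns by header name rather than by position keeps the parser correct if extra columns are appended.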

Group 2: Gaussian Noise

All right, next I decided to avoid parameter space noise. In the TD3 paper, which includes DDPG as a baseline, the authors used Gaussian noise with standard deviation 0.1. I decided to try that, keeping all other settings fixed:

python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Hopper-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Swimmer-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Walker2d-v2 --num_timesteps=1e6 --noise_type=normal_0.1

Here are the results:

[Plots: ant01, halfcheetah01, hopper01, swimmer01, walker01]

Once again, it seems like there is no learning happening. The performance appears to be similar to the parameter space noise case.
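For reference, here is a minimal sketch of how a --noise_type string such as normal_0.1, ou_0.2, or a parameter-noise spec can be split into a noise kind and a standard deviation, and how Gaussian action noise is then applied. The parsing rule (kind, underscore, stddev; comma-separated specs allowed), the 'adaptive-param' and 'none' names, and the helpers parse_noise_type and noisy_action are assumptions for illustration, not Baselines' actual code.

```python
import numpy as np


def parse_noise_type(noise_type, nb_actions):
    """Hypothetical parser: returns (param_noise_stddev, action_noise).

    action_noise is None or a (kind, sigma_vector) tuple with kind 'normal' or 'ou'.
    """
    param_noise_stddev, action_noise = None, None
    for spec in noise_type.split(','):
        spec = spec.strip()
        if spec == 'none':
            continue
        kind, stddev = spec.rsplit('_', 1)  # e.g. 'normal_0.1' -> ('normal', '0.1')
        if kind == 'adaptive-param':
            param_noise_stddev = float(stddev)
        elif kind in ('normal', 'ou'):
            action_noise = (kind, float(stddev) * np.ones(nb_actions))
        else:
            raise ValueError('unknown noise type: {}'.format(spec))
    return param_noise_stddev, action_noise


def noisy_action(mu, sigma, rng, low=-1.0, high=1.0):
    """Gaussian action noise: add N(0, sigma^2) per dimension, then clip to the action range."""
    return np.clip(mu + rng.normal(0.0, sigma), low, high)


rng = np.random.default_rng(0)
_, (kind, sigma) = parse_noise_type('normal_0.1', nb_actions=6)
action = noisy_action(np.zeros(6), sigma, rng)
```

Gaussian action noise is drawn independently at every step, which is part of why it is conceptually simpler than perturbing the network weights.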

Group 3: OU Noise (along with tau=0.001)

I decided to run one last batch of commands, this time with the original OU noise. After carefully checking the TD3 paper and the DDPG directory from the July 27, 2017 commit, when DDPG was first released, I saw that the tau parameter back then was set to 0.001. Now, for some reason, it is 0.01. DeepMind used 0.001, so I decided to try OU noise with tau 0.001. This appears to be the only hyperparameter difference I can see between this code base and the values used by DeepMind.

python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6  --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Hopper-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Swimmer-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Walker2d-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001

Results:

[Plots: ant02, halfcheetah02, hopper02, swimmer02, walker02]

(The Swimmer curve looks like it's going up, but the reward is lower compared with the other two plots.)

The results I am getting seem to differ from the blog post here, which shows HalfCheetah rewards of at least +1500 over 2M steps, and much higher depending on the parameter noise setting. It might be a hyperparameter issue, but I'm not sure. In particular, notice that the hyperparameters here (for the most part) match those from the DDPG or TD3 papers.

The TD3 paper reports these results:

[Image: td3_paper]

The TD3 paper says it used DDPG (presumably from OpenAI Baselines as of late 2017?), and "Our DDPG" above is the version whose hyperparameters the authors tuned. Both get final rewards far higher than what I am seeing, and we are all using 1M training steps here. Unfortunately, from reading the TD3 code base, it is not clear which Baselines commit was used for the results.

The paper above does not report results for Swimmer, so I looked at the "Benchmarking DeepRL" paper, which says DDPG on Swimmer should get 85 +/- 1.8, and this is far higher than the Swimmer results I am getting above.

I suspect there must have been some change to the code that caused it to either stop working well or become exorbitantly sensitive to hyperparameters. For example, maybe the process of removing MPI caused some unexpected results? Or it could be due to the switch from MuJoCo v1 to v2 environments, since the TD3 paper used MuJoCo v1 environments, but as this report suggests, RL performance should be similar. Notice that all the reward curves there for PPO show increasing reward, whereas I'm just seeing stagnation and noise for DDPG.

This is perhaps relevant to the following issue reports:

all of which have noticed issues with DDPG. If the fix is found, then the above can probably all be closed.

Hopefully, in the spirit of my previous report on DQN here, we can resolve this issue together. Does anyone have any general suggestions or ideas about the potential causes? At this point I am unable to confidently use the DDPG code because it does not pass standard benchmarks. My previous issue report about DQN suggests that it could be an environment processing issue. Is the code processing the MuJoCo environments in the same way as in July 2017? Do the PPO2 results appear fine while the DDPG results are off? Is there a difference in how the two algorithms process observations and normalize data?

I'm happy to help investigate this if you have ideas on what might be the root cause. I only report this issue because having highly tuned algorithms and hyper-parameters ready to go "off-the-shelf" greatly helps the entire research community by accelerating research cycles and reducing the need to write our own error-prone implementations of algorithms.

Thanks!

@DanielTakeshi
Author

PPO2 Results as a Sanity Check

To confirm that something related to DDPG is the issue (which might include DDPG-specific processing steps) I ran these commands for PPO2, using the same master branch commit as above. These commands:

python -m baselines.run --alg=ppo2 --env=Ant-v2 --num_timesteps=1e6
python -m baselines.run --alg=ppo2 --env=HalfCheetah-v2 --num_timesteps=1e6
python -m baselines.run --alg=ppo2 --env=Hopper-v2 --num_timesteps=1e6
python -m baselines.run --alg=ppo2 --env=Swimmer-v2 --num_timesteps=1e6
python -m baselines.run --alg=ppo2 --env=Walker2d-v2 --num_timesteps=1e6

Resulted in the following learning curves:

[Plots: ant-ppo, halfcheetah-ppo, hopper-ppo, swimmer-ppo, walker2d-ppo]

These were generated with a plotting script similar to the one in my post above.

These look much better!

  • Ant exceeds 1000; it actually seems better than the TD3 paper results.
  • HalfCheetah gets around 1500, roughly matching the TD3 paper results.
  • Hopper gets around 2000, as expected.
  • Swimmer exceeds 100 points.
  • Walker2d gets around 1500, slightly lower than the TD3 paper results but still far better than the DDPG results above.

It therefore seems like something is wrong with DDPG.

@DanielTakeshi DanielTakeshi changed the title DDPG implementation fails to learn well on at least five MuJoCo-v2 envs for all three noise types. I report steps to reproduce and learning curve plots. DDPG implementation fails to learn well on at least five MuJoCo-v2 envs for all three noise types. I report steps to reproduce and learning curve plots [and show that PPO2 seems to work fine]. Jun 20, 2019
@araffin
Contributor

araffin commented Jun 20, 2019

Hello,

Was normalize_obs set to True? (It seems to be the default, but I would rather double check.)
Otherwise, the observations get clipped.
And how many random seeds did you try?

You can find hyperparams (and trained agents) for pybullet envs here (in the rl zoo)

NOTE: it uses stable baselines version, but should be the same underlying algorithm.

@DanielTakeshi
Author

@araffin
The normalize observation seems to be True here:

normalize_observations=True,

@schneimo

A question about your experiments with OU Noise only:
The OU noise does not use tau; it uses another parameter called dt. Where did you read that dt and tau are the same? tau is the soft update parameter of the target networks, which is 0.001 in the Baselines code and also in the DeepMind paper. dt is a separate parameter used only by the OU noise, and it does not appear in the DeepMind paper.

But this is only a side question, and I don't think it will solve your problem. It does clarify, though, that there seems to be no difference between the hyperparameters of the DeepMind paper and the Baselines code when using OU noise only.
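To make the dt/tau distinction concrete, here is a minimal sketch of Ornstein-Uhlenbeck action noise using the common Euler discretization x <- x + theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1). dt only appears inside the noise process; the target-network tau never enters it. The defaults theta=0.15 and dt=1e-2 are assumptions, and the class name is hypothetical.

```python
import numpy as np


class OrnsteinUhlenbeckNoise:
    """Hypothetical sketch of OU action noise (assumed defaults theta=0.15, dt=1e-2)."""

    def __init__(self, mu, sigma, theta=0.15, dt=1e-2, x0=None, seed=None):
        self.mu, self.sigma, self.theta, self.dt = mu, sigma, theta, dt
        self.x0 = x0
        self.rng = np.random.default_rng(seed)
        self.reset()

    def reset(self):
        # Restart the process, typically at the start of each episode.
        self.x = self.x0 if self.x0 is not None else np.zeros_like(self.mu)

    def __call__(self):
        # Mean reversion toward mu plus a Brownian increment scaled by sqrt(dt).
        self.x = (self.x + self.theta * (self.mu - self.x) * self.dt
                  + self.sigma * np.sqrt(self.dt) * self.rng.normal(size=self.mu.shape))
        return self.x


noise = OrnsteinUhlenbeckNoise(mu=np.zeros(3), sigma=0.2 * np.ones(3), seed=0)
samples = np.array([noise() for _ in range(1000)])
```

Because consecutive samples are strongly correlated, OU noise produces smoother, temporally extended perturbations than i.i.d. Gaussian noise.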

@DanielTakeshi
Author

DanielTakeshi commented Jun 21, 2019

@MoritzTaylor

Just to be clear, the tau variable indeed is the soft update. I only changed it from 0.01 to 0.001 because that is what the DeepMind paper used and I wanted to pick some noise setting to go with it (OU, normal, or param noise). It does not have anything to do specifically with the OU noise. So perhaps I should have run a fourth trial above, with OU noise but with tau 0.01. It is not a big deal.

I agree that it will not solve DDPG's current performance issues.

Note that tau defaults to 0.01 in the learn function, which will override the 0.001 default from the DDPG class here:

class DDPG(object):
    def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None,
        gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
        batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf),
        critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
        # Inputs.
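For context on why the 0.01 vs 0.001 default matters, here is a minimal sketch of the soft (Polyak) target update that tau controls, target <- (1 - tau) * target + tau * online, plus a helper computing how many updates it takes for the old target's weight (1 - tau)^n to fall to one half. The function names are hypothetical; the arithmetic is exact.

```python
import math

import numpy as np


def soft_update(target_params, online_params, tau):
    """Polyak averaging: each target tensor moves a fraction tau toward its online tensor."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target_params, online_params)]


def updates_to_half_weight(tau):
    """Smallest n with (1 - tau)**n <= 0.5, i.e. when the old target has lost half its weight."""
    return math.ceil(math.log(0.5) / math.log(1.0 - tau))


target, online = [np.zeros(4)], [np.ones(4)]
for _ in range(100):
    target = soft_update(target, online, tau=0.01)

print(updates_to_half_weight(0.01), updates_to_half_weight(0.001))
```

This prints 69 and 693: with tau=0.01 the target networks track the online networks roughly ten times faster than with tau=0.001.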

@sritee
Contributor

sritee commented Jun 21, 2019

As a side note, in case you're interested in quickly trying something out before this issue gets resolved, I would highly recommend the TD3 author's official implementation (which is in PyTorch, though). It has a clean, standalone interface, and when I benchmarked it a month or so ago, it matched the paper's results. Looking at the differences could help resolve this issue as well.

@DanielTakeshi
Author

@sritee Good point. :) I am actually benchmarking with rlkit right now https://github.com/vitchyr/rlkit which seems to be similar to the original author's implementation.

@araffin
Contributor

araffin commented Jun 21, 2019

Hello,

I have been doing a quick sanity check using the rl zoo.

To reproduce, add the following to hyperparams/ddpg.yaml:

Gaussian noise:

HalfCheetah-v2:
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  gamma: 0.99
  memory_limit: 1000000
  noise_type: 'normal'
  noise_std: 0.2
  batch_size: 64
  normalize_observations: True
  normalize_returns: False

Param noise:

HalfCheetah-v2:
  n_timesteps: !!float 1e6
  policy: 'LnMlpPolicy'
  gamma: 0.99
  memory_limit: 1000000
  noise_type: 'adaptive-param'
  noise_std: 0.2
  batch_size: 64
  normalize_observations: True
  normalize_returns: False

Command to run (with tensorboard support):

python train.py --algo ddpg --env HalfCheetah-v2 -tb /tmp/ddpg/

With one random seed and 100k steps on HalfCheetah-v2, I get these results (Gaussian noise is orange, param noise is blue):

[Image: ddpg]

So it seems something wrong happened with the original baselines (stable baselines is based on OpenAI Baselines of last year).

EDIT: to match TD3 paper, you would need to change the network architecture too, using policy_kwargs: "dict(layers=[256, 256, 256])" for instance

@schneimo

schneimo commented Jun 22, 2019

Note that tau by default is 0.01: [...] The above will override the 0.001 from the DDPG code here: [...]

Ok, unfortunately I did not see that tau is set to 0.01 in the learn function. Of course this overwrites all other tau params.

But since tau is used regardless of the type of noise, shouldn't you always set tau to 0.001 manually? You did not set tau to 0.001 manually in either the parameter noise setup or the Gaussian noise setup, or am I missing something again? Still, I am not sure this really solves the problem, since the OU noise setup does not work properly either.

@DanielTakeshi
Author

@MoritzTaylor Right, if this were a research paper, I would always keep tau = 0.001, or always keep it at 0.01 to keep experimental conditions the same while varying just one thing (the noise setting, if that was what I was investigating, but really I was just trying to see what setting would lead to any improvements in performance). I only changed it here because I thought I might as well at least try to see what happens.

@araffin Thanks for the results. It does seem like something happened between the commit corresponding to stable-baselines and the commit corresponding to the most recent one on master, that I used to test. Out of curiosity do you have the exact commit for baselines here that corresponds to what stable-baselines uses?

@araffin
Contributor

araffin commented Jun 22, 2019

@DanielTakeshi

Out of curiosity do you have the exact commit for baselines here that corresponds to what stable-baselines uses?

There were several changes/bug fixes afterward, but we (apparently) forked it from:
hill-a@a6b1bc7

@DanielTakeshi
Author

Thanks, maybe something happened between then that caused changes in the environment processing code? From reading OpenAI's DDPG code, I can't find any obvious errors yet.

Hmm @araffin just wondering, you report the "exploration policy" right? From looking at the DDPG code, there is an exploration environment by default, and there is a separate evaluation environment which will step in the environment with a deterministic policy (which is what we want all along with DDPG). I wonder if the evaluation policy can still do a good job even if the exploration policy is very bad. My tests with other software packages that have DDPG show that the exploration policy does nearly as well as the evaluation policy so I am not entirely optimistic, but it is something to think about.

@araffin
Contributor

araffin commented Jun 24, 2019

maybe something happened between then that caused changes in the environment processing code

Maybe. In the stable-baselines code, there is no preprocessing for DDPG.

you report the "exploration policy" right?

Exactly. This is the behavioral policy, so the deterministic policy + noise in that case.

there is a separate evaluation environment which will step in the environment with a deterministic policy

Yes, the eval env will only be used if you provide it. I did not provide any in my case.

I wonder if the evaluation policy can still do a good job even if the exploration policy is very bad.
My tests with other software packages that have DDPG show that the exploration policy does nearly as well as the evaluation policy so I am not entirely optimistic, but it is something to think about.

I think that would only happen with a large amount of noise. Otherwise, the training performance (of the exploration policy) is usually a good proxy for the real performance.
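The distinction discussed here can be sketched as two ways of picking an action from the same actor. behavior_action (deterministic output plus exploration noise, clipped to the action range) is what produces the monitor-file rewards during training; eval_action is the noise-free action a separate evaluation environment would use. Both helpers, the toy linear-tanh actor, and the noise scale are hypothetical.

```python
import numpy as np


def behavior_action(actor, obs, noise, low=-1.0, high=1.0):
    """Exploration (behavior) policy: deterministic actor output plus noise, clipped."""
    return np.clip(actor(obs) + noise(), low, high)


def eval_action(actor, obs):
    """Evaluation policy: the deterministic actor output, with no noise."""
    return actor(obs)


rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.5, size=(4, 2))  # toy linear actor weights


def actor(obs):
    return np.tanh(obs @ W)  # actions in (-1, 1)


def noise():
    return rng.normal(0.0, 0.1, size=2)  # Gaussian exploration noise


obs = rng.normal(size=4)
a_eval = eval_action(actor, obs)
a_behavior = behavior_action(actor, obs, noise)
```

With a small noise scale the two actions stay close, which is consistent with training rewards usually being a reasonable proxy for evaluation rewards.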

@araffin
Copy link
Contributor

araffin commented Jun 24, 2019

@DanielTakeshi I think your intuition was good:

env = VecNormalize(env, use_tf=True)

It seems the normalization is applied twice (and reward normalization is also active by default).
Can you check commenting this line?
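To illustrate what "applied twice" means, here is a simplified sketch of two stacked running normalizers: an outer, VecNormalize-like one (observation clip 10, and reward scaling by the running std of the discounted return; both are assumptions about the wrapper's defaults) feeding an inner one like DDPG's normalize_observations with observation_range=(-5., 5.) from the DDPG constructor quoted earlier. Episode resets, which a real wrapper would handle, are omitted, and all names are hypothetical.

```python
import numpy as np


class RunningMeanStd:
    """Running mean/variance, merged batch by batch (parallel-variance formula)."""

    def __init__(self, shape=()):
        self.mean = np.zeros(shape)
        self.var = np.ones(shape)
        self.count = 1e-4

    def update(self, x):
        b_mean, b_var, b_count = x.mean(axis=0), x.var(axis=0), x.shape[0]
        delta = b_mean - self.mean
        total = self.count + b_count
        self.mean = self.mean + delta * b_count / total
        m2 = self.var * self.count + b_var * b_count + delta ** 2 * self.count * b_count / total
        self.var = m2 / total
        self.count = total


def normalize(x, rms, clip, eps=1e-8):
    rms.update(x)
    return np.clip((x - rms.mean) / np.sqrt(rms.var + eps), -clip, clip)


rng = np.random.default_rng(0)
nenvs, obs_dim, gamma = 8, 3, 0.99
wrapper_obs_rms = RunningMeanStd((obs_dim,))  # outer, wrapper-level statistics
agent_obs_rms = RunningMeanStd((obs_dim,))    # inner, agent-level statistics
ret_rms, ret = RunningMeanStd(()), np.zeros(nenvs)

for _ in range(200):
    raw_obs = rng.normal(50.0, 20.0, size=(nenvs, obs_dim))
    raw_rew = rng.normal(3.0, 1.0, size=nenvs)
    once = normalize(raw_obs, wrapper_obs_rms, clip=10.0)  # first normalization
    twice = normalize(once, agent_obs_rms, clip=5.0)       # second normalization
    # Reward scaling: divide by the running std of the discounted return.
    ret = ret * gamma + raw_rew
    ret_rms.update(ret)
    scaled_rew = np.clip(raw_rew / np.sqrt(ret_rms.var + 1e-8), -10.0, 10.0)
```

The inner normalizer ends up seeing inputs that are already roughly zero-mean and unit-variance, and the rewards the agent learns from are rescaled by a factor unrelated to the environment's own reward scale. Which of these effects actually hurts DDPG is not established by this sketch.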

@DanielTakeshi
Author

@araffin @pzhokhov

I commented out the line above. Results look much better, though they're not as good as some published results (e.g., the TD3 paper figure above).

After removing the VecNormalize wrapper, I ran the following:

python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Hopper-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Swimmer-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Walker2d-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=InvertedPendulum-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Reacher-v2 --num_timesteps=1e6
python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Hopper-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Swimmer-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Walker2d-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=InvertedPendulum-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Reacher-v2 --num_timesteps=1e6 --noise_type=normal_0.1
python -m baselines.run --alg=ddpg --env=Ant-v2 --num_timesteps=1e6  --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Hopper-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Swimmer-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Walker2d-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=InvertedPendulum-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001
python -m baselines.run --alg=ddpg --env=Reacher-v2 --num_timesteps=1e6 --noise_type=ou_0.2 --tau=0.001

For each of the seven environments, I ran one training run per noise setting. The following plot overlays the three noise settings for each environment to make comparisons easier:

[Image: compare_three_noise_methods]

Note that these are the exploration environment rewards; I don't have an evaluation environment here. (It's easy to add an evaluation environment, and I'm actually testing that now, but there's no single command-line option that sets one up.) Also, I plot these episode rewards from the 0.0.monitor.csv files stored in the log directory, smoothed with a sliding window of size 20 over the list of episodes.

  • Ant: parameter space noise doesn't seem to work well, but the other two finally get above 0; getting 1000+ would be ideal.
  • HalfCheetah: looks better, reaching +2000-ish for parameter/Gaussian noise, though the TD3 paper claims DDPG should get 3000+.
  • Hopper: parameter space noise does the best, with 1500, and the other two get over 500. But we should be seeing about 2000+.
  • InvertedPendulum: Gaussian noise sometimes reaches the 1000 limit and the others don't, though this might be because these are exploration-policy rewards. Ideally all methods get 1000 on this env.
  • Reacher: Gaussian and OU noise are better, with a reward of around -10, but getting -4 or -3 would be ideal.
  • Swimmer: the methods get around 14-43. Unfortunately, this is not much of an improvement over my earlier results.
  • Walker2d: gets around 400-800, but ideally we'd get around 2000.

It looks like parameter space noise might have a slight advantage, but I am not sure I can tell from these limited results. Gaussian noise on the actions seems easier conceptually.

Overall these results are giving me much more confidence in the code base. :) And I think with evaluation rewards instead of exploration rewards, the above would be closer to published results for DDPG.
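Since these curves are smoothed with a sliding window over episodes, here is a vectorized sketch equivalent to the smoothed() helper in the plotting script above: a trailing mean over the last w episodes, with a growing window for the first w - 1 episodes, computed from cumulative sums instead of a Python loop. smoothed_fast is a hypothetical name.

```python
import numpy as np


def smoothed_fast(x, w):
    """Trailing mean over the last w points; the first w - 1 points use all points so far."""
    x = np.asarray(x, dtype=float)
    if len(x) <= w:
        return x
    csum = np.cumsum(np.insert(x, 0, 0.0))  # csum[i] = sum of x[0:i]
    idx = np.arange(1, len(x) + 1)          # number of points seen so far
    start = np.maximum(idx - w, 0)          # left edge of each window
    return (csum[idx] - csum[start]) / (idx - start)


print(smoothed_fast([1, 2, 3, 4, 5, 6], w=3))
```

This prints [1.  1.5 2.  3.  4.  5. ], matching what the loop-based smoothed() returns for the same input. The cumulative-sum form costs O(n) regardless of the window size, which helps when smoothing long monitor files with w=100.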

@sritee
Contributor

sritee commented Jul 11, 2019

@DanielTakeshi Did you run any of these benchmarks on vision-based tasks, or know of any results?

@DanielTakeshi
Author

@sritee No, I'm not aware of any standard vision-based DDPG tasks. You'd also need to change the network design in the DDPG class.

@GameHoo

GameHoo commented Aug 17, 2019

There are some differences between this code and the implementation in the paper:

[Image]

@schneimo

@GameHoo
Yes, because OpenAI published not the original DDPG but rather the implementation from the paper "Parameter Space Noise for Exploration". I remember that the parameters matched that paper.
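For readers unfamiliar with that paper, its adaptive scheme perturbs the actor's weights with Gaussian noise of scale sigma and adapts sigma so that the distance between perturbed and unperturbed actions stays near a target delta: shrink sigma by a factor alpha when the distance exceeds delta, and grow it otherwise. Below is a minimal sketch with a toy linear-tanh actor; the class and function names are hypothetical, and the defaults (sigma0 = 0.1, delta = 0.1, alpha = 1.01) are assumptions.

```python
import numpy as np


class AdaptiveParamNoise:
    """Hypothetical sketch of adaptive parameter-noise scaling."""

    def __init__(self, initial_stddev=0.1, desired_action_stddev=0.1, adaptation_coefficient=1.01):
        self.current_stddev = initial_stddev
        self.desired_action_stddev = desired_action_stddev
        self.adaptation_coefficient = adaptation_coefficient

    def adapt(self, distance):
        if distance > self.desired_action_stddev:
            # Perturbed actions moved too far from the clean ones: shrink the weight noise.
            self.current_stddev /= self.adaptation_coefficient
        else:
            # Perturbation is too timid: grow the weight noise.
            self.current_stddev *= self.adaptation_coefficient


def action_distance(actions, perturbed_actions):
    """Root-mean-square distance between clean and perturbed actions on one batch."""
    return float(np.sqrt(np.mean(np.square(actions - perturbed_actions))))


rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.1, size=(5, 2))  # toy linear actor weights
noise = AdaptiveParamNoise()
for _ in range(500):
    obs = rng.normal(size=(64, 5))
    W_perturbed = W + rng.normal(0.0, noise.current_stddev, size=W.shape)
    noise.adapt(action_distance(np.tanh(obs @ W), np.tanh(obs @ W_perturbed)))
```

Measuring the perturbation in action space rather than weight space is the key design choice: the same weight-noise scale can have very different effects on the actions depending on the network, so sigma is tuned against its observable effect instead.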

@christopherhesse
Contributor

@DanielTakeshi do you get the same issues if you downgrade MuJoCo as in openai/gym#1541?

@DanielTakeshi
Author

@christopherhesse Apologies for not responding; I never had time to follow up on this.
