Commit 0c1b112: fix

Tim Salimans committed Jul 4, 2018
1 parent b38191e
Showing 3 changed files with 13 additions and 4 deletions.
9 changes: 9 additions & 0 deletions atari_reset/ppo.py
@@ -243,6 +243,15 @@ def learn(policy, env, nsteps, total_timesteps, ent_coef=1e-4, lr=1e-4,
                     [epinfo['as_good_as_demo'] for epinfo in epinfos_to_report]))
                 logger.logkv('perc_started_below_max_sp', safemean(
                     [epinfo['starting_point'] <= env.max_starting_point for epinfo in epinfos_to_report]))
+            elif hasattr(env.venv, 'max_starting_point'):
+                logger.logkv('max_starting_point', env.venv.max_starting_point)
+                logger.logkv('as_good_as_demo_start', safemean(
+                    [epinfo['as_good_as_demo'] for epinfo in epinfos_to_report if
+                     epinfo['starting_point'] <= env.venv.max_starting_point]))
+                logger.logkv('as_good_as_demo_all', safemean(
+                    [epinfo['as_good_as_demo'] for epinfo in epinfos_to_report]))
+                logger.logkv('perc_started_below_max_sp', safemean(
+                    [epinfo['starting_point'] <= env.venv.max_starting_point for epinfo in epinfos_to_report]))
 
             logger.logkv('time_elapsed', tnow - tfirststart)
             logger.logkv('perc_valid', np.mean(valids))
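The new elif branch mirrors the existing logging block for the case where max_starting_point lives on the inner env.venv rather than on env itself. Note that the as_good_as_demo_start comprehension can filter down to an empty list, which is only safe if safemean tolerates empty input. A minimal sketch of such a helper, assuming the same convention as safemean in OpenAI baselines' ppo2 (the definition in this repo may differ):

import numpy as np

def safemean(xs):
    # report NaN rather than crash when no episode matches the filter
    return np.nan if len(xs) == 0 else np.mean(xs)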
6 changes: 3 additions & 3 deletions atari_reset/wrappers.py
@@ -7,6 +7,7 @@
 import numpy as np
 from multiprocessing import Process, Pipe
 import mpi4py.rc
+import horovod.tensorflow as hvd
 mpi4py.rc.initialize = False
 from mpi4py import MPI
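The horovod import is added just above mpi4py.rc.initialize = False, so the flag is in place before `from mpi4py import MPI` runs and mpi4py never calls MPI_Init on import; initialization is left to Horovod. A sketch of the intended startup order (hvd.init() is Horovod's standard entry point; the print line is purely illustrative):

import mpi4py.rc
import horovod.tensorflow as hvd
mpi4py.rc.initialize = False  # keep mpi4py from running MPI_Init at import time
from mpi4py import MPI

hvd.init()  # Horovod initializes MPI for all workers
print('worker %d of %d' % (MPI.COMM_WORLD.Get_rank(), MPI.COMM_WORLD.Get_size()))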

@@ -392,15 +393,14 @@ def __init__(self, env):
         super(ResetManager, self).__init__(env)
         starting_points = self.env.recursive_getattr('starting_point')
         all_starting_points = flatten_lists(MPI.COMM_WORLD.allgather(starting_points))
-        self.min_starting_point = max(all_starting_points)
+        self.min_starting_point = min(all_starting_points)
         self.max_starting_point = max(all_starting_points)
         self.nrstartsteps = self.max_starting_point - self.min_starting_point
+        assert(self.nrstartsteps > 10)
         self.max_max_starting_point = self.max_starting_point
         self.starting_point_success = np.zeros(self.max_starting_point+10000)
         self.counter = 0
         self.infos = []
-        if isinstance(self.nrstartsteps, list):
-            self.nrstartsteps = self.nrstartsteps[0]
 
     def proc_infos(self):
         epinfos = [info['episode'] for info in self.infos if 'episode' in info]
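The one-line change above is the substance of this commit: min_starting_point had been assigned max(all_starting_points), so nrstartsteps = max - min was always zero and there was no window of demo starting points to anneal over; the removed isinstance branch looks like a leftover workaround, since nrstartsteps is a plain int here. A toy illustration with hypothetical starting points:

all_starting_points = [800, 950, 1200]  # hypothetical values gathered across workers

# before the fix: both bounds equal the maximum
nrstartsteps = max(all_starting_points) - max(all_starting_points)
print(nrstartsteps)  # 0, so the new assert nrstartsteps > 10 would fail

# after the fix: a real range for the reset point to move over
nrstartsteps = max(all_starting_points) - min(all_starting_points)
print(nrstartsteps)  # 400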
2 changes: 1 addition & 1 deletion train_atari.py
@@ -49,7 +49,7 @@ def env_fn():
     env.set_max_starting_point(starting_point)
 
     policy = {'cnn' : CnnPolicy, 'gru': GRUPolicy}[policy]
-    learn(policy=policy, env=env, nsteps=128, lam=.95, gamma=.9995, noptepochs=4, log_interval=1, save_interval=100,
+    learn(policy=policy, env=env, nsteps=128, lam=.95, gamma=.999, noptepochs=4, log_interval=1, save_interval=100,
          ent_coef=entropy_coef, l2_coef=1e-7, lr=lr, cliprange=0.1, total_timesteps=num_timesteps,
          norm_adv=True, load_path=load_path, save_path=save_path, game_name=game_name)
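Dropping gamma from .9995 to .999 halves the effective horizon of the discounted return, roughly 1/(1 - gamma) steps. A quick check of that rule of thumb:

for gamma in (0.9995, 0.999):
    print('gamma=%g: effective horizon ~ %.0f steps' % (gamma, 1.0 / (1.0 - gamma)))
# gamma=0.9995: effective horizon ~ 2000 steps
# gamma=0.999: effective horizon ~ 1000 steps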

