[RLlib] Fix 2 broken CI test cases: test_learner_group and `cartpole_dqn_envrunner`. (ray-project#45110)
sven1977 authored and harborn committed May 8, 2024
1 parent 9d33dae commit 211c8d0
Showing 3 changed files with 40 additions and 44 deletions.
67 changes: 28 additions & 39 deletions rllib/core/learner/tests/test_learner_group.py
@@ -247,7 +247,7 @@ def test_add_remove_module(self):
),
)

- self._check_multi_worker_weights(learner_group, results)
+ _check_multi_worker_weights(learner_group, results)

# check that module ids are updated to include the new module
module_ids_after_add = {DEFAULT_MODULE_ID, new_module_id}
@@ -260,7 +260,7 @@ def test_add_remove_module(self):
# run training without the test_module
results = learner_group.update_from_batch(batch.as_multi_agent())

- self._check_multi_worker_weights(learner_group, results)
+ _check_multi_worker_weights(learner_group, results)

# check that module ids are updated after remove operation to not
# include the new module
@@ -272,20 +272,6 @@ def test_add_remove_module(self):
learner_group.shutdown()
del learner_group

- def _check_multi_worker_weights(self, learner_group, results):
-     # Check that module weights are updated across workers and synchronized.
-     # for i in range(1, len(results)):
-     for module_id, mod_results in results.items():
-         if module_id == ALL_MODULES:
-             continue
-         # Compare the reported mean weights (merged across all Learner workers,
-         # which all should have the same weights after updating) with the actual
-         # current mean weights.
-         reported_mean_weights = mod_results["mean_weight"]
-         parameters = learner_group.get_weights(module_ids=[module_id])[module_id]
-         actual_mean_weights = np.mean([w.mean() for w in parameters.values()])
-         check(reported_mean_weights, actual_mean_weights, rtol=0.02)


class TestLearnerGroupCheckpointRestore(unittest.TestCase):
@classmethod
@@ -525,7 +511,6 @@ def test_async_update(self):
config = BaseTestingAlgorithmConfig().update_from_dict(config_overrides)
learner_group = config.build_learner_group(env=env)
reader = get_cartpole_dataset_reader(batch_size=512)
- min_loss = float("inf")
batch = reader.next()
timer_sync = _Timer()
timer_async = _Timer()
@@ -541,8 +526,7 @@ def test_async_update(self):
# way to check that is if the time for an async update call is faster
# than the time for a sync update call.
self.assertLess(timer_async.mean, timer_sync.mean)
- self.assertIsInstance(result_async, list)
- self.assertEqual(len(result_async), 0)
+ self.assertIsInstance(result_async, dict)
iter_i = 0
while True:
batch = reader.next()
@@ -551,31 +535,36 @@ def test_async_update(self):
)
if not async_results:
continue
- losses = [
-     np.mean(
-         [res[ALL_MODULES][Learner.TOTAL_LOSS_KEY] for res in results]
-     )
-     for results in async_results
- ]
- min_loss_this_iter = min(losses)
- min_loss = min(min_loss_this_iter, min_loss)
- print(
-     f"[iter = {iter_i}] Loss: {min_loss_this_iter:.3f}, Min Loss: "
-     f"{min_loss:.3f}"
- )
+ loss = async_results[ALL_MODULES][Learner.TOTAL_LOSS_KEY]
# The loss is initially around 0.69 (ln2). When it gets to around
# 0.57 the return of the policy gets to around 100.
- if min_loss < 0.57:
+ if loss < 0.57:
break
- for results in async_results:
-     for res1, res2 in zip(results, results[1:]):
-         self.assertEqual(
-             res1[DEFAULT_MODULE_ID]["mean_weight"],
-             res2[DEFAULT_MODULE_ID]["mean_weight"],
-         )
+ # Compare reported "mean_weight" with actual ones.
+ # TODO (sven): Right now, we don't have any way to know whether
+ # an async update result came from the most recent call to
+ # `learner_group.update_from_batch(async_update=True)` or an earlier
+ # one. Once APPO/IMPALA are properly implemented on the new API stack,
+ # this problem should be resolved and we can uncomment the below line.
+ # _check_multi_worker_weights(learner_group, async_results)
iter_i += 1
learner_group.shutdown()
- self.assertLess(min_loss, 0.57)
+ self.assertLess(loss, 0.57)
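For context on the 0.57 threshold used above: the comment in this hunk notes that the total loss starts around 0.69, i.e. ln 2, presumably the log-loss of a maximally uncertain policy over CartPole's two actions (assuming the tested Learner minimizes a cross-entropy-style objective), and that a loss near 0.57 corresponds to an episode return of roughly 100. A quick check of the starting value:

```python
import math

# A two-action policy assigning probability 0.5 to each action has a
# negative log-likelihood of ln(2) per sample, the ~0.69 starting loss.
print(math.log(2))  # 0.6931...
```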


+ def _check_multi_worker_weights(learner_group, results):
+     # Check that module weights are updated across workers and synchronized.
+     # for i in range(1, len(results)):
+     for module_id, mod_results in results.items():
+         if module_id == ALL_MODULES:
+             continue
+         # Compare the reported mean weights (merged across all Learner workers,
+         # which all should have the same weights after updating) with the actual
+         # current mean weights.
+         reported_mean_weights = mod_results["mean_weight"]
+         parameters = learner_group.get_weights(module_ids=[module_id])[module_id]
+         actual_mean_weights = np.mean([w.mean() for w in parameters.values()])
+         check(reported_mean_weights, actual_mean_weights, rtol=0.02)
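The commit moves this helper from a method of the test class to module level so both callers above can use the plain `_check_multi_worker_weights(...)` form. As a rough illustration of the data shapes it relies on, here is a minimal, self-contained sketch: a per-module results dict carrying a reported "mean_weight" (the real results also contain an ALL_MODULES entry, which the helper skips), and a learner group whose get_weights() returns {module_id: {param_name: ndarray}}. FakeLearnerGroup is a hypothetical stand-in, not an RLlib class, and RLlib's check() is replaced by numpy's allclose assertion.

```python
import numpy as np


class FakeLearnerGroup:
    """Hypothetical stand-in for LearnerGroup, only mimicking get_weights()."""

    def __init__(self, weights):
        # {module_id: {param_name: np.ndarray}}
        self._weights = weights

    def get_weights(self, module_ids):
        return {mid: self._weights[mid] for mid in module_ids}


group = FakeLearnerGroup(
    {"default_policy": {"w": np.full((4, 2), 0.1), "b": np.full((2,), 0.1)}}
)
# Per-module results as the tests above expect them (module ID is illustrative).
results = {"default_policy": {"mean_weight": 0.1}}

for module_id, mod_results in results.items():
    params = group.get_weights(module_ids=[module_id])[module_id]
    actual_mean = np.mean([w.mean() for w in params.values()])
    np.testing.assert_allclose(mod_results["mean_weight"], actual_mean, rtol=0.02)
```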


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions rllib/env/single_agent_env_runner.py
@@ -274,7 +274,7 @@ def _sample_timesteps(
if explore:
env_steps_lifetime = self.metrics.peek(
NUM_ENV_STEPS_SAMPLED_LIFETIME
- ) + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
+ )
to_env = self.module.forward_exploration(
to_module, t=env_steps_lifetime
)
@@ -465,7 +465,7 @@ def _sample_episodes(
if explore:
env_steps_lifetime = self.metrics.peek(
NUM_ENV_STEPS_SAMPLED_LIFETIME
- ) + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
+ )
to_env = self.module.forward_exploration(
to_module, t=env_steps_lifetime
)
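Both hunks in single_agent_env_runner.py make the same change: the timestep `t` handed to `forward_exploration()`, which typically drives time-dependent exploration such as the epsilon schedule below, is now taken from the lifetime counter alone rather than from lifetime plus the per-iteration NUM_ENV_STEPS_SAMPLED count. A minimal sketch of the before/after arithmetic; FakeMetrics and the metric key strings are hypothetical stand-ins for the env runner's metrics logger, used only to illustrate the peek() calls shown in the diff.

```python
class FakeMetrics:
    """Hypothetical stand-in for the env runner's metrics logger."""

    def __init__(self, values):
        self._values = values

    def peek(self, key, default=None):
        return self._values.get(key, default)


metrics = FakeMetrics(
    {"num_env_steps_sampled_lifetime": 12000, "num_env_steps_sampled": 500}
)

# Old computation (removed by this commit): lifetime plus the current counter.
t_old = metrics.peek("num_env_steps_sampled_lifetime") + metrics.peek(
    "num_env_steps_sampled", default=0
)
# New computation: the lifetime counter alone drives exploration.
t_new = metrics.peek("num_env_steps_sampled_lifetime")

print(t_old, t_new)  # 12500 12000
```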
13 changes: 10 additions & 3 deletions rllib/tuned_examples/dqn/cartpole_dqn_envrunner.py
@@ -13,7 +13,7 @@
model_config_dict={
"fcnet_hiddens": [256],
"fcnet_activation": "relu",
- "epsilon": [(0, 1.0), (50000, 0.05)],
+ "epsilon": [(0, 1.0), (10000, 0.02)],
"fcnet_bias_initializer": "zeros_",
"post_fcnet_bias_initializer": "zeros_",
"post_fcnet_hiddens": [256],
@@ -23,7 +23,7 @@
# Settings identical to old stack.
replay_buffer_config={
"type": "PrioritizedEpisodeReplayBuffer",
- "capacity": 100000,
+ "capacity": 50000,
"alpha": 0.6,
"beta": 0.4,
},
@@ -37,7 +37,14 @@
evaluation_parallel_to_training=True,
evaluation_num_env_runners=1,
evaluation_duration="auto",
- evaluation_config={"explore": False},
+ evaluation_config={
+     "explore": False,
+     # TODO (sven): Add support for window=float(inf) and reduce=mean for
+     # evaluation episode_return_mean reductions (identical to old stack
+     # behavior, which does NOT use a window (100 by default) to reduce
+     # eval episode returns).
+     "metrics_num_episodes_for_smoothing": 4,
+ },
)
)
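Regarding the epsilon change in the first hunk of this file: the entry is a piecewise schedule over env timesteps, and the new setting anneals epsilon from 1.0 at timestep 0 down to 0.02 by timestep 10000 (previously 0.05 by timestep 50000), so exploration decays much faster. Assuming linear interpolation between the listed (timestep, value) points and a constant value after the last point, which is how such schedules are commonly evaluated, a sketch of the decay looks like this:

```python
def piecewise_epsilon(t, points):
    """Linearly interpolate a [(timestep, value), ...] schedule at step t."""
    for (t0, v0), (t1, v1) in zip(points, points[1:]):
        if t0 <= t < t1:
            frac = (t - t0) / (t1 - t0)
            return v0 + frac * (v1 - v0)
    # Past the last breakpoint: hold the final value.
    return points[-1][1]


new_schedule = [(0, 1.0), (10000, 0.02)]
old_schedule = [(0, 1.0), (50000, 0.05)]

for t in (0, 5000, 10000, 50000):
    print(t, round(piecewise_epsilon(t, new_schedule), 3),
          round(piecewise_epsilon(t, old_schedule), 3))
# At t=5000 the new schedule is already at ~0.51, the old one still at ~0.905.
```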

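The new "metrics_num_episodes_for_smoothing": 4 in the evaluation hunk above means the evaluation episode_return_mean is averaged over a window of only the 4 most recent episodes instead of the default window of 100, so the metric reacts quickly when evaluation runs just a handful of episodes per iteration (the TODO notes the old stack used no window at all). A sketch of the difference, assuming a simple moving-window mean over illustrative episode returns:

```python
from collections import deque

# Illustrative evaluation episode returns, improving over time.
returns = [20.0, 35.0, 80.0, 150.0, 210.0, 350.0, 450.0, 500.0]

window = deque(maxlen=4)
for ep_return in returns:
    window.append(ep_return)

# Window of 4: reflects only the most recent evaluation episodes.
print(sum(window) / len(window))    # 377.5
# No window (old-stack style): the mean over everything seen so far.
print(sum(returns) / len(returns))  # 224.375
```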
