Commit
[RLlib] Fix bad assertion error in PPO when use_kl_loss=False. (r…
simonsays1980 authored and harborn committed May 8, 2024
1 parent 287ecc0 commit f1e0590
Showing 4 changed files with 3 additions and 6 deletions.
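
The point of the change: with use_kl_loss=False the PPO loss does not contribute KL values, so the previously unconditional assert in the TF and Torch PPO learners could abort training even though the KL-coefficient update would be skipped anyway; the assert now lives inside the use_kl_loss branch. A hedged reproduction sketch, not part of this commit (PPOConfig usage is taken from RLlib's public API; the env name is an arbitrary choice):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")       # any env; CartPole is just an example choice
    .training(use_kl_loss=False)      # the setting that used to trip the assertion
)
algo = config.build()
algo.train()  # before this fix this could raise: "Sampled KL values are empty."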
2 changes: 0 additions & 2 deletions rllib/algorithms/ppo/ppo_learner.py
@@ -216,13 +216,11 @@ def additional_update_for_module(
module_id: ModuleID,
config: "PPOConfig",
timestep: int,
- sampled_kl_values: dict,
) -> Dict[str, Any]:
results = super().additional_update_for_module(
module_id=module_id,
config=config,
timestep=timestep,
- sampled_kl_values=sampled_kl_values,
)

# Update entropy coefficient via our Scheduler.
3 changes: 1 addition & 2 deletions rllib/algorithms/ppo/tf/ppo_tf_learner.py
@@ -152,17 +152,16 @@ def additional_update_for_module(
timestep: int,
sampled_kl_values: dict,
) -> Dict[str, Any]:
- assert sampled_kl_values, "Sampled KL values are empty."

results = super().additional_update_for_module(
module_id=module_id,
config=config,
timestep=timestep,
- sampled_kl_values=sampled_kl_values,
)

# Update KL coefficient.
if config.use_kl_loss:
+ assert sampled_kl_values, "Sampled KL values are empty."
sampled_kl = sampled_kl_values[module_id]
curr_var = self.curr_kl_coeffs_per_module[module_id]
if sampled_kl > 2.0 * config.kl_target:
3 changes: 1 addition & 2 deletions rllib/algorithms/ppo/torch/ppo_torch_learner.py
@@ -142,17 +142,16 @@ def additional_update_for_module(
timestep: int,
sampled_kl_values: dict,
) -> Dict[str, Any]:
- assert sampled_kl_values, "Sampled KL values are empty."

results = super().additional_update_for_module(
module_id=module_id,
config=config,
timestep=timestep,
- sampled_kl_values=sampled_kl_values,
)

# Update KL coefficient.
if config.use_kl_loss:
+ assert sampled_kl_values, "Sampled KL values are empty."
sampled_kl = sampled_kl_values[module_id]
curr_var = self.curr_kl_coeffs_per_module[module_id]
if sampled_kl > 2.0 * config.kl_target:
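
For context, the KL-coefficient update guarded by use_kl_loss follows the usual adaptive schedule: raise the coefficient when the sampled KL overshoots the target, lower it when it undershoots. A standalone sketch of that rule (the 1.5 and 0.5 multipliers are the conventional choice and an assumption here, not values read from the truncated hunk):

def update_kl_coeff(kl_coeff: float, sampled_kl: float, kl_target: float) -> float:
    # Adaptive KL penalty: steer the coefficient so sampled_kl stays near kl_target.
    if sampled_kl > 2.0 * kl_target:
        kl_coeff *= 1.5   # policy moved too far -> penalize KL harder
    elif sampled_kl < 0.5 * kl_target:
        kl_coeff *= 0.5   # policy barely moved -> relax the penalty
    return kl_coeff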
1 change: 1 addition & 0 deletions rllib/tuned_examples/ppo/cartpole_ppo_envrunner.py
@@ -22,6 +22,7 @@
lr=0.0003,
num_sgd_iter=6,
vf_loss_coeff=0.01,
+ use_kl_loss=True,
)
.evaluation(
evaluation_num_env_runners=1,
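
The tuned CartPole example now enables the KL loss explicitly, so the guarded KL-coefficient path above is still exercised by this example. A hedged sketch of how the visible settings slot into a full PPOConfig chain (the import, env name, and build() call are assumptions; only the training arguments and evaluation_num_env_runners appear in the diff):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")        # assumption: the env is not shown in this excerpt
    .training(
        lr=0.0003,
        num_sgd_iter=6,
        vf_loss_coeff=0.01,
        use_kl_loss=True,              # the line added by this commit
    )
    .evaluation(evaluation_num_env_runners=1)
)
algo = config.build()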
