From 7a1bd2a80a6953d1f9132b9022c9d73da62fca3c Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 19 Jul 2023 19:54:53 +0200 Subject: [PATCH] [RLlib] Index tensors in slate epsilon greedy properly (#37481) Signed-off-by: Artur Niederfahrenhorst Signed-off-by: e428265 --- rllib/utils/exploration/slate_epsilon_greedy.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/rllib/utils/exploration/slate_epsilon_greedy.py b/rllib/utils/exploration/slate_epsilon_greedy.py index d4cf3cf59d966..95a52e057da6a 100644 --- a/rllib/utils/exploration/slate_epsilon_greedy.py +++ b/rllib/utils/exploration/slate_epsilon_greedy.py @@ -79,8 +79,10 @@ def _get_torch_exploration_action( per_slate_q_values = action_distribution.inputs all_slates = self.model.slates + device = all_slates.device exploit_indices = action_distribution.deterministic_sample() + exploit_indices = exploit_indices.to(device) exploit_action = all_slates[exploit_indices] batch_size = per_slate_q_values.size()[0] @@ -94,7 +96,10 @@ def _get_torch_exploration_action( epsilon = self.epsilon_schedule(self.last_timestep) # A random action. random_indices = torch.randint( - 0, per_slate_q_values.shape[1], (per_slate_q_values.shape[0],) + 0, + per_slate_q_values.shape[1], + (per_slate_q_values.shape[0],), + device=device, ) random_actions = all_slates[random_indices]