Skip to content

Commit

Permalink
[RLlib] Index tensors in slate epsilon greedy properly (ray-project#37481)
Browse files Browse the repository at this point in the history

Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
ArturNiederfahrenhorst authored and arvind-chandra committed Aug 31, 2023
1 parent e6a56ad commit 7a1bd2a
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion rllib/utils/exploration/slate_epsilon_greedy.py
Expand Up @@ -79,8 +79,10 @@ def _get_torch_exploration_action(

per_slate_q_values = action_distribution.inputs
all_slates = self.model.slates
device = all_slates.device

exploit_indices = action_distribution.deterministic_sample()
exploit_indices = exploit_indices.to(device)
exploit_action = all_slates[exploit_indices]

batch_size = per_slate_q_values.size()[0]
Expand All @@ -94,7 +96,10 @@ def _get_torch_exploration_action(
epsilon = self.epsilon_schedule(self.last_timestep)
# A random action.
random_indices = torch.randint(
0, per_slate_q_values.shape[1], (per_slate_q_values.shape[0],)
0,
per_slate_q_values.shape[1],
(per_slate_q_values.shape[0],),
device=device,
)
random_actions = all_slates[random_indices]

Expand Down

0 comments on commit 7a1bd2a

Please sign in to comment.