Bugfix/greedy eval (#246)
* only add framestack body if framestack > 1

* allow greedy policy to still randomly choose actions in eval mode if specified

* clean up code for parallel greedy policy
cpnota committed May 28, 2021
1 parent 4bef9a1 commit bd40aa9
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions all/policies/greedy.py
@@ -39,6 +39,8 @@ def no_grad(self, state):
         return torch.argmax(self.q.no_grad(state)).item()
 
     def eval(self, state):
+        if np.random.rand() < self.epsilon:
+            return np.random.randint(0, self.num_actions)
         return torch.argmax(self.q.eval(state)).item()


@@ -68,16 +70,16 @@ def __init__(
         self.epsilon = epsilon
 
     def __call__(self, state):
-        best_actions = torch.argmax(self.q(state), dim=-1)
-        random_actions = torch.randint(0, self.n_actions, best_actions.shape, device=best_actions.device)
-        choices = (torch.rand(best_actions.shape, device=best_actions.device) < self.epsilon).int()
-        return choices * random_actions + (1 - choices) * best_actions
+        return self._choose_action(self.q(state))
 
     def no_grad(self, state):
-        best_actions = torch.argmax(self.q.no_grad(state), dim=-1)
+        return self._choose_action(self.q.no_grad(state))
+
+    def eval(self, state):
+        return self._choose_action(self.q.eval(state))
+
+    def _choose_action(self, action_values):
+        best_actions = torch.argmax(action_values, dim=-1)
         random_actions = torch.randint(0, self.num_actions, best_actions.shape, device=best_actions.device)
         choices = (torch.rand(best_actions.shape, device=best_actions.device) < self.epsilon).int()
         return choices * random_actions + (1 - choices) * best_actions
-
-    def eval(self, state):
-        return torch.argmax(self.q.eval(state), dim=-1)
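For reference, the eval-mode behavior introduced in the first hunk can be sketched in isolation: with probability epsilon the policy still samples a random action during evaluation, otherwise it acts greedily on the evaluated Q-values. The helper name, the toy Q-values, and the call below are illustrative stand-ins, not code from greedy.py.

    import numpy as np
    import torch

    def eval_action(q_values, epsilon, num_actions):
        # With probability epsilon, explore even in eval mode.
        if np.random.rand() < epsilon:
            return np.random.randint(0, num_actions)
        # Otherwise pick the greedy action from the evaluated Q-values.
        return torch.argmax(q_values).item()

    q_values = torch.tensor([0.1, 0.7, 0.2])
    action = eval_action(q_values, epsilon=0.05, num_actions=3)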

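The second hunk funnels __call__, no_grad, and eval through a single _choose_action helper. A standalone sketch of that vectorized epsilon-greedy selection follows, assuming a batch of action values with shape (n_envs, num_actions); the function name and the toy batch are illustrative, not part of the library.

    import torch

    def choose_actions(action_values, epsilon, num_actions):
        # Greedy action per parallel environment.
        best_actions = torch.argmax(action_values, dim=-1)
        # Independent random action per parallel environment.
        random_actions = torch.randint(0, num_actions, best_actions.shape, device=best_actions.device)
        # 1 where we explore (probability epsilon), 0 where we exploit.
        choices = (torch.rand(best_actions.shape, device=best_actions.device) < epsilon).int()
        return choices * random_actions + (1 - choices) * best_actions

    action_values = torch.rand(4, 6)  # 4 parallel environments, 6 actions
    actions = choose_actions(action_values, epsilon=0.1, num_actions=6)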