allow for negative rewards/evaluations + a little mcts cleanup
lowrollr committed Nov 20, 2023
1 parent 43e8112 commit 87fd4d8
Showing 2 changed files with 39 additions and 32 deletions.
2 changes: 1 addition & 1 deletion core/algorithms/lazy_mcts.py
@@ -61,7 +61,7 @@ def choose_action_with_puct(self, probs: torch.Tensor, legal_actions: torch.Tensor
puct_scores = q_values + \
(self.puct_coeff * probs * torch.sqrt(n_sum + 1) / (1 + self.visit_counts))

-     puct_scores *= legal_actions
+     puct_scores = (puct_scores * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions))
return torch.argmax(puct_scores, dim=1)

def iterate(self, evaluation_fn: Callable, depth: int, rewards: torch.Tensor) -> torch.Tensor: # type: ignore
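This masking change, made identically in both files, is the substance of the commit: multiplying PUCT scores by the legal-action mask zeroes out illegal actions, which is only safe while every legal score is non-negative. Once rewards/evaluations can be negative, an illegal action's 0.0 can win the argmax. Assigning illegal actions the float32 minimum instead keeps them unselectable. A toy demonstration (tensors invented for illustration, not from the repo):

import torch

puct_scores = torch.tensor([[-0.5, -0.1]])
legal_actions = torch.tensor([[True, False]])  # only action 0 is legal

# old masking: illegal action 1 becomes 0.0, which beats the legal -0.5
old = puct_scores * legal_actions
print(old.argmax(dim=1))  # tensor([1]) -- illegal action selected

# new masking: illegal actions get float32 min, so any legal score wins
new = (puct_scores * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions))
print(new.argmax(dim=1))  # tensor([0]) -- legal action selected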
69 changes: 38 additions & 31 deletions core/algorithms/mcts.py
@@ -147,8 +147,44 @@ def choose_action(self) -> torch.Tensor:
n_sum = visits.sum(dim=1, keepdim=True)
probs = self.p_vals[self.env_indices, self.cur_nodes]
puct_scores = q_values + (self.puct_coeff * probs * torch.sqrt(1 + n_sum) / (1 + visits))
-     puct_scores *= self.env.get_legal_actions()
+     legal_actions = self.env.get_legal_actions()
+     puct_scores = (puct_scores * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions))
return torch.argmax(puct_scores, dim=1)

+ def traverse(self, actions: torch.Tensor) -> torch.Tensor:
+     # make a step in the environment with the chosen actions
+     self.env.step(actions)
+
+     # look up master index for each child node
+     master_action_indices = self.next_idx[self.env_indices, self.cur_nodes, actions]
+
+     # if the node doesn't have an index yet (0 is null), the node is unvisited
+     unvisited = master_action_indices == 0
+
+     # check if creating a new node will go out of bounds
+     in_bounds = ~((self.next_empty >= self.total_slots) & unvisited)
+
+     # assign new nodes to the next empty indices (if there is space)
+     master_action_indices += self.next_empty * in_bounds * unvisited
+
+     # increment self.next_empty to reflect the new next empty index
+     self.next_empty += 1 * in_bounds * unvisited
+
+     # map action to child idx in parent node
+     self.next_idx[self.env_indices, self.cur_nodes, actions] = master_action_indices
+
+     # update visits, actions to reflect the path taken from the root
+     self.visits[self.env_indices, self.depths] = master_action_indices
+     self.actions[self.env_indices, self.depths - 1] = actions
+
+     # map master child idx to master parent idx
+     self.parents[self.env_indices, master_action_indices] = self.cur_nodes
+
+     # cur nodes should now reflect the taken actions
+     self.cur_nodes = master_action_indices
+
+     return unvisited
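The allocation logic in traverse above is branchless: boolean masks stand in for per-environment if/else, so a whole batch of environments expands nodes in one vectorized step. A standalone sketch of just the allocation arithmetic (toy values, not the class's real state):

import torch

# three parallel envs; 0 is the null index, meaning the child is unvisited
master_action_indices = torch.tensor([0, 5, 0])
next_empty = torch.tensor([7, 9, 12])  # next free tree slot per env
total_slots = 12

unvisited = master_action_indices == 0                   # envs 0 and 2
in_bounds = ~((next_empty >= total_slots) & unvisited)   # env 2 has no space left

# only unvisited, in-bounds envs receive a fresh slot; the rest are unchanged
master_action_indices += next_empty * in_bounds * unvisited
next_empty += 1 * in_bounds * unvisited

print(master_action_indices)  # tensor([7, 5, 0])
print(next_empty)             # tensor([8, 9, 12])

Note that an env that is out of space keeps index 0, so it simply revisits the null node rather than raising an error.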


def evaluate(self, evaluation_fn: Callable) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
self.reset_search()
@@ -174,36 +210,7 @@ def evaluate(self, evaluation_fn: Callable) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# choose next action with PUCT scores
actions = self.choose_action()

-     # make a step in the environment with the chosen actions
-     self.env.step(actions)
-
-     # look up master index for each child node
-     master_action_indices = self.next_idx[self.env_indices, self.cur_nodes, actions]
-
-     # if the node doesn't have an index yet (0 is null), the node is unvisited
-     unvisited = master_action_indices == 0
-
-     # check if creating a new node will go out of bounds
-     in_bounds = ~((self.next_empty >= self.total_slots) & unvisited)
-
-     # assign new nodes to the next empty indices (if there is space)
-     master_action_indices += self.next_empty * in_bounds * unvisited
-
-     # increment self.next_empty to reflect the new next empty index
-     self.next_empty += 1 * in_bounds * unvisited
-
-     # map action to child idx in parent node
-     self.next_idx[self.env_indices, self.cur_nodes, actions] = master_action_indices
-
-     # update visits, actions to reflect the path taken from the root
-     self.visits[self.env_indices, self.depths] = master_action_indices
-     self.actions[self.env_indices, self.depths - 1] = actions
-
-     # map master child idx to master parent idx
-     self.parents[self.env_indices, master_action_indices] = self.cur_nodes
-
-     # cur nodes should now reflect the taken actions
-     self.cur_nodes = master_action_indices
+     unvisited = self.traverse(actions)

# get (policy distribution, evaluation) from evaluation function
with torch.no_grad():
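With traverse factored out, each search iteration in evaluate reduces to select, traverse/expand, then evaluate. A condensed paraphrase of that flow (the evaluation_fn signature and variable names here are assumptions, and terminal-state and backup handling are omitted):

actions = self.choose_action()      # PUCT selection restricted to legal actions
unvisited = self.traverse(actions)  # step envs, expand new nodes, record the path
with torch.no_grad():
    policy_probs, values = evaluation_fn(self.env)  # assumed (policy, value) return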
