feature(xcy): save the updated searched policy and value to the buffer during reanalyze (#190)

* feature(xcy): reanalyze with no noise and save

* polish(xcy):fix some typos

* polish(xcy): a test config

* polish(xcy):fix config

* polish(xcy):solve some review problems

* polish(xcy):add noise check

* polish(xcy): unify the format for final commit
HarryXuancy committed Mar 13, 2024
1 parent dbff144 commit 3a7424d
Showing 12 changed files with 108 additions and 86 deletions.
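Before the per-file diffs: this commit gates root exploration noise behind a new reanalyze_noise config flag and writes the reanalyzed search results back into the sampled game segments. A minimal sketch of how the flag could be switched from a user config follows; the exact key placement is an assumption (only its use as self._cfg.reanalyze_noise, self._cfg.root_noise_weight and self._cfg.root_dirichlet_alpha is visible in this diff).

# Hypothetical config snippet (key placement assumed, not taken from this commit).
example_policy_config = dict(
    policy=dict(
        # When False, reanalyze prepares the MCTS roots with prepare_no_noise(),
        # so recomputed target policies are not perturbed by Dirichlet noise.
        reanalyze_noise=False,
        root_noise_weight=0.25,
        root_dirichlet_alpha=0.3,
    ),
)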
37 changes: 30 additions & 7 deletions lzero/mcts/buffer/game_buffer_efficientzero.py
@@ -235,13 +235,19 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
if self._cfg.mcts_ctree:
# cpp mcts_tree
roots = MCTSCtree.roots(transition_batch_size, legal_actions)
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
if self._cfg.reanalyze_noise:
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
else:
roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
else:
# python mcts_tree
roots = MCTSPtree.roots(transition_batch_size, legal_actions)
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
if self._cfg.reanalyze_noise:
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
else:
roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSPtree(self._cfg).search(
roots, model, latent_state_roots, reward_hidden_state_roots, to_play=to_play
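For context on what the branch above changes: prepare() mixes Dirichlet noise into the root prior for exploration, while prepare_no_noise() keeps the raw prior, which is what reanalyze wants when recomputing training targets. A rough sketch of that mixing, based on the standard MuZero formulation rather than the actual C++/Python tree code:

import numpy as np

def mix_root_prior(policy_logits, noise_weight, dirichlet_alpha):
    # prior(a) = (1 - w) * softmax(logits)(a) + w * eta(a),  eta ~ Dirichlet(alpha)
    logits = np.asarray(policy_logits, dtype=np.float64)
    prior = np.exp(logits - logits.max())
    prior /= prior.sum()
    eta = np.random.dirichlet([dirichlet_alpha] * len(prior))
    return (1.0 - noise_weight) * prior + noise_weight * eta

# prepare_no_noise corresponds to using the softmax prior directly,
# i.e. targets are computed from an unperturbed prior.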
@@ -326,7 +332,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
return []
batch_target_policies_re = []

policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_values, game_segment_lens, action_mask_segment, \
to_play_segment = policy_re_context # noqa
# transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
transition_batch_size = len(policy_obs_list)
@@ -369,34 +375,44 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
)
value_prefix_pool = value_prefix_pool.squeeze().tolist()
policy_logits_pool = policy_logits_pool.tolist()
# noises are not necessary for reanalyze
noises = [
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
).astype(np.float32).tolist() for _ in range(transition_batch_size)
]
if self._cfg.mcts_ctree:
# cpp mcts_tree
roots = MCTSCtree.roots(transition_batch_size, legal_actions)
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
if self._cfg.reanalyze_noise:
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
else:
roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
else:
# python mcts_tree
roots = MCTSPtree.roots(transition_batch_size, legal_actions)
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
if self._cfg.reanalyze_noise:
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
else:
roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSPtree(self._cfg).search(
roots, model, latent_state_roots, reward_hidden_state_roots, to_play=to_play
)

roots_legal_actions_list = legal_actions
roots_distributions = roots.get_distributions()
roots_values = roots.get_values()
policy_index = 0
for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list):
for state_index, child_visit, root_value in zip(pos_in_game_segment_list, child_visits, root_values):
target_policies = []
for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
distributions = roots_distributions[policy_index]
searched_value = roots_values[policy_index]

if policy_mask[policy_index] == 0:
# NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0
# NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0
target_policies.append([0 for _ in range(self._cfg.model.action_space_size)])
else:
if distributions is None:
@@ -405,6 +421,13 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
)
else:
# Update the data in game segment:
# after the reanalyze search, new target policies and root values are obtained
# the target policies and root values are stored in the gamesegment, specifically, ``child_visit_segment`` and ``root_value_segment``
# we replace the data at the corresponding location with the latest search results to keep the most up-to-date targets
sim_num = sum(distributions)
child_visit[current_index] = [visit_count/sim_num for visit_count in distributions]
root_value[current_index] = searched_value
if self._cfg.mcts_ctree:
# cpp mcts_tree
if self._cfg.action_type == 'fixed_action_space':
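The lines added at the end of the hunk above are the core of the commit: after the reanalyze search, raw visit counts are normalized into a target policy and written back into the sampled game segment together with the searched root value. A standalone sketch of that write-back step, assuming child_visit and root_value are the per-segment lists unpacked earlier in the function:

def write_back_search_result(child_visit, root_value, index, distributions, searched_value):
    # Normalize raw visit counts from the reanalyze search into a target policy
    # and overwrite the stale entries stored when the data was first collected.
    sim_num = sum(distributions)
    child_visit[index] = [visit_count / sim_num for visit_count in distributions]
    root_value[index] = searched_value

On the next sample from the buffer, these refreshed entries are what the target computation reads, which is how the commit keeps the training targets up to date.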
39 changes: 31 additions & 8 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -304,7 +304,7 @@ def _prepare_policy_reanalyzed_context(
policy_mask = []
# 0 -> Invalid target policy for padding outside of game segments,
# 1 -> Previous target policy for game segments.
rewards, child_visits, game_segment_lens = [], [], []
rewards, child_visits, game_segment_lens, root_values = [], [], [], []
# for board games
action_mask_segment, to_play_segment = [], []
for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
@@ -316,6 +316,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
to_play_segment.append(game_segment.to_play_segment)

child_visits.append(game_segment.child_visit_segment)
root_values.append(game_segment.root_value_segment)
# prepare the corresponding observations
game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps)
for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
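The hunk above starts collecting game_segment.root_value_segment alongside child_visit_segment. A minimal sketch of the per-step fields involved, using only the attribute names visible in this diff (the real GameSegment class in LightZero carries many more attributes):

class GameSegmentSketch:
    def __init__(self):
        self.child_visit_segment = []  # per-step normalized visit distributions -> target policies
        self.root_value_segment = []   # per-step searched root values -> value targets
        self.to_play_segment = []      # per-step player to move (board games)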
@@ -331,7 +332,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
policy_obs_list.append(obs)

policy_re_context = [
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens,
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_values, game_segment_lens,
action_mask_segment, to_play_segment
]
return policy_re_context
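Because root_values is inserted into the middle of policy_re_context, the packing order here and the unpacking order in _compute_target_policy_reanalyzed have to stay in sync. A compressed, runnable sketch of the paired sites (placeholder lists stand in for the real per-batch data; only the ordering matters):

# Placeholders for the real per-batch lists.
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list = [], [], [], []
child_visits, root_values, game_segment_lens, action_mask_segment, to_play_segment = [], [], [], [], []

# Packing side (_prepare_policy_reanalyzed_context):
policy_re_context = [
    policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list,
    child_visits, root_values, game_segment_lens, action_mask_segment, to_play_segment,
]

# Unpacking side (_compute_target_policy_reanalyzed) must list the same fields in the same order:
(policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list,
 child_visits, root_values, game_segment_lens, action_mask_segment,
 to_play_segment) = policy_re_context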
@@ -411,13 +412,19 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
if self._cfg.mcts_ctree:
# cpp mcts_tree
roots = MCTSCtree.roots(transition_batch_size, legal_actions)
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
if self._cfg.reanalyze_noise:
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
else:
roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
else:
# python mcts_tree
roots = MCTSPtree.roots(transition_batch_size, legal_actions)
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
if self._cfg.reanalyze_noise:
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
else:
roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)

@@ -495,7 +502,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
batch_target_policies_re = []

# for board games
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_values, game_segment_lens, action_mask_segment, \
to_play_segment = policy_re_context
# transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
transition_batch_size = len(policy_obs_list)
@@ -542,31 +549,40 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
_, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero')
reward_pool = reward_pool.squeeze().tolist()
policy_logits_pool = policy_logits_pool.tolist()
# noises are not necessary for reanalyze
noises = [
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
).astype(np.float32).tolist() for _ in range(transition_batch_size)
]
if self._cfg.mcts_ctree:
# cpp mcts_tree
roots = MCTSCtree.roots(transition_batch_size, legal_actions)
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
if self._cfg.reanalyze_noise:
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
else:
roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
else:
# python mcts_tree
roots = MCTSPtree.roots(transition_batch_size, legal_actions)
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
if self._cfg.reanalyze_noise:
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
else:
roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)

roots_legal_actions_list = legal_actions
roots_distributions = roots.get_distributions()
roots_values = roots.get_values()
policy_index = 0
for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list):
for state_index, child_visit, root_value in zip(pos_in_game_segment_list, child_visits, root_values):
target_policies = []

for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
distributions = roots_distributions[policy_index]
searched_value = roots_values[policy_index]

if policy_mask[policy_index] == 0:
# NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0
@@ -578,6 +594,13 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
)
else:
# Update the data in game segment:
# after the reanalyze search, new target policies and root values are obtained
# the target policies and root values are stored in the gamesegment, specifically, ``child_visit_segment`` and ``root_value_segment``
# we replace the data at the corresponding location with the latest search results to keep the most up-to-date targets
sim_num = sum(distributions)
child_visit[current_index] = [visit_count/sim_num for visit_count in distributions]
root_value[current_index] = searched_value
if self._cfg.action_type == 'fixed_action_space':
# for atari/classic_control/box2d environments that only have one player.
sum_visits = sum(distributions)
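A quick, purely illustrative sanity check of the write-back behaviour introduced by this commit: a stored target policy should be a proper distribution over the action space. The example visit counts and action-space size below are made up for demonstration.

import numpy as np

distributions = [8, 1, 1, 0, 0, 0]          # example visit counts from one reanalyzed root
sim_num = sum(distributions)
target_policy = [v / sim_num for v in distributions]
assert abs(sum(target_policy) - 1.0) < 1e-6  # normalized distribution
assert len(target_policy) == 6               # should equal cfg.model.action_space_size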
