fix(pu): fix total_episode_count bug in collector
puyuan1996 committed Mar 30, 2024
1 parent 29c9afd commit 5fc7c45
Showing 4 changed files with 7 additions and 6 deletions.
8 changes: 5 additions & 3 deletions lzero/policy/muzero.py
@@ -8,7 +8,6 @@
 from ding.policy.base_policy import Policy
 from ding.torch_utils import to_tensor
 from ding.utils import POLICY_REGISTRY
-from torch.distributions import Categorical
 from torch.nn import L1Loss
 
 from lzero.mcts import MuZeroMCTSCtree as MCTSCtree
@@ -39,6 +38,7 @@ class MuZeroPolicy(Policy):
         # (bool) Whether to use the self-supervised learning loss.
         self_supervised_learning_loss=False,
         # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix.
+        # reference: http://proceedings.mlr.press/v80/imani18a/imani18a.pdf, https://arxiv.org/abs/2403.03950
         categorical_distribution=True,
         # (int) The image channel in image observation.
         image_channel=1,
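
Note: categorical_distribution=True means that scalar value/reward targets are represented as a categorical distribution over a fixed discrete support, as discussed in the two linked references. A minimal sketch of such a scalar-to-support transform follows; the function name, support size, and omission of MuZero's invertible h(x) transform are simplifications for illustration, not the repository's actual implementation.

import torch

def scalar_to_support(x: torch.Tensor, support_size: int) -> torch.Tensor:
    # Two-hot encoding: each scalar is spread over its two neighbouring
    # integer bins in the support [-support_size, support_size].
    x = x.clamp(-support_size, support_size)
    floor = x.floor()
    prob_upper = x - floor
    target = torch.zeros(*x.shape, 2 * support_size + 1)
    lower_idx = (floor + support_size).long()
    upper_idx = (lower_idx + 1).clamp(max=2 * support_size)
    target.scatter_(-1, lower_idx.unsqueeze(-1), (1.0 - prob_upper).unsqueeze(-1))
    target.scatter_add_(-1, upper_idx.unsqueeze(-1), prob_upper.unsqueeze(-1))
    return target

# Example: a batch of scalar value targets -> categorical targets over 601 bins.
values = torch.tensor([-1.3, 0.0, 2.7])
categorical_targets = scalar_to_support(values, support_size=300)
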
@@ -453,7 +453,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         weighted_total_loss.backward()
         if self._cfg.multi_gpu:
             self.sync_gradients(self._learn_model)
-        total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self._cfg.grad_clip_value)
+        total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(),
+                                                                     self._cfg.grad_clip_value)
         self._optimizer.step()
         if self._cfg.lr_piecewise_constant_decay:
             self.lr_scheduler.step()
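
For reference, the reformatted call above is the standard PyTorch idiom of clipping the global gradient norm immediately before the optimizer step; clip_grad_norm_ rescales gradients in place and returns the pre-clip total norm, which is handy for logging. A self-contained sketch with placeholder model, optimizer, and clip value (not the repository's objects):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_clip_value = 10.0  # analogous role to the config's grad_clip_value

loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
# Returns the total gradient norm computed before clipping.
total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_value)
optimizer.step()
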
@@ -645,7 +646,8 @@ def _get_target_obs_index_in_step_k(self, step):
             end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num)
         return beg_index, end_index
 
-    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict:
+    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1,
+                      ready_env_id: np.array = None, ) -> Dict:
         """
         Overview:
             The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
1 change: 0 additions & 1 deletion lzero/worker/alphazero_collector.py
@@ -259,7 +259,6 @@ def collect(self,
 
                 self._env_info[env_id]['time'] += self._timer.value + interaction_duration
                 if timestep.done:
-                    self._total_episode_count += 1
                     # the eval_episode_return is calculated from Player 1's perspective
                     reward = timestep.info['eval_episode_return']
                     info = {
2 changes: 1 addition & 1 deletion lzero/worker/muzero_collector.py
@@ -686,7 +686,7 @@ def collect(self,
             collected_episode = allreduce_data(collected_episode, 'sum')
             collected_duration = allreduce_data(collected_duration, 'sum')
         self._total_envstep_count += collected_step
-        self._total_episode_count += collected_episode
+        # self._total_episode_count += collected_episode
         self._total_duration += collected_duration
 
         # log
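
Together with the alphazero_collector.py change above, this hunk removes increments of _total_episode_count; the commit message suggests the counter was previously inflated, and the increment site that remains is not visible in this diff. A toy illustration of the double-counting hazard, using hypothetical names unrelated to the repository's collectors:

class ToyCollector:
    # Hypothetical example, not LightZero code.
    def __init__(self):
        self._total_episode_count = 0

    def collect(self, dones):
        collected_episode = sum(1 for done in dones if done)
        # If the counter were also incremented once per finished episode while
        # iterating over timesteps, every episode would be counted twice.
        # Keeping a single increment site avoids the inflated total.
        self._total_episode_count += collected_episode
        return collected_episode

collector = ToyCollector()
collector.collect([False, True, True])
assert collector._total_episode_count == 2
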
2 changes: 1 addition & 1 deletion lzero/worker/muzero_evaluator.py
@@ -454,4 +454,4 @@ def eval(
             episode_info = to_item(episode_info)
         if return_trajectory:
             episode_info['trajectory'] = game_segments
-        return stop_flag, episode_info
+        return stop_flag, episode_info