fix(pu): fix muzero_rnn_fullobs variants
jiayilee65 committed May 1, 2024
1 parent 3a44e8a commit 4168d9f
Showing 6 changed files with 151 additions and 131 deletions.
25 changes: 16 additions & 9 deletions lzero/mcts/tree_search/mcts_ctree.py
@@ -422,7 +422,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e
# @profile
def search(
self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any],
- world_model_hidden_state_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None,
+ world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None,
) -> None:
"""
Overview:
@@ -448,7 +448,7 @@ def search(
# the data storage of latent states: storing the latent state of all the nodes in one search.
latent_state_batch_in_search_path = [latent_state_roots]
# the data storage of value prefix hidden states in LSTM
- world_model_hidden_state_batch = [world_model_hidden_state_roots]
+ world_model_latent_history_batch = [world_model_latent_history_roots]

# minimax value storage
min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size)
@@ -465,7 +465,7 @@
# In each simulation we expand one new node, so one search adds at most ``num_simulations`` nodes.

latent_states = []
- world_model_hidden_state = []
+ world_model_latent_history = []

# prepare a result wrapper to transport results between python and c++ parts
results = tree_muzero.ResultsWrapper(num=batch_size)
@@ -495,10 +495,10 @@
# obtain the latent state for leaf node
for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch):
latent_states.append(latent_state_batch_in_search_path[ix][iy])
- world_model_hidden_state.append(world_model_hidden_state_batch[ix][0][iy])
+ world_model_latent_history.append(world_model_latent_history_batch[ix][0][iy])

latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device)
- world_model_hidden_state = torch.from_numpy(np.asarray(world_model_hidden_state)).to(self._cfg.device).unsqueeze(0)
+ world_model_latent_history = torch.from_numpy(np.asarray(world_model_latent_history)).to(self._cfg.device).unsqueeze(0)

# latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
# TODO: .long() is only for discrete action
@@ -519,20 +519,27 @@
"""
## MuZeroRNN full_obs ######################
if ready_env_id is None:
# for train
network_output = model.recurrent_inference(
- latent_states, world_model_hidden_state, last_actions
+ latent_states, world_model_latent_history, last_actions
)
else:
# try:
network_output = model.recurrent_inference(
- latent_states, world_model_hidden_state, last_actions, ready_env_id=ready_env_id
+ latent_states, world_model_latent_history, last_actions, ready_env_id=ready_env_id
)
# except Exception as e:
# print(e)
network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
network_output.value_prefix = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value_prefix))

network_output.reward_hidden_state = network_output.reward_hidden_state.detach().cpu().numpy()

if network_output.reward_hidden_state.shape[1] != world_model_latent_history.shape[1]:
print('debug')

latent_state_batch_in_search_path.append(network_output.latent_state)

# TODO: check the muzero_context/muzero_rnn hidden state
@@ -543,15 +550,15 @@
value_batch = network_output.value.reshape(-1).tolist()
policy_logits_batch = network_output.policy_logits.tolist()

- world_model_hidden_state = network_output.reward_hidden_state
+ world_model_latent_history = network_output.reward_hidden_state
# reset the LSTM hidden states every ``lstm_horizon_len`` steps within one search,
# so that the model only needs to predict the value prefix over a limited range (e.g. [s0, ..., s5])
# reset_idx = ( (model.timestep + np.array(search_lens)) % self._cfg.context_length_in_search == 0)
# # reset_idx = (np.array(search_lens) % self._cfg.context_length_in_search == 0)
# reward_latent_state_batch[0][:, reset_idx, :] = 0
# reward_latent_state_batch[1][:, reset_idx, :] = 0
# is_reset_list = reset_idx.astype(np.int32).tolist()
- world_model_hidden_state_batch.append(world_model_hidden_state)
+ world_model_latent_history_batch.append(world_model_latent_history)

# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
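
Note on the mcts_ctree.py hunk above: the substance of the change is a rename of the LSTM-style state carried through the tree search (world_model_hidden_state* -> world_model_latent_history*), plus an added debug check comparing the batch dimension of the returned reward_hidden_state against the latent history. The snippet below is an illustrative reconstruction, not the LightZero API: the helper name gather_leaf_states and its argument names are hypothetical, and shapes are assumptions. It only shows the pattern visible in the diff — gathering each leaf's parent latent state and latent history, then restoring the leading hidden-state axis before calling recurrent_inference.

# Illustrative sketch only -- helper name, arguments, and shapes are assumptions.
import numpy as np
import torch

def gather_leaf_states(latent_state_batch, latent_history_batch,
                       path_indices, batch_indices, device="cpu"):
    """Collect the latent state and world-model latent history of each leaf's parent node.

    latent_state_batch:   one array per expanded depth, each of shape (batch, ...)
    latent_history_batch: one array per expanded depth, each of shape (1, batch, hidden_size)
    path_indices / batch_indices: which depth / which environment each leaf comes from
    """
    latent_states, latent_history = [], []
    for ix, iy in zip(path_indices, batch_indices):
        latent_states.append(latent_state_batch[ix][iy])
        # the history keeps a leading (num_layers * num_directions) axis, hence the extra [0]
        latent_history.append(latent_history_batch[ix][0][iy])

    latent_states = torch.from_numpy(np.asarray(latent_states)).to(device)
    # restore the leading axis expected by an LSTM-style hidden state: (1, batch, hidden_size)
    latent_history = torch.from_numpy(np.asarray(latent_history)).to(device).unsqueeze(0)
    return latent_states, latent_history

The two tensors returned this way correspond to what the diff passes to model.recurrent_inference together with last_actions.
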
6 changes: 3 additions & 3 deletions lzero/model/common.py
@@ -752,7 +752,7 @@ def __init__(
last_linear_layer_init_zero=last_linear_layer_init_zero
)

- def forward(self, latent_state: torch.Tensor, world_model_hidden_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Overview:
Forward computation of the prediction network.
@@ -778,8 +778,8 @@ def forward(self, latent_state: torch.Tensor, world_model_hidden_state: torch.Te
latent_state_policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)

# try:
- latent_history_value = torch.cat([latent_state_value, world_model_hidden_state.squeeze(0)], dim=1) # TODO: world_model_hidden_state.squeeze(0): the hidden state has shape (num_layers * num_directions, batch_size, hidden_size) -> (batch_size, hidden_size)
- latent_history_policy = torch.cat([latent_state_policy, world_model_hidden_state.squeeze(0)], dim=1) # TODO
+ latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1) # TODO: world_model_latent_history.squeeze(0): the hidden state has shape (num_layers * num_directions, batch_size, hidden_size) -> (batch_size, hidden_size)
+ latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1) # TODO
# except Exception as e:
# print(e)
value = self.fc_value(latent_history_value)
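
The common.py hunk applies the same rename inside the prediction network, which fuses the flattened value/policy features with the squeezed world-model latent history before the final fully connected heads. The following is a minimal sketch of that fusion step under stated assumptions: the class, layer names, and output sizes (TinyPredictionHead, fc_value as a scalar head, etc.) are placeholders, not the actual LightZero PredictionNetwork.

# Illustrative sketch, not the actual LightZero PredictionNetwork.
import torch
import torch.nn as nn

class TinyPredictionHead(nn.Module):
    def __init__(self, feature_size: int, hidden_size: int, action_space_size: int):
        super().__init__()
        # value/policy heads consume [flattened latent features ; world-model latent history]
        self.fc_value = nn.Linear(feature_size + hidden_size, 1)
        self.fc_policy = nn.Linear(feature_size + hidden_size, action_space_size)

    def forward(self, latent_features: torch.Tensor, world_model_latent_history: torch.Tensor):
        # world_model_latent_history: (num_layers * num_directions, batch_size, hidden_size)
        # squeeze the leading axis so it can be concatenated along the feature dimension
        history = world_model_latent_history.squeeze(0)        # (batch_size, hidden_size)
        fused = torch.cat([latent_features, history], dim=1)   # (batch_size, feature_size + hidden_size)
        return self.fc_value(fused), self.fc_policy(fused)

Concatenating along dim=1 is what requires the squeeze(0) highlighted in the diff's TODO comment: the GRU/LSTM-style history arrives with a leading layer axis, while the features are plain (batch_size, feature_size) tensors.
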
