Skip to content

Commit

Permalink
feature(pu): add muzero_rnn variant
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Apr 23, 2024
1 parent f05516c commit 1af5f92
Show file tree
Hide file tree
Showing 10 changed files with 2,853 additions and 49 deletions.
44 changes: 22 additions & 22 deletions lzero/entry/train_muzero_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def train_muzero_context(
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context','sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context','muzero_rnn', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'"

if create_cfg.policy.type in ['muzero', 'muzero_context']:
if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn']:
from lzero.mcts import MuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'efficientzero':
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
Expand Down Expand Up @@ -213,26 +213,26 @@ def train_muzero_context(
# remove the oldest data if the replay buffer is full.
replay_buffer.remove_oldest_data_to_fit()

if replay_buffer.get_num_of_transitions() > 2000:
# Learn policy from collected data.
for i in range(update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
if replay_buffer.get_num_of_transitions() > batch_size:
train_data = replay_buffer.sample(batch_size, policy)
else:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, '
f'{replay_buffer} '
f'continue to collect now ....'
)
break

# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
# if replay_buffer.get_num_of_transitions() > 2000: # TODO
# Learn policy from collected data.
for i in range(update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
if replay_buffer.get_num_of_transitions() > batch_size:
train_data = replay_buffer.sample(batch_size, policy)
else:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, '
f'{replay_buffer} '
f'continue to collect now ....'
)
break

# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
if cfg.policy.eval_offline:
Expand Down
2 changes: 1 addition & 1 deletion lzero/mcts/tree_search/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .mcts_ctree import MuZeroMCTSCtree, EfficientZeroMCTSCtree, GumbelMuZeroMCTSCtree, UniZeroMCTSCtree
from .mcts_ctree import MuZeroMCTSCtree, EfficientZeroMCTSCtree, GumbelMuZeroMCTSCtree, UniZeroMCTSCtree, MuZeroRNNMCTSCtree
from .mcts_ctree_sampled import SampledEfficientZeroMCTSCtree
from .mcts_ctree_stochastic import StochasticMuZeroMCTSCtree
from .mcts_ptree import MuZeroMCTSPtree, EfficientZeroMCTSPtree
Expand Down
234 changes: 231 additions & 3 deletions lzero/mcts/tree_search/mcts_ctree.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,234 @@ def search(
)



class MuZeroRNNMCTSCtree(object):
"""
Overview:
The C++ implementation of MCTS (batch format) for EfficientZero. \
It completes the ``roots``and ``search`` methods by calling functions in module ``ctree_muzero``, \
which are implemented in C++.
Interfaces:
``__init__``, ``roots``, ``search``
..note::
The benefit of searching for a batch of nodes at the same time is that \
it can be parallelized during model inference, thus saving time.
"""

config = dict(
# (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree.
root_dirichlet_alpha=0.3,
# (float) The noise weight at the root node of the search tree.
root_noise_weight=0.25,
# (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search.
pb_c_base=19652,
# (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search.
pb_c_init=1.25,
# (float) The maximum change in value allowed during the backup step of the search tree update.
value_delta_max=0.01,
env_type='not_board_games',
)

@classmethod
def default_config(cls: type) -> EasyDict:
"""
Overview:
A class method that returns a default configuration in the form of an EasyDict object.
Returns:
- cfg (:obj:`EasyDict`): The dict of the default configuration.
"""
# Create a deep copy of the `config` attribute of the class.
cfg = EasyDict(copy.deepcopy(cls.config))
# Add a new attribute `cfg_type` to the `cfg` object.
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg

def __init__(self, cfg: EasyDict = None) -> None:
"""
Overview:
Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
in the default configuration, the user-provided value will override the default configuration. Otherwise,
the default configuration will be used.
Arguments:
- cfg (:obj:`EasyDict`): The configuration passed in by the user.
"""
# Get the default configuration.
default_config = self.default_config()
# Update the default configuration with the values provided by the user in ``cfg``.
default_config.update(cfg)
self._cfg = default_config
self.inverse_scalar_transform_handle = InverseScalarTransform(
self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution
)

@classmethod
def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "ez_ctree.Roots":
"""
Overview:
Initializes a batch of roots to search parallelly later.
Arguments:
- root_num (:obj:`int`): the number of the roots in a batch.
- legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots.
..note::
The initialization is achieved by the ``Roots`` class from the ``ctree_muzero`` module.
"""
return tree_muzero.Roots(active_collect_env_num, legal_actions)

# @profile
def search(
self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any],
reward_hidden_state_roots: List[Any], to_play_batch: Union[int, List[Any]]
) -> None:
"""
Overview:
Do MCTS for a batch of roots. Parallel in model inference. \
Use C++ to implement the tree search.
Arguments:
- roots (:obj:`Any`): a batch of expanded root nodes.
- latent_state_roots (:obj:`list`): the hidden states of the roots.
- reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots.
- model (:obj:`torch.nn.Module`): The model used for inference.
- to_play (:obj:`list`): the to_play list used in in self-play-mode board games.
.. note::
The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++.
"""
with torch.no_grad():
model.eval()

# preparation some constant
batch_size = roots.num
pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor

# the data storage of latent states: storing the latent state of all the nodes in one search.
latent_state_batch_in_search_path = [latent_state_roots]
# the data storage of value prefix hidden states in LSTM
reward_hidden_state_c_batch = [reward_hidden_state_roots[0]]
reward_hidden_state_h_batch = [reward_hidden_state_roots[1]]

# minimax value storage
min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size)
min_max_stats_lst.set_delta(self._cfg.value_delta_max)

state_action_history = [] # 初始化 state_action_history 变量
last_latent_state = latent_state_roots
# NOTE: very important, from the right init key-value-cache
# forward_initial_inference()以及执行了下面的操作
# _ = model.world_model.refresh_keys_values_with_initial_obs_tokens(model.world_model.obs_tokens)

# model.world_model.past_keys_values_cache.clear() # 清除缓存
for simulation_index in range(self._cfg.num_simulations):
# In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most.

latent_states = []
hidden_states_c_reward = []
hidden_states_h_reward = []

# prepare a result wrapper to transport results between python and c++ parts
results = tree_muzero.ResultsWrapper(num=batch_size)

# latent_state_index_in_search_path: the first index of leaf node states in latent_state_batch_in_search_path, i.e. is current_latent_state_index in one the search.
# latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.
# The index of value prefix hidden state of the leaf node is in the same manner.
"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
"""
if self._cfg.env_type=='not_board_games':
latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_muzero.batch_traverse(
roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results,
to_play_batch
)
else:
latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_muzero.batch_traverse(
roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results,
copy.deepcopy(to_play_batch)
)

# obtain the search horizon for leaf nodes
search_lens = results.get_search_len()

# obtain the latent state for leaf node
for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch):
latent_states.append(latent_state_batch_in_search_path[ix][iy])
hidden_states_c_reward.append(reward_hidden_state_c_batch[ix][0][iy])
hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy])

latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device)
hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(self._cfg.device
).unsqueeze(0)
hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(self._cfg.device
).unsqueeze(0)
# latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
# TODO: .long() is only for discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()

# TODO
# 在每次模拟后更新 state_action_history
# state_action_history.append((last_latent_state, last_actions.detach().cpu().numpy()))
# state_action_history.append((latent_states.detach().cpu().numpy(), last_actions.detach().cpu().numpy()))
state_action_history.append((latent_states.detach().cpu().numpy(), last_actions))

"""
MCTS stage 2: Expansion
At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation)
MCTS stage 3: Backup
At the end of the simulation, the statistics along the trajectory are updated.
"""
## MuZeroRNN ######################
network_output = model.recurrent_inference(
latent_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions
)

network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
network_output.value_prefix = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value_prefix))

network_output.reward_hidden_state = (
network_output.reward_hidden_state[0].detach().cpu().numpy(),
network_output.reward_hidden_state[1].detach().cpu().numpy()
)

latent_state_batch_in_search_path.append(network_output.latent_state)

# TODO
# last_latent_state = network_output.latent_state

# tolist() is to be compatible with cpp datatype.
reward_batch = network_output.value_prefix.reshape(-1).tolist()
value_batch = network_output.value.reshape(-1).tolist()
policy_logits_batch = network_output.policy_logits.tolist()

reward_latent_state_batch = network_output.reward_hidden_state
# reset the hidden states in LSTM every ``lstm_horizon_len`` steps in one search.
# which enable the model only need to predict the value prefix in a range (e.g.: [s0,...,s5])
# assert self._cfg.context_length_in_search > 0
reset_idx = (np.array(search_lens) % self._cfg.context_length_in_search == 0)
assert len(reset_idx) == batch_size
reward_latent_state_batch[0][:, reset_idx, :] = 0
reward_latent_state_batch[1][:, reset_idx, :] = 0
# is_reset_list = reset_idx.astype(np.int32).tolist()
reward_hidden_state_c_batch.append(reward_latent_state_batch[0])
reward_hidden_state_h_batch.append(reward_latent_state_batch[1])

# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

# NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node.
current_latent_state_index = simulation_index + 1
tree_muzero.batch_backpropagate(
current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch,
min_max_stats_lst, results, virtual_to_play_batch
)



class EfficientZeroMCTSCtree(object):
"""
Overview:
Expand Down Expand Up @@ -516,9 +744,9 @@ def search(
At the end of the simulation, the statistics along the trajectory are updated.
"""
## EZ ######################
# network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero
# network_output = model.recurrent_inference(last_actions) # TODO: for unizero latent_states is not used in the model.
network_output = model.recurrent_inference(state_action_history) # TODO: latent_states is not used in the model.
network_output = model.recurrent_inference(
latent_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions
)

network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
Expand Down
24 changes: 1 addition & 23 deletions lzero/model/muzero_context_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np

from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork, PredictionNetworkMLP, FeatureAndGradientHook
from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean
from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean, SimNorm
import torch.nn.init as init
import torch.nn.functional as F

Expand Down Expand Up @@ -442,28 +442,6 @@ def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.T
def get_params_mean(self) -> float:
return get_params_mean(self)

class SimNorm(nn.Module):
"""
Simplicial normalization.
Adapted from https://arxiv.org/abs/2204.00616.
"""

def __init__(self, simnorm_dim):
super().__init__()
self.dim = simnorm_dim

def forward(self, x):
shp = x.shape
# Ensure that there is at least one simplex to normalize across.
if shp[1] != 0:
x = x.view(*shp[:-1], -1, self.dim)
x = F.softmax(x, dim=-1)
return x.view(*shp)
else:
return x

def __repr__(self):
return f"SimNorm(dim={self.dim})"

class DynamicsNetwork(nn.Module):

Expand Down

0 comments on commit 1af5f92

Please sign in to comment.