polish(nyz): complete ppo/impala/dt comments
PaParaZz1 committed Oct 31, 2023
1 parent 2c6408c commit ac7c6e2
Showing 12 changed files with 398 additions and 138 deletions.
13 changes: 7 additions & 6 deletions ding/policy/base_policy.py
@@ -25,7 +25,7 @@ def default_config(cls: type) -> EasyDict:
Overview:
Get the default config of policy. This method is used to create the default config of policy.
Returns:
- cfg (:obj:`EasyDict`): The default config of corresponding policy. For the derived policy class, \
+ - cfg (:obj:`EasyDict`): The default config of corresponding policy. For the derived policy class, \
it will recursively merge the default config of base class and its own default config.
.. tip::
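As a rough illustration of the recursive merge described above, a derived policy's default config can be layered on top of its base class config roughly as follows. This is a minimal sketch with a hypothetical helper, not DI-engine's actual implementation:

```python
import copy

from easydict import EasyDict


def merge_default_config(base_cfg: dict, derived_cfg: dict) -> EasyDict:
    """Recursively overlay a derived policy's default config on its base class config."""
    merged = copy.deepcopy(dict(base_cfg))
    for key, value in derived_cfg.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # nested config section: merge field by field
            merged[key] = merge_default_config(merged[key], value)
        else:
            # derived value overrides the base value
            merged[key] = copy.deepcopy(value)
    return EasyDict(merged)
```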
@@ -196,16 +196,17 @@ def hook(*ignore):
def _create_model(self, cfg: EasyDict, model: Optional[torch.nn.Module] = None) -> torch.nn.Module:
"""
Overview:
- Create neural network model according to input configures and model. If the input model is None, then \
- the model will be created according to ``default_model`` method and ``cfg.model`` field. Otherwise, the \
- model will be set to the ``model`` instance created by outside caller.
+ Create or validate the neural network model according to the input configuration and model. If the input model is \
+ None, then the model will be created according to ``default_model`` method and ``cfg.model`` field. \
+ Otherwise, the model will be verified as an instance of ``torch.nn.Module`` and set to the ``model`` \
+ instance created by the outside caller.
Arguments:
- cfg (:obj:`EasyDict`): The final merged config used to initialize policy.
- model (:obj:`torch.nn.Module`): The neural network model used to initialize policy. User can refer to \
the default model defined in corresponding policy to customize its own model.
Returns:
- - model (:obj:`torch.nn.Module`): The created neural network model. Then different modes of policy will \
- add wrappers and plugins to the model, which is used to train, collect and evaluate.
+ - model (:obj:`torch.nn.Module`): The created neural network model. The different modes of policy will \
+ add distinct wrappers and plugins to the model, which are used to train, collect and evaluate.
Raises:
- RuntimeError: If the input model is not None and is not an instance of ``torch.nn.Module``.
"""
2 changes: 1 addition & 1 deletion ding/policy/ddpg.py
@@ -438,7 +438,7 @@ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.
timestep: namedtuple) -> Dict[str, torch.Tensor]:
"""
Overview:
- Process and pack one timestep transition data info a dict, which can be directly used for training and \
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
saved in replay buffer. For DDPG, it contains obs, next_obs, action, reward, done.
Arguments:
- obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
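The packing described in this hunk amounts to collecting the five fields into a plain dict. A minimal sketch, assuming the usual ``timestep`` namedtuple with ``obs``, ``reward`` and ``done`` fields:

```python
from typing import Dict

import torch


def process_transition(obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], timestep) -> Dict[str, torch.Tensor]:
    """Pack one timestep into the dict stored in the replay buffer (simplified sketch)."""
    return {
        'obs': obs,                       # observation at time t
        'next_obs': timestep.obs,         # observation returned by the env step, i.e. at time t+1
        'action': policy_output['action'],
        'reward': timestep.reward,
        'done': timestep.done,
    }
```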
9 changes: 5 additions & 4 deletions ding/policy/dqn.py
@@ -403,10 +403,10 @@ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str,
in ``self._forward_learn`` method.
Arguments:
- transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \
- the same format as the return value of ``self._process_transition`` method.
+ in the same format as the return value of ``self._process_transition`` method.
Returns:
- - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \
- as input transitions, but may contain more data for training, such as nstep reward and target obs.
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is similar in format \
+ to input transitions, but may contain more data for training, such as nstep reward and target obs.
"""
transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma)
return get_train_sample(transitions, self._unroll_len)
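``get_nstep_return_data`` aggregates rewards over the next ``nstep`` transitions before ``get_train_sample`` slices unrolled samples. The reward part of that preprocessing can be sketched as follows; this is illustrative only, since the real helper also rewrites ``next_obs`` and ``done``:

```python
from typing import Any, Dict, List


def nstep_reward(transitions: List[Dict[str, Any]], start: int, nstep: int, gamma: float) -> float:
    """Discounted sum of rewards over the next ``nstep`` transitions starting at ``start``."""
    ret, discount = 0.0, 1.0
    for i in range(start, min(start + nstep, len(transitions))):
        ret += discount * float(transitions[i]['reward'])
        discount *= gamma
        if transitions[i]['done']:  # stop accumulating at episode termination
            break
    return ret
```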
@@ -415,7 +415,7 @@ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.
timestep: namedtuple) -> Dict[str, torch.Tensor]:
"""
Overview:
- Process and pack one timestep transition data info a dict, which can be directly used for training and \
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
saved in replay buffer. For DQN, it contains obs, next_obs, action, reward, done.
Arguments:
- obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
@@ -539,6 +539,7 @@ class DQNSTDIMPolicy(DQNPolicy):
"""
Overview:
Policy class of DQN algorithm, extended by ST-DIM auxiliary objectives.
+ ST-DIM paper link: https://arxiv.org/abs/1906.08226.
Config:
== ==================== ======== ============== ======================================== =======================
ID Symbol Type Default Value Description Other(Shape)
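ST-DIM trains the encoder with contrastive (InfoNCE-style) objectives between representations of temporally adjacent observations. A generic sketch of such a loss, not DI-engine's exact auxiliary heads (which also use spatially local features):

```python
import torch
import torch.nn.functional as F


def infonce_loss(anchor: torch.Tensor, positive: torch.Tensor) -> torch.Tensor:
    """InfoNCE over a batch: row i of ``positive`` is the positive for row i of ``anchor``,
    every other row in the batch serves as a negative. Shapes: (B, D)."""
    logits = anchor @ positive.t()                      # (B, B) similarity matrix
    labels = torch.arange(anchor.size(0), device=anchor.device)
    return F.cross_entropy(logits, labels)


# e.g. encodings of obs_t and obs_{t+1} from the same trajectory
loss = infonce_loss(torch.randn(32, 64), torch.randn(32, 64))
```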
50 changes: 42 additions & 8 deletions ding/policy/dt.py
@@ -6,7 +6,6 @@
from ding.torch_utils import to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_decollate
- from ding.torch_utils import one_hot
from .base_policy import Policy


@@ -56,8 +55,20 @@ def default_model(self) -> Tuple[str, List[str]]:
def _init_learn(self) -> None:
"""
Overview:
- Learn mode init method. Called by ``self.__init__``.
- Init the optimizer, algorithm config, main and target models.
+ Initialize the learn mode of policy, including related attributes and modules. For Decision Transformer, \
+ it mainly contains the optimizer and algorithm-specific arguments such as rtg_scale and the lr scheduler.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
"""
# rtg_scale: scale of `return to go`
# rtg_target: max target of `return to go`
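A minimal sketch of the optimizer and lr-scheduler setup this docstring refers to; the numeric values below are hypothetical placeholders for fields that would normally come from the learn config:

```python
import torch

# hypothetical values standing in for the learn config (lr, weight decay, warmup steps)
lr, wt_decay, warmup_steps = 1e-4, 1e-4, 10000

model = torch.nn.Linear(8, 4)  # placeholder for the Decision Transformer network
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay)
# linear warmup: scale the lr from ~0 up to its nominal value over the first warmup_steps updates
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda steps: min((steps + 1) / warmup_steps, 1.0)
)
```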
@@ -92,14 +103,26 @@ def _init_learn(self) -> None:

self.max_env_score = -1.0

- def _forward_learn(self, data: list) -> Dict[str, Any]:
+ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
"""
Overview:
- Forward and backward function of learn mode.
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the offline dataset and then returns the output \
+ result, including various training information such as loss, current learning rate.
Arguments:
- - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+ - data (:obj:`List[torch.Tensor]`): The input data used for policy forward, including a series of \
+ processed torch.Tensor data, i.e., timesteps, states, actions, returns_to_go, traj_mask.
Returns:
- - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
+ recorded in text log and tensorboard, values must be python scalars or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For the data types that are not supported, the main reason is that the corresponding model does not \
+ support them. You can implement your own model rather than use the default model. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
"""
self._learn_model.train()
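The training step described above boils down to predicting actions from the (returns_to_go, states, actions) context and penalizing the prediction only on valid timesteps. A sketch of the masked loss for continuous actions (discrete-action setups typically use cross-entropy instead):

```python
import torch
import torch.nn.functional as F


def dt_action_loss(action_preds: torch.Tensor, action_target: torch.Tensor, traj_mask: torch.Tensor) -> torch.Tensor:
    """Masked MSE over padded trajectories.
    action_preds / action_target: (B, T, act_dim); traj_mask: (B, T) with 1 for real steps, 0 for padding."""
    mask = traj_mask.unsqueeze(-1).expand_as(action_preds) > 0
    return F.mse_loss(action_preds[mask], action_target[mask])
```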

@@ -156,7 +179,18 @@ def _forward_learn(self, data: list) -> Dict[str, Any]:
def _init_eval(self) -> None:
"""
Overview:
- Evaluate mode init method. Called by ``self.__init__``, initialize eval_model.
+ Initialize the eval mode of policy, including related attributes and modules. For Decision Transformer, \
+ it contains the eval model and some algorithm-specific parameters such as context_len, max_eval_ep_len, etc.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+ .. tip::
+ For the evaluation of complete episodes, we need to maintain some historical information for transformer \
+ inference. These variables need to be initialized in ``_init_eval`` and reset in ``_reset_eval`` when \
+ necessary.
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
"""
self._eval_model = self._model
# init data
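The historical information mentioned in the ``.. tip::`` above is typically kept as per-episode tensors that grow with the episode and are re-created in ``_reset_eval``. A sketch with hypothetical shapes; context_len, max_eval_ep_len and the state/action dims below are placeholders, not the actual config:

```python
import torch

# hypothetical eval-config values
context_len, max_eval_ep_len, state_dim, act_dim = 20, 1000, 17, 6
rtg_target, rtg_scale = 3600, 1000

# per-episode history used for transformer inference; rebuilt in ``_reset_eval`` between episodes
timesteps = torch.arange(max_eval_ep_len).unsqueeze(0)                        # (1, T)
states = torch.zeros(1, max_eval_ep_len, state_dim)                           # states observed so far
actions = torch.zeros(1, max_eval_ep_len, act_dim)                            # actions taken so far
returns_to_go = torch.full((1, max_eval_ep_len, 1), rtg_target / rtg_scale)   # running return-to-go

# at env step t, the model only attends to the most recent ``context_len`` entries
t = 42
ctx = slice(max(0, t - context_len + 1), t + 1)
# predicted_action = model(timesteps[:, ctx], states[:, ctx], actions[:, ctx], returns_to_go[:, ctx])
```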
