polish(nyz): polish rl_utils api docs
PaParaZz1 committed Apr 15, 2024
1 parent 15ff277 commit 96c4955
Showing 6 changed files with 78 additions and 53 deletions.
9 changes: 5 additions & 4 deletions ding/rl_utils/adder.py
@@ -23,7 +23,8 @@ def get_gae(cls, data: List[Dict[str, Any]], last_value: torch.Tensor, gamma: fl
Overview:
Get GAE advantage for stacked transitions (T timestep, 1 batch). Call ``gae`` for calculation.
Arguments:
- data (:obj:`list`): Transitions list, each element is a transition dict with at least ['value', 'reward']
- data (:obj:`list`): Transitions list, each element is a transition dict with at least \
``['value', 'reward']``.
- last_value (:obj:`torch.Tensor`): The last value (i.e. the T+1 timestep)
- gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
- gae_lambda (:obj:`float`): GAE lambda parameter, should be in [0, 1], defaults to 0.97, \
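A minimal usage sketch of the documented interface; the enclosing class name ``Adder``, the trailing ``cuda`` argument, and the tensor shapes are assumptions, since the signature is truncated in this hunk:

```python
import torch
from ding.rl_utils.adder import Adder  # assumed class name for this file

# T = 8 transitions, batch size 1; each dict needs at least 'value' and 'reward'.
data = [{'value': torch.randn(1), 'reward': torch.randn(1)} for _ in range(8)]
last_value = torch.randn(1)  # value estimate for the T+1 timestep
processed = Adder.get_gae(data, last_value, gamma=0.99, gae_lambda=0.97, cuda=False)
```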
@@ -63,7 +64,7 @@ def get_gae_with_default_last_value(cls, data: deque, done: bool, gamma: float,
Overview:
Like ``get_gae`` above to get GAE advantage for stacked transitions. However, this function is designed in
case ``last_value`` is not passed. If transition is not done yet, it would assign last value in ``data``
as ``last_value``, discard the last element in ``data``(i.e. len(data) would decrease by 1), and then call
as ``last_value``, discard the last element in ``data`` (i.e. len(data) would decrease by 1), and then call
``get_gae``. Otherwise it would make ``last_value`` equal to 0.
Arguments:
- data (:obj:`deque`): Transitions list, each element is a transition dict with \
@@ -103,7 +104,7 @@ def get_nstep_return_data(
) -> deque:
"""
Overview:
Process raw traj data by updating keys ['next_obs', 'reward', 'done'] in data's dict element.
Process raw traj data by updating keys ``['next_obs', 'reward', 'done']`` in data's dict element.
Arguments:
- data (:obj:`deque`): Transitions list, each element is a transition dict
- nstep (:obj:`int`): Number of steps. If equals to 1, return ``data`` directly; \
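The key rewrite can be sketched independently of the implementation; the zero-padding convention past the trajectory end is an assumption:

```python
import torch

def nstep_sketch(data, nstep):
    # Hedged sketch of the n-step rewrite described above: 'next_obs' becomes
    # the observation n steps later and 'reward' stacks the next n rewards
    # (assumed zero-padded past the end of the trajectory).
    T = len(data)
    out = []
    for i in range(T):
        j = min(i + nstep, T) - 1
        rewards = [data[k]['reward'] for k in range(i, min(i + nstep, T))]
        rewards += [torch.zeros_like(rewards[0])] * (nstep - len(rewards))
        out.append({
            **data[i],
            'next_obs': data[j]['next_obs'],
            'reward': torch.stack(rewards),  # shape (nstep, ...)
            'done': data[j]['done'],
        })
    return out
```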
@@ -159,7 +160,7 @@ def get_train_sample(
) -> List[Dict[str, Any]]:
"""
Overview:
Process raw traj data by updating keys ['next_obs', 'reward', 'done'] in data's dict element.
Process raw traj data by updating keys ``['next_obs', 'reward', 'done']`` in data's dict element.
If ``unroll_len`` equals to 1, which means no process is needed, can directly return ``data``.
Otherwise, ``data`` will be split according to ``unroll_len``, process residual part according to
``last_fn_type`` and call ``lists_to_dicts`` to form sampled training data.
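The split step itself can be pictured with a small sketch (residual handling via ``last_fn_type`` is not reproduced):

```python
def split_by_unroll_len(data, unroll_len):
    # Hedged sketch: cut a trajectory into fixed-length unroll segments and
    # leave a shorter residual, if any, to last_fn_type handling.
    segments = [data[i:i + unroll_len] for i in range(0, len(data), unroll_len)]
    full = [s for s in segments if len(s) == unroll_len]
    residual = segments[-1] if segments and len(segments[-1]) != unroll_len else None
    return full, residual

full, residual = split_by_unroll_len(list(range(10)), unroll_len=4)
# full == [[0, 1, 2, 3], [4, 5, 6, 7]], residual == [8, 9]
```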
27 changes: 27 additions & 0 deletions ding/rl_utils/beta_function.py
@@ -14,6 +14,15 @@
# For CPW, eta = 0.71 most closely match human subjects
# this function is locally concave for small values of τ and becomes locally convex for larger values of τ
def cpw(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor, float]:
"""
Overview:
The implementation of CPW function.
Arguments:
- x (:obj:`Union[torch.Tensor, float]`): The input value.
- eta (:obj:`float`): The hyperparameter of CPW function.
Returns:
- output (:obj:`Union[torch.Tensor, float]`): The output value.
"""
return (x ** eta) / ((x ** eta + (1 - x) ** eta) ** (1 / eta))
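A short usage sketch of the function documented above, applied to quantile fractions as in distributional RL (values are illustrative):

```python
import torch
from ding.rl_utils.beta_function import cpw

tau = torch.linspace(0.05, 0.95, 5)  # quantile fractions in (0, 1)
print(cpw(tau))        # distorted fractions with the default eta = 0.71
print(cpw(0.5, 0.71))  # plain floats work as well
```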


@@ -22,6 +31,15 @@ def cpw(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor,

# CVaR is risk-averse
def CVaR(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor, float]:
"""
Overview:
The implementation of CVaR function, which is a risk-averse function.
Arguments:
- x (:obj:`Union[torch.Tensor, float]`): The input value.
- eta (:obj:`float`): The hyperparameter of CVaR function.
Returns:
- output (:obj:`Union[torch.Tensor, float]`): The output value.
"""
assert eta <= 1.0
return x * eta
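A usage sketch: with ``eta < 1`` every quantile fraction is scaled toward 0, so sampling concentrates on the lower, pessimistic part of the return distribution:

```python
import torch
from ding.rl_utils.beta_function import CVaR

tau = torch.rand(4)        # quantile fractions in [0, 1)
print(CVaR(tau, eta=0.3))  # every fraction lands in [0, 0.3): risk-averse
```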

@@ -31,6 +49,15 @@ def CVaR(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor

# risk-averse (eta < 0) or risk-seeking (eta > 0)
def Pow(x: Union[torch.Tensor, float], eta: float = 0.0) -> Union[torch.Tensor, float]:
"""
Overview:
The implementation of Pow function, which is risk-averse when eta < 0 and risk-seeking when eta > 0.
Arguments:
- x (:obj:`Union[torch.Tensor, float]`): The input value.
- eta (:obj:`float`): The hyperparameter of Pow function.
Returns:
- output (:obj:`Union[torch.Tensor, float]`): The output value.
"""
if eta >= 0:
return x ** (1 / (1 + eta))
else:
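The risk-averse/risk-seeking split of ``Pow`` is easy to see numerically; only the visible ``eta >= 0`` branch is exercised below, since the other branch is collapsed in this diff:

```python
import torch
from ding.rl_utils.beta_function import Pow

tau = torch.linspace(0.1, 0.9, 5)
print(Pow(tau, eta=0.0))  # identity mapping: tau ** (1 / (1 + 0))
print(Pow(tau, eta=1.0))  # tau ** 0.5 shifts fractions toward 1: risk-seeking
```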
59 changes: 30 additions & 29 deletions ding/rl_utils/exploration.py
@@ -12,13 +12,13 @@ def get_epsilon_greedy_fn(start: float, end: float, decay: int, type_: str = 'ex
Overview:
Generate an epsilon_greedy function with decay, which inputs current timestep and outputs current epsilon.
Arguments:
- start (:obj:`float`): Epsilon start value. For 'linear', it should be 1.0.
- start (:obj:`float`): Epsilon start value. For ``linear`` , it should be 1.0.
- end (:obj:`float`): Epsilon end value.
- decay (:obj:`int`): Controls the speed that epsilon decreases from ``start`` to ``end``. \
We recommend epsilon decays according to env step rather than iteration.
- type (:obj:`str`): How epsilon decays, now supports ['linear', 'exp'(exponential)]
- type (:obj:`str`): How epsilon decays, now supports ``['linear', 'exp'(exponential)]`` .
Returns:
- eps_fn (:obj:`function`): The epsilon greedy function with decay
- eps_fn (:obj:`function`): The epsilon greedy function with decay.
"""
assert type_ in ['linear', 'exp'], type_
if type_ == 'exp':
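Reading the docstring together with the assertion above, a typical call might look like this (the decay constant is illustrative):

```python
from ding.rl_utils.exploration import get_epsilon_greedy_fn

eps_fn = get_epsilon_greedy_fn(start=0.95, end=0.1, decay=10000, type_='exp')
for env_step in (0, 1000, 10000, 100000):
    print(env_step, eps_fn(env_step))  # epsilon decays from 0.95 toward 0.1
```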
@@ -48,27 +48,27 @@ class BaseNoise(ABC):
def __init__(self) -> None:
"""
Overview:
Initialization method
Initialization method.
"""
super().__init__()

@abstractmethod
def __call__(self, shape: tuple, device: str) -> torch.Tensor:
"""
Overview:
Generate noise according to action tensor's shape, device
Generate noise according to action tensor's shape, device.
Arguments:
- shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same
- device (:obj:`str`): device of the action tensor, output noise's device should be the same as it
- shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same.
- device (:obj:`str`): device of the action tensor, output noise's device should be the same as it.
Returns:
- noise (:obj:`torch.Tensor`): generated action noise, \
have the same shape and device with the input action tensor
have the same shape and device with the input action tensor.
"""
raise NotImplementedError


class GaussianNoise(BaseNoise):
r"""
"""
Overview:
Derived class for generating gaussian noise, which satisfies :math:`X \sim N(\mu, \sigma^2)`
Interface:
@@ -78,10 +78,10 @@ class GaussianNoise(BaseNoise):
def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None:
"""
Overview:
Initialize :math:`\mu` and :math:`\sigma` in Gaussian Distribution
Initialize :math:`\mu` and :math:`\sigma` in Gaussian Distribution.
Arguments:
- mu (:obj:`float`): :math:`\mu` , mean value
- sigma (:obj:`float`): :math:`\sigma` , standard deviation, should be positive
- mu (:obj:`float`): :math:`\mu` , mean value.
- sigma (:obj:`float`): :math:`\sigma` , standard deviation, should be positive.
"""
super(GaussianNoise, self).__init__()
self._mu = mu
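A hedged usage sketch of the class, following the ``__call__`` contract defined by ``BaseNoise`` above:

```python
import torch
from ding.rl_utils.exploration import GaussianNoise

noise_gen = GaussianNoise(mu=0.0, sigma=0.1)
noise = noise_gen((4, 2), 'cpu')    # N(0, 0.1^2) samples, shape (4, 2)
action = torch.randn(4, 2) + noise  # illustrative action perturbation
```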
@@ -125,14 +125,15 @@ def __init__(
"""
Overview:
Initialize ``_alpha`` :math:`=\theta * dt`,
``beta`` :math:`= \sigma * \sqrt{dt}`, in Ornstein-Uhlenbeck process
``beta`` :math:`= \sigma * \sqrt{dt}`, in Ornstein-Uhlenbeck process.
Arguments:
- mu (:obj:`float`): :math:`\mu` , mean value
- sigma (:obj:`float`): :math:`\sigma` , standard deviation of the perturbation noise
- theta (:obj:`float`): how strongly the noise reacts to perturbations, \
greater value means stronger reaction
- dt (:obj:`float`): derivative of time t
- x0 (:obj:`float` or :obj:`torch.Tensor`): initial action
- mu (:obj:`float`): :math:`\mu` , mean value.
- sigma (:obj:`float`): :math:`\sigma` , standard deviation of the perturbation noise.
- theta (:obj:`float`): How strongly the noise reacts to perturbations, \
greater value means stronger reaction.
- dt (:obj:`float`): The derivative of time t.
- x0 (:obj:`Union[float, torch.Tensor]`): The initial state of the noise, \
should be a scalar or tensor with the same shape as the action tensor.
"""
super().__init__()
self._mu = mu
@@ -144,21 +145,21 @@ def __init__(
def reset(self) -> None:
"""
Overview:
Reset ``_x`` to the initial state ``_x0``
Reset ``_x`` to the initial state ``_x0``.
"""
self._x = deepcopy(self._x0)

def __call__(self, shape: tuple, device: str, mu: Optional[float] = None) -> torch.Tensor:
"""
Overview:
Generate gaussian noise according to action tensor's shape, device
Generate gaussian noise according to action tensor's shape, device.
Arguments:
- shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same
- device (:obj:`str`): device of the action tensor, output noise's device should be the same as it
- mu (:obj:`float`): new mean value :math:`\mu`, you can set it to `None` if don't need it
- shape (:obj:`tuple`): The size of the action tensor, output noise's size should be the same.
- device (:obj:`str`): The device of the action tensor, output noise's device should be the same as it.
- mu (:obj:`float`): The new mean value :math:`\mu`, you can set it to `None` if don't need it.
Returns:
- noise (:obj:`torch.Tensor`): generated action noise, \
have the same shape and device with the input action tensor
have the same shape and device with the input action tensor.
"""
if self._x is None or \
(isinstance(self._x, torch.Tensor) and self._x.shape != shape):
@@ -174,15 +175,15 @@ def __call__(self, shape: tuple, device: str, mu: Optional[float] = None) -> tor
def x0(self) -> Union[float, torch.Tensor]:
"""
Overview:
Get ``self._x0``
Get ``self._x0``.
"""
return self._x0

@x0.setter
def x0(self, _x0: Union[float, torch.Tensor]) -> None:
"""
Overview:
Set ``self._x0`` and reset ``self.x`` to ``self._x0`` as well
Set ``self._x0`` and reset ``self.x`` to ``self._x0`` as well.
"""
self._x0 = _x0
self.reset()
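Put together, the coefficients above define the familiar discretized Ornstein-Uhlenbeck update; a standalone sketch with illustrative parameter values:

```python
import math
import torch

def ou_step(x, mu=0.0, theta=0.15, sigma=0.3, dt=1e-2):
    # x_{t+1} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * N(0, 1),
    # i.e. _alpha = theta * dt and _beta = sigma * sqrt(dt) as in the docstring.
    return x + theta * dt * (mu - x) + sigma * math.sqrt(dt) * torch.randn_like(x)

x = torch.zeros(4)
for _ in range(100):
    x = ou_step(x)  # temporally correlated noise drifting back toward mu
```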
@@ -198,10 +199,10 @@ def create_noise_generator(noise_type: str, noise_kwargs: dict) -> BaseNoise:
or raise a KeyError. In other words, a derived noise generator must first register,
then call ``create_noise_generator`` to get the instance object.
Arguments:
- noise_type (:obj:`str`): the type of noise generator to be created
- noise_type (:obj:`str`): the type of noise generator to be created.
Returns:
- noise (:obj:`BaseNoise`): the created new noise generator, should be an instance of one of \
noise_mapping's values
noise_mapping's values.
"""
if noise_type not in noise_mapping.keys():
raise KeyError("not support noise type: {}".format(noise_type))
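A hedged usage sketch of the factory; the registered key ``'gauss'`` and the kwargs below are assumptions based on the noise classes in this file:

```python
from ding.rl_utils.exploration import create_noise_generator

noise_gen = create_noise_generator('gauss', {'mu': 0.0, 'sigma': 0.1})
noise = noise_gen((4, 2), 'cpu')  # same interface as the classes above
```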
18 changes: 8 additions & 10 deletions ding/rl_utils/td.py
@@ -578,7 +578,7 @@ def v_nstep_td_error(
nstep: int = 1,
criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
) -> torch.Tensor:
r"""
"""
Overview:
Multistep (n step) td_error for distributed value based algorithm
Arguments:
@@ -588,14 +588,14 @@
Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
Shapes:
- data (:obj:`dist_nstep_td_data`): The v_nstep_td_data containing\
- data (:obj:`dist_nstep_td_data`): The v_nstep_td_data containing \
['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']
- v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
- next_v (:obj:`torch.FloatTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
- done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
- value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step\
- value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step \
we use value_gamma as the gamma discount value for next_v rather than gamma**n_step
Examples:
>>> v = torch.randn(5).requires_grad_(True)
@@ -1098,7 +1098,7 @@ def qrdqn_nstep_td_error(
Overview:
Multistep (1 step or n step) td_error in QRDQN
Arguments:
- data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss
- data (:obj:`qrdqn_nstep_td_data`): The input data, qrdqn_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- nstep (:obj:`int`): nstep num, default set to 1
Returns:
@@ -1605,18 +1605,16 @@ def multistep_forward_view(
lambda_: float,
done: Optional[torch.Tensor] = None
) -> torch.Tensor:
r"""
"""
Overview:
Same as trfl.sequence_ops.multistep_forward_view
Implementing (12.18) in Sutton & Barto
Same as trfl.sequence_ops.multistep_forward_view, which implements (12.18) in Sutton & Barto.
Assuming the first dim of input tensors corresponds to the index in batch.
```
.. note::
result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T]
for t in 0...T-2 :
result[t] = rewards[t] + gammas[t]*(lambdas[t]*result[t+1] + (1-lambdas[t])*bootstrap_values[t+1])
```
Assuming the first dim of input tensors correspond to the index in batch
Arguments:
- bootstrap_values (:obj:`torch.Tensor`): Estimation of the value at *step 1 to T*, of size [T_traj, batchsize]
- rewards (:obj:`torch.Tensor`): The returns from 0 to T-1, of size [T_traj, batchsize]
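The recurrence in the note can be written as a plain backward loop; this is a hedged, non-vectorized reference sketch (not the trfl implementation), with ``bootstrap_values[t]`` holding the estimate at step ``t + 1`` as the argument description states:

```python
import torch

def multistep_forward_view_sketch(bootstrap_values, rewards, gammas, lambda_):
    # All tensors are [T, B]; lambda_ is a scalar, matching the visible signature.
    T = rewards.shape[0]
    result = torch.empty_like(rewards)
    result[T - 1] = rewards[T - 1] + gammas[T - 1] * bootstrap_values[T - 1]
    for t in range(T - 2, -1, -1):
        result[t] = rewards[t] + gammas[t] * (
            lambda_ * result[t + 1] + (1 - lambda_) * bootstrap_values[t]
        )
    return result
```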
4 changes: 2 additions & 2 deletions ding/rl_utils/upgo.py
@@ -44,7 +44,7 @@ def tb_cross_entropy(logit, label, mask=None):


def upgo_returns(rewards: torch.Tensor, bootstrap_values: torch.Tensor) -> torch.Tensor:
r"""
"""
Overview:
Computing UPGO return targets. Also notice there is no special handling for the terminal state.
Arguments:
@@ -82,7 +82,7 @@ def upgo_loss(
bootstrap_values: torch.Tensor,
mask=None
) -> torch.Tensor:
r"""
"""
Overview:
Computing UPGO loss given constant gamma and lambda. There is no special handling for terminal state value,
if the last state in the trajectory is terminal, just pass a 0 as bootstrap_terminal_value.
14 changes: 6 additions & 8 deletions ding/rl_utils/value_rescale.py
@@ -2,7 +2,7 @@


def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
r"""
"""
Overview:
A function to reduce the scale of the action-value function.
:math:`h(x) = sign(x)(\sqrt{|x|+1} - 1) + \epsilon x`.
@@ -14,14 +14,13 @@ def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
- (:obj:`torch.Tensor`) Normalized tensor.
.. note::
Observe and Look Further: Achieving Consistent Performance on Atari
(https://arxiv.org/abs/1805.11593)
Observe and Look Further: Achieving Consistent Performance on Atari (https://arxiv.org/abs/1805.11593).
"""
return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x


def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
r"""
"""
Overview:
The inverse form of value rescale.
:math:`h^{-1}(x) = sign(x)({(\frac{\sqrt{1+4\epsilon(|x|+1+\epsilon)}-1}{2\epsilon})}^2-1)`.
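The two functions are exact inverses of each other, which is easy to check numerically (a hedged sanity-check sketch):

```python
import torch
from ding.rl_utils.value_rescale import value_transform, value_inv_transform

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
y = value_transform(x)  # compressed targets
print(torch.allclose(value_inv_transform(y), x, atol=1e-4))  # True
```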
@@ -36,7 +35,7 @@ def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:


def symlog(x: torch.Tensor) -> torch.Tensor:
r"""
"""
Overview:
A function to normalize the targets.
:math:`symlog(x) = sign(x)\ln(|x|+1)`.
@@ -46,14 +45,13 @@ def symlog(x: torch.Tensor) -> torch.Tensor:
- (:obj:`torch.Tensor`) Normalized tensor.
.. note::
Mastering Diverse Domains through World Models
(https://arxiv.org/abs/2301.04104)
Mastering Diverse Domains through World Models (https://arxiv.org/abs/2301.04104)
"""
return torch.sign(x) * (torch.log(torch.abs(x) + 1))


def inv_symlog(x: torch.Tensor) -> torch.Tensor:
r"""
"""
Overview:
The inverse form of symlog.
:math:`symexp(x) = sign(x)(\exp(|x|)-1)`.
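Like the value-rescale pair above, ``symlog`` and ``inv_symlog`` invert each other while compressing large magnitudes symmetrically (a hedged sanity-check sketch):

```python
import torch
from ding.rl_utils.value_rescale import symlog, inv_symlog

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
y = symlog(x)  # sign-preserving logarithmic compression
print(torch.allclose(inv_symlog(y), x, atol=1e-3))  # True up to float error
```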
