-
Notifications
You must be signed in to change notification settings - Fork 348
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(lxy): add popart & value rescale & symlog to ppof (#605)
* add popart & value rescale & symlog * polish: enable_save_replay * add unittest of popart and symlog, polish format * polish assert and comment * polish popart update
- Loading branch information
Showing
11 changed files
with
363 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \ | ||
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \ | ||
independent_normal_dist, AttentionPolicyHead | ||
independent_normal_dist, AttentionPolicyHead, PopArtVHead | ||
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder | ||
from .utils import create_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,65 @@ | ||
""" | ||
Referenced paper <Observe and Look Further: Achieving Consistent Performance on Atari> | ||
""" | ||
import torch | ||
|
||
|
||
def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    r"""
    Overview:
        A function to reduce the scale of the action-value function.
        :math:`h(x) = sign(x)(\sqrt{|x|+1} - 1) + \epsilon x` .
    Arguments:
        - x: (:obj:`torch.Tensor`) The input tensor to be normalized.
        - eps: (:obj:`float`) The coefficient of the additive regularization term \
            to ensure h^{-1} is Lipschitz continuous
    Returns:
        - (:obj:`torch.Tensor`) Normalized tensor.

    .. note::
        Observe and Look Further: Achieving Consistent Performance on Atari
        (https://arxiv.org/abs/1805.11593)
    """
    # sign(x) * (sqrt(|x| + 1) - 1) is algebraically x / (sqrt(|x| + 1) + 1).
    # The rationalized form avoids subtracting two nearly equal numbers, which
    # would round the result to exactly zero for very small |x|.
    compressed = x / (torch.sqrt(torch.abs(x) + 1) + 1)
    return compressed + eps * x
|
||
|
||
def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    r"""
    Overview:
        The inverse form of value rescale.
        :math:`h^{-1}(x) = sign(x)({(\frac{\sqrt{1+4\epsilon(|x|+1+\epsilon)}-1}{2\epsilon})}^2-1)` .
    Arguments:
        - x: (:obj:`torch.Tensor`) The input tensor to be unnormalized.
        - eps: (:obj:`float`) The coefficient of the additive regularization term \
            to ensure h^{-1} is Lipschitz continuous
    Returns:
        - (:obj:`torch.Tensor`) Unnormalized tensor.
    """
    shifted = torch.abs(x) + 1 + eps
    # (sqrt(1 + 4*eps*y) - 1) / (2*eps) is algebraically 2*y / (sqrt(1 + 4*eps*y) + 1).
    # The rationalized form has no division by ``eps``, so ``eps == 0`` no longer
    # produces 0/0 = NaN, and it avoids cancellation when ``eps`` is small.
    root = 2 * shifted / (torch.sqrt(1 + 4 * eps * shifted) + 1)
    return torch.sign(x) * (root ** 2 - 1)
|
||
|
||
def symlog(x: torch.Tensor) -> torch.Tensor:
    r"""
    Overview:
        A function to normalize the targets.
        :math:`symlog(x) = sign(x)\ln(|x|+1)` .
    Arguments:
        - x: (:obj:`torch.Tensor`) The input tensor to be normalized.
    Returns:
        - (:obj:`torch.Tensor`) Normalized tensor.

    .. note::
        Mastering Diverse Domains through World Models
        (https://arxiv.org/abs/2301.04104)
    """
    # log1p evaluates ln(1 + |x|) without first rounding 1 + |x|, so tiny inputs
    # are not flushed to zero.
    return torch.sign(x) * torch.log1p(torch.abs(x))
|
||
|
||
def inv_symlog(x: torch.Tensor) -> torch.Tensor:
    r"""
    Overview:
        The inverse form of symlog.
        :math:`symexp(x) = sign(x)(\exp(|x|)-1)` .
    Arguments:
        - x: (:obj:`torch.Tensor`) The input tensor to be unnormalized.
    Returns:
        - (:obj:`torch.Tensor`) Unnormalized tensor.
    """
    # expm1 evaluates exp(|x|) - 1 without the cancellation that the explicit
    # subtraction suffers for tiny |x|.
    return torch.sign(x) * torch.expm1(torch.abs(x))
Oops, something went wrong.