
feature(lxy): add popart & value rescale & symlog to ppof (#605)
* add popart & value rescale & symlog

* polish: enable_save_replay

* add unittest of popart and symlog, polish format

* polish assert and comment

* polish popart update
karroyan committed Apr 7, 2023
1 parent 44226be commit 1cb1038
Showing 11 changed files with 363 additions and 33 deletions.
33 changes: 24 additions & 9 deletions ding/bonus/model.py
@@ -6,7 +6,7 @@
from copy import deepcopy
from ding.utils import SequenceType, squeeze
from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \
FCEncoder, ConvEncoder, IMPALAConvEncoder
FCEncoder, ConvEncoder, IMPALAConvEncoder, PopArtVHead
from ding.torch_utils import MLP, fc_block


@@ -57,6 +57,7 @@ def __init__(
fixed_sigma_value: Optional[int] = 0.3,
bound_type: Optional[str] = None,
encoder: Optional[torch.nn.Module] = None,
popart_head: bool = False,
) -> None:
super(PPOFModel, self).__init__()
obs_shape = squeeze(obs_shape)
@@ -108,9 +109,15 @@ def new_encoder(outsize):
self.critic_encoder = new_encoder(critic_head_hidden_size)

# Head Type
self.critic_head = RegressionHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
)
if not popart_head:
self.critic_head = RegressionHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
)
else:
self.critic_head = PopArtVHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
)

self.action_space = action_space
assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
if self.action_space == 'continuous':
@@ -207,7 +214,7 @@ def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
else:
x = self.critic_encoder(x)
x = self.critic_head(x)
return x['pred']
return x

def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
if self.share_encoder:
@@ -216,15 +223,23 @@ def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
actor_embedding = self.actor_encoder(x)
critic_embedding = self.critic_encoder(x)

value = self.critic_head(critic_embedding)['pred']
value = self.critic_head(critic_embedding)

if self.action_space == 'discrete':
logit = self.actor_head(actor_embedding)
return ttorch.as_tensor({'logit': logit, 'value': value})
return ttorch.as_tensor({'logit': logit, 'value': value['pred']})
elif self.action_space == 'continuous':
x = self.actor_head(actor_embedding)
return ttorch.as_tensor({'logit': x, 'value': value})
return ttorch.as_tensor({'logit': x, 'value': value['pred']})
elif self.action_space == 'hybrid':
action_type = self.actor_head[0](actor_embedding)
action_args = self.actor_head[1](actor_embedding)
return ttorch.as_tensor({'logit': {'action_type': action_type, 'action_args': action_args}, 'value': value})
return ttorch.as_tensor(
{
'logit': {
'action_type': action_type,
'action_args': action_args
},
'value': value['pred']
}
)
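
A minimal usage sketch of the new popart_head option (illustrative only: an 8-dim observation space, 4 discrete actions, and default values for the remaining constructor arguments are assumed):

import torch
from ding.bonus.model import PPOFModel

# With popart_head=True, compute_critic returns the PopArtVHead dict
# (keys 'pred' and 'unnormalized_pred') instead of a bare value tensor.
model = PPOFModel(8, 4, action_space='discrete', popart_head=True)
obs = torch.randn(2, 8)
out = model.compute_critic(obs)
assert set(out.keys()) >= {'pred', 'unnormalized_pred'}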
21 changes: 18 additions & 3 deletions ding/bonus/ppof.py
@@ -89,10 +89,25 @@ def __init__(
action_shape = get_hybrid_shape(action_space)
else:
action_shape = action_space.shape

# Three types of value normalization are currently supported
assert self.cfg.value_norm in ['popart', 'value_rescale', 'symlog']
if model is None:
model = PPOFModel(
self.env.observation_space.shape, action_shape, action_space=self.cfg.action_space, **self.cfg.model
)
if self.cfg.value_norm != 'popart':
model = PPOFModel(
self.env.observation_space.shape,
action_shape,
action_space=self.cfg.action_space,
**self.cfg.model
)
else:
model = PPOFModel(
self.env.observation_space.shape,
action_shape,
action_space=self.cfg.action_space,
popart_head=True,
**self.cfg.model
)
self.policy = PPOFPolicy(self.cfg, model=model)
if policy_state_dict is not None:
self.policy.load_state_dict(policy_state_dict)
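Since only 'popart' changes the critic architecture (the other two methods keep RegressionHead and transform the regression targets instead), the branch above is equivalent to this condensed fragment of the same __init__ (a sketch, not the committed code):

model = PPOFModel(
self.env.observation_space.shape,
action_shape,
action_space=self.cfg.action_space,
popart_head=(self.cfg.value_norm == 'popart'),
**self.cfg.model
)
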
2 changes: 1 addition & 1 deletion ding/model/common/__init__.py
@@ -1,5 +1,5 @@
from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist, AttentionPolicyHead
independent_normal_dist, AttentionPolicyHead, PopArtVHead
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
from .utils import create_model
75 changes: 74 additions & 1 deletion ding/model/common/head.py
@@ -6,7 +6,7 @@
import torch.nn.functional as F
from torch.distributions import Normal, Independent

from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP
from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP, PopArt
from ding.rl_utils import beta_function_map
from ding.utils import lists_to_dicts, SequenceType

@@ -1176,6 +1176,78 @@ def forward(self, x: torch.Tensor) -> Dict:
return {'mu': mu, 'sigma': sigma}


class PopArtVHead(nn.Module):
"""
Overview:
The ``PopArtVHead`` is used to predict state values, with a PopArt layer as the last layer. \
Input is a (:obj:`torch.Tensor`) of shape ``(B, N)``; returns a (:obj:`Dict`) containing \
the normalized prediction ``pred`` and the unnormalized prediction ``unnormalized_pred``.
Interfaces:
``__init__``, ``forward``.
"""

def __init__(
self,
hidden_size: int,
output_size: int,
layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
noise: Optional[bool] = False,
) -> None:
"""
Overview:
Init the ``PopArtVHead`` layers according to the provided arguments.
Arguments:
- hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``PopArtVHead``.
- output_size (:obj:`int`): The number of outputs.
- layer_num (:obj:`int`): The number of layers used in the network to compute the value output.
- activation (:obj:`nn.Module`): The type of activation function to use in the MLP. \
Default ``nn.ReLU()``.
- norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
for more details. Default ``None``.
- noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` in the MLP. \
Default ``False``.
"""
super(PopArtVHead, self).__init__()
layer = NoiseLinearLayer if noise else nn.Linear
self.popart = PopArt(hidden_size, output_size)
self.Q = nn.Sequential(
MLP(
hidden_size,
hidden_size,
hidden_size,
layer_num,
layer_fn=layer,
activation=activation,
norm_type=norm_type
), self.popart
)

def forward(self, x: torch.Tensor) -> Dict:
"""
Overview:
Use the encoded embedding tensor to run the MLP with ``PopArtVHead`` and return a dictionary \
containing both the normalized and the unnormalized value predictions.
Arguments:
- x (:obj:`torch.Tensor`): Tensor containing input embedding.
Returns:
- outputs (:obj:`Dict`): Dict containing keyword ``pred`` (:obj:`torch.Tensor`) \
and ``unnormalized_pred`` (:obj:`torch.Tensor`).
Shapes:
- x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
- pred / unnormalized_pred: :math:`(B, M)`, where ``M = output_size``.
Examples:
>>> head = PopArtVHead(64, 64)
>>> inputs = torch.randn(4, 64)
>>> outputs = head(inputs)
>>> assert isinstance(outputs, dict) and outputs['pred'].shape == torch.Size([4, 64]) and \
outputs['unnormalized_pred'].shape == torch.Size([4, 64])
"""
x = self.Q(x)
return x


class AttentionPolicyHead(nn.Module):

def __init__(self) -> None:
@@ -1266,6 +1338,7 @@ def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Di
# continuous
'regression': RegressionHead,
'reparameterization': ReparameterizationHead,
'popart': PopArtVHead,
# multi
'multi': MultiHead,
}
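
A usage sketch of the new head and its registry entry. The relation checked at the end is an assumption: it holds if ``ding.torch_utils.PopArt`` follows the standard PopArt convention (van Hasselt et al., 2016), where ``unnormalized_pred = pred * sigma + mu`` with running statistics ``mu`` and ``sigma``:

import torch
from ding.model.common import PopArtVHead, head_cls_map

head = head_cls_map['popart'](hidden_size=64, output_size=1)  # same as PopArtVHead(64, 1)
x = torch.randn(4, 64)
out = head(x)
assert out['pred'].shape == torch.Size([4, 1])
# Assumed PopArt relation between the two outputs:
mu, sigma = head.popart.mu, head.popart.sigma
assert torch.allclose(out['unnormalized_pred'], out['pred'] * sigma + mu, atol=1e-5)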
51 changes: 37 additions & 14 deletions ding/policy/ppof.py
@@ -10,7 +10,7 @@

from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \
get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \
HybridStochasticSampler, HybridDeterminsticSampler
HybridStochasticSampler, HybridDeterminsticSampler, value_transform, value_inv_transform, symlog, inv_symlog
from ding.utils import POLICY_REGISTRY, RunningMeanStd


@@ -32,7 +32,7 @@ class PPOFPolicy:
entropy_weight=0.01,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
value_norm='symlog',
ppo_param_init=True,
grad_norm=0.5,
# collect
@@ -148,25 +148,48 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
for epoch in range(self._cfg.epoch_per_collect):
# recompute adv
with torch.no_grad():
# get the value dictionary
# In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred'
value = self._model.compute_critic(data.obs)
next_value = self._model.compute_critic(data.next_obs)
if self._cfg.value_norm:
value *= self._running_mean_std.std
next_value *= self._running_mean_std.std
reward = data.reward

assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog'],\
'Unsupported value normalization! Supported methods: popart, value_rescale, symlog'

if self._cfg.value_norm == 'popart':
unnormalized_value = value['unnormalized_pred']
unnormalized_next_value = next_value['unnormalized_pred']

mu = self._model.critic_head.popart.mu
sigma = self._model.critic_head.popart.sigma
reward = (reward - mu) / sigma

value = value['pred']
next_value = next_value['pred']
elif self._cfg.value_norm == 'value_rescale':
value = value_inv_transform(value['pred'])
next_value = value_inv_transform(next_value['pred'])
elif self._cfg.value_norm == 'symlog':
value = inv_symlog(value['pred'])
next_value = inv_symlog(next_value['pred'])

traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
adv_data = gae_data(value, next_value, data.reward, data.done, traj_flag)
adv_data = gae_data(value, next_value, reward, data.done, traj_flag)
data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)

unnormalized_returns = value + data.adv
unnormalized_returns = value + data.adv # In popart, this return is normalized

if self._cfg.value_norm:
data.value = value / self._running_mean_std.std
data.return_ = unnormalized_returns / self._running_mean_std.std
self._running_mean_std.update(unnormalized_returns.cpu().numpy())
else:
data.value = value
data.return_ = unnormalized_returns
if self._cfg.value_norm == 'popart':
self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1))
elif self._cfg.value_norm == 'value_rescale':
value = value_transform(value)
unnormalized_returns = value_transform(unnormalized_returns)
elif self._cfg.value_norm == 'symlog':
value = symlog(value)
unnormalized_returns = symlog(unnormalized_returns)
data.value = value
data.return_ = unnormalized_returns

# inner training loop
split_data = ttorch.split(data, self._cfg.batch_size)
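For reference, a compact sketch of the 'symlog' target pipeline implemented above: the critic predicts in symlog space, GAE runs in raw-return space, and the regression targets are mapped back (the tensors are stand-ins for the real data):

import torch
from ding.rl_utils import symlog, inv_symlog

pred = torch.randn(8)     # critic output, lives in symlog space
value = inv_symlog(pred)  # raw-return space, fed to gae(...)
adv = torch.randn(8)      # stand-in for the gae(...) output
returns = value + adv     # raw-space returns
target = symlog(returns)  # regression target, back in symlog space
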
2 changes: 1 addition & 1 deletion ding/rl_utils/__init__.py
@@ -16,7 +16,7 @@
from .vtrace import vtrace_loss, compute_importance_weights
from .upgo import upgo_loss
from .adder import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
from .value_rescale import value_transform, value_inv_transform
from .value_rescale import value_transform, value_inv_transform, symlog, inv_symlog
from .vtrace import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action
from .beta_function import beta_function_map
from .retrace import compute_q_retraces
25 changes: 24 additions & 1 deletion ding/rl_utils/tests/test_value_rescale.py
@@ -1,6 +1,6 @@
import pytest
import torch
from ding.rl_utils.value_rescale import value_inv_transform, value_transform
from ding.rl_utils.value_rescale import value_inv_transform, value_transform, symlog, inv_symlog


@pytest.mark.unittest
@@ -24,3 +24,26 @@ def test_trans_inverse(self):
diff = value_inv_transform(value_transform(t)) - t
assert pytest.approx(diff.abs().max().item(), abs=2e-5) == 0


@pytest.mark.unittest
class TestSymlog:

def test_symlog(self):
for _ in range(10):
t = torch.rand((3, 4))
assert isinstance(symlog(t), torch.Tensor)
assert symlog(t).shape == t.shape

def test_inv_symlog(self):
for _ in range(10):
t = torch.rand((3, 4))
assert isinstance(inv_symlog(t), torch.Tensor)
assert inv_symlog(t).shape == t.shape

def test_trans_inverse(self):
for _ in range(10):
t = torch.rand((4, 16))
diff = inv_symlog(symlog(t)) - t
assert pytest.approx(diff.abs().max().item(), abs=2e-5) == 0
51 changes: 48 additions & 3 deletions ding/rl_utils/value_rescale.py
@@ -1,20 +1,65 @@
"""
Referenced paper: <Observe and Look Further: Achieving Consistent Performance on Atari>
"""
import torch


def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
r"""
Overview:
A function to reduce the scale of the action-value function.
:math:`h(x) = \operatorname{sign}(x)(\sqrt{|x|+1} - 1) + \epsilon \cdot x`.
Arguments:
- x: (:obj:`torch.Tensor`) The input tensor to be normalized.
- eps: (:obj:`float`) The coefficient of the additive regularization term \
to ensure :math:`h^{-1}` is Lipschitz continuous.
Returns:
- (:obj:`torch.Tensor`) Normalized tensor.
.. note::
Observe and Look Further: Achieving Consistent Performance on Atari
(https://arxiv.org/abs/1805.11593)
"""
return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x


def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
r"""
Overview:
The inverse form of value rescale.
:math:`h^{-1}(x) = \operatorname{sign}(x)\left(\left(\frac{\sqrt{1+4\epsilon(|x|+1+\epsilon)}-1}{2\epsilon}\right)^2-1\right)`.
Arguments:
- x: (:obj:`torch.Tensor`) The input tensor to be unnormalized.
- eps: (:obj:`float`) The coefficient of the additive regularization term \
to ensure :math:`h^{-1}` is Lipschitz continuous.
Returns:
- (:obj:`torch.Tensor`) Unnormalized tensor.
"""
return torch.sign(x) * (((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1)


def symlog(x: torch.Tensor) -> torch.Tensor:
r"""
Overview:
A function to normalize the targets.
:math:`\operatorname{symlog}(x) = \operatorname{sign}(x)\ln(|x|+1)`.
Arguments:
- x: (:obj:`torch.Tensor`) The input tensor to be normalized.
Returns:
- (:obj:`torch.Tensor`) Normalized tensor.
.. note::
Mastering Diverse Domains through World Models
(https://arxiv.org/abs/2301.04104)
"""
return torch.sign(x) * (torch.log(torch.abs(x) + 1))


def inv_symlog(x: torch.Tensor) -> torch.Tensor:
r"""
Overview:
The inverse form of symlog.
:math:`\operatorname{symexp}(x) = \operatorname{sign}(x)(\exp(|x|)-1)`.
Arguments:
- x: (:obj:`torch.Tensor`) The input tensor to be unnormalized.
Returns:
- (:obj:`torch.Tensor`) Unnormalized tensor.
"""
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
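
A quick round-trip check of both transform pairs (a usage sketch mirroring the unit tests above):

import torch
from ding.rl_utils.value_rescale import value_transform, value_inv_transform, symlog, inv_symlog

t = torch.rand(4, 16)
assert torch.allclose(value_inv_transform(value_transform(t)), t, atol=2e-5)
assert torch.allclose(inv_symlog(symlog(t)), t, atol=2e-5)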