fix(pu): fix output_activation and output_norm in MLP
puyuan1996 committed Apr 25, 2023
1 parent 782bb96 commit 5a514fd
Showing 2 changed files with 55 additions and 63 deletions.
32 changes: 16 additions & 16 deletions ding/torch_utils/network/nn_module.py
@@ -314,8 +314,8 @@ def MLP(
     norm_type: str = None,
     use_dropout: bool = False,
     dropout_probability: float = 0.5,
-    output_activation: nn.Module = None,
-    output_norm_type: str = None,
+    output_activation: bool = True,
+    output_norm: bool = True,
     last_linear_layer_init_zero: bool = False
 ):
     r"""
@@ -333,9 +333,11 @@ def MLP(
         - norm_type (:obj:`str`): The type of the normalization.
         - use_dropout (:obj:`bool`): Whether to use dropout in the fully-connected block.
         - dropout_probability (:obj:`float`): The probability of an element to be zeroed in the dropout. Default: 0.5.
-        - output_activation (:obj:`nn.Module`): The activation function in the last layer. Default: None.
-        - output_norm_type (:obj:`str`): The type of the normalization in the last layer. Default: None.
-        - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last linear layer
+        - output_activation (:obj:`bool`): Whether to use activation in the output layer. If True,
+            we use the same activation as front layers. Default: True.
+        - output_norm (:obj:`bool`): Whether to use normalization in the output layer. If True,
+            we use the same normalization as front layers. Default: True.
+        - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last linear layer
             (including w and b), which can provide stable zero outputs in the beginning,
             usually used in the policy network in RL settings.
     Returns:
@@ -367,20 +369,18 @@ def MLP(
     out_channels = channels[-1]
     block.append(layer_fn(in_channels, out_channels))
     """
-    In the final layer of a neural network, the output normalization and activation functions are typically determined
+    In the final layer of a neural network, whether to use normalization and activation are typically determined
     based on user specifications. These specifications depend on the problem at hand and the desired properties of
     the model's output.
     """
-    if output_norm_type is not None and output_activation is not None:
-        # The last layer uses the user-specified output_norm and output_activation.
-        block.append(build_normalization(output_norm_type, dim=1)(out_channels))
-        block.append(output_activation)
-    elif output_activation is not None and output_norm_type is None:
-        # The last layer uses the user-specified output_activation.
-        block.append(output_activation)
-    elif output_activation is None and output_norm_type is not None:
-        # The last layer uses the user-specified output_norm.
-        block.append(build_normalization(output_norm_type, dim=1)(out_channels))
+    if output_norm is True:
+        # The last layer uses the same norm as front layers.
+        if norm_type is not None:
+            block.append(build_normalization(norm_type, dim=1)(out_channels))
+    if output_activation is True:
+        # The last layer uses the same activation as front layers.
+        if activation is not None:
+            block.append(activation)
 
     if last_linear_layer_init_zero:
         # Locate the last linear layer and initialize its weights and biases to 0.
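Editor's note, not part of the commit: a minimal usage sketch of the changed interface, assuming MLP is imported from the file shown above; the channel sizes and activation are arbitrary choices for illustration. With the new boolean flags, the output layer reuses the hidden-layer activation and normalization unless the flags are set to False.

import torch
from ding.torch_utils.network.nn_module import MLP  # assumed import path, matching the file above

# Hypothetical 3-layer MLP: hidden layers use ReLU + LayerNorm, while the output
# layer stays a plain Linear (output_activation=False, output_norm=False) and is
# zero-initialized, as is common for policy heads in RL.
head = MLP(
    64,   # in_channels
    128,  # hidden_channels
    6,    # out_channels
    3,    # layer_num
    activation=torch.nn.ReLU(),
    norm_type='LN',
    output_activation=False,
    output_norm=False,
    last_linear_layer_init_zero=True
)
x = torch.rand(4, 64)
print(head(x).shape)  # expected: torch.Size([4, 6])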
86 changes: 39 additions & 47 deletions ding/torch_utils/network/tests/test_nn_module.py
@@ -49,53 +49,45 @@ def test_mlp(self):
         layer_num = 3
         input_tensor = torch.rand(batch_size, in_channels).requires_grad_(True)
 
-        # Test case 1: Simple MLP without dropout, normalization, or output activation.
-        model = MLP(in_channels, hidden_channels, out_channels, layer_num)
-        output_tensor = self.run_model(input_tensor, model)
-        assert output_tensor.shape == (batch_size, out_channels)
-
-        # Test case 2: MLP with dropout and normalization.
-        for norm_type in ["LN", "BN", None]:
-            model = MLP(
-                in_channels,
-                hidden_channels,
-                out_channels,
-                layer_num,
-                use_dropout=True,
-                dropout_probability=0.5,
-                norm_type=norm_type
-            )
-            output_tensor = self.run_model(input_tensor, model)
-            assert output_tensor.shape == (batch_size, out_channels)
-
-        for act in [torch.nn.LeakyReLU(), torch.nn.ReLU(), torch.nn.Sigmoid(), None]:
-            for norm_type in ["LN", "BN", None]:
-                # Test case 3: MLP without last linear layer initialized to 0.
-                model = MLP(
-                    in_channels, hidden_channels, out_channels, layer_num, norm_type=norm_type, output_activation=act
-                )
-                output_tensor = self.run_model(input_tensor, model)
-                assert output_tensor.shape == (batch_size, out_channels)
-
-                # Test case 4: MLP with last linear layer initialized to 0.
-                model = MLP(
-                    in_channels,
-                    hidden_channels,
-                    out_channels,
-                    layer_num,
-                    norm_type=norm_type,
-                    output_activation=act,
-                    last_linear_layer_init_zero=True
-                )
-                output_tensor = self.run_model(input_tensor, model)
-                assert output_tensor.shape == (batch_size, out_channels)
-                last_linear_layer = None
-                for layer in reversed(model):
-                    if isinstance(layer, torch.nn.Linear):
-                        last_linear_layer = layer
-                        break
-                assert_allclose(last_linear_layer.weight, torch.zeros_like(last_linear_layer.weight))
-                assert_allclose(last_linear_layer.bias, torch.zeros_like(last_linear_layer.bias))
+        for output_activation in [True, False]:
+            for output_norm in [True, False]:
+                for activation in [torch.nn.ReLU(), torch.nn.LeakyReLU(), torch.nn.Tanh(), None]:
+                    for norm_type in ["LN", "BN", None]:
+                        # Test case 1: MLP without last linear layer initialized to 0.
+                        model = MLP(
+                            in_channels,
+                            hidden_channels,
+                            out_channels,
+                            layer_num,
+                            activation=activation,
+                            norm_type=norm_type,
+                            output_activation=output_activation,
+                            output_norm=output_norm
+                        )
+                        output_tensor = self.run_model(input_tensor, model)
+                        assert output_tensor.shape == (batch_size, out_channels)
+
+                        # Test case 2: MLP with last linear layer initialized to 0.
+                        model = MLP(
+                            in_channels,
+                            hidden_channels,
+                            out_channels,
+                            layer_num,
+                            activation=activation,
+                            norm_type=norm_type,
+                            output_activation=output_activation,
+                            output_norm=output_norm,
+                            last_linear_layer_init_zero=True
+                        )
+                        output_tensor = self.run_model(input_tensor, model)
+                        assert output_tensor.shape == (batch_size, out_channels)
+                        last_linear_layer = None
+                        for layer in reversed(model):
+                            if isinstance(layer, torch.nn.Linear):
+                                last_linear_layer = layer
+                                break
+                        assert_allclose(last_linear_layer.weight, torch.zeros_like(last_linear_layer.weight))
+                        assert_allclose(last_linear_layer.bias, torch.zeros_like(last_linear_layer.bias))
 
     def test_conv1d_block(self):
         length = 2
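Editor's note, not part of the commit: beyond the shape and zero-init assertions above, one could also check the structural effect of the new flags, e.g. that disabling both leaves a bare Linear at the end of the block. This sketch assumes MLP returns an nn.Sequential (as the reversed(model) loop in the test suggests); the sizes and import path are arbitrary assumptions.

import torch
from ding.torch_utils.network.nn_module import MLP  # assumed import path

# With both flags disabled, the block should end with the bare output Linear.
mlp = MLP(8, 16, 4, 3, activation=torch.nn.ReLU(), norm_type='BN',
          output_activation=False, output_norm=False)
assert isinstance(list(mlp)[-1], torch.nn.Linear)

# With the default flags (both True), the output layer reuses the same
# activation as the hidden layers, so the block ends with a ReLU.
mlp = MLP(8, 16, 4, 3, activation=torch.nn.ReLU(), norm_type='BN')
assert isinstance(list(mlp)[-1], torch.nn.ReLU)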
