Commit
style(pu): polish the annotations in MLP, yapf format
puyuan1996 committed Apr 24, 2023
1 parent 451fa45 commit dd238a2
Showing 2 changed files with 26 additions and 21 deletions.
37 changes: 21 additions & 16 deletions ding/torch_utils/network/nn_module.py
@@ -316,7 +316,7 @@ def MLP(
dropout_probability: float = 0.5,
output_activation: nn.Module = None,
output_norm_type: str = None,
last_linear_layer_weight_bias_init_zero: bool = False
last_linear_layer_init_zero: bool = False
):
r"""
Overview:
@@ -328,15 +328,16 @@ def MLP(
- hidden_channels (:obj:`int`): Number of channels in the hidden tensor.
- out_channels (:obj:`int`): Number of channels in the output tensor.
- layer_num (:obj:`int`): Number of layers.
- layer_fn (:obj:`Callable`): layer function.
- activation (:obj:`nn.Module`): the optional activation function.
- norm_type (:obj:`str`): type of the normalization.
- use_dropout (:obj:`bool`): whether to use dropout in the fully-connected block.
- dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5.
- output_activation (:obj:`nn.Module`): the activation function in the last layer. Default: None.
- output_norm_type (:obj:`str`): the type of the normalization in the last layer. Default: None.
- last_linear_layer_weight_bias_init_zero (:obj:`bool`): zero initialization for the last linear layer
(including w and b), which can provide stable zero outputs in the beginning.
- layer_fn (:obj:`Callable`): Layer function.
- activation (:obj:`nn.Module`): The optional activation function.
- norm_type (:obj:`str`): The type of the normalization.
- use_dropout (:obj:`bool`): Whether to use dropout in the fully-connected block.
- dropout_probability (:obj:`float`): The probability of an element to be zeroed in the dropout. Default: 0.5.
- output_activation (:obj:`nn.Module`): The activation function in the last layer. Default: None.
- output_norm_type (:obj:`str`): The type of the normalization in the last layer. Default: None.
- last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last linear layer
(including w and b), which can provide stable zero outputs in the beginning,
usually used in the policy network in RL settings.
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block.
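For context, a minimal usage sketch of the interface documented above (not part of this commit; the channel sizes, layer count, and import path below are illustrative and assume the module layout shown in this diff):

import torch
from ding.torch_utils.network.nn_module import MLP

# A small policy-head-style MLP: Tanh on the output and the last Linear
# zero-initialized, so the initial outputs are exactly zero.
policy_head = MLP(
    in_channels=64,
    hidden_channels=128,
    out_channels=6,
    layer_num=3,
    activation=torch.nn.ReLU(),
    norm_type='LN',
    output_activation=torch.nn.Tanh(),
    last_linear_layer_init_zero=True,
)
x = torch.randn(4, 64)
y = policy_head(x)  # expected to be all zeros at initialization, since tanh(0) == 0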
@@ -361,23 +362,27 @@ def MLP(
if use_dropout:
block.append(nn.Dropout(dropout_probability))

# the last layer
# The last layer
in_channels = channels[-2]
out_channels = channels[-1]
block.append(layer_fn(in_channels, out_channels))

"""
In the final layer of a neural network, the output normalization and activation functions are typically determined
based on user specifications. These specifications depend on the problem at hand and the desired properties of
the model's output.
"""
if output_norm_type is not None and output_activation is not None:
# the last layer use the user specified output_norm and output_activation
# The last layer uses the user-specified output_norm and output_activation.
block.append(build_normalization(output_norm_type, dim=1)(out_channels))
block.append(output_activation)

elif output_activation is not None and output_norm_type is None:
# the last layer use the user specified output_activation
# The last layer uses the user-specified output_activation.
block.append(output_activation)
elif output_activation is None and output_norm_type is not None:
# the last layer use the user specified output_norm
# The last layer uses the user-specified output_norm.
block.append(build_normalization(output_norm_type, dim=1)(out_channels))


if last_linear_layer_weight_bias_init_zero:
if last_linear_layer_init_zero:
# Locate the last linear layer and initialize its weights and biases to 0.
for _, layer in enumerate(reversed(block)):
if isinstance(layer, nn.Linear):
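As a standalone illustration of the zero-initialization step above (a sketch, not the library code; the helper name here is made up):

import torch.nn as nn

def zero_init_last_linear(block: nn.Sequential) -> None:
    # Scan the layers in reverse; the first nn.Linear encountered is the
    # last Linear in forward order, so zero its weight and bias and stop.
    for layer in reversed(block):
        if isinstance(layer, nn.Linear):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)
            break

block = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
zero_init_last_linear(block)
assert block[2].weight.abs().sum().item() == 0.0 and block[2].bias.abs().sum().item() == 0.0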
10 changes: 5 additions & 5 deletions ding/torch_utils/network/tests/test_nn_module.py
@@ -49,12 +49,12 @@ def test_mlp(self):
layer_num = 3
input_tensor = torch.rand(batch_size, in_channels).requires_grad_(True)

# Test case 1: Simple MLP without dropout, normalization, or output activation
# Test case 1: Simple MLP without dropout, normalization, or output activation.
model = MLP(in_channels, hidden_channels, out_channels, layer_num)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)

# Test case 2: MLP with dropout and normalization
# Test case 2: MLP with dropout and normalization.
for norm_type in ["LN", "BN", None]:
model = MLP(
in_channels,
@@ -70,22 +70,22 @@ def test_mlp(self):

for act in [torch.nn.LeakyReLU(), torch.nn.ReLU(), torch.nn.Sigmoid(), None]:
for norm_type in ["LN", "BN", None]:
# Test case 3: MLP without last linear layer initialized to 0
# Test case 3: MLP without last linear layer initialized to 0.
model = MLP(
in_channels, hidden_channels, out_channels, layer_num, norm_type=norm_type, output_activation=act
)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)

# Test case 4: MLP with last linear layer initialized to 0
# Test case 4: MLP with last linear layer initialized to 0.
model = MLP(
in_channels,
hidden_channels,
out_channels,
layer_num,
norm_type=norm_type,
output_activation=act,
last_linear_layer_weight_bias_init_zero=True
last_linear_layer_init_zero=True
)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)
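Beyond the shape checks in the test above, a quick sanity check of the renamed flag (illustrative only, not part of the test suite; it assumes the default arguments leave output_activation and output_norm_type unset):

import torch
from ding.torch_utils.network.nn_module import MLP

model = MLP(8, 16, 4, layer_num=3, last_linear_layer_init_zero=True)
x = torch.rand(5, 8)
# With the last Linear's weight and bias zeroed and no output activation/norm,
# the forward pass should return exactly zeros at initialization.
assert torch.all(model(x) == 0)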
