Specify default initialization schemes for modules in docs (pytorch#9038)

Summary: This closes pytorch#6906.

Reviewed By: ezyang

Differential Revision: D8698632

Pulled By: weiyangfb

fbshipit-source-id: 259c1dbdc264a8e9f83e196fa72d135babd97d48
vishwakftw authored and jramseyer committed Jul 30, 2018
1 parent 0c2d105 commit 7ca6b08
Showing 7 changed files with 131 additions and 45 deletions.
20 changes: 12 additions & 8 deletions torch/nn/modules/batchnorm.py
@@ -2,6 +2,7 @@
 from .module import Module
 from torch.nn.parameter import Parameter
 from .. import functional as F
+from .. import init
 
 
 # TODO: check contiguous in THNN
@@ -42,8 +43,8 @@ def reset_running_stats(self):
     def reset_parameters(self):
         self.reset_running_stats()
         if self.affine:
-            self.weight.data.uniform_()
-            self.bias.data.zero_()
+            init.uniform_(self.weight)
+            init.zeros_(self.bias)
 
     def _check_input_dim(self, input):
         raise NotImplementedError
@@ -96,9 +97,10 @@ class BatchNorm1d(_BatchNorm):
     The mean and standard-deviation are calculated per-dimension over
     the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
-    of size `C` (where `C` is the input size).
+    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
+    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
 
-    By default, during training this layer keeps running estimates of its
+    Also by default, during training this layer keeps running estimates of its
     computed mean and variance, which are then used for normalization during
     evaluation. The running estimates are kept with a default :attr:`momentum`
     of 0.1.
@@ -167,9 +169,10 @@ class BatchNorm2d(_BatchNorm):
     The mean and standard-deviation are calculated per-dimension over
     the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
-    of size `C` (where `C` is the input size).
+    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
+    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
 
-    By default, during training this layer keeps running estimates of its
+    Also by default, during training this layer keeps running estimates of its
     computed mean and variance, which are then used for normalization during
     evaluation. The running estimates are kept with a default :attr:`momentum`
     of 0.1.
@@ -238,9 +241,10 @@ class BatchNorm3d(_BatchNorm):
     The mean and standard-deviation are calculated per-dimension over
     the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
-    of size `C` (where `C` is the input size).
+    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
+    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
 
-    By default, during training this layer keeps running estimates of its
+    Also by default, during training this layer keeps running estimates of its
     computed mean and variance, which are then used for normalization during
     evaluation. The running estimates are kept with a default :attr:`momentum`
     of 0.1.
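A minimal sketch (not part of the commit; the layer size is arbitrary) confirming the batch norm defaults documented above, namely that with affine parameters enabled, :math:`\gamma` starts in [0, 1] and :math:`\beta` starts at zero:

```python
import torch
import torch.nn as nn

# Per this commit's defaults: weight (gamma) ~ U(0, 1), bias (beta) = 0.
bn = nn.BatchNorm2d(16)  # affine=True is the default
assert bn.weight.min() >= 0 and bn.weight.max() <= 1
assert torch.all(bn.bias == 0)
```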
58 changes: 46 additions & 12 deletions torch/nn/modules/conv.py
@@ -3,6 +3,7 @@
 import torch
 from torch.nn.parameter import Parameter
 from .. import functional as F
+from .. import init
 from .module import Module
 from .utils import _single, _pair, _triple
 
@@ -39,12 +40,11 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,
 
     def reset_parameters(self):
         n = self.in_channels
-        for k in self.kernel_size:
-            n *= k
-        stdv = 1. / math.sqrt(n)
-        self.weight.data.uniform_(-stdv, stdv)
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

Comment from pkaiserui (Mar 19, 2019): What is the reasoning for using the number 5 in a=math.sqrt(5)? Is there a research paper? Using five seems slightly arbitrary and would appreciate any additional clarity or documentation as to why.


         if self.bias is not None:
-            self.bias.data.uniform_(-stdv, stdv)
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
 
     def extra_repr(self):
         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
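On the question above: `kaiming_uniform_` samples from U(-b, b) with b = gain * sqrt(3 / fan_in) and gain = sqrt(2 / (1 + a^2)), so a = sqrt(5) gives gain = sqrt(1/3) and b = 1 / sqrt(fan_in), the same bound the removed `uniform_(-stdv, stdv)` code used. The change swaps notation, not the distribution. A sketch verifying this numerically (the tensor shape is arbitrary):

```python
import math
import torch
from torch.nn import init

# kaiming_uniform_ with a=sqrt(5) reproduces the old U(-stdv, stdv) init:
# gain = sqrt(2 / (1 + 5)) = sqrt(1/3), bound = gain * sqrt(3 / fan_in)
#      = 1 / sqrt(fan_in), i.e. the old stdv = 1 / sqrt(n).
w = torch.empty(64, 16, 3, 3)  # (out_channels, in_channels, kH, kW)
init.kaiming_uniform_(w, a=math.sqrt(5))

fan_in = 16 * 3 * 3  # in_channels * kernel elements
old_stdv = 1.0 / math.sqrt(fan_in)
assert w.abs().max() <= old_stdv  # every sample stays within the old bound
```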
@@ -144,9 +144,13 @@ class Conv1d(_ConvNd):
     Attributes:
         weight (Tensor): the learnable weights of the module of shape
-            (out_channels, in_channels, kernel_size)
+            (out_channels, in_channels, kernel_size). The values of these weights are sampled from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \text{kernel\_size}}`
         bias (Tensor): the learnable bias of the module of shape
-            (out_channels)
+            (out_channels). If :attr:`bias` is ``True``, then the values of these weights are
+            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \text{kernel\_size}}`
 
     Examples::
@@ -265,8 +269,14 @@ class Conv2d(_ConvNd):
     Attributes:
         weight (Tensor): the learnable weights of the module of shape
-            (out_channels, in_channels, kernel_size[0], kernel_size[1])
-        bias (Tensor): the learnable bias of the module of shape (out_channels)
+            (out_channels, in_channels, kernel_size[0], kernel_size[1]).
+            The values of these weights are sampled from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
+        bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
+            then the values of these weights are
+            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
 
     Examples::
@@ -388,7 +398,13 @@ class Conv3d(_ConvNd):
     Attributes:
         weight (Tensor): the learnable weights of the module of shape
             (out_channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
-        bias (Tensor): the learnable bias of the module of shape (out_channels)
+            The values of these weights are sampled from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
+        bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
+            then the values of these weights are
+            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
 
     Examples::
@@ -537,8 +553,14 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
     Attributes:
         weight (Tensor): the learnable weights of the module of shape
-            (in_channels, out_channels, kernel_size[0], kernel_size[1])
-        bias (Tensor): the learnable bias of the module of shape (out_channels)
+            (in_channels, out_channels, kernel_size[0], kernel_size[1]). The values
+            of these weights are sampled from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \text{kernel\_size}}`
+        bias (Tensor): the learnable bias of the module of shape (out_channels).
+            If :attr:`bias` is ``True``, then the values of these weights are
+            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \text{kernel\_size}}`
     """
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
@@ -645,7 +667,13 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
     Attributes:
         weight (Tensor): the learnable weights of the module of shape
             (in_channels, out_channels, kernel_size[0], kernel_size[1])
+            The values of these weights are sampled from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
         bias (Tensor): the learnable bias of the module of shape (out_channels)
+            If :attr:`bias` is ``True``, then the values of these weights are
+            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
 
     Examples::
@@ -782,7 +810,13 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
     Attributes:
         weight (Tensor): the learnable weights of the module of shape
             (in_channels, out_channels, kernel_size[0], kernel_size[1], kernel_size[2])
+            The values of these weights are sampled from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
         bias (Tensor): the learnable bias of the module of shape (out_channels)
+            If :attr:`bias` is ``True``, then the values of these weights are
+            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_channels} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
 
     Examples::
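The documented conv bounds are easy to check empirically; a sketch with arbitrary channel and kernel sizes:

```python
import torch.nn as nn

# For Conv2d, k = 1 / (in_channels * kernel_size[0] * kernel_size[1]);
# weight and bias entries should lie in [-sqrt(k), sqrt(k)].
conv = nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3)
k = 1 / (8 * 3 * 3)
assert conv.weight.abs().max() <= k ** 0.5
assert conv.bias.abs().max() <= k ** 0.5
```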
9 changes: 6 additions & 3 deletions torch/nn/modules/instancenorm.py
@@ -84,7 +84,8 @@ class InstanceNorm1d(_InstanceNorm):
         eps: a value added to the denominator for numerical stability. Default: 1e-5
         momentum: the value used for the running_mean and running_var computation. Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
-            learnable affine parameters. Default: ``False``
+            learnable affine parameters, initialized the same way as done for batch normalization.
+            Default: ``False``.
         track_running_stats: a boolean value that when set to ``True``, this
             module tracks the running mean and variance, and when set to ``False``,
             this module does not track such statistics and always uses batch
@@ -148,7 +149,8 @@ class InstanceNorm2d(_InstanceNorm):
         eps: a value added to the denominator for numerical stability. Default: 1e-5
         momentum: the value used for the running_mean and running_var computation. Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
-            learnable affine parameters. Default: ``False``
+            learnable affine parameters, initialized the same way as done for batch normalization.
+            Default: ``False``.
         track_running_stats: a boolean value that when set to ``True``, this
             module tracks the running mean and variance, and when set to ``False``,
             this module does not track such statistics and always uses batch
@@ -212,7 +214,8 @@ class InstanceNorm3d(_InstanceNorm):
         eps: a value added to the denominator for numerical stability. Default: 1e-5
         momentum: the value used for the running_mean and running_var computation. Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
-            learnable affine parameters. Default: ``False``
+            learnable affine parameters, initialized the same way as done for batch normalization.
+            Default: ``False``.
         track_running_stats: a boolean value that when set to ``True``, this
             module tracks the running mean and variance, and when set to ``False``,
             this module does not track such statistics and always uses batch
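Note that `affine` defaults to ``False`` for the instance norm layers, so the batch-norm-style initialization only applies when it is requested; a short sketch (sizes arbitrary, and the `weight is None` check assumes the usual `register_parameter(..., None)` behavior):

```python
import torch
import torch.nn as nn

# InstanceNorm only creates affine parameters when affine=True;
# they then follow the batch norm defaults (weight ~ U(0, 1), bias = 0).
inorm = nn.InstanceNorm2d(16, affine=True)
assert inorm.weight.min() >= 0 and inorm.weight.max() <= 1
assert torch.all(inorm.bias == 0)

plain = nn.InstanceNorm2d(16)  # affine=False: no learnable parameters
assert plain.weight is None and plain.bias is None
```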
30 changes: 21 additions & 9 deletions torch/nn/modules/linear.py
@@ -3,6 +3,7 @@
 import torch
 from torch.nn.parameter import Parameter
 from .. import functional as F
+from .. import init
 from .module import Module


@@ -23,8 +24,13 @@ class Linear(Module):
     Attributes:
         weight: the learnable weights of the module of shape
-            `(out_features x in_features)`
-        bias: the learnable bias of the module of shape `(out_features)`
+            `(out_features x in_features)`. The values are initialized from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_features}}`
+        bias: the learnable bias of the module of shape `(out_features)`.
+            If :attr:`bias` is ``True``, the values are initialized from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in\_features}}`
 
     Examples::
@@ -46,10 +52,11 @@ def __init__(self, in_features, out_features, bias=True):
         self.reset_parameters()
 
     def reset_parameters(self):
-        stdv = 1. / math.sqrt(self.weight.size(1))
-        self.weight.data.uniform_(-stdv, stdv)
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
         if self.bias is not None:
-            self.bias.data.uniform_(-stdv, stdv)
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
 
     def forward(self, input):
         return F.linear(input, self.weight, self.bias)
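The same pattern applies to `Linear`; a sketch checking the documented k = 1/in_features bound (sizes arbitrary):

```python
import torch.nn as nn

# Linear: weight and bias ~ U(-sqrt(k), sqrt(k)) with k = 1 / in_features.
fc = nn.Linear(in_features=128, out_features=32)
k = 1 / 128
assert fc.weight.abs().max() <= k ** 0.5
assert fc.bias.abs().max() <= k ** 0.5
```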
@@ -80,8 +87,13 @@ class Bilinear(Module):
     Attributes:
         weight: the learnable weights of the module of shape
-            `(out_features x in1_features x in2_features)`
+            `(out_features x in1_features x in2_features)`. The values are initialized from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in1\_features}}`
         bias: the learnable bias of the module of shape `(out_features)`
+            If :attr:`bias` is ``True``, the values are initialized from
+            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+            :math:`k = \frac{1}{\text{in1\_features}}`
 
     Examples::
@@ -106,10 +118,10 @@ def __init__(self, in1_features, in2_features, out_features, bias=True):
         self.reset_parameters()
 
     def reset_parameters(self):
-        stdv = 1. / math.sqrt(self.weight.size(1))
-        self.weight.data.uniform_(-stdv, stdv)
+        bound = 1 / math.sqrt(self.weight.size(1))
+        init.uniform_(self.weight, -bound, bound)
         if self.bias is not None:
-            self.bias.data.uniform_(-stdv, stdv)
+            init.uniform_(self.bias, -bound, bound)
 
     def forward(self, input1, input2):
         return F.bilinear(input1, input2, self.weight, self.bias)
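`Bilinear` keeps a plain uniform init whose bound depends only on in1_features, since `weight.size(1)` is in1_features; a sketch (sizes arbitrary):

```python
import torch.nn as nn

# Bilinear: weight and bias ~ U(-sqrt(k), sqrt(k)) with k = 1 / in1_features.
bl = nn.Bilinear(in1_features=20, in2_features=30, out_features=5)
k = 1 / 20
assert bl.weight.abs().max() <= k ** 0.5
assert bl.bias.abs().max() <= k ** 0.5
```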
15 changes: 9 additions & 6 deletions torch/nn/modules/normalization.py
@@ -4,6 +4,7 @@
 from .module import Module
 from .batchnorm import _BatchNorm
 from .. import functional as F
+from .. import init
 
 
 class LocalResponseNorm(Module):
@@ -101,7 +102,8 @@ class LayerNorm(Module):
             normalize over the last dimension which is expected to be of that specific size.
         eps: a value added to the denominator for numerical stability. Default: 1e-5
         elementwise_affine: a boolean value that when set to ``True``, this module
-            has learnable per-element affine parameters. Default: ``True``
+            has learnable per-element affine parameters initialized to ones (for weights)
+            and zeros (for biases). Default: ``True``.
 
     Shape:
         - Input: :math:`(N, *)`
@@ -140,8 +142,8 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
 
     def reset_parameters(self):
         if self.elementwise_affine:
-            self.weight.data.fill_(1)
-            self.bias.data.zero_()
+            init.ones_(self.weight)
+            init.zeros_(self.bias)
 
     def forward(self, input):
         return F.layer_norm(
@@ -173,7 +175,8 @@ class GroupNorm(Module):
         num_channels (int): number of channels expected in input
         eps: a value added to the denominator for numerical stability. Default: 1e-5
         affine: a boolean value that when set to ``True``, this module
-            has learnable per-channel affine parameters. Default: ``True``
+            has learnable per-channel affine parameters initialized to ones (for weights)
+            and zeros (for biases). Default: ``True``.
 
     Shape:
         - Input: :math:`(N, num\_channels, *)`
@@ -209,8 +212,8 @@ def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
 
     def reset_parameters(self):
         if self.affine:
-            self.weight.data.fill_(1)
-            self.bias.data.zero_()
+            init.ones_(self.weight)
+            init.zeros_(self.bias)
 
     def forward(self, input):
         return F.group_norm(
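A last sketch confirming the ones/zeros defaults for `LayerNorm` and `GroupNorm` (shapes arbitrary):

```python
import torch
import torch.nn as nn

# Both layers start as the identity transform: weight = 1, bias = 0.
ln = nn.LayerNorm(64)
gn = nn.GroupNorm(num_groups=4, num_channels=64)
for m in (ln, gn):
    assert torch.all(m.weight == 1) and torch.all(m.bias == 0)
```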
