
objax.nn package

objax.nn

BatchNorm, BatchNorm0D, BatchNorm1D, BatchNorm2D, Conv2D, ConvTranspose2D, Dropout, Linear, MovingAverage, ExponentialMovingAverage, Sequential, SyncedBatchNorm, SyncedBatchNorm0D, SyncedBatchNorm1D, SyncedBatchNorm2D

BatchNorm

$$y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta$$

The mean (E[x]) and variance (Var[x]) are calculated per specified dimensions and over the mini-batches. β and γ are trainable parameter tensors of shape dims. The elements of β are initialized with zeros and those of γ are initialized with ones.
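
As an illustration of the formula only (not the module's actual implementation, which also maintains running statistics for evaluation), the normalization can be written directly in JAX:

import jax.numpy as jnp

def batch_norm_reference(x, gamma, beta, eps=1e-6):
    # Normalize over the batch axis, then scale by gamma and shift by beta.
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * gamma + beta

x = jnp.arange(30.0).reshape(10, 3)   # (N, nin)
gamma = jnp.ones((1, 3))              # initialized with ones
beta = jnp.zeros((1, 3))              # initialized with zeros
print(batch_norm_reference(x, gamma, beta).shape)  # (10, 3)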

BatchNorm0D

$$y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta$$

The mean (E[x]) and variance (Var[x]) are calculated over the mini-batches. β and γ are trainable parameter tensors of shape (1, nin). The elements of β are initialized with zeros and those of γ are initialized with ones.

BatchNorm1D

$$y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta$$

The mean (E[x]) and variance (Var[x]) are calculated per channel and over the mini-batches. β and γ are trainable parameter tensors of shape (1, nin, 1). The elements of β are initialized with zeros and those of γ are initialized with ones.

BatchNorm2D

$$y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta$$

The mean (E[x]) and variance (Var[x]) are calculated per channel and over the mini-batches. β and γ are trainable parameter tensors of shape (1, nin, 1, 1). The elements of β are initialized with zeros and those of γ are initialized with ones.
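
A minimal usage sketch for the 2D case (assuming the constructor takes the number of channels and the call takes a training flag, as in the Sequential example below):

import objax

bn = objax.nn.BatchNorm2D(8)              # one gamma/beta pair per channel
x = objax.random.normal((4, 8, 16, 16))   # (N, C, H, W)
y = bn(x, training=True)                  # uses batch statistics and updates the running averages
print(y.shape)                            # (4, 8, 16, 16)
y_eval = bn(x, training=False)            # uses the stored running statistics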

Conv2D

In the simplest case (strides = 1, padding = VALID), the output tensor (N, Cout, Hout, Wout) is computed from an input tensor (N, Cin, H, W) with kernel weight (k, k, Cin, Cout) and bias (Cout) as follows:

$$\mathrm{out}[n,c,h,w] = \mathrm{b}[c] + \sum_{t=0}^{C_{in}-1}\sum_{i=0}^{k-1}\sum_{j=0}^{k-1} \mathrm{in}[n,t,i+h,j+w] \times \mathrm{w}[i,j,t,c]$$

where Hout = H − k + 1 and Wout = W − k + 1. Note that the implementation follows the definition of cross-correlation (the kernel is not flipped, as it would be in a mathematical convolution). When padding = SAME, the input tensor is zero-padded by $\lfloor\frac{k-1}{2}\rfloor$ on the left and top sides and by $\lfloor\frac{k}{2}\rfloor$ on the right and bottom sides, so that Hout = H and Wout = W.
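
A minimal sketch of the resulting output shapes; it assumes the constructor is called as Conv2D(nin, nout, k) and that padding accepts objax.constants.ConvPadding values:

import objax
from objax.constants import ConvPadding

x = objax.random.normal((2, 3, 32, 32))   # (N, Cin, H, W)

conv_valid = objax.nn.Conv2D(3, 16, 5, padding=ConvPadding.VALID)
print(conv_valid(x).shape)  # (2, 16, 28, 28) since Hout = Wout = 32 - 5 + 1

conv_same = objax.nn.Conv2D(3, 16, 5, padding=ConvPadding.SAME)
print(conv_same(x).shape)   # (2, 16, 32, 32): zero-padding preserves H and W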

ConvTranspose2D

Dropout

During training, the module randomly zeroes each element of the input with probability 1 − keep. During evaluation, the module does not modify the input tensor. Dropout (Improving neural networks by preventing co-adaptation of feature detectors) is an effective regularization technique that reduces overfitting and improves generalization.
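
A minimal sketch (the keep argument and the training flag follow the Sequential example below):

import objax

drop = objax.nn.Dropout(keep=0.8)   # keep each element with probability 0.8
x = objax.random.normal((4, 10))
y_train = drop(x, training=True)    # roughly 20% of the elements are zeroed (stochastic)
y_eval = drop(x, training=False)    # evaluation: the input is returned unchanged
print((y_eval == x).all())          # True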

Linear

The output tensor (N, Cout) is computed from an input tensor (N, Cin) with kernel weight (Cin, Cout) and bias (Cout) as follows:

$$\mathrm{out}[n,c] = \mathrm{b}[c] + \sum_{t=1}^{C_{in}} \mathrm{in}[n,t] \times \mathrm{w}[t,c]$$
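
A minimal sketch; it assumes the weight and bias are exposed as w and b (as in the vars() listing in the Sequential example below), with their arrays accessible through .value:

import objax
import jax.numpy as jnp

lin = objax.nn.Linear(2, 3)   # w: (2, 3), b: (3,)
x = objax.random.normal((10, 2))
y = lin(x)
print(y.shape)  # (10, 3)

# The formula above is an ordinary affine map: out = x @ w + b.
y_manual = jnp.dot(x, lin.w.value) + lin.b.value
print(jnp.allclose(y, y_manual))  # True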

MovingAverage

ExponentialMovingAverage


$$x_{\mathrm{EMA}} \leftarrow \mathrm{momentum} \times x_{\mathrm{EMA}} + (1 - \mathrm{momentum}) \times x$$
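
A plain-Python illustration of the update rule (not the class API):

momentum = 0.9
x_ema = 0.0
for x in [1.0, 1.0, 1.0, 1.0]:
    x_ema = momentum * x_ema + (1 - momentum) * x
print(x_ema)  # ~0.3439: the average drifts toward the observed value of 1.0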

Sequential

Usage example:

import objax

ml = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu,
                          objax.nn.Linear(3, 4)])
x = objax.random.normal((10, 2))
y = ml(x)  # Runs all the operations (Linear -> ReLU -> Linear).
print(y.shape)  # (10, 4)

# objax.nn.Sequential is really a list.
ml.insert(2, objax.nn.BatchNorm0D(3))  # Add a batch norm layer after ReLU
ml.append(objax.nn.Dropout(keep=0.5))  # Add a dropout layer at the end
y = ml(x, training=False)  # Both batch norm and dropout expect a training argument.
# Sequential automatically passes arguments to the modules that use them.

# You can run a subset of operations since it is a list.
y1 = ml[:2](x)  # Run first two layers (Linear -> ReLU)
y2 = ml[2:](y1, training=False)  # Run all layers starting from the third (BatchNorm0D -> Linear -> Dropout)
print(ml(x, training=False) - y2)  # [[0. 0. ...]] - results are the same.

print(ml.vars())
# (Sequential)[0](Linear).b                              3 (3,)
# (Sequential)[0](Linear).w                              6 (2, 3)
# (Sequential)[2](BatchNorm0D).running_mean              3 (1, 3)
# (Sequential)[2](BatchNorm0D).running_var               3 (1, 3)
# (Sequential)[2](BatchNorm0D).beta                      3 (1, 3)
# (Sequential)[2](BatchNorm0D).gamma                     3 (1, 3)
# (Sequential)[3](Linear).b                              4 (4,)
# (Sequential)[3](Linear).w                             12 (3, 4)
# (Sequential)[4](Dropout).keygen(Generator)._key        2 (2,)
# +Total(9)                                             39

SyncedBatchNorm

SyncedBatchNorm0D

SyncedBatchNorm1D

SyncedBatchNorm2D

objax.nn.init

gain_leaky_relu, identity, kaiming_normal_gain, kaiming_normal, kaiming_truncated_normal, orthogonal, truncated_normal, xavier_normal_gain, xavier_normal, xavier_truncated_normal
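
A minimal usage sketch; it assumes the initializers are callables that take the desired shape, and that layers such as Linear accept a w_init callable (both are assumptions, not verified signatures):

import objax
from objax.nn import init

# Assumption: initializers take the desired output shape and return an array.
w = init.xavier_normal((2, 3))
print(w.shape)  # (2, 3)

# Assumption: Linear accepts a w_init callable used to initialize its weight.
lin = objax.nn.Linear(2, 3, w_init=init.kaiming_normal)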

gain_leaky_relu

The returned gain value is

$$\sqrt{\frac{2}{1 + \text{relu\_slope}^2}}.$$

identity

kaiming_normal_gain

The returned gain value is

$$\sqrt{\frac{1}{\text{fan\_in}}}.$$

kaiming_normal

kaiming_truncated_normal

orthogonal

truncated_normal

xavier_normal_gain

The returned gain value is

$$\sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}.$$
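
For concreteness, the three gain formulas in this section evaluate as follows for example values of relu_slope, fan_in and fan_out:

import math

relu_slope = 0.2              # example slope for gain_leaky_relu
fan_in, fan_out = 256, 128    # example fan values

print(math.sqrt(2 / (1 + relu_slope ** 2)))  # ~1.3868 (gain_leaky_relu)
print(math.sqrt(1 / fan_in))                 # 0.0625  (kaiming_normal_gain)
print(math.sqrt(2 / (fan_in + fan_out)))     # ~0.0722 (xavier_normal_gain)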

xavier_normal

xavier_truncated_normal