r"""Normalization layers"""
__all__ = [
'BatchNorm',
'LayerNorm',
]
import jax
import jax.numpy as jnp
from jax import Array
from typing import *
from .module import Module, Buffer
from ..debug import same_trace
class BatchNorm(Module):
r"""Creates a batch-normalization layer.
.. math:: y = \frac{x - \mathbb{E}[x]}{\sqrt{\mathbb{V}[x] + \epsilon}}
The mean and variance are calculated over the batch and spatial axes. During
training, the layer keeps running estimates of the computed mean and variance, which
are then used for normalization during evaluation. The update rule for a running
average statistic :math:`\hat{s}` is
.. math:: \hat{s} \gets \alpha \hat{s} + (1 - \alpha) s
where :math:`s` is the statistic calculated for the current batch.
References:
| Accelerating Deep Network Training by Reducing Internal Covariate Shift (Ioffe et al., 2015)
| https://arxiv.org/abs/1502.03167
Arguments:
channels: The number of channels :math:`C`.
epsilon: A numerical stability term :math:`\epsilon`.
momentum: The momentum :math:`\alpha \in [0, 1]` for the running estimates.
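
    Example:
        A minimal usage sketch; the shapes are illustrative, and toggling the
        `training` flag directly is assumed to be the intended way to switch
        between batch and running statistics.

        >>> import jax.numpy as jnp
        >>> norm = BatchNorm(channels=5)
        >>> x = jnp.ones((16, 5))
        >>> y = norm(x)               # batch statistics, updates running estimates
        >>> norm.training = False
        >>> y = norm(x)               # running estimates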
"""
    training: bool = True

    def __init__(
self,
channels: int,
epsilon: float = 1e-05,
momentum: float = 0.9,
):
self.epsilon = epsilon
self.momentum = momentum
self.stats = Buffer(
mean=jnp.zeros((channels,)),
var=jnp.ones((channels,)),
)

    def __call__(self, x: Array) -> Array:
        r"""
        Arguments:
            x: The input tensor :math:`x`, with shape :math:`(N, *, C)`.

        Returns:
            The output tensor :math:`y`, with shape :math:`(N, *, C)`.
        """
if self.training:
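            # Flatten the batch and spatial axes so that the mean and variance
            # are computed per channel.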
y = x.reshape(-1, x.shape[-1])
mean = jnp.mean(y, axis=0)
var = jnp.var(y, axis=0)
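            # Update the running estimates; stop_gradient detaches the batch
            # statistics from the autodiff graph before they enter the buffer.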
self.stats = Buffer(
mean=self.ema(self.stats.mean, jax.lax.stop_gradient(mean)),
var=self.ema(self.stats.var, jax.lax.stop_gradient(var)),
)
else:
mean = self.stats.mean
var = self.stats.var
return (x - mean) / jnp.sqrt(var + self.epsilon)

    def ema(self, x: Array, y: Array) -> Array:
        assert same_trace(x, y), (
            "An unsafe side effect was detected. Ensure that the running "
            "statistic and input tensors have the same trace."
        )
return self.momentum * x + (1 - self.momentum) * y


class LayerNorm(Module):
r"""Creates a layer-normalization layer.
.. math:: y = \frac{x - \mathbb{E}[x]}{\sqrt{\mathbb{V}[x] + \epsilon}}
References:
| Layer Normalization (Ba et al., 2016)
| https://arxiv.org/abs/1607.06450
Arguments:
axis: The axis(es) over which the mean and variance are calculated.
epsilon: A numerical stability term :math:`\epsilon`.
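
    Example:
        A minimal usage sketch; the shapes are illustrative.

        >>> import jax.numpy as jnp
        >>> norm = LayerNorm()
        >>> x = jnp.ones((16, 5))
        >>> y = norm(x)               # normalizes over the last axis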
"""

    def __init__(
self,
axis: Union[int, Sequence[int]] = -1,
epsilon: float = 1e-05,
):
self.axis = axis
self.epsilon = epsilon

    @jax.jit
def __call__(self, x: Array) -> Array:
r"""
Arguments:
x: The input tensor :math:`x`, with shape :math:`(*, C)`.
Returns:
The output tensor :math:`y`, with shape :math:`(*, C)`.
"""
mean = jnp.mean(x, axis=self.axis, keepdims=True)
var = jnp.var(x, axis=self.axis, keepdims=True)
return (x - mean) / jnp.sqrt(var + self.epsilon)