/
normalization.py
176 lines (130 loc) · 4.83 KB
/
normalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
r"""Normalization layers.

Batch, layer and group normalization modules.
"""

__all__ = ['BatchNorm', 'LayerNorm', 'GroupNorm']
import jax
import jax.numpy as jnp
from einops import rearrange
from jax import Array
from typing import Dict, NamedTuple, Sequence, Tuple, Union
# isort: local
from .module import Module
from .state import StateEntry, update_state
class Statistics(NamedTuple):
    r"""Mean/variance pair stored as state by :class:`BatchNorm`.

    Attributes:
        mean: The mean estimate.
        var: The variance estimate.
    """

    mean: Array  # mean estimate (one entry per channel as built by BatchNorm)
    var: Array  # variance estimate (one entry per channel as built by BatchNorm)
class BatchNorm(Module):
    r"""Creates a batch-normalization layer.

    .. math:: y = \frac{x - \mathbb{E}[x]}{\sqrt{\mathbb{V}[x] + \epsilon}}

    Mean and variance are reduced over every axis except the last (channel) one,
    that is over the batch and spatial axes. In training mode the current batch
    statistics normalize the input, and exponential moving averages of them are
    written back to the state. In evaluation mode those stored averages are used
    instead. A running statistic :math:`\hat{s}` is updated as

    .. math:: \hat{s} \gets \alpha \hat{s} + (1 - \alpha) s

    where :math:`s` is the statistic of the current batch.

    References:
        | Accelerating Deep Network Training by Reducing Internal Covariate Shift (Ioffe et al., 2015)
        | https://arxiv.org/abs/1502.03167

    Arguments:
        channels: The number of channels :math:`C`.
        epsilon: A numerical stability term :math:`\epsilon`.
        momentum: The momentum :math:`\alpha \in [0, 1]` for the running estimates.
    """

    training: bool = True

    def __init__(
        self,
        channels: int,
        epsilon: Union[float, Array] = 1e-05,
        momentum: Union[float, Array] = 0.9,
    ):
        self.epsilon = jnp.asarray(epsilon)
        self.momentum = jnp.asarray(momentum)

        # Running estimates start as the identity transform: zero mean, unit variance.
        initial = Statistics(
            mean=jnp.zeros(channels),
            var=jnp.ones(channels),
        )
        self.stats = StateEntry(initial)

    def __call__(self, x: Array, state: Dict) -> Tuple[Array, Dict]:
        r"""
        Arguments:
            x: The input tensor :math:`x`, with shape :math:`(N, *, C)`.
            state: The state dictionary.

        Returns:
            The output tensor :math:`y`, with shape :math:`(N, *, C)`, and the
            (updated) state dictionary.
        """

        if not self.training:
            # Evaluation: normalize with the stored running estimates, state untouched.
            mean, var = state[self.stats]
            return (x - mean) / jnp.sqrt(var + self.epsilon), state

        assert x.ndim > 1, "the input tensor is not batched."

        # Fold all leading axes together so statistics are taken per channel.
        flat = x.reshape(-1, x.shape[-1])
        batch_mean = jnp.mean(flat, axis=0)
        batch_var = jnp.var(flat, axis=0)

        # Gradients are blocked from flowing into the running estimates.
        running = state[self.stats]
        running = Statistics(
            mean=self.ema(running.mean, jax.lax.stop_gradient(batch_mean)),
            var=self.ema(running.var, jax.lax.stop_gradient(batch_var)),
        )
        state = update_state(state, {self.stats: running})

        return (x - batch_mean) / jnp.sqrt(batch_var + self.epsilon), state

    def ema(self, x: Array, y: Array) -> Array:
        r"""Blends the running statistic ``x`` with the new value ``y``.

        Returns :math:`\alpha x + (1 - \alpha) y`, where :math:`\alpha` is the momentum.
        """
        alpha = self.momentum
        return alpha * x + (1 - alpha) * y
class LayerNorm(Module):
    r"""Creates a layer-normalization layer.

    .. math:: y = \frac{x - \mathbb{E}[x]}{\sqrt{\mathbb{V}[x] + \epsilon}}

    The statistics are reduced over ``axis`` only, separately for every position
    along the remaining axes. No state or running estimate is kept.

    References:
        | Layer Normalization (Ba et al., 2016)
        | https://arxiv.org/abs/1607.06450

    Arguments:
        axis: The axis(es) over which the mean and variance are calculated.
        epsilon: A numerical stability term :math:`\epsilon`.
    """

    def __init__(
        self,
        axis: Union[int, Sequence[int]] = -1,
        epsilon: Union[float, Array] = 1e-05,
    ):
        self.axis = axis
        self.epsilon = jnp.asarray(epsilon)

    def __call__(self, x: Array) -> Array:
        r"""
        Arguments:
            x: The input tensor :math:`x`, with shape :math:`(*, C)`.

        Returns:
            The output tensor :math:`y`, with shape :math:`(*, C)`.
        """

        # keepdims so the statistics broadcast back against x.
        centered = x - jnp.mean(x, axis=self.axis, keepdims=True)
        scale = jnp.sqrt(jnp.var(x, axis=self.axis, keepdims=True) + self.epsilon)

        return centered / scale
class GroupNorm(LayerNorm):
    r"""Creates a group-normalization layer.

    The last axis of size :math:`C` is split into :math:`G` contiguous groups of
    :math:`C / G` channels each, and each group is normalized on its own with

    .. math:: y = \frac{x - \mathbb{E}[x]}{\sqrt{\mathbb{V}[x] + \epsilon}}

    where the statistics are taken over the channels of that group. :math:`C` is
    expected to be divisible by :math:`G`; otherwise the regrouping step fails.

    References:
        | Group Normalization (Wu et al., 2018)
        | https://arxiv.org/abs/1803.08494

    Arguments:
        groups: The number of groups :math:`G` to separate channels into. If :math:`G = 1`,
            the layer is equivalent to :class:`LayerNorm`.
        epsilon: A numerical stability term :math:`\epsilon`.
    """

    def __init__(
        self,
        groups: int,
        epsilon: Union[float, Array] = 1e-05,
    ):
        # The parent normalizes over the trailing (within-group) axis.
        super().__init__(axis=-1, epsilon=epsilon)
        self.groups = groups

    def __call__(self, x: Array) -> Array:
        r"""
        Arguments:
            x: The input tensor :math:`x`, with shape :math:`(*, C)`.

        Returns:
            The output tensor :math:`y`, with shape :math:`(*, C)`.
        """

        # (*, C) -> (*, G, C / G), normalize the last axis, then flatten back.
        grouped = rearrange(x, '... (G D) -> ... G D', G=self.groups)
        normalized = super().__call__(grouped)

        return rearrange(normalized, '... G D -> ... (G D)')