-
Notifications
You must be signed in to change notification settings - Fork 0
/
dropout.py
85 lines (59 loc) · 2.01 KB
/
dropout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
r"""Dropout layers.

This module provides two layers:

* :class:`Dropout`, which always applies dropout to its input, given an
  explicit PRNG key;
* :class:`TrainingDropout`, which applies dropout only while its
  ``training`` flag is set and acts as the identity otherwise.
"""
__all__ = [
    'Dropout',
    'TrainingDropout',
]
import jax
from jax import Array
from typing import Union
# isort: local
from .module import Module
from ..random import get_rng
class Dropout(Module):
    r"""Creates a dropout layer.

    .. math:: y = \frac{m \odot x}{1 - p}

    where the binary mask :math:`m` is drawn from a Bernoulli distribution such that
    :math:`P(m_i = 0) = p`. This has proven to be an effective technique for
    regularization and preventing overfitting.

    When :math:`p = 1`, every element is dropped: the output is zero, and so is
    its gradient with respect to :math:`x` (no division by zero occurs).

    References:
        | A Simple Way to Prevent Neural Networks from Overfitting (Srivastava et al., 2014)
        | https://jmlr.org/papers/v15/srivastava14a.html

    Arguments:
        p: The dropout rate :math:`p \in [0, 1]`.
    """

    def __init__(self, p: Union[float, Array] = 0.5):
        self.p = jax.numpy.asarray(p)

    def __call__(self, x: Array, key: Array) -> Array:
        r"""
        Arguments:
            x: The input tensor :math:`x`, with shape :math:`(*)`.
            key: A PRNG key.

        Returns:
            The output tensor :math:`y`, with shape :math:`(*)`.
        """

        # Probability of keeping each element.
        keep = 1 - self.p
        mask = jax.random.bernoulli(key, keep, shape=x.shape)

        # With p = 1 the keep probability is 0, so `x / keep` would be inf/NaN.
        # Even though `where` never selects that branch (the mask is all False),
        # its gradient would still be NaN. Substitute a harmless divisor of 1 in
        # that case; for keep > 0 the divisor is exactly `keep`, as before.
        safe_keep = jax.numpy.where(keep > 0, keep, 1)

        return jax.numpy.where(mask, x / safe_keep, 0)
class TrainingDropout(Dropout):
    r"""Creates a dropout layer that is only active during training.

    While :py:`self.training` is :py:`True`, this layer behaves exactly like
    :class:`Dropout`. When :py:`self.training = False`, it is the identity

    .. math:: y = x

    See also:
        :class:`Dropout`

    Arguments:
        p: The dropout rate :math:`p \in [0, 1]`.
    """

    training: bool = True

    def __call__(self, x: Array, key: Array = None) -> Array:
        r"""
        Arguments:
            x: The input tensor :math:`x`, with shape :math:`(*)`.
            key: A PRNG key, only consumed while training. If :py:`None`, a key
                is split from :func:`inox.random.get_rng` instead.

        Returns:
            The output tensor :math:`y`, with shape :math:`(*)`.
        """

        # Outside of training the layer is a pass-through; no key is needed.
        if not self.training:
            return x

        # Without an explicit key, split one off the generator from `get_rng()`.
        rng_key = get_rng().split() if key is None else key

        return super().__call__(x, rng_key)