Add Gelu.
PiperOrigin-RevId: 319323810
SiuMath authored and romanngg committed Jul 2, 2020
1 parent c01e965 commit 70b983d
Showing 2 changed files with 102 additions and 12 deletions.
94 changes: 90 additions & 4 deletions neural_tangents/stax.py
@@ -483,7 +483,7 @@ def GeneralConv(
  Args:
    dimension_numbers: Specifies which axes should be convolved over. Should
-     match the specification in `jax.lax.conv_general_dilated`.
+     match the specification in `jax.lax.dot_general_dilated`.
    out_chan: The number of output channels / features of the
      convolution. This is ignored by the `kernel_fn` in NTK
      parameterization.
@@ -1231,12 +1231,29 @@ def Erf(a: float = 1.,
                      do_backprop=do_backprop)


@layer
@_supports_masking(remask_kernel=True)
def Gelu(do_backprop: bool = False) -> InternalLayer:
"""Gelu function.
Args:
do_backprop: set to `True` if you want to backpropagate through the kernel.
Returns:
`(init_fn, apply_fn, kernel_fn)`.
"""
return _elementwise(_gelu,
'Gelu',
do_backprop=do_backprop)
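
For orientation (an illustration, not part of the diff): the new layer plugs into the existing `stax` combinators exactly like `Erf` or `Relu`. A minimal sketch, assuming only the public `neural_tangents.stax` API at this revision:

```python
from jax import random
from neural_tangents import stax

# Infinite-width fully-connected network using the new Gelu nonlinearity.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Gelu(),
    stax.Dense(512), stax.Gelu(),
    stax.Dense(1))

x1 = random.normal(random.PRNGKey(1), (3, 8))
x2 = random.normal(random.PRNGKey(2), (4, 8))

nngp = kernel_fn(x1, x2, 'nngp')  # closed-form NNGP, shape (3, 4)
ntk = kernel_fn(x1, x2, 'ntk')    # closed-form NTK, shape (3, 4)
```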


@layer
@_supports_masking(remask_kernel=True)
def Sin(a: float = 1.,
        b: float = 1.,
        c: float = 0.) -> InternalLayer:
  """Affine transform of `Sin` nonlinearity, i.e. `a sin(b*x + c)`.
  Args:
    a: a float.
    b: a float.
@@ -2209,6 +2226,10 @@ def _erf(x, a, b, c, **kwargs):
  return a * erf(b * x) + c


def _gelu(x, **kwargs):
  return 0.5 * x * (1. + erf(x / np.sqrt(2.)))
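
As a sanity check (not part of the commit): `_gelu` above is the exact GELU, `x * Phi(x)` with `Phi` the standard normal CDF, rather than the common tanh approximation. A standalone sketch of that identity using only `jax.numpy` and `jax.scipy`:

```python
import jax.numpy as np
from jax.scipy.special import erf
from jax.scipy.stats import norm

x = np.linspace(-4., 4., 101)
gelu_via_erf = 0.5 * x * (1. + erf(x / np.sqrt(2.)))  # the form used by `_gelu`
gelu_via_cdf = x * norm.cdf(x)                        # x * Phi(x)

print(np.max(np.abs(gelu_via_erf - gelu_via_cdf)))  # ~1e-6 or smaller
```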


def _sin(x, a, b, c, **kwargs):
  return a * np.sin(b * x + c)

@@ -2371,11 +2392,11 @@ def _get_erf_kernel(
    prod: np.ndarray,
    do_backprop: bool,
    ntk: np.ndarray = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
- dot_sigma = 4 / (np.pi * np.sqrt(prod - 4 * ker_mat**2))
- ker_mat = _arcsin(2 * ker_mat / _safe_sqrt(prod), do_backprop) * 2 / np.pi

  if ntk is not None:
+   dot_sigma = 4 / (np.pi * np.sqrt(prod - 4 * ker_mat**2))
    ntk *= dot_sigma
+ ker_mat = _arcsin(2 * ker_mat / np.sqrt(prod), do_backprop) * 2 / np.pi


  return ker_mat, ntk

@@ -2410,6 +2431,69 @@ def _transform_kernels_erf_non_scaled(k: Kernel, do_backprop: bool) -> Kernel:
                   is_gaussian=False)


def _get_gelu_kernel(nngp: np.ndarray,
                     prod: np.ndarray,
                     prod_plus_1: np.ndarray,
                     do_backprop: bool,
                     ntk: np.ndarray = None
                     ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
  delta_squared = prod_plus_1 - nngp**2
  delta = _safe_sqrt(delta_squared)
  ratio = nngp / _safe_sqrt(prod_plus_1)
  new_nngp = (nngp**2 + prod * delta_squared) / (prod_plus_1 * delta)
  new_nngp += nngp * _arcsin(ratio, do_backprop)
  new_nngp /= 2 * np.pi
  new_nngp += 0.25 * nngp

  if ntk is not None:
    second_term = 0.25 + _arcsin(ratio, do_backprop) / (2 * np.pi)
    first_term = 1 / delta_squared + (1 - prod) / prod_plus_1 + 1
    first_term *= nngp / delta / (2. * np.pi)
    dot_sigma = first_term + second_term
    ntk *= dot_sigma
  return new_nngp, ntk
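
One way to spot-check the closed form above (an illustrative sketch, not part of the commit): estimate `E[gelu(u) gelu(v)]` by sampling `(u, v) ~ N(0, [[q11, q12], [q12, q22]])` and compare it with `new_nngp` computed as in `_get_gelu_kernel`, with `prod = q11 * q22` and `prod_plus_1 = (q11 + 1) * (q22 + 1)` as in `_transform_kernels_gelu` below. The covariance values `q11`, `q22`, `q12` are arbitrary example numbers, and plain NumPy/SciPy stand in for the internal helpers.

```python
import numpy as onp
from scipy.special import erf

def gelu(x):
  return 0.5 * x * (1. + erf(x / onp.sqrt(2.)))

q11, q22, q12 = 1.3, 0.8, 0.5  # example input covariances
samples = onp.random.RandomState(0).multivariate_normal(
    onp.zeros(2), [[q11, q12], [q12, q22]], size=1_000_000)
mc_nngp = onp.mean(gelu(samples[:, 0]) * gelu(samples[:, 1]))

# Closed form, mirroring `_get_gelu_kernel` with scalar inputs.
prod, prod_plus_1, nngp = q11 * q22, (q11 + 1.) * (q22 + 1.), q12
delta_squared = prod_plus_1 - nngp**2
delta = onp.sqrt(delta_squared)
new_nngp = (nngp**2 + prod * delta_squared) / (prod_plus_1 * delta)
new_nngp += nngp * onp.arcsin(nngp / onp.sqrt(prod_plus_1))
new_nngp /= 2 * onp.pi
new_nngp += 0.25 * nngp

print(mc_nngp, new_nngp)  # should agree up to Monte Carlo error (~1e-3)
```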


def _get_gelu_nngp_diag(nngp_diag: np.ndarray, do_backprop: bool) -> np.ndarray:
  new_diag = nngp_diag / ((nngp_diag + 1.) * np.sqrt(1. + 2.* nngp_diag))
  new_diag += _arcsin(nngp_diag/(nngp_diag + 1), do_backprop) / 2
  new_diag /= np.pi
  new_diag += 0.25
  new_diag *= nngp_diag
  return new_diag
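
For reference (again an illustration, not in the commit): `_get_gelu_nngp_diag` is the general expression above specialized to a diagonal entry, i.e. `nngp = q`, `prod = q**2`, `prod_plus_1 = (q + 1)**2`, hence `delta = sqrt(2*q + 1)`. A scalar check of that specialization:

```python
import numpy as onp

def gelu_nngp_general(nngp, prod, prod_plus_1):
  # Scalar restatement of the `new_nngp` computation in `_get_gelu_kernel`.
  delta_squared = prod_plus_1 - nngp**2
  delta = onp.sqrt(delta_squared)
  out = (nngp**2 + prod * delta_squared) / (prod_plus_1 * delta)
  out += nngp * onp.arcsin(nngp / onp.sqrt(prod_plus_1))
  return out / (2 * onp.pi) + 0.25 * nngp

def gelu_nngp_diag(q):
  # Scalar restatement of `_get_gelu_nngp_diag`.
  out = q / ((q + 1.) * onp.sqrt(1. + 2. * q))
  out += onp.arcsin(q / (q + 1.)) / 2
  return (out / onp.pi + 0.25) * q

for q in [0.1, 0.5, 1.0, 2.0, 5.0]:
  assert onp.isclose(gelu_nngp_general(q, q * q, (q + 1.)**2), gelu_nngp_diag(q))
```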


def _transform_kernels_gelu(k: Kernel, do_backprop: bool) -> Kernel:
  """Compute new kernels after a `Gelu` layer; NNGP from `arXiv:2002.08517`."""
  cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

  cov1_plus_1 = cov1 + 1
  cov2_plus_1 = None if cov2 is None else cov2 + 1

  prod11_plus_1, prod12_plus_1, prod22_plus_1 = _get_diagonal_outer_prods(
      cov1_plus_1, cov2_plus_1, k.diagonal_batch, k.diagonal_spatial, op.mul)
  prod11, prod12, prod22 = _get_diagonal_outer_prods(
      cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.mul)

  nngp, ntk = _get_gelu_kernel(nngp, prod12, prod12_plus_1, do_backprop,
                               ntk=ntk)

  if k.diagonal_batch and k.diagonal_spatial:
    cov1 = _get_gelu_nngp_diag(cov1, do_backprop)
    if cov2 is not None:
      cov2 = _get_gelu_nngp_diag(cov2, do_backprop)
  else:
    cov1, _ = _get_gelu_kernel(cov1, prod11, prod11_plus_1, do_backprop)
    if cov2 is not None:
      cov2, _ = _get_gelu_kernel(cov2, prod22, prod22_plus_1, do_backprop)

  return k.replace(cov1=cov1,
                   nngp=nngp,
                   cov2=cov2,
                   ntk=ntk,
                   is_gaussian=False)


def _transform_kernels_affine_erf(
    k: Kernel,
    do_backprop: bool,
@@ -2494,6 +2578,8 @@ def _transform_kernels(
    return _transform_kernels_affine_erf(k, **fn_kwargs)
  if fn is _sin:
    return _transform_kernels_sin(k, **fn_kwargs)
  if fn is _gelu:
    return _transform_kernels_gelu(k, **fn_kwargs)
  # TODO(xlc): Monte Carlo approximation to the integral (suggested by schsam@.)
  raise NotImplementedError(f'Analytic kernel for activation {fn} is not '
                            f'implemented.')
20 changes: 12 additions & 8 deletions neural_tangents/tests/stax_test.py
@@ -840,12 +840,13 @@ class ActivationTest(test_utils.NeuralTangentsTestCase):
          'abc': abc,
      }
                          for model in ['fc', 'conv-pool', 'conv-flatten']
-                         for phi_name in ['Sin', 'Erf']
-                         for same_inputs in [False, True]
+                         for phi_name in ['Sin', 'Erf', 'Gelu']
+                         for same_inputs in [True, False]
                          for get in ['nngp', 'ntk']
                          for abc in product([1., 2., 0.3],
                                             [1., 1.5, 0.3],
-                                            [0., -np.pi/4., np.pi/2.])))
+                                            [0., -np.pi/4., np.pi/2.])
+ ))
  def test_activation(self, same_inputs, model, phi_name, get, abc):
    platform = xla_bridge.get_backend().platform
    if platform == 'cpu' and 'conv' in model:
@@ -854,19 +855,19 @@ def test_activation(self, same_inputs, model, phi_name, get, abc):
    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 2048 if get == 'nngp' else 1
-
+   W_std = 0.9 if phi_name == 'Sin' else 2.
    if model == 'fc':
      rtol = 0.02
      X0_1 = random.normal(key, (6, 7))
      X0_2 = None if same_inputs else random.normal(split, (10, 7))
-     affine = stax.Dense(1024, 1., 0.)
+     affine = stax.Dense(1024, W_std, 0.5)
      readout = stax.Dense(output_dim)
      depth = 1
    else:
      rtol = 0.05
      X0_1 = random.normal(key, (4, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (6, 8, 8, 3))
-     affine = stax.Conv(1024, (3, 2), W_std=1., b_std=0.1, padding='SAME')
+     affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=0.5, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
@@ -875,18 +876,21 @@ def test_activation(self, same_inputs, model, phi_name, get, abc):
      num_samplings = 200
      rtol *= 2
    else:
-     num_samplings = 500
+     num_samplings = 500 if phi_name == 'Sin' else 300
    a, b, c = abc
    if phi_name == 'Sin':
      activation = stax.Sin(a=a, b=b, c=c)
    elif phi_name == 'Erf':
      activation = stax.Erf(a=a, b=b, c=c)
+   elif phi_name == 'Gelu':
+     activation = stax.Gelu()
+     if a != 1. or b != 1. or c != 0.:
+       raise unittest.SkipTest('Skip `Gelu` test if (a, b, c) != (1., 1., 0.).')
    else:
      raise unittest.SkipTest(f'Activation {phi_name} is not implemented.')
    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    key, split = random.split(key)
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings)
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
