Commit dacd4f9

Adding sign function.
PiperOrigin-RevId: 347758916
SiuMath authored and romanngg committed Dec 16, 2020
1 parent 2a8e5f5 commit dacd4f9
Showing 2 changed files with 39 additions and 5 deletions.
33 changes: 33 additions & 0 deletions neural_tangents/stax.py
@@ -3000,6 +3000,39 @@ def Abs(
  return ABRelu(-1, 1, do_backprop, do_stabilize)


@layer
@supports_masking(remask_kernel=True)
def Sign(do_backprop: bool = False) -> InternalLayer:
  """Sign function.

  Args:
    do_backprop: set to `True` if you want to backpropagate through the kernel.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return np.sign(x)

  @_requires(diagonal_spatial=_Diagonal())  # pytype:disable=wrong-keyword-args
  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    if ntk is not None:
      ntk = np.zeros_like(ntk)
    _, prod12, _ = _get_diagonal_outer_prods(cov1,
                                             cov2,
                                             k.diagonal_batch,
                                             k.diagonal_spatial,
                                             op.mul)
    nngp = 1 - _arccos(nngp / _safe_sqrt(prod12), do_backprop) * 2 / np.pi
    cov1 = np.ones_like(cov1)
    cov2 = cov2 if cov2 is None else np.ones_like(cov2)
    k = k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)
    return k

  return _elementwise(fn, 'Sign', kernel_fn)


@layer
@supports_masking(remask_kernel=True)
def ElementwiseNumerical(
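A note on the math in the `Sign` kernel above (commentary, not part of the commit): `sign` squashes every pre-activation to ±1, so the post-activation variances are identically one (hence `cov1`/`cov2` are replaced by all-ones), and `sign` has zero derivative almost everywhere, so the layer's NTK contribution vanishes (hence `ntk` is zeroed). The `nngp` update is the classical Gaussian identity E[sign(u) sign(v)] = 1 - (2/π) arccos(ρ), where ρ is the correlation of the jointly Gaussian pre-activations u and v. A minimal NumPy sketch that sanity-checks this closed form by Monte Carlo (values here are arbitrary, for illustration only):

import numpy as np

# Covariance of a pair of centered, jointly Gaussian pre-activations.
k11, k22, k12 = 1.5, 0.8, 0.6
cov = np.array([[k11, k12],
                [k12, k22]])

rng = np.random.default_rng(0)
u, v = rng.multivariate_normal([0., 0.], cov, size=1_000_000).T

mc = np.mean(np.sign(u) * np.sign(v))  # Monte Carlo estimate
closed_form = 1 - 2 / np.pi * np.arccos(k12 / np.sqrt(k11 * k22))
print(mc, closed_form)  # agree to ~2-3 decimal places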
11 changes: 6 additions & 5 deletions tests/stax_test.py
@@ -915,7 +915,7 @@ def _test_activation(self, activation_fn, same_inputs, model, get,
      'abc': abc,
  }
                  for model in ['fc', 'conv-pool', 'conv-flatten']
-                 for phi_name in ['Sin', 'Erf', 'Gelu']
+                 for phi_name in ['Sin', 'Erf', 'Gelu', 'Sign']
                  for same_inputs in [False]
                  for get in ['nngp', 'ntk']
                  for abc in itertools.product(
@@ -933,10 +933,11 @@ def test_activation(self, same_inputs, model, phi_name, get, abc):
      activation = stax.Sin(a=a, b=b, c=c)
    elif phi_name == 'Erf':
      activation = stax.Erf(a=a, b=b, c=c)
-   elif phi_name == 'Gelu':
-     activation = stax.Gelu()
-     if a != 1. or b != 1. or c != 0.:
-       absltest.SkipTest('Skip `Gelu` test if (a, b, c) != (1., 1., 0.).')
+   elif phi_name in ['Gelu', 'Sign']:
+     if a != 0.3 or b != 0.3 or c != 0.:
+       raise absltest.SkipTest('Skip `Gelu/Sign` test if '
+                               ' (a, b, c) != (.3, .3, 0.).')
+     activation = stax.Gelu() if phi_name == 'Gelu' else stax.Sign()
    else:
      raise absltest.SkipTest(f'Activation {phi_name} is not implemented.')
    self._test_activation(activation, same_inputs, model, get)
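For completeness, a hypothetical usage sketch (not from the commit) of the new layer through the standard neural_tangents stax API; the `Dense` widths and input shapes are arbitrary choices for illustration:

from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),
    stax.Sign(),
    stax.Dense(1))

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (4, 8))   # 4 inputs of dimension 8
x2 = random.normal(key2, (6, 8))   # 6 inputs of dimension 8

nngp = kernel_fn(x1, x2, 'nngp')   # (4, 6) NNGP covariance matrix
ntk = kernel_fn(x1, x2, 'ntk')     # (4, 6) neural tangent kernel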
