From dacd4f9c5531e93b4a0b70b9102414391c2f7b16 Mon Sep 17 00:00:00 2001 From: Lechao Xiao Date: Tue, 15 Dec 2020 22:26:40 -0800 Subject: [PATCH] Adding `sign` function. PiperOrigin-RevId: 347758916 --- neural_tangents/stax.py | 33 +++++++++++++++++++++++++++++++++ tests/stax_test.py | 11 ++++++----- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index a31865ae..0c3ec667 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -3000,6 +3000,39 @@ def Abs( return ABRelu(-1, 1, do_backprop, do_stabilize) +@layer +@supports_masking(remask_kernel=True) +def Sign(do_backprop: bool = False) -> InternalLayer: + """Sign function. + + Args: + do_backprop: set to `True` if you want to backpropagate through the kernel. + + Returns: + `(init_fn, apply_fn, kernel_fn)`. + """ + def fn(x): + return np.sign(x) + + _requires(diagonal_spatial=_Diagonal()) # pytype:disable=wrong-keyword-args + def kernel_fn(k: Kernel) -> Kernel: + cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk + if ntk is not None: + ntk = np.zeros_like(ntk) + _, prod12, _ = _get_diagonal_outer_prods(cov1, + cov2, + k.diagonal_batch, + k.diagonal_spatial, + op.mul) + nngp = 1 - _arccos(nngp / _safe_sqrt(prod12), do_backprop) * 2 / np.pi + cov1 = np.ones_like(cov1) + cov2 = cov2 if cov2 is None else np.ones_like(cov2) + k = k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk) + return k + + return _elementwise(fn, 'Sign', kernel_fn) + + @layer @supports_masking(remask_kernel=True) def ElementwiseNumerical( diff --git a/tests/stax_test.py b/tests/stax_test.py index a6f11f01..63a61746 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -915,7 +915,7 @@ def _test_activation(self, activation_fn, same_inputs, model, get, 'abc': abc, } for model in ['fc', 'conv-pool', 'conv-flatten'] - for phi_name in ['Sin', 'Erf', 'Gelu'] + for phi_name in ['Sin', 'Erf', 'Gelu', 'Sign'] for same_inputs in [False] for get in ['nngp', 'ntk'] for abc in itertools.product( @@ -933,10 +933,11 @@ def test_activation(self, same_inputs, model, phi_name, get, abc): activation = stax.Sin(a=a, b=b, c=c) elif phi_name == 'Erf': activation = stax.Erf(a=a, b=b, c=c) - elif phi_name == 'Gelu': - activation = stax.Gelu() - if a != 1. or b != 1. or c != 0.: - absltest.SkipTest('Skip `Gelu` test if (a, b, c) != (1., 1., 0.).') + elif phi_name in ['Gelu', 'Sign']: + if a != 0.3 or b != 0.3 or c != 0.: + raise absltest.SkipTest('Skip `Gelu/Sign` test if ' + ' (a, b, c) != (.3, .3, 0.).') + activation = stax.Gelu() if phi_name == 'Gelu' else stax.Sign() else: raise absltest.SkipTest(f'Activation {phi_name} is not implemented.') self._test_activation(activation, same_inputs, model, get)