Adding activation function: sin()
PiperOrigin-RevId: 317781121
SiuMath authored and romanngg committed Jun 24, 2020
1 parent b4a618c commit a936d87
Showing 2 changed files with 122 additions and 20 deletions.
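
For reference, a derivation sketch of the closed form implemented below (not part of the commit): for jointly Gaussian preactivations \((u, v)\) with variances \(\sigma_1^2, \sigma_2^2\) and covariance \(\Sigma\), the identity \(\sin\alpha\,\sin\beta = \tfrac{1}{2}[\cos(\alpha - \beta) - \cos(\alpha + \beta)]\) together with \(\mathbb{E}[\cos X] = \cos(\mathbb{E}X)\, e^{-\mathrm{Var}(X)/2}\) gives

\[
\mathbb{E}\big[a\sin(bu + c)\cdot a\sin(bv + c)\big]
= \frac{a^2}{2}\Big[\exp\Big(b^2\Big(\Sigma - \tfrac{\sigma_1^2 + \sigma_2^2}{2}\Big)\Big)
- \cos(2c)\,\exp\Big(-b^2\Big(\Sigma + \tfrac{\sigma_1^2 + \sigma_2^2}{2}\Big)\Big)\Big].
\]

Differentiating \(f(x) = a\sin(bx + c)\) yields the analogous cosine expectation, which is why the NTK below is scaled by \(b^2(s_1 + s_2)\).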
96 changes: 76 additions & 20 deletions neural_tangents/stax.py
@@ -1165,6 +1165,13 @@ def Erf(do_backprop: bool = False) -> InternalLayer:
do_backprop=do_backprop)


@layer
@_supports_masking(remask_kernel=True)
def Sin(a: float = 1., b: float = 1., c: float = 0.) -> InternalLayer:
"""Affine-transformed sine nonlinearity, i.e. `f(x) = a sin(b*x + c)`."""
return _elementwise(_sin, 'Sin', a=a, b=b, c=c)


@layer
@_supports_masking(remask_kernel=True)
def Relu(
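
For orientation, a hypothetical usage sketch, not part of this commit, assuming only the public `stax` API exercised in the tests below: the new layer composes like any other activation, and `kernel_fn` returns its closed-form kernels.

from jax import random
from neural_tangents import stax

# Infinite-width model with the new Sin activation.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),
    stax.Sin(a=1., b=2., c=0.),
    stax.Dense(1))

x1 = random.normal(random.PRNGKey(0), (3, 5))
nngp = kernel_fn(x1, None, 'nngp')  # closed-form NNGP kernel
ntk = kernel_fn(x1, None, 'ntk')    # closed-form NTK
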
@@ -1512,11 +1519,12 @@ def prepare_mask(m):

return m

prod11, prod12, prod22 = _get_diagonal_prods(
prod11, prod12, prod22 = _get_diagonal_outer_prods(
eps + cov1,
cov2 if cov2 is None else eps + cov2,
kernels.diagonal_batch,
kernels.diagonal_spatial,
op.mul,
axis=kernel_axis,
mask1=prepare_mask(kernels.mask1),
mask2=prepare_mask(kernels.mask2),
@@ -2059,6 +2067,10 @@ def _erf(x, **kwargs):
return erf(x)


def _sin(x, a, b, c, **kwargs):
return a * np.sin(b * x + c)


def _arccos(x, do_backprop):
if do_backprop:
# https://github.com/google/jax/issues/654
@@ -2152,14 +2164,15 @@ def _get_diagonal_prod(cov1: np.ndarray,
return prod11, prod12, prod22


def _get_diagonal_prods(cov1: np.ndarray,
cov2: Optional[np.ndarray],
diagonal_batch: bool,
diagonal_spatial: bool,
axis: Tuple[int, ...] = (),
mask1: Optional[np.ndarray] = None,
mask2: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
def _get_diagonal_outer_prods(cov1: np.ndarray,
cov2: Optional[np.ndarray],
diagonal_batch: bool,
diagonal_spatial: bool,
operation: Callable[[float, float], float],
axis: Tuple[int, ...] = (),
mask1: Optional[np.ndarray] = None,
mask2: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Gets outer products of diagonals `cov1, cov1`, `cov1, cov2`, `cov2, cov2`.
`prod11[x1, x2, h1, h2, ...]` =
@@ -2181,11 +2194,11 @@ def _get_diagonal_prods(cov1: np.ndarray,
cov2 = _mean_and_var(cov2, axis=axis, keepdims=True, mask=mask2)

end_axis = 1 if diagonal_spatial else cov1.ndim # pytype: disable=attribute-error
prod12 = utils.outer_prod(cov1, cov2, 0, end_axis, op.mul)
prod12 = utils.outer_prod(cov1, cov2, 0, end_axis, operation)

start_axis = 1 if diagonal_batch else 0
prod11 = utils.outer_prod(cov1, cov1, start_axis, end_axis, op.mul)
prod22 = (utils.outer_prod(cov2, cov2, start_axis, end_axis, op.mul)
prod11 = utils.outer_prod(cov1, cov1, start_axis, end_axis, operation)
prod22 = (utils.outer_prod(cov2, cov2, start_axis, end_axis, operation)
if cov2 is not None else prod11)

return prod11, prod12, prod22
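
The rename reflects the new `operation` argument: `op.mul` reproduces the old outer products, while `op.add` (used by the sin kernel below) yields pairwise sums of variances. A toy illustration of the broadcasting pattern, with hypothetical 1-D inputs rather than the library's real `utils.outer_prod`:

import operator as op
import jax.numpy as np

cov1 = np.array([1., 2.])      # per-example variances of x1
cov2 = np.array([3., 4., 5.])  # per-example variances of x2

def pairwise(x, y, operation):
  # Insert singleton axes so `operation` broadcasts over all pairs.
  return operation(x[:, None], y[None, :])

prod12 = pairwise(cov1, cov2, op.mul)  # prod12[i, j] = cov1[i] * cov2[j]
sum12 = pairwise(cov1, cov2, op.add)   # sum12[i, j]  = cov1[i] + cov2[j]
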
@@ -2224,10 +2237,11 @@ def _transform_kernels_ab_relu(
if cov2 is not None:
cov2 /= factor

prod11, prod12, prod22 = _get_diagonal_prods(cov1,
cov2,
kernels.diagonal_batch,
kernels.diagonal_spatial)
prod11, prod12, prod22 = _get_diagonal_outer_prods(cov1,
cov2,
kernels.diagonal_batch,
kernels.diagonal_spatial,
op.mul)
nngp, ntk = _get_ab_relu_kernel(nngp, prod12, a, b, do_backprop, ntk=ntk)
if do_stabilize:
nngp *= factor
@@ -2270,10 +2284,11 @@ def _transform_kernels_erf(kernels: Kernel, do_backprop: bool) -> Kernel:
_cov1_denom = 1 + 2 * cov1
_cov2_denom = None if cov2 is None else 1 + 2 * cov2

prod11, prod12, prod22 = _get_diagonal_prods(_cov1_denom,
_cov2_denom,
kernels.diagonal_batch,
kernels.diagonal_spatial)
prod11, prod12, prod22 = _get_diagonal_outer_prods(_cov1_denom,
_cov2_denom,
kernels.diagonal_batch,
kernels.diagonal_spatial,
op.mul)
nngp, ntk = _get_erf_kernel(nngp, prod12, do_backprop, ntk=ntk)

if kernels.diagonal_batch and kernels.diagonal_spatial:
@@ -2292,6 +2307,45 @@ def _transform_kernels_erf(kernels: Kernel, do_backprop: bool) -> Kernel:
is_gaussian=False)


def _transform_kernels_sin(
kernels: Kernel,
a: float = 1.0,
b: float = 1.0,
c: float = 0.0) -> Kernel:
"""Compute new kernels after an `Sin` layer."""
cov1, nngp, cov2, ntk = kernels.cov1, kernels.nngp, kernels.cov2, kernels.ntk

sum11, sum12, sum22 = _get_diagonal_outer_prods(cov1,
cov2,
kernels.diagonal_batch,
kernels.diagonal_spatial,
op.add)

def _get_sin_kernel(sum_, cov, ntk):
s1 = a**2 * np.exp(b**2 * (-0.5 * sum_ + cov)) / 2.
s2 = a**2 * np.exp(b**2 * (-0.5 * sum_ - cov)) / 2. * np.cos(2 * c)
nngp = s1 - s2
if ntk is not None:
ntk *= (s1 + s2) * b**2
return nngp, ntk
nngp, ntk = _get_sin_kernel(sum12, nngp, ntk)

if kernels.diagonal_batch and kernels.diagonal_spatial:
cov1 = a**2 * (1. - np.exp(-b**2 * sum11) * np.cos(2 * c)) / 2.
if cov2 is not None:
cov2 = a**2 * (1. - np.exp(-b**2 * sum22) * np.cos(2 * c)) / 2.
else:
cov1 = _get_sin_kernel(sum11, cov1, None)[0]
if cov2 is not None:
cov2 = _get_sin_kernel(sum22, cov2, None)[0]

return kernels.replace(cov1=cov1,
nngp=nngp,
cov2=cov2,
ntk=ntk,
is_gaussian=False)
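
As a quick numerical sanity check of the formula above (a standalone sketch with assumed scalar Gaussian inputs, not part of the commit), the closed form should match a direct Monte Carlo average:

import jax.numpy as np
from jax import random

a, b, c = 1.5, 2.0, 0.3
var1, var2, cov = 1.0, 0.8, 0.5  # Var(u), Var(v), Cov(u, v)

# Closed form: a^2/2 * [exp(b^2 (cov - (var1 + var2) / 2))
#                       - cos(2c) * exp(-b^2 (cov + (var1 + var2) / 2))].
s = var1 + var2
closed = a**2 / 2. * (np.exp(b**2 * (-0.5 * s + cov))
                      - np.cos(2 * c) * np.exp(b**2 * (-0.5 * s - cov)))

# Monte Carlo estimate from the corresponding bivariate Gaussian.
L = np.linalg.cholesky(np.array([[var1, cov], [cov, var2]]))
u, v = L @ random.normal(random.PRNGKey(0), (2, 500000))
mc = np.mean(a * np.sin(b * u + c) * a * np.sin(b * v + c))  # ~= closed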


def _transform_kernels(
kernels: Kernel, fn: Callable[[float], float], **fn_kwargs) -> Kernel:
"""Apply transformation to kernels.
@@ -2309,6 +2363,8 @@ def _transform_kernels(
return _transform_kernels_ab_relu(kernels, **fn_kwargs)
if fn is _erf:
return _transform_kernels_erf(kernels, **fn_kwargs)
if fn is _sin:
return _transform_kernels_sin(kernels, **fn_kwargs)
# TODO: Monte Carlo approximation to the integral (suggested by schsam.)
raise NotImplementedError('Analytic kernel for activation {} is not '
'implemented'.format(fn))
46 changes: 46 additions & 0 deletions neural_tangents/tests/stax_test.py
@@ -781,6 +781,52 @@ def _get_empirical(n_samples, get):
self.assertEqual(shape2, x2_out_shape)


@jtu.parameterized.parameters([
{
'same_inputs': True
},
{
'same_inputs': False
},
])
class SinTest(test_utils.NeuralTangentsTestCase):

def test_sin(self, same_inputs):
key = random.PRNGKey(1)
for a, b, c in [(1., 1., 0.),
(1., 1., np.pi/2),
(1.5, 2., np.pi/4),
(10., 25., 2.)]:
for get in ['nngp', 'ntk']:
output_dim = 1024 if get == 'nngp' else 1
key, split = random.split(key)
for model in ['fc', 'conv']:
with self.subTest(get=get, a=a, b=b, c=c, model=model):
if model == 'fc':
X0_1 = random.normal(key, (6, 7))
X0_2 = None if same_inputs else random.normal(split, (10, 7))
affine = stax.Dense(2048, 1., 0.)
readout = stax.Dense(output_dim)
else:
X0_1 = random.normal(key, (4, 8, 8, 3))
X0_2 = None if same_inputs else random.normal(split, (6, 8, 8, 3))
affine = stax.Conv(1024, (3, 2), W_std=1., b_std=0.1,
padding='SAME')
readout = stax.serial(stax.GlobalAvgPool(),
stax.Dense(output_dim))
init_fn, apply_sin, kernel_fn_sin = stax.serial(affine,
stax.Sin(a=a,
b=b,
c=c),
readout)
analytic_kernel = kernel_fn_sin(X0_1, X0_2, get)
mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_sin,
key, 200)
empirical_kernel = np.squeeze(mc_kernel_fn(X0_1, X0_2, get))
test_utils.assert_close_matrices(self, analytic_kernel,
empirical_kernel, RTOL)


@jtu.parameterized.parameters([
{
'same_inputs': True
