Skip to content

Commit

Permalink
fix a bug in sin().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 318197390
  • Loading branch information
SiuMath authored and romanngg committed Jun 25, 2020
1 parent bd0e56a commit 689b0be
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
20 changes: 11 additions & 9 deletions neural_tangents/stax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ def Sin(a=1., b=1., c=0.) -> InternalLayer:
Returns:
`(init_fn, apply_fn, kernel_fn)`.
"""
return _elementwise(_sin, 'Sin', a=1., b=1., c=0.)
return _elementwise(_sin, 'Sin', a=a, b=b, c=c)


@layer
Expand Down Expand Up @@ -2410,20 +2410,22 @@ def _transform_kernels_sin(
k.diagonal_batch,
k.diagonal_spatial,
op.add)

def _get_sin_kernel(prod, cov, ntk):
s1 = a**2 * np.exp(b * (-0.5 * prod + cov)) / 2.
s2 = a**2 * np.exp(b * (-0.5 * prod - cov)) / 2. * np.cos(2*c)
nngp = s1 - s2
half_a_square = a**2 / 2.
def _get_sin_kernel(sum_, cov, ntk):
s1 = np.exp(b**2 * (-0.5 * sum_ + cov))
s2 = np.exp(b**2 * (-0.5 * sum_ - cov)) * np.cos(2*c)
nngp = half_a_square * (s1 - s2)
if ntk is not None:
ntk *= (s1 + s2) * b**2
ntk *= half_a_square * b**2 * (s1 + s2)
return nngp, ntk
def _get_diag_sin_kernel(mat):
return half_a_square *(1. - np.exp(-b**2 * mat) *np.cos(2*c))
nngp, ntk = _get_sin_kernel(sum12, nngp, ntk)

if k.diagonal_batch and k.diagonal_spatial:
cov1 = -a**2 * np.expm1(-2. * b * sum11) / 2.
cov1 = _get_diag_sin_kernel(sum11)
if cov2 is not None:
cov2 = -a**2 * np.expm1(-2. * b * sum22) / 2.
cov2 = _get_diag_sin_kernel(sum22)
else:
cov1 = _get_sin_kernel(sum11, cov1, None)[0]
if cov2 is not None:
Expand Down
26 changes: 16 additions & 10 deletions neural_tangents/tests/stax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import string
from functools import partial
from itertools import product

import random as prandom
import itertools
import logging
Expand Down Expand Up @@ -818,6 +820,7 @@ def _get_empirical(n_samples, get):
self.assertEqual(shape2, x2_out_shape)



@jtu.parameterized.parameters([
{
'same_inputs': True
Expand All @@ -830,38 +833,41 @@ class SinTest(test_utils.NeuralTangentsTestCase):

def test_sin(self, same_inputs):
key = random.PRNGKey(1)
for a, b, c in [(1., 1., 0.),
(1., 1., np.pi/2),
(1.5, 2., np.pi/4),
(10., 25., 2.)]:
for a, b, c in product([5.],
[1.5],
[0., -np.pi/4.]):
for get in ['nngp', 'ntk']:
output_dim = 1024 if get == 'nngp' else 1
output_dim = 2048 if get == 'nngp' else 1
key, split = random.split(key)
for model in ['fc', 'conv']:
for model in ['fc', 'conv-pool', 'conv-flatten']:
with self.subTest(get=get, a=a, b=b, c=c, model=model):
if model == 'fc':
X0_1 = random.normal(key, (6, 7))
X0_2 = None if same_inputs else random.normal(split, (10, 7))
affine = stax.Dense(2048, 1., 0.)
readout = stax.Dense(output_dim)
else:
if xla_bridge.get_backend().platform == 'cpu':
raise unittest.SkipTest('Not running CNNs on CPU to save time.')
X0_1 = random.normal(key, (4, 8, 8, 3))
X0_2 = None if same_inputs else random.normal(split, (6, 8, 8, 3))
affine = stax.Conv(1024, (3, 2), W_std=1., b_std=0.1,
padding='SAME')
readout = stax.serial(stax.GlobalAvgPool(),
readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
stax.Flatten(),
stax.Dense(output_dim))
init_fn, apply_sin, kernel_fn_sin = stax.serial(affine,
stax.Sin(a=a,
b=b,
c=c),
readout)
analytic_kernel = kernel_fn_sin(X0_1, X0_2, get)
mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_sin,
key, 200)
key, split = random.split(key)
mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
init_fn, apply_sin, key, 200)
empirical_kernel = np.squeeze(mc_kernel_fn(X0_1, X0_2, get))
test_utils.assert_close_matrices(self, analytic_kernel,
empirical_kernel, RTOL)
empirical_kernel, 0.05)


@jtu.parameterized.parameters([
Expand Down

0 comments on commit 689b0be

Please sign in to comment.