Skip to content

Commit

Permalink
Allow trace_axes to precede diagonal_axes in empirical kernels - …
Browse files Browse the repository at this point in the history
…otherwise, in such situations, the current code would fail unexpectedly or silently.

PiperOrigin-RevId: 319869589
  • Loading branch information
romanngg committed Jul 6, 2020
1 parent 6eea987 commit 3582ba0
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 37 deletions.
53 changes: 30 additions & 23 deletions neural_tangents/tests/empirical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,20 +245,25 @@ def testNTKAgainstDirect(
(0,),
(0, 1),
(0, 1, 2),
(0, 1, 2, 3),]
(0, 1, 2, 3),
(-1,),
(-2,),
(0, -1),
(1, -2),
(2, 3)]
for trace_axes in [(),
(0,),
(0, 1),
(-1,),
(1,),
(0, -1),
(-1, -2),
(0, 1, 2, 3),
(3, 1, 2, 0),
(1, 2, 3),
(-3, -2),
(-3, -1),
(-2, -4)]))
(0,),
(0, 1),
(-1,),
(1,),
(0, -1),
(-1, -2),
(0, 1, 2, 3),
(3, 1, 2, 0),
(1, 2, 3),
(-3, -2),
(-3, -1),
(-2, -4)]))
def testAxes(self, diagonal_axes, trace_axes):
key = random.PRNGKey(0)
key, self_split, other_split = random.split(key, 3)
Expand All @@ -282,23 +287,25 @@ def testAxes(self, diagonal_axes, trace_axes):
n_marg = len(_diagonal_axes)
n_chan = len(_trace_axes)

g = implicit(data_self, None)
g_direct = direct(data_self, None)
g_nngp = nngp(data_self, None)

self.assertAllClose(g, g_direct)
self.assertEqual(g_nngp.shape, g.shape)
self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

g_direct = direct(data_self, None)
self.assertEqual(g_nngp.shape, g_direct.shape)

g = implicit(data_self, None)
self.assertAllClose(g_direct, g)

if 0 not in _trace_axes and 0 not in _diagonal_axes:
g = implicit(data_other, data_self)
g_direct = direct(data_other, data_self)
g_nngp = nngp(data_other, data_self)

self.assertAllClose(g, g_direct)
self.assertEqual(g_nngp.shape, g.shape)
self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

g_direct = direct(data_other, data_self)
self.assertEqual(g_nngp.shape, g_direct.shape)

g = implicit(data_other, data_self)
self.assertAllClose(g_direct, g)


if __name__ == '__main__':
absltest.main()
45 changes: 34 additions & 11 deletions neural_tangents/utils/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def delta_vjp(delta):
fx_dummy = np.ones(fx2_struct.shape, fx2_struct.dtype)

ntk = jacobian(delta_vjp_jvp)(fx_dummy)
return _index_and_contract(ntk, trace_axes, diagonal_axes)
return _trace_and_diagonal(ntk, trace_axes, diagonal_axes)

return ntk_fn

Expand Down Expand Up @@ -571,28 +571,51 @@ def _read_keys(keys: Union[None, PRNGKey, Tuple[PRNGKey, PRNGKey]]
return key1, key2


def _trace_and_diagonal(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk:
      input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes:
      axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
      and remove the respective pairs of axes from the `ntk`.
    diagonal_axes:
      axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
      diagonal along the respective pairs of axes from the `ntk` (and hence
      reduce the resulting `ntk` axes count by 2).

  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis).

  Raises:
    ValueError: if `ntk` does not have an even number of dimensions (a kernel
      must consist of pairs of input axes).
  """
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel. Please file a bug at'
                     'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2

  # NOTE(review): the index arithmetic below assumes `canonicalize_axis`
  # returns sorted, non-negative axes — confirm against `utils`.
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

  n_diag, n_trace = len(diagonal_axes), len(trace_axes)
  # Product of the traced dimensions, used to normalize the traces into means.
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace out each requested pair of axes. Iterating in reverse (largest axis
  # first) keeps the first-copy index `c` valid, while the second-copy index
  # shifts down by one for every pair already traced out.
  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  # Extract the diagonal along each requested pair of axes, adjusting both
  # indices for axes already removed: traced axes below `d` shift each copy
  # down by one, and every previous diagonal removed two axes (appending a
  # single diagonal axis at the end).
  for i, d in enumerate(diagonal_axes):
    axis1 = d - i
    axis2 = output_ndim + d - 2 * i - n_trace
    for c in trace_axes:
      if c < d:
        axis1 -= 1
        axis2 -= 1
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  # Interleave the remaining (N1, ..., N2, ...) axis pairs, then move the
  # diagonal axes (currently trailing) to their expected output positions.
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
  return ntk / contract_size
9 changes: 6 additions & 3 deletions neural_tangents/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,15 +419,18 @@ def dot_general(lhs: np.ndarray,
n_batch_dims = len(batch_dims)
leading_batch_dims = range(n_batch_dims)

dimension_numbers = ((contracting_dims, contracting_dims),
(leading_batch_dims, leading_batch_dims))

lhs = np.moveaxis(lhs, batch_dims, leading_batch_dims)
if rhs is None:
rhs = lhs
else:
rhs = np.moveaxis(rhs, batch_dims, leading_batch_dims)

shifted_contracting_dims = [i + sum(1 if i < b else 0 for b in batch_dims)
for i in contracting_dims]

dimension_numbers = ((shifted_contracting_dims, shifted_contracting_dims),
(leading_batch_dims, leading_batch_dims))

prod = lax.dot_general(lhs, rhs, dimension_numbers, precision)
prod = zip_axes(prod, n_batch_dims)

Expand Down

0 comments on commit 3582ba0

Please sign in to comment.