Skip to content

Commit

Permalink
Add ReluNTKFeatures test
Browse files Browse the repository at this point in the history
  • Loading branch information
insuhan committed Jul 13, 2022
1 parent af524d2 commit a16098c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
27 changes: 18 additions & 9 deletions experimental/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,14 @@ def _cholesky(mat):


@layer
def ReluNTKFeatures(
num_layers: int,
poly_degree: int = 16,
poly_sketch_dim: int = 1024,
W_std: float = 1.,
):
def ReluNTKFeatures(num_layers: int,
poly_degree: int = 16,
poly_sketch_dim: int = 1024,
batch_axis: int = 0,
channel_axis: int = -1):

if batch_axis != 0 or channel_axis != -1:
raise NotImplementedError(f'Not supported axes.')

def init_fn(rng, input_shape):
input_dim = input_shape[0][-1]
Expand All @@ -541,14 +543,18 @@ def init_fn(rng, input_shape):

return (), (polysketch, nngp_coeffs, ntk_coeffs)

def feature_fn(f, input=None, **kwargs):
@requires(batch_axis=batch_axis, channel_axis=channel_axis)
def feature_fn(f: Features, input=None, **kwargs):
input_shape = f.nngp_feat.shape[:-1]

polysketch: PolyTensorSketch = input[0]
nngp_coeffs: np.ndarray = input[1]
ntk_coeffs: np.ndarray = input[2]

polysketch_feats = polysketch.sketch(f.nngp_feat)
norms = np.linalg.norm(f.nngp_feat, axis=channel_axis, keepdims=True)
nngp_feat = f.nngp_feat / norms

polysketch_feats = polysketch.sketch(nngp_feat)
nngp_feat = polysketch.expand_feats(polysketch_feats, nngp_coeffs)
ntk_feat = polysketch.expand_feats(polysketch_feats, ntk_coeffs)

Expand All @@ -557,8 +563,11 @@ def feature_fn(f, input=None, **kwargs):
ntk_feat = polysketch.standardsrht(ntk_feat).reshape(input_shape + (-1,))

# Convert complex features to real ones.
ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1)
nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=-1)
ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1)

nngp_feat *= norms / 2**(num_layers / 2.)
ntk_feat *= norms / 2**(num_layers / 2.)

return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat)

Expand Down
31 changes: 31 additions & 0 deletions experimental/tests/features_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,37 @@ def test_aggregate_features(self):
self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T)
self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T)

@parameterized.product(n_layers=[1, 2, 3, 4, 5], do_jit=[True, False])
def test_onepass_fc_relu_nngp_ntk(self, n_layers, do_jit):
rng = random.PRNGKey(1)
n, d = 4, 256
x = _get_init_data(rng, (n, d))

kernel_fn = stax.serial(*[stax.Dense(1), stax.Relu()] * n_layers +
[stax.Dense(1)])[2]

poly_degree = 8
poly_sketch_dim = 4096

init_fn, feature_fn = ft.ReluNTKFeatures(n_layers, poly_degree,
poly_sketch_dim)

rng2 = random.PRNGKey(2)
_, feat_fn_inputs = init_fn(rng2, x.shape)

if do_jit:
kernel_fn = jit(kernel_fn)
feature_fn = jit(feature_fn)

k = kernel_fn(x)
f = feature_fn(x, feat_fn_inputs)

k_nngp_approx = f.nngp_feat @ f.nngp_feat.T
k_ntk_approx = f.ntk_feat @ f.ntk_feat.T

test_utils.assert_close_matrices(self, k.nngp, k_nngp_approx, 0.2, 1.)
test_utils.assert_close_matrices(self, k.ntk, k_ntk_approx, 0.2, 1.)


if __name__ == "__main__":
absltest.main()

0 comments on commit a16098c

Please sign in to comment.