Skip to content

Commit

Permalink
Fix crashing Github tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 559848201
  • Loading branch information
romanngg committed Aug 24, 2023
1 parent b3b728e commit 91846d8
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ def testNTKMeanCovPrediction(
self.assertGreater(np.min(np.linalg.eigh(cov_train_inf)[0]), -1e-8)

_kernel_fn = nt.empirical_kernel_fn(f)
kernel_fn = jit(lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params))
# TODO(romann): figure out the slow compile on Ubuntu 22.04 CPU Python 3.9
kernel_fn = lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params)

def predict_empirical(key):
_, params = init_fn(key, train_shape)
Expand Down

0 comments on commit 91846d8

Please sign in to comment.