Skip to content

Commit

Permalink
Update readthedocs requirements to include all neural tangents build …
Browse files Browse the repository at this point in the history
…dependencies. Fix pytype python version

PiperOrigin-RevId: 557183101
  • Loading branch information
romanngg committed Aug 24, 2023
1 parent db53e46 commit b3b728e
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
python-version: [3.9, '3.10', 3.11]
JAX_ENABLE_X64: [0, 1]

runs-on: ubuntu-20.04
runs-on: ubuntu-latest

steps:

Expand Down
2 changes: 1 addition & 1 deletion tests/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def testNTKMeanCovPrediction(
self.assertGreater(np.min(np.linalg.eigh(cov_train_inf)[0]), -1e-8)

_kernel_fn = nt.empirical_kernel_fn(f)
kernel_fn = lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params)
kernel_fn = jit(lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params))

def predict_empirical(key):
_, params = init_fn(key, train_shape)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import os
from types import ModuleType
from typing import Callable, Optional, Sequence
from typing import Callable, Optional, Sequence

from absl import flags
from absl.testing import parameterized
Expand All @@ -41,7 +41,7 @@

flags.DEFINE_integer(
'nt_num_generated_cases',
int(os.getenv('NT_NUM_GENERATED_CASES', '2')),
int(os.getenv('NT_NUM_GENERATED_CASES', '4')),
help='Number of generated cases to test'
)

Expand Down

0 comments on commit b3b728e

Please sign in to comment.