Skip to content

Commit

Permalink
Remove use of jax.numpy.onp, which was an accidental export of classi…
Browse files Browse the repository at this point in the history
…c NumPy from the jax.numpy namespace.

Instead, `import numpy as onp` directly where needed.

PiperOrigin-RevId: 310191530
  • Loading branch information
hawkinsp authored and sschoenholz committed May 8, 2020
1 parent 1a1ffc8 commit 272dc5e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
9 changes: 5 additions & 4 deletions neural_tangents/tests/stax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from neural_tangents import stax
from neural_tangents.utils import monte_carlo
from neural_tangents.utils import test_utils
import numpy as onp
import unittest


Expand Down Expand Up @@ -168,7 +169,7 @@ def conv(out_chan): return stax.GeneralConv(
)
affine = conv(width) if is_conv else fc(width)

rate = np.onp.random.uniform(0.5, 0.9)
rate = onp.random.uniform(0.5, 0.9)
dropout = stax.Dropout(rate, mode='train')

if pool_type == 'AVG':
Expand Down Expand Up @@ -650,11 +651,11 @@ def test_sparse_inputs(self, act, kernel):
samples = N_SAMPLES

if xla_bridge.get_backend().platform == 'gpu':
jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-4
jtu._default_tolerance[onp.dtype(onp.float64)] = 5e-4
samples = 100 * N_SAMPLES
else:
jtu._default_tolerance[np.onp.dtype(np.onp.float32)] = 5e-2
jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-3
jtu._default_tolerance[onp.dtype(onp.float32)] = 5e-2
jtu._default_tolerance[onp.dtype(onp.float64)] = 5e-3

# a batch of dense inputs
x_dense = random.normal(key, (input_count, input_size))
Expand Down
3 changes: 2 additions & 1 deletion neural_tangents/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from jax.tree_util import tree_multimap
from neural_tangents.utils.kernel import Kernel
from neural_tangents.utils import utils
import numpy as onp


def _scan(f, init, xs, store_on_device):
Expand Down Expand Up @@ -418,7 +419,7 @@ def batch(kernel_fn, batch_size=0, device_count=-1, store_on_device=True):

def _get_n_batches_and_batch_sizes(n1, n2, batch_size, device_count):
# TODO(romann): if dropout batching works for different batch sizes, relax.
max_serial_batch_size = np.onp.gcd(n1, n2) // device_count
max_serial_batch_size = onp.gcd(n1, n2) // device_count

n2_batch_size = min(batch_size, max_serial_batch_size)
if n2_batch_size != batch_size:
Expand Down
7 changes: 4 additions & 3 deletions neural_tangents/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from jax.lib import xla_bridge
import jax.numpy as np
import jax.test_util as jtu
import numpy as onp
from .kernel import Kernel
import dataclasses

Expand All @@ -31,13 +32,13 @@ def _jit_vmap(f):


def update_test_tolerance(f32_tol=5e-3, f64_tol=1e-5):
jtu._default_tolerance[np.onp.dtype(np.onp.float32)] = f32_tol
jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = f64_tol
jtu._default_tolerance[onp.dtype(onp.float32)] = f32_tol
jtu._default_tolerance[onp.dtype(onp.float64)] = f64_tol
def default_tolerance():
if jtu.device_under_test() != 'tpu':
return jtu._default_tolerance
tol = jtu._default_tolerance.copy()
tol[np.onp.dtype(np.onp.float32)] = 5e-2
tol[onp.dtype(onp.float32)] = 5e-2
return tol
jtu.default_tolerance = default_tolerance

Expand Down

1 comment on commit 272dc5e

@DarrenZhang01
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When installing using pip3, the bug on line 421 of batch.py still exists (i.e., np.onp) and this make the command python3.6 examples/function_space.py terminates.

Please sign in to comment.