Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change jax.api.* -> jax.*. #150

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 8 additions & 8 deletions jax_md/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from jax import lax
from jax import ops
from jax.api import jit, vmap, eval_shape
from jax import jit, vmap, eval_shape
from jax.abstract_arrays import ShapedArray
from jax.interpreters import partial_eval as pe
import jax.numpy as np
Expand Down Expand Up @@ -137,7 +137,7 @@ def count(cell_hash, filling):


def _is_variable_compatible_with_positions(R: Array) -> bool:
if (isinstance(R, np.ndarray) and
if (util.is_array(R) and
len(R.shape) == 2 and
np.issubdtype(R.dtype, np.floating)):
return True
Expand Down Expand Up @@ -179,11 +179,11 @@ def _unflatten_cell_buffer(arr: Array,
dim: int) -> Array:
if (isinstance(cells_per_side, int) or
isinstance(cells_per_side, float) or
(isinstance(cells_per_side, np.ndarray) and not cells_per_side.shape)):
(util.is_array(cells_per_side) and not cells_per_side.shape)):
cells_per_side = (int(cells_per_side),) * dim
elif isinstance(cells_per_side, np.ndarray) and len(cells_per_side.shape) == 1:
elif util.is_array(cells_per_side) and len(cells_per_side.shape) == 1:
cells_per_side = tuple([int(x) for x in cells_per_side[::-1]])
elif isinstance(cells_per_side, np.ndarray) and len(cells_per_side.shape) == 2:
elif util.is_array(cells_per_side) and len(cells_per_side.shape) == 2:
cells_per_side = tuple([int(x) for x in cells_per_side[0][::-1]])
else:
raise ValueError() # TODO
Expand Down Expand Up @@ -264,12 +264,12 @@ def cell_list(box_size: Box,
containing the partition.
"""

if isinstance(box_size, np.ndarray):
if util.is_array(box_size):
box_size = onp.array(box_size)
if len(box_size.shape) == 1:
box_size = np.reshape(box_size, (1, -1))

if isinstance(minimum_cell_size, np.ndarray):
if util.is_array(minimum_cell_size):
minimum_cell_size = onp.array(minimum_cell_size)

cell_capacity = cell_capacity_or_example_R
Expand Down Expand Up @@ -318,7 +318,7 @@ def build_cells(R, **kwargs):
empty_kwarg_value = 10 ** 5
cell_kwargs = {}
for k, v in kwargs.items():
if not isinstance(v, np.ndarray):
if not util.is_array(v):
raise ValueError((
'Data must be specified as an ndarry. Found "{}" with '
'type {}'.format(k, type(v))))
Expand Down
2 changes: 1 addition & 1 deletion jax_md/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import TypeVar, Callable, Union, Tuple

from jax.api import grad, vmap, eval_shape
from jax import grad, vmap, eval_shape
import jax.numpy as jnp

from jax_md import space, dataclasses, partition, util
Expand Down
2 changes: 1 addition & 1 deletion jax_md/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from typing import Callable, TypeVar, Union, Tuple, Dict, Optional

from jax.api import grad
from jax import grad
from jax import ops
from jax import random
import jax.numpy as jnp
Expand Down
14 changes: 7 additions & 7 deletions jax_md/smap.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _get_bond_type_parameters(params: Array, bond_type: Array) -> Array:
assert isinstance(bond_type, jnp.ndarray)
assert len(bond_type.shape) == 1

if isinstance(params, jnp.ndarray):
if util.is_array(params):
if len(params.shape) == 1:
return params[bond_type]
elif len(params.shape) == 0:
Expand Down Expand Up @@ -167,7 +167,7 @@ def mapped_fn(R: Array,
def _get_species_parameters(params: Array, species: Array) -> Array:
"""Get parameters for interactions between species pairs."""
# TODO(schsam): We should do better error checking here.
if isinstance(params, jnp.ndarray):
if util.is_array(params):
if len(params.shape) == 2:
return params[species]
elif len(params.shape) == 0:
Expand All @@ -180,7 +180,7 @@ def _get_species_parameters(params: Array, species: Array) -> Array:

def _get_matrix_parameters(params: Array, combinator: Callable) -> Array:
"""Get an NxN parameter matrix from per-particle parameters."""
if isinstance(params, jnp.ndarray):
if util.is_array(params):
if len(params.shape) == 1:
return combinator(params[:, jnp.newaxis], params[jnp.newaxis, :])
elif len(params.shape) == 0 or len(params.shape) == 2:
Expand Down Expand Up @@ -339,7 +339,7 @@ def fn_mapped(R: Array, **dynamic_kwargs) -> Array:
# we are mapping. Should this be an option?
return high_precision_sum(_diagonal_mask(fn(dr, **_kwargs)),
axis=reduce_axis, keepdims=keepdims) * f32(0.5)
elif isinstance(species, jnp.ndarray):
elif util.is_array(species):
species = onp.array(species)
_check_species_dtype(species)
species_count = int(onp.max(species))
Expand Down Expand Up @@ -395,7 +395,7 @@ def fn_mapped(R, species, **dynamic_kwargs):
def _get_neighborhood_matrix_params(idx: Array,
params: Array,
combinator: Callable) -> Array:
if isinstance(params, jnp.ndarray):
if util.is_array(params):
if len(params.shape) == 1:
return combinator(jnp.reshape(params, params.shape + (1,)), params[idx])
elif len(params.shape) == 2:
Expand Down Expand Up @@ -423,7 +423,7 @@ def lookup(species_a, species_b, params):
lookup = vmap(vmap(lookup, (None, 0, None)), (0, 0, None))

neighbor_species = jnp.reshape(species[idx], idx.shape)
if isinstance(params, jnp.ndarray):
if util.is_array(params):
if len(params.shape) == 2:
return lookup(species, neighbor_species, params)
elif len(params.shape) == 0:
Expand Down Expand Up @@ -611,7 +611,7 @@ def fn_mapped(R, **dynamic_kwargs) -> Array:
return high_precision_sum(output,
axis=reduce_axis,
keepdims=keepdims) / 2.
elif isinstance(species, jnp.ndarray):
elif util.is_array(species):
def fn_mapped(R, **dynamic_kwargs):
d = partial(displacement_or_metric, **dynamic_kwargs)
idx = onp.tile(onp.arange(R.shape[0]), [R.shape[0], 1])
Expand Down
8 changes: 6 additions & 2 deletions jax_md/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

"""Defines utility functions."""

from typing import Iterable, Union, Optional
from typing import Iterable, Union, Optional, Any

from jax.tree_util import register_pytree_node
from jax.lib import xla_bridge
import jax.numpy as jnp
from jax.api import jit
from jax import jit

from functools import partial

Expand Down Expand Up @@ -81,3 +81,7 @@ def maybe_downcast(x):
if isinstance(x, jnp.ndarray) and x.dtype is jnp.dtype('float64'):
return x
return jnp.array(x, f32)


def is_array(x: Any) -> bool:
return isinstance(x, (onp.ndarray, jnp.ndarray))
6 changes: 3 additions & 3 deletions notebooks/jax_md_cookbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@
"\n",
"import jax.numpy as np\n",
"\n",
"from jax.api import jit\n",
"from jax.api import grad\n",
"from jax.api import vmap\n",
"from jax import jit\n",
"from jax import grad\n",
"from jax import vmap\n",
"from jax import value_and_grad\n",
"\n",
"from jax import random\n",
Expand Down
10 changes: 5 additions & 5 deletions notebooks/lanl_summer_school_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
"id": "QrOivTPpCJV7"
},
"source": [
"from jax.api import grad\n",
"from jax import grad\n",
"\n",
"du_dr = grad(soft_sphere)\n",
"\n",
Expand All @@ -222,7 +222,7 @@
"id": "C3sx2gTb23s7"
},
"source": [
"from jax.api import vmap\n",
"from jax import vmap\n",
"\n",
"du_dr_v = vmap(du_dr)\n",
"\n",
Expand Down Expand Up @@ -466,7 +466,7 @@
"id": "ZNMxdujG81-6"
},
"source": [
"from jax.api import jit\n",
"from jax import jit\n",
"\n",
"# Just-In-Time compile to GPU\n",
"minimize = jit(minimize)"
Expand Down Expand Up @@ -526,7 +526,7 @@
"id": "9Jesy9PRZc62"
},
"source": [
"from jax.api import jit\n",
"from jax import jit\n",
"\n",
"minimize = jit(minimize)"
],
Expand Down Expand Up @@ -596,7 +596,7 @@
"id": "1oAkt_2K8Lwv"
},
"source": [
"from jax.api import hessian\n",
"from jax import hessian\n",
"\n",
"K = hessian(strain_energy)(np.eye(2), R_is)\n",
"print(K.shape)"
Expand Down
8 changes: 4 additions & 4 deletions notebooks/meta_optimization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@
"\n",
"import jax.numpy as np\n",
"\n",
"from jax.api import jit\n",
"from jax.api import grad\n",
"from jax.api import vmap\n",
"from jax.api import value_and_grad\n",
"from jax import jit\n",
"from jax import grad\n",
"from jax import vmap\n",
"from jax import value_and_grad\n",
"\n",
"from jax import random\n",
"from jax import lax\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/neural_networks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"import jax\n",
"from jax import lax\n",
"\n",
"from jax.api import jit, vmap, grad\n",
"from jax import jit, vmap, grad\n",
"\n",
"# TODO: Re-enable x64 mode after XLA bug fix.\n",
"# from jax.config import config ; config.update('jax_enable_x64', True)\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/neurips_spotlight_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"!wget https://github.com/google/jax-md/blob/master/examples/models/si_gnn.pickle?raw=true\n",
"\n",
"import numpy as onp\n",
"from jax.api import device_put\n",
"from jax import device_put\n",
"\n",
"box_size = 10.862\n",
"\n",
Expand Down
10 changes: 5 additions & 5 deletions notebooks/talk_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@
"id": "QrOivTPpCJV7"
},
"source": [
"from jax.api import grad\n",
"from jax import grad\n",
"\n",
"du_dr = grad(soft_sphere)\n",
"\n",
Expand All @@ -231,7 +231,7 @@
"id": "C3sx2gTb23s7"
},
"source": [
"from jax.api import vmap\n",
"from jax import vmap\n",
"\n",
"du_dr_v = vmap(du_dr)\n",
"\n",
Expand Down Expand Up @@ -475,7 +475,7 @@
"id": "ZNMxdujG81-6"
},
"source": [
"from jax.api import jit\n",
"from jax import jit\n",
"\n",
"# Just-In-Time compile to GPU\n",
"minimize = jit(minimize)"
Expand Down Expand Up @@ -535,7 +535,7 @@
"id": "9Jesy9PRZc62"
},
"source": [
"from jax.api import jit\n",
"from jax import jit\n",
"\n",
"minimize = jit(minimize)"
],
Expand Down Expand Up @@ -605,7 +605,7 @@
"id": "1oAkt_2K8Lwv"
},
"source": [
"from jax.api import hessian\n",
"from jax import hessian\n",
"\n",
"K = hessian(strain_energy)(np.eye(2), R_is)\n",
"print(K.shape)"
Expand Down
2 changes: 1 addition & 1 deletion tests/energy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import numpy as onp

from jax.api import grad
from jax import grad
from jax_md import space
from jax_md.util import *
from jax_md import test_util
Expand Down
2 changes: 1 addition & 1 deletion tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as onp

from jax.api import jit, grad
from jax import jit, grad
from jax_md import space, quantity, nn, dataclasses, partition
from jax_md.util import f32, f64
from jax_md.test_util import update_test_tolerance
Expand Down
2 changes: 1 addition & 1 deletion tests/partition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import jax.numpy as np
from jax import ops

from jax.api import grad
from jax import grad

from jax import test_util as jtu
from jax import jit, vmap
Expand Down
2 changes: 1 addition & 1 deletion tests/quantity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from jax import random
import jax.numpy as np

from jax.api import jit, grad, vmap
from jax import jit, grad, vmap
from jax_md import space, quantity, test_util, energy
from jax_md.util import *

Expand Down
2 changes: 1 addition & 1 deletion tests/smap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax import random
import jax.numpy as np

from jax.api import grad
from jax import grad

from jax import test_util as jtu
from jax import jit, vmap
Expand Down
2 changes: 1 addition & 1 deletion tests/space_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from jax import random
import jax.numpy as jnp

from jax.api import grad, jit, jacfwd
from jax import grad, jit, jacfwd

from jax import test_util as jtu

Expand Down