Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

random.keyArray has been removed from JAX #299

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions jax_md/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,23 @@ def wrapped_force_fn(R, *args, **kwargs):


def canonicalize_force(energy_or_force_fn: Union[EnergyFn, ForceFn]) -> ForceFn:
_force_fn = None
def force_fn(R, **kwargs):
nonlocal _force_fn
if _force_fn is None:
out_shaped = eval_shape(energy_or_force_fn, R, **kwargs)
if isinstance(out_shaped, ShapeDtypeStruct) and out_shaped.shape == ():
_force_fn = force(energy_or_force_fn)
else:
# Check that the output has the right shape to be a force.
is_valid_force = tree_reduce(
lambda x, y: x and y,
tree_map(lambda x, y: x.shape == y.shape, out_shaped, R),
True
)
if not is_valid_force:
raise ValueError('Provided function should be compatible with '
'either an energy or a force. Found a function '
f'whose output has shape {out_shaped}.')

_force_fn = energy_or_force_fn
out_shaped = eval_shape(energy_or_force_fn, R, **kwargs)
if isinstance(out_shaped, ShapeDtypeStruct) and out_shaped.shape == ():
_force_fn = force(energy_or_force_fn)
else:
# Check that the output has the right shape to be a force.
is_valid_force = tree_reduce(
lambda x, y: x and y,
tree_map(lambda x, y: x.shape == y.shape, out_shaped, R),
True
)
if not is_valid_force:
raise ValueError('Provided function should be compatible with '
'either an energy or a force. Found a function '
f'whose output has shape {out_shaped}.')

_force_fn = energy_or_force_fn
return _force_fn(R, **kwargs)

return force_fn
Expand Down
5 changes: 2 additions & 3 deletions jax_md/rigid_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
PyTree = Any
f64 = util.f64
f32 = util.f32
KeyArray = random.KeyArray
NeighborListFns = partition.NeighborListFns
ShiftFn = space.ShiftFn

Expand Down Expand Up @@ -152,7 +151,7 @@ def _quaternion_rotate_bwd(res, g: Array) -> Tuple[Array, Array]:
_quaternion_rotate.defvjp(_quaternion_rotate_fwd, _quaternion_rotate_bwd)


def _random_quaternion(key: KeyArray, dtype: DType) -> Array:
def _random_quaternion(key: Array, dtype: DType) -> Array:
"""Generate a random quaternion of a given dtype."""
rnd = random.uniform(key, (3,), minval=0.0, maxval=1.0, dtype=dtype)

Expand Down Expand Up @@ -214,7 +213,7 @@ def quaternion_rotate(q: Quaternion, v: Array) -> Array:
return jnp.vectorize(_quaternion_rotate, signature='(q),(d)->(d)')(q.vec, v)


def random_quaternion(key: KeyArray, dtype: DType) -> Quaternion:
def random_quaternion(key: Array, dtype: DType) -> Quaternion:
"""Generate a random quaternion of a given dtype."""
rand_quat = partial(_random_quaternion, dtype=dtype)
rand_quat = jnp.vectorize(rand_quat, signature='(k)->(q)')
Expand Down
7 changes: 0 additions & 7 deletions jax_md/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import functools

from jax import grad
from jax import jit
from jax import random
import jax.numpy as jnp
from jax import lax
Expand Down Expand Up @@ -281,14 +280,12 @@ def nve(energy_or_force_fn, shift_fn, dt=1e-3, **sim_kwargs):
"""
force_fn = quantity.canonicalize_force(energy_or_force_fn)

@jit
def init_fn(key, R, kT, mass=f32(1.0), **kwargs):
force = force_fn(R, **kwargs)
state = NVEState(R, None, force, mass)
state = canonicalize_mass(state)
return initialize_momenta(state, key, kT)

@jit
def step_fn(state, **kwargs):
_dt = kwargs.pop('dt', dt)
return velocity_verlet(force_fn, shift_fn, _dt, state, **kwargs)
Expand Down Expand Up @@ -588,7 +585,6 @@ def nvt_nose_hoover(energy_or_force_fn: Callable[..., Array],

thermostat = nose_hoover_chain(dt, chain_length, chain_steps, sy_steps, tau)

@jit
def init_fn(key, R, mass=f32(1.0), **kwargs):
_kT = kT if 'kT' not in kwargs else kwargs['kT']

Expand All @@ -600,7 +596,6 @@ def init_fn(key, R, mass=f32(1.0), **kwargs):
KE = kinetic_energy(state)
return state.set(chain=thermostat.initialize(dof, KE, _kT))

@jit
def apply_fn(state, **kwargs):
_kT = kT if 'kT' not in kwargs else kwargs['kT']

Expand Down Expand Up @@ -1046,7 +1041,6 @@ def nvt_langevin(energy_or_force_fn: Callable[..., Array],
"""
force_fn = quantity.canonicalize_force(energy_or_force_fn)

@jit
def init_fn(key, R, mass=f32(1.0), **kwargs):
_kT = kwargs.pop('kT', kT)
key, split = random.split(key)
Expand All @@ -1055,7 +1049,6 @@ def init_fn(key, R, mass=f32(1.0), **kwargs):
state = canonicalize_mass(state)
return initialize_momenta(state, split, _kT)

@jit
def step_fn(state, **kwargs):
_dt = kwargs.pop('dt', dt)
_kT = kwargs.pop('kT', kT)
Expand Down