Skip to content

Commit

Permalink
Removed as a required keyword argument from remaining simulations.
Browse files Browse the repository at this point in the history
  • Loading branch information
sschoenholz committed Feb 1, 2021
1 parent 9d51c8d commit c55ec95
Showing 1 changed file with 30 additions and 18 deletions.
48 changes: 30 additions & 18 deletions jax_md/simulate.py
Expand Up @@ -123,10 +123,10 @@ def init_fun(key: Array,
V = np.sqrt(velocity_scale) * random.normal(key, R.shape, dtype=R.dtype)
mass = quantity.canonicalize_mass(mass)
return NVEState(R, V, force(R, **kwargs) / mass, mass) # pytype: disable=wrong-arg-count
def apply_fun(state: NVEState, t: float=f32(0), **kwargs) -> NVEState:
def apply_fun(state: NVEState, **kwargs) -> NVEState:
R, V, A, mass = dataclasses.astuple(state)
R = shift_fn(R, V * dt + A * dt_2, t=t, **kwargs)
A_prime = force(R, t=t, **kwargs) / mass
R = shift_fn(R, V * dt + A * dt_2, **kwargs)
A_prime = force(R, **kwargs) / mass
V = V + f32(0.5) * (A + A_prime) * dt
return NVEState(R, V, A_prime, mass) # pytype: disable=wrong-arg-count
return init_fun, apply_fun
Expand Down Expand Up @@ -216,6 +216,9 @@ def nvt_nose_hoover(energy_or_force: Callable[..., Array],
and dR should be ndarrays of shape [n, spatial_dimension].
dt: Floating point number specifying the timescale (step size) of the
simulation.
kT: Floating point number specifying the temperature inunits of Boltzmann
constant. To update the temperature dynamically during a simulation one
should pass `kT` as a keyword argument to the step function.
chain_length: An integer specifying the number of particles in
the Nose-Hoover chain.
chain_steps: An integer specifying the number, $n_c$, of outer substeps.
Expand Down Expand Up @@ -283,15 +286,15 @@ def substep_chain_fn(delta, KE, V, xi, v_xi, Q, DOF, T):
G = (Q[M - 1] * v_xi[M - 1] ** f32(2) - T) / Q[M]
v_xi = ops.index_add(v_xi, M, delta_4 * G)

def backward_loop_fn(v_xi_new, m):
def backward_loop_fn(v_xi_new, m):
G = (Q[m - 1] * v_xi[m - 1] ** 2 - T) / Q[m]
scale = np.exp(-delta_8 * v_xi_new)
v_xi_new = scale * (scale * v_xi[m] + delta_4 * G)
return v_xi_new, v_xi_new
idx = np.arange(M - 1, 0, -1)
_, v_xi_update = lax.scan(backward_loop_fn, v_xi[M], idx, unroll=2)
v_xi = ops.index_update(v_xi, idx, v_xi_update)

G = (f32(2.0) * KE - DOF * T) / Q[0]
scale = np.exp(-delta_8 * v_xi[1])
v_xi = ops.index_update(v_xi, 0, scale * (scale * v_xi[0] + delta_4 * G))
Expand Down Expand Up @@ -430,10 +433,9 @@ def nvt_langevin(energy_or_force: Callable[..., Array],
and dR should be ndarrays of shape [n, spatial_dimension].
dt: Floating point number specifying the timescale (step size) of the
simulation.
T_schedule: Either a floating point number specifying a constant temperature
or a function specifying temperature as a function of time.
quant: Either a quantity.Energy or a quantity.Force specifying whether
energy_or_force is an energy or force respectively.
kT: Floating point number specifying the temperature inunits of Boltzmann
constant. To update the temperature dynamically during a simulation one
should pass `kT` as a keyword argument to the step function.
gamma: A float specifying the friction coefficient between the particles
and the solvent.
Returns:
Expand Down Expand Up @@ -463,7 +465,7 @@ def init_fn(key, R, mass=f32(1), **kwargs):
V = np.sqrt(_kT / mass) * random.normal(split, R.shape, dtype=R.dtype)
V = V - np.mean(V, axis=0, keepdims=True)

return NVTLangevinState(R, V, force_fn(R, t=f32(0), **kwargs), mass, key) # pytype: disable=wrong-arg-count
return NVTLangevinState(R, V, force_fn(R, **kwargs), mass, key) # pytype: disable=wrong-arg-count

def apply_fn(state, **kwargs):
R, V, F, mass, key = dataclasses.astuple(state)
Expand Down Expand Up @@ -510,7 +512,7 @@ class BrownianState:
def brownian(energy_or_force: Callable[..., Array],
shift: ShiftFn,
dt: float,
T_schedule: Schedule,
kT: float,
gamma: float=0.1) -> Simulator:
"""Simulation of Brownian dynamics.
Expand All @@ -521,7 +523,18 @@ def brownian(energy_or_force: Callable[..., Array],
case of Langevin dynamics our implementation follows [1].
Args:
See nvt_langevin.
energy_or_force: A function that produces either an energy or a force from
a set of particle positions specified as an ndarray of shape
[n, spatial_dimension].
shift_fn: A function that displaces positions, R, by an amount dR. Both R
and dR should be ndarrays of shape [n, spatial_dimension].
dt: Floating point number specifying the timescale (step size) of the
simulation.
kT: Floating point number specifying the temperature inunits of Boltzmann
constant. To update the temperature dynamically during a simulation one
should pass `kT` as a keyword argument to the step function.
gamma: A float specifying the friction coefficient between the particles
and the solvent.
Returns:
See above.
Expand All @@ -535,26 +548,25 @@ def brownian(energy_or_force: Callable[..., Array],

dt, gamma = static_cast(dt, gamma)

T_schedule = interpolate.canonicalize(T_schedule)

def init_fn(key, R, mass=f32(1)):
mass = quantity.canonicalize_mass(mass)

return BrownianState(R, mass, key) # pytype: disable=wrong-arg-count

def apply_fn(state, t=f32(0), **kwargs):
def apply_fn(state, **kwargs):
_kT = kT if 'kT' not in kwargs else kwargs['kT']

R, mass, key = dataclasses.astuple(state)

key, split = random.split(key)

F = force_fn(R, t=t, **kwargs)
F = force_fn(R, **kwargs)
xi = random.normal(split, R.shape, R.dtype)

nu = f32(1) / (mass * gamma)

dR = F * dt * nu + np.sqrt(f32(2) * T_schedule(t) * dt * nu) * xi
R = shift(R, dR, t=t, **kwargs)
dR = F * dt * nu + np.sqrt(f32(2) * _kT * dt * nu) * xi
R = shift(R, dR, **kwargs)

return BrownianState(R, mass, key) # pytype: disable=wrong-arg-count

Expand Down

0 comments on commit c55ec95

Please sign in to comment.