
Selectively fix atom positions #115

Closed
Noahyt opened this issue Jan 22, 2021 · 2 comments

Comments

@Noahyt

Noahyt commented Jan 22, 2021

I believe this is a feature request, though perhaps I am missing an obvious solution that is already available.

I would like to be able to selectively fix atom positions, i.e., run an energy minimization with some atom positions fixed, then "release" them and run the minimization again with a different subset of atoms fixed.

I cannot figure out how to do this, since all positions are handled through the R array, and the minimization algorithms update every element of it at each step.

@cpgoodri
Contributor

Have you tried just using a custom shift function?

@sschoenholz
Collaborator

Hey @Noahyt,

Sorry about the delay in responding; we were busy getting ready for ICML. As @cpgoodri suggested, it should be possible to do this without modifying JAX MD itself. The basic idea is to use a new shift function that only moves a subset of the particles. Because the mask of mobile particles is passed explicitly at each step, the set of frozen particles can be changed dynamically. One caveat is that this won't work with simulation environments that use the number of degrees of freedom as a parameter (for example, NVT simulations); if this use case is important for you, please let me know and I'll see what I can do!
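To make the degrees-of-freedom caveat concrete: thermostats compare the kinetic energy against kT times the number of degrees of freedom, which is normally `dim * N`. With frozen particles, only the mobile ones should count. The sketch below is a hypothetical helper (`kinetic_temperature` is not a JAX MD function) that estimates the instantaneous temperature from mobile particles only. It also masks the velocities, on the assumption that frozen particles can still carry nonzero entries in an integrator's velocity array even though the masked shift never moves them.

```python
import jax.numpy as jnp

def kinetic_temperature(velocity, mass, is_mobile):
    # Hypothetical helper, not part of JAX MD: instantaneous kT = 2 * KE / n_dof.
    dim = velocity.shape[-1]
    # Exclude any velocity the integrator may have assigned to frozen particles.
    v = jnp.where(is_mobile[:, None], velocity, 0.0)
    kinetic_energy = 0.5 * jnp.sum(mass * v ** 2)
    # Only mobile particles contribute degrees of freedom.
    n_dof = dim * jnp.sum(is_mobile)
    return 2.0 * kinetic_energy / n_dof

velocity = jnp.ones((4, 2))
is_mobile = jnp.array([True, True, False, False])
kT = kinetic_temperature(velocity, 1.0, is_mobile)
```

Counting all `dim * N = 8` degrees of freedom here would report half the temperature that the two mobile particles actually have, which is why an unmodified NVT thermostat would be driven to the wrong state.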

In any case, here is schematically how I would write such a shift function:

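import jax.numpy as np
from jax_md import space

displacement, shift = space.periodic(box_size)
def masked_shift(R, dR, is_mobile=None, **kwargs):
  if is_mobile is None:  # No mask: behave like the ordinary shift.
    return shift(R, dR, **kwargs)
  # Move only mobile particles; frozen particles keep their positions.
  return np.where(is_mobile[:, None], shift(R, dR, **kwargs), R)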
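To see the mask in action without setting up a periodic box, here is a small standalone check. It swaps in a hypothetical `free_shift` (plain `R + dR`, no wrapping) as an assumed stand-in for the shift returned by `space.periodic`; the masking logic is the same.

```python
import jax.numpy as jnp

def free_shift(R, dR, **kwargs):
    # Stand-in for a free-space shift: no periodic wrapping (an assumption).
    return R + dR

def masked_shift(R, dR, is_mobile=None, **kwargs):
    if is_mobile is None:
        return free_shift(R, dR, **kwargs)
    # is_mobile has shape (N,); [:, None] broadcasts it across the spatial
    # dimension so each particle's whole row is either moved or kept.
    return jnp.where(is_mobile[:, None], free_shift(R, dR, **kwargs), R)

R = jnp.zeros((3, 2))
dR = jnp.ones((3, 2))
is_mobile = jnp.array([True, False, True])
R_new = masked_shift(R, dR, is_mobile=is_mobile)  # particle 1 stays at the origin
R_all = masked_shift(R, dR)  # no mask: every particle moves
```

The `[:, None]` is what makes the mask act per particle rather than per coordinate: without it, a shape `(N,)` mask would try to broadcast against the trailing spatial axis of `R`.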

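This masked_shift function can then be passed to a minimizer. The is_mobile array is passed as a keyword argument to the minimizer's apply function, which forwards its keyword arguments to the shift function (and to the energy function, so energy_fn should accept extra keyword arguments); it should just work.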

from jax import jit
from jax_md import minimize

fire_init, fire_apply = minimize.fire_descent(energy_fn, masked_shift)
fire_apply = jit(fire_apply)
fire_state = fire_init(R)

# Freeze all but the first 10 particles.
is_mobile = np.arange(N) < 10
for i in range(steps):
  fire_state = fire_apply(fire_state, is_mobile=is_mobile)
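For the original use case (fix one subset, minimize, then release it and fix another), the only thing that changes between stages is the mask. The sketch below shows that pattern end to end in plain JAX. It assumes a toy harmonic energy and uses a simple gradient-descent step as a stand-in for FIRE (`descent_step`, `run_minimization`, and `STEP_SIZE` are hypothetical names, not JAX MD API).

```python
import jax
import jax.numpy as jnp

STEP_SIZE = 0.1  # hypothetical fixed step size for the toy minimizer

def energy_fn(R):
    # Toy energy (an assumption): every particle is tied to the origin by a spring.
    return 0.5 * jnp.sum(R ** 2)

def masked_shift(R, dR, is_mobile=None):
    # Free-space shift (R + dR), applied only to mobile particles.
    moved = R + dR
    if is_mobile is None:
        return moved
    return jnp.where(is_mobile[:, None], moved, R)

@jax.jit
def descent_step(R, is_mobile):
    # One plain gradient-descent step; a stand-in for a FIRE update.
    force = -jax.grad(energy_fn)(R)
    return masked_shift(R, STEP_SIZE * force, is_mobile=is_mobile)

def run_minimization(R, is_mobile, steps=200):
    for _ in range(steps):
        R = descent_step(R, is_mobile)
    return R

R0 = jnp.ones((4, 2))

# Stage 1: particles 0 and 1 move; particles 2 and 3 are frozen.
R1 = run_minimization(R0, jnp.array([True, True, False, False]))

# Stage 2: "release" particles 2 and 3 and freeze 0 and 1 instead.
R2 = run_minimization(R1, jnp.array([False, False, True, True]))
```

Because the mask is an ordinary array argument to the jitted step, swapping in a different mask of the same shape and dtype reuses the compiled function rather than triggering a recompile.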

Here is an example colab notebook that puts all of this together. Please let me know if this isn't what you had in mind or if you get stuck!
