[BUG] Out of Memory Issue During Neighbor List Generation #116

Closed
alecbetancourt opened this issue Sep 30, 2023 · 2 comments
Labels
bug Something isn't working

alecbetancourt commented Sep 30, 2023

Bug summary

I'm attempting to test MD runs with the AMBER protein force fields and a solvated system as a basis for some future free energy experiments. For very small systems I've generally had good luck getting things working, but when I try to test a larger solvated system (37266 atoms), neighbor list generation fails with an out of memory error.

DMFF Version

0.2.1

JAX Version

jaxlib[cuda11_cudnn805]==0.3.15

OpenMM Version

7.7.0

How did you download the software?

pip

Input Files, Running Commands, Error Log, etc.

Input file:

import sys
import time
import jax
import jax.numpy as jnp
import numpy as np
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList

from jax_md import space, smap, energy, minimize, quantity, simulate

from jax.config import config
config.update("jax_enable_x64", True)

prmtop = app.AmberPrmtopFile('../RAMP1_ion.prmtop')
inpcrd = app.AmberInpcrdFile('../RAMP1_ion.inpcrd')
ff = Hamiltonian("amber14/protein.ff14SB.xml", "amber14/tip3p.xml")

def hhbond(bond):
  if bond[0].residue.name == 'HOH':
    if bond[0].element._symbol == 'H' and bond[1].element._symbol == 'H':
      return True
  return False

#remove extra H-H bonds found in AMBER format
prmtop.topology._bonds = [bond for bond in prmtop.topology._bonds if not hhbond(bond)]

potentials = ff.createPotential(prmtop.topology, nonbondedMethod=app.PME, nonbondedCutoff=8*unit.angstrom, prm=prmtop)

params = ff.getParameters()
positions = jnp.array(inpcrd.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
positions = positions - jnp.min(positions, axis=0)

#positions range from 0 to ~9.6 in any given direction
box = jnp.array([
    [10.0, 0.0, 0.0], 
    [0.0, 10.0, 0.0],
    [0.0, 0.0, 10.0]
])

#8 angstrom cutoff
nbList = NeighborList(box, .8, potentials.meta["cov_map"])
nbList.allocate(positions)

Error log:

Traceback (most recent call last):
  File "/mnt/ufs18/home-094/betanc18/DMFF/examples/classical/forces_bench_ramp/jax/jaxrampdebug.py", line 85, in <module>
    nbList.allocate(positions)
  File "/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py", line 44, in allocate
    self.nblist = self.neighborlist_fn.allocate(positions)
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax_md/partition.py", line 816, in allocate_fn
    return neighbor_list_fn(position, extra_capacity=extra_capacity, **kwargs)
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax_md/partition.py", line 803, in neighbor_list_fn
    return neighbor_fn((position, False))
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax_md/partition.py", line 772, in neighbor_fn
    idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs)
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/_src/api.py", line 528, in cache_miss
    out_flat = xla.xla_call(
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/core.py", line 1963, in bind
    return call_bind(self, fun, *args, **params)
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/core.py", line 1979, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/core.py", line 689, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl
    return compiled_fun(*args)
  File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/_src/dispatch.py", line 837, in _execute_compiled
    out_flat = compiled.execute(in_flat)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 93.12GiB (99990344464B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   10.35GiB
              constant allocation:       144B
        maybe_live_out allocation:   10.35GiB
     preallocated temp allocation:   93.12GiB
  preallocated temp fragmentation:        64B (0.00%)
                 total allocation:  113.82GiB
              total fragmentation:   10.35GiB (9.09%)
Peak buffers:
	Buffer 1:
		Size: 31.04GiB
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/vmap(jit(_einsum))/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: custom-call
		Shape: f64[3,1388754756]
		==========================

	Buffer 2:
		Size: 31.04GiB
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1, 3) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: f64[1388754756,3]
		==========================

	Buffer 3:
		Size: 31.04GiB
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1, 3) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: f64[1388754756,3]
		==========================

	Buffer 4:
		Size: 10.35GiB
		Entry Parameter Subshape: s64[37266,37266]
		==========================

	Buffer 5:
		Size: 10.35GiB
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/concatenate[dimension=0]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: s32[2,1388754756]
		==========================

	Buffer 6:
		Size: 873.4KiB
		Entry Parameter Subshape: f64[37266,3]
		==========================

	Buffer 7:
		Size: 72B
		XLA Label: constant
		Shape: f64[3,3]
		==========================

	Buffer 8:
		Size: 72B
		XLA Label: constant
		Shape: f64[3,3]
		==========================

	Buffer 9:
		Size: 16B
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 1, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: (s64[10595], s64[10595])
		==========================

	Buffer 10:
		Size: 16B
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: (s64[694377378], s64[694377378])
		==========================

	Buffer 11:
		Size: 16B
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: (s64[173594344], s64[173594344])
		==========================

	Buffer 12:
		Size: 16B
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: (s64[43398586], s64[43398586])
		==========================

	Buffer 13:
		Size: 16B
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: (s64[10849646], s64[10849646])
		==========================

	Buffer 14:
		Size: 16B
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 1, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: (s64[2712411], s64[2712411])
		==========================

	Buffer 15:
		Size: 16B
		Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
		XLA Label: fusion
		Shape: (s64[678102], s64[678102])
		==========================

Steps to Reproduce

Running the above snippet of Python code with the input files below causes this issue. From the above error, it looks like some of the allocated buffers are essentially n^2 in one dimension (1388754756 = 37266^2). I don't understand the neighbor/cell list generation code in JAX MD well enough to figure out why this is happening, but my understanding is that a cell neighbor list should avoid these issues with an appropriately set cutoff.
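As a sanity check on the n^2 observation, the buffer sizes in the log can be reproduced from the atom count alone. This is only a sketch of the arithmetic; the helper name `gib` is made up for illustration, and the shapes are copied from the "Peak buffers" section of the error log.

```python
# Reproduce the buffer sizes reported in the XLA OOM log from the atom count.
N_ATOMS = 37266
GIB = 1024 ** 3

# Number of ordered atom pairs, i.e. the "1388754756" dimension in the log.
n_pairs = N_ATOMS ** 2


def gib(n_elements, bytes_per_element):
    """Size in GiB of an array with the given element count and width."""
    return n_elements * bytes_per_element / GIB


# Shapes taken from the "Peak buffers" section of the error log.
buffers = {
    "f64[3,1388754756]": gib(3 * n_pairs, 8),  # Buffer 1
    "f64[1388754756,3]": gib(n_pairs * 3, 8),  # Buffers 2 and 3
    "s64[37266,37266]": gib(n_pairs, 8),       # Buffer 4
    "s32[2,1388754756]": gib(2 * n_pairs, 4),  # Buffer 5
}

for shape, size in buffers.items():
    print(f"{shape}: {size:.2f} GiB")

# The three f64 pair buffers together account for the failed temp allocation.
temp_total = 3 * gib(3 * n_pairs, 8)
print(f"three f64 pair buffers: {temp_total:.2f} GiB")
```

The f64 pair buffers come out at about 31.04 GiB each and the integer buffers at about 10.35 GiB, which agrees with the sizes in the log and suggests that the candidate list covers every atom pair rather than only pairs in neighboring cells.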

Further Information, Files, and Links

The 2 files are the protein system building examples from the AMBER tutorials of the RAMP1 protein solvated in a water box.
RAMP1_ion.zip

Collaborator

KuangYu commented Oct 2, 2023

Try using NeighborListFreud instead of NeighborList. The latter relies on the neighbor list module of jax-md, which is slow and memory-intensive. The Freud neighbor list, on the other hand, runs on the CPU, so it does not have this memory issue for large systems and, surprisingly, often runs faster.

Nevertheless, the performance of both modules is unsatisfactory. We are working on more efficient neighbor list implementations and hope to get them done in the next few months.

@alecbetancourt
Author

This seems to do the trick. Comparing some larger systems at reasonable cutoffs (8-10 Angstrom), the results scale somewhat better with NeighborListFreud, although they still run into bottlenecks on especially large systems from the benchmark suites of other packages. Regarding my other issue, it looks like there were also some problems with the box vector input format that would cause the generation of some large n^2 arrays. This no longer seems to be a problem in the development branch mentioned in that issue, though.
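For readers wondering why a cell list is expected to avoid the n^2 arrays discussed in this thread, here is a minimal pure-Python sketch of the general technique. It is not the jax-md, Freud, or DMFF implementation; the function names are hypothetical, and it assumes a cubic periodic box with all coordinates in [0, box_length) and a cutoff of at most half the box length. Assuming roughly uniform density, a 10 nm box with a 0.8 nm cutoff gives 12 cells per side, so each atom is only compared against atoms in 27 of 1728 cells, about 1/64 of all pairs.

```python
import itertools
import math
import random


def min_image_distance(a, b, box_length):
    """Distance between a and b under the minimum image convention."""
    total = 0.0
    for x, y in zip(a, b):
        d = x - y
        d -= box_length * round(d / box_length)
        total += d * d
    return math.sqrt(total)


def cell_list_pairs(points, box_length, cutoff):
    """Return the set of pairs (i, j), i < j, closer than cutoff.

    Points are binned into cubic cells at least `cutoff` wide, so only
    the 27 cells around each point need to be searched.
    """
    n_cells = max(1, int(box_length // cutoff))
    cell_size = box_length / n_cells

    # Bin every point into its cell; clamp guards against float round-off.
    cells = {}
    for idx, p in enumerate(points):
        key = tuple(min(int(c / cell_size), n_cells - 1) for c in p)
        cells.setdefault(key, []).append(idx)

    offsets = list(itertools.product((-1, 0, 1), repeat=3))
    pairs = set()
    for key, members in cells.items():
        for off in offsets:
            # Periodic wrap of the neighboring cell index.
            nkey = tuple((k + o) % n_cells for k, o in zip(key, off))
            for i in members:
                for j in cells.get(nkey, ()):
                    if i < j and min_image_distance(
                        points[i], points[j], box_length
                    ) < cutoff:
                        pairs.add((i, j))
    return pairs


# Small demonstration with random points in a 5 x 5 x 5 box.
random.seed(0)
demo_points = [
    tuple(random.uniform(0.0, 5.0) for _ in range(3)) for _ in range(100)
]
demo_pairs = cell_list_pairs(demo_points, box_length=5.0, cutoff=1.0)
print(f"{len(demo_pairs)} pairs within the cutoff")
```

The candidate set here grows with the number of atoms times the local density, not with the square of the atom count, which is why the memory footprint stays manageable for large solvated systems.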
