Skip to content
Merged
2 changes: 1 addition & 1 deletion dmff/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .settings import *
from .common.nblist import NeighborList
from .common.nblist import NeighborList, NeighborListFreud
from .api import Hamiltonian
61 changes: 54 additions & 7 deletions dmff/common/nblist.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Optional
import numpy as np
import jax.numpy as jnp
from jax_md import space, partition
from dmff.utils import jit_condition
from dmff.utils import regularize_pairs
from jax import jit
try:
import freud
except ImportError:
pass


class NeighborList:
Expand All @@ -21,7 +25,7 @@ def __init__(self, box, rc, covalent_map) -> None:
self.neighborlist_fn = partition.neighbor_list(self.displacement_fn, box, rc, 0, format=partition.OrderedSparse)
self.nblist = None

def allocate(self, positions: jnp.ndarray):
def allocate(self, positions: jnp.ndarray, box: Optional[jnp.ndarray] = None):
""" A function to allocate a new neighbor list. This function cannot be compiled, since it uses the values of positions to infer the shapes.

Args:
Expand All @@ -33,10 +37,10 @@ def allocate(self, positions: jnp.ndarray):
if self.nblist is None:
self.nblist = self.neighborlist_fn.allocate(positions)
else:
self.update(positions)
self.update(positions, box)
return self.nblist

def update(self, positions: jnp.ndarray):
def update(self, positions: jnp.ndarray, box: Optional[jnp.ndarray] = None):
""" A function to update a neighbor list given a new set of positions and a previously allocated neighbor list.

Args:
Expand All @@ -45,7 +49,10 @@ def update(self, positions: jnp.ndarray):
Returns:
jax_md.partition.NeighborList
"""
self.nblist = self.nblist.update(positions)
if box is None:
self.nblist = self.nblist.update(positions)
else:
self.nblist = self.nblist.update(positions, box)
return self.nblist

@property
Expand Down Expand Up @@ -113,4 +120,44 @@ def did_buffer_overflow(self)->bool:
-------
boolen
"""
return self.nblist.did_buffer_overflow
return self.nblist.did_buffer_overflow


class NeighborListFreud:
def __init__(self, box, rcut, cov_map, padding=True):
self.fbox = freud.box.Box.from_matrix(box)
self.rcut = rcut
self.nmax = None
self.padding = padding
self.cov_map = cov_map

def _do_cov_map(self, pairs):
nbond = self.cov_map[pairs[:, 0], pairs[:, 1]]
pairs = jnp.concatenate([pairs, nbond[:, None]], axis=1)
return pairs

def allocate(self, coords, box=None):
fbox = freud.box.Box.from_matrix(box) if box is not None else self.fbox
aq = freud.locality.AABBQuery(fbox, coords)
res = aq.query(coords, dict(r_max=self.rcut, exclude_ii=True))
nlist = res.toNeighborList()
nlist = np.vstack((nlist[:, 0], nlist[:, 1])).T
nlist = nlist.astype(np.int32)
msk = (nlist[:, 0] - nlist[:, 1]) < 0
nlist = nlist[msk]
if self.nmax is None:
self.nmax = int(nlist.shape[0] * 1.3)

if not self.padding:
return self._do_cov_map(nlist)

self.nmax = max(self.nmax, nlist.shape[0])
padding_width = self.nmax - nlist.shape[0]
if padding_width == 0:
return self._do_cov_map(nlist)
elif padding_width > 0:
padding = np.ones((self.nmax - nlist.shape[0], 2), dtype=np.int32) * coords.shape[0]
nlist = np.vstack((nlist, padding))
return self._do_cov_map(nlist)
else:
raise ValueError("padding width < 0")
43 changes: 36 additions & 7 deletions dmff/fftree.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
import xml.etree.ElementTree as ET
import xml.dom.minidom
from dmff.utils import convertStr2Float, DMFFException
from typing import Dict, List, Union, TypeVar
from itertools import permutations
from openmm.app.forcefield import _getDataDirectories


value = TypeVar('value') # generic type: interpreted as either a number or str


class SelectError(BaseException):
pass

Expand Down Expand Up @@ -88,7 +92,7 @@ def get_nodes(self, parser:str)->List[Node]:
val = val[0]
return val

def get_attribs(self, parser:str, attrname:Union[str, List[str]])->List[Union[value, List[value]]]:
def get_attribs(self, parser:str, attrname:Union[str, List[str]], convert_to_float: bool = True)->List[Union[value, List[value]]]:
"""
get all values of attributes of nodes which nodes matching certain path

Expand All @@ -105,6 +109,8 @@ def get_attribs(self, parser:str, attrname:Union[str, List[str]])->List[Union[va
a path to locate nodes
attrname : _type_
attribute name or a list of attribute names of a node
conver_to_float : bool
whether to covert the value of query attrnames to float type

Returns
-------
Expand All @@ -115,11 +121,23 @@ def get_attribs(self, parser:str, attrname:Union[str, List[str]])->List[Union[va
if isinstance(attrname, list):
ret = []
for item in sel:
vals = [convertStr2Float(item.attrs[an]) if an in item.attrs else None for an in attrname]
vals = []
for an in attrname:
if an in item.attrs:
val = convertStr2Float(item.attrs[an]) if convert_to_float else item.attrs[an]
else:
val = None
vals.append(val)
ret.append(vals)
return ret
else:
attrs = [convertStr2Float(n.attrs[attrname]) if attrname in n.attrs else None for n in sel]
attrs = []
for n in sel:
if attrname in n.attrs:
val = convertStr2Float(n.attrs[attrname]) if convert_to_float else n.attrs[attrname]
else:
val = None
attrs.append(val)
return attrs

def set_node(self, parser:str, values:List[Dict[str, value]])->None:
Expand Down Expand Up @@ -182,9 +200,19 @@ def parse_node(self, root):
if children:
node.add_children(children)
return node

def _render_interal_ff_path(self, xml):
rendered_xml = xml
for dataDir in _getDataDirectories():
rendered_xml = os.path.join(dataDir, xml)
if os.path.isfile(rendered_xml):
break
return rendered_xml

def parse(self, *xmls):
for xml in xmls:
if not os.path.isfile(xml):
xml = self._render_interal_ff_path(xml)
root = ET.parse(xml).getroot()
for leaf in root:
n = self.parse_node(leaf)
Expand Down Expand Up @@ -233,8 +261,9 @@ def __init__(self, fftree: ForcefieldTree, parser):
"""
Freeze type matching list.
"""
atypes = fftree.get_attribs("AtomTypes/Type", "name")
aclasses = fftree.get_attribs("AtomTypes/Type", "class")
# not convert to float for atom types
atypes = fftree.get_attribs("AtomTypes/Type", "name", convert_to_float=False)
aclasses = fftree.get_attribs("AtomTypes/Type", "class", convert_to_float=False)
self.class2type = {}
for nline in range(len(atypes)):
if aclasses[nline] not in self.class2type:
Expand All @@ -256,9 +285,9 @@ def __init__(self, fftree: ForcefieldTree, parser):
tmp.append((1, [node.attrs[key]]))
elif len(key) > 5 and "class" == key[:5]:
nit = int(key[5:])
tmp.append((nit, self.class2type[node.attrs[key]]))
tmp.append((nit, self.class2type.get(node.attrs[key], [None])))
elif key == "class":
tmp.append((1, self.class2type[node.attrs[key]]))
tmp.append((1, self.class2type.get(node.attrs[key], [None])))
tmp = sorted(tmp, key=lambda x: x[0])
self.functions.append([i[1] for i in tmp])

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ mkdocs-gen-files>=0.3.4
mkdocs-literate-nav>=0.4.1
mkdocstrings>=0.19.0
mkdocstrings-python>=0.7.0
pygments>=2.12
pygments>=2.12