From 9cb01be5ff043e27da90712b7dafe3145a8c9060 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Mon, 9 May 2022 21:34:42 +0800 Subject: [PATCH] add nblist wrapper and its docs --- dmff/__init__.py | 3 +- dmff/common/__init__.py | 0 dmff/common/nblist.py | 95 ++++++++++++++++++ docs/dev_guide/arch.md | 18 +++- docs/user_guide/tutorial.md | 193 +++++++++++++++++++++++++++++++++++- tests/test_nblist.py | 50 ++++++++++ 6 files changed, 352 insertions(+), 7 deletions(-) create mode 100644 dmff/common/__init__.py create mode 100644 dmff/common/nblist.py create mode 100644 tests/test_nblist.py diff --git a/dmff/__init__.py b/dmff/__init__.py index 19eacb97d..e44570345 100644 --- a/dmff/__init__.py +++ b/dmff/__init__.py @@ -1 +1,2 @@ -import dmff.settings \ No newline at end of file +import dmff.settings +from dmff.common.nblist import NeighborList \ No newline at end of file diff --git a/dmff/common/__init__.py b/dmff/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dmff/common/nblist.py b/dmff/common/nblist.py new file mode 100644 index 000000000..74adcd967 --- /dev/null +++ b/dmff/common/nblist.py @@ -0,0 +1,95 @@ +from jax_md import space, partition +import jax.numpy as jnp +from dmff.utils import regularize_pairs +import jax.numpy as jnp + +class NeighborList: + + def __init__(self, box, rc) -> None: + """ wrapper of jax_md.space_periodic_general and jax_md.partition.NeighborList + + Args: + box (jnp.ndarray): A (spatial_dim, spatial_dim) affine transformation or [lx, ly, lz] vector + rc (float): cutoff radius + """ + self.box = box + self.rc = rc + self.displacement_fn, self.shift_fn = space.periodic_general(box, fractional_coordinates=False) + self.neighborlist_fn = partition.neighbor_list(self.displacement_fn, box, rc, 0, format=partition.OrderedSparse) + + def allocate(self, positions: jnp.ndarray): + """ A function to allocate a new neighbor list. 
This function cannot be compiled, since it uses the values of positions to infer the shapes. + + Args: + positions (jnp.ndarray): particle positions + + Returns: + jax_md.partition.NeighborList + """ + self.nblist = self.neighborlist_fn.allocate(positions) + return self.nblist + + def update(self, positions: jnp.ndarray): + """ Update the previously allocated neighbor list for a new set of positions and store the result on self.nblist (jax_md's update returns a new list instead of mutating in place). + + Args: + positions (jnp.ndarray): particle positions + + Returns: + jax_md.partition.NeighborList + """ + self.nblist = self.nblist.update(positions) + + return self.nblist + + @property + def pairs(self): + """ get raw pair index + + Returns: + jnp.ndarray: (nPairs, 2) + """ + return self.nblist.idx.T + + @property + def pair_mask(self): + """ get regularized pair index and mask + + Returns: + (jnp.ndarray, jnp.ndarray): ((nPairs, 2), (nPairs, )) + """ + + mask = jnp.sum(self.pairs == len(self.positions), axis=1) + mask = jnp.logical_not(mask) + pair = regularize_pairs(self.pairs) + + return pair, mask + + @property + def positions(self): + """ get current positions in current neighborlist + + Returns: + jnp.ndarray: (n, 3) + """ + return self.nblist.reference_position + + @property + def dr(self): + """ get pair distance vector (plain coordinate difference of the regularized pairs) in current neighborlist + + Returns: + jnp.ndarray: (nPairs, 3) + """ + pair, _ = self.pair_mask + return self.positions[pair[:, 0]] - self.positions[pair[:, 1]] + + @property + def distance(self): + """ get pair distance in current neighborlist + + Returns: + jnp.ndarray: (nPairs, ) + + """ + return jnp.linalg.norm(self.dr, axis=1) \ No newline at end of file diff --git a/docs/dev_guide/arch.md b/docs/dev_guide/arch.md index 0204000df..6d86e88dd 100644 --- a/docs/dev_guide/arch.md +++ b/docs/dev_guide/arch.md @@ -302,7 +302,6 @@ the Force class of harmonic bond potential is shown below as an example.
def distance(p1v, p2v): return jnp.sqrt(jnp.sum(jnp.power(p1v - p2v, 2), axis=1)) - class HarmonicBondJaxForce: def __init__(self, p1idx, p2idx, prmidx): self.p1idx = p1idx @@ -312,12 +311,17 @@ class HarmonicBondJaxForce: def generate_get_energy(self): def get_energy(positions, box, pairs, k, length): + + # NOTE: pairs array from jax-md has invalid index + pairs = regularize_pairs(pairs) + buffer_scales = pair_buffer_scales(pairs) + p1 = positions[self.p1idx] p2 = positions[self.p2idx] kprm = k[self.prmidx] b0prm = length[self.prmidx] dist = distance(p1, p2) - return jnp.sum(0.5 * kprm * jnp.power(dist - b0prm, 2)) + return jnp.sum(0.5 * kprm * jnp.power(dist - b0prm, 2) * buffer_scales) # mask invalid pairs return get_energy @@ -336,7 +340,7 @@ class HarmonicBondJaxForce: self.get_forces = value_and_grad(self.get_energy) ``` -The design logic for the `Force` class is: it saves the *static* variables inside the class as +The design logic for the `Force` class is: it saves the *static* variables inside the class as the *environment* of the real calculators. Examples of the static environment variables include: the $\kappa$ and $K_{max}$ in PME calculators, the covalent_map in real-space calculators etc. For a typical `Force` class, one needs to define the following methods: @@ -350,3 +354,11 @@ For a typical `Force` class, one needs to define the following methods: In ADMP, all backend calculators only take atomic parameters as input, so they can be invoked independently in hybrid ML/force field models. The dispatch of force field parameters is done in the `potential_fn` function defined in the frontend. + +Please note that the `pairs` array accepted by the `get_energy` compute kernel is constructed **directly** from `jax-md`'s neighbor list. +To keep the array shape fixed and avoid re-JITting the code every time `get_energy` is called, the `pairs` array is padded.
It has +some invalid indices in the padding area: any entry with `pair_index==len(positions)` is an invalid padding index. That means there are many +`[len(positions), len(positions)]` pairs in the `pairs` array, resulting in a distance equal to 0. The solution is to first call the `regularize_pairs` +helper function to replace `[len(positions), len(positions)]` with `[len(positions)-2, len(positions)-1]`, so the distance is always non-zero. Because +we then compute additional invalid pairs, we need a `buffer_scales` array to mask them out. Use `pair_buffer_scales(pairs)` +to get the mask, and apply it to the pair energy array before summing it up. diff --git a/docs/user_guide/tutorial.md b/docs/user_guide/tutorial.md index c48fab261..9e4508658 100644 --- a/docs/user_guide/tutorial.md +++ b/docs/user_guide/tutorial.md @@ -1,7 +1,194 @@ # Tutorial -## install DMFF +## Write XML -## compute energy and force +DMFF uses a simple XML file to describe force fields. Let us take an example of writing a DMFF XML file using the classical force field to calculate the water molecule system. -## auto differentiation \ No newline at end of file +Suppose we treat the water molecule as a three-body molecule. Within the molecule, we need harmonic interaction to describe the bonded interaction and harmonic angle potential. Between molecules, the interactions between atoms are expressed through the Lennard-Jones potential. + +Let us create a new file called `forcefield.xml`. The root element of the XML file must be a `` tag: + +``` + +... + +``` + +The `` tag contains the following children: + +- An `` tag containing the atom type definitions + +- A `` tag containing the residue template definitions + +- Zero or more tags defining specific forces + +The order of these tags does not matter. They are described in detail below. + +`` defines atom type in the System.
In this case, we have two types of atom: + + +``` + + + + +``` + +Each `` tag in this section represents a type of atom. It specifies the name of the type, the class it belongs to, the symbol for its element, and its mass. The names are arbitrary strings: they need not be numbers, as in this example. The only requirement is that all types have unique names. The classes are also arbitrary strings and in general will not be unique. If they list the same value for the class attribute, two types belong to the same class. + +The residue template definitions look like this: + +``` + + + + + + + + + +``` + +`` template contains the following tags: + +- An `` tag for each atom in the residue. This specifies the name of the atom and its atom type. + +- A `` tag for each pair of atoms that are bonded to each other. The atomName1 and atomName2 attributes are the names of the two bonded atoms. (Some older force fields use the alternate tags to and from to specify the atoms by index instead of name. This is still supported for backward compatibility, but specifying atoms by name is recommended since it makes the residue definition much easier to understand.) + +The `` tag may also contain `` tags, as in the following example: + + +``` + + + + + + + + + +``` + +Each `` tag indicates an atom in the residue that should be represented with a virtual site. The type attribute may equal "average2", "average3", "outOfPlane", or "localCoords", which correspond to the TwoParticleAverageSite, ThreeParticleAverageSite, OutOfPlaneSite, and LocalCoordinatesSite classes respectively. The siteName attribute gives the name of the atom to represent with a virtual site. The atoms it is calculated based on are specified by atomName1, atomName2, etc. (Some old force fields use the deprecated tags index, atom1, atom2, etc. to refer to them by index instead of name.) + +The remaining attributes are specific to the virtual site class and specify the parameters for calculating the site position. 
For a TwoParticleAverageSite, they are weight1 and weight2. For a ThreeParticleAverageSite, they are weight1, weight2, and weight3. For an OutOfPlaneSite, they are weight12, weight13, and weightCross. For a LocalCoordinatesSite, they are p1, p2, and p3 (giving the x, y, and z coordinates of the site position in the local coordinate System), and wo1, wx1, wy1, wo2, wx2, wy2, … (giving the weights for computing the origin, x-axis, and y-axis). + +Next, to add a HarmonicBondForce to the System, include a tag that looks like this: + +``` + + + +``` + +Every `` tag defines a rule for creating harmonic bond interactions between atoms. Each tag may identify the atoms either by type (using the attributes type1 and type2) or by class (using the attributes class1 and class2). For every pair of bonded atoms, the force field searches for a rule whose atom types or atom classes match the two atoms. If it finds one, it calls addBond() on the HarmonicBondForce with the specified parameters. Otherwise, it ignores that pair and continues. length is the equilibrium bond length in nm, and k is the spring constant in kJ/mol/nm2. + +To add a HarmonicAngleForce to the System, include a tag that looks like this: + +``` + + + +``` + +Every `` tag defines a rule for creating harmonic angle interactions between triplets of atoms. Each tag may identify the atoms either by type (using the attributes type1, type2, …) or by class (using the attributes class1, class2, …). The force field identifies every set of three atoms in the System where the first is bonded to the second, and the second to the third. For each one, it searches for a rule whose atom types or atom classes match the three atoms. If it finds one, it calls addAngle() on the HarmonicAngleForce with the specified parameters. Otherwise, it ignores that set and continues. angle is the equilibrium angle in radians, and k is the spring constant in kJ/mol/radian2. 
+ +To add a NonbondedForce to the System, include a tag that looks like this: + +``` + + + + + +``` + +Each `` tag specifies the OBC parameters for one atom type (specified with the type attribute) or atom class (specified with the class attribute). It is fine to mix these two methods, having some tags specify a type and others specify a class. However you do it, you must make sure that a unique set of parameters is defined for every atom type. charge is measured in units of the proton charge, radius is the GBSA radius in nm, and scale is the OBC scaling factor. + +This is what we should do to describe a simple system with a classical force field. + +## Write a run script + +We already have a XML file to describe our System, now we need to write a python script to calculate energy and force. + +First, we need to parse PDB file + +``` +import openmm.app as app +pdb = app.PDBFile('/path/to/pdb') +positions = jnp.array(pdb.positions._value) +a, b, c = pdb.topology.getPeriodicBoxVectors() +box = jnp.array([a._value, b._value, c._value]) +``` + +Second, a `Hamiltonian` class should be initialized with XML file path + +``` +from dmff.api import Hamiltonian +H = Hamiltonian('forcefield.xml') +rc = 4.0 # cutoff +system = H.createPotential(pdb.topology, nonbondedCutoff=rc) +``` + +The `Hamiltonian` class will parse tags in XML file and invoke corresponding potential functions. We can access those potentials in this way: + +``` +bondE = H._potentials[0] +angleE = H._potentials[1] +nonBondE = H._potentials[2] +``` + +> Note: only when the `createPotential` method is called can potentials be obtained + +Next, we need to construct neighbor list. 
Here we use the code from `jax_md`: + +``` +from jax_md import space, partition +displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False) +neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse) +nbr = neighbor_list_fn.allocate(positions) +pairs = nbr.idx.T +``` + +Also, we provide a wrapper to simplify neighborList construction: + +``` +from dmff import NeighborList +nblist = NeighborList(box, rc) +nblist.allocate(positions) +pairs = nblist.pairs # equivalent to nbr.idx.T +distance = nblist.distance # distance between pairs +dr = nblist.dr # distance vector + +``` + +`pairs` is a `(N, 2)` shape array, which indicates the index of atom i and atom j. ATTENTION: the pairs array contains many **invalid** indices. For example, in this case, we only have 6 atoms and the shape of pairs may be `(18, 2)`. There are even three `[6, 6]` pairs, which are obviously out of range. This works because `jax-md` takes advantage of a feature of jax.numpy: indexing out of range does not throw an error, and instead returns the [last element](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing).
+ +Finally, we can calculate energy and force using the aforementioned potential: + +``` +print("Bond:", value_and_grad(bondE)(positions, box, pairs, H.getGenerators()[0].params)) +print("Angle:", value_and_grad(angleE)(positions, box, pairs, H.getGenerators()[1].params)) +print('NonBonded:', value_and_grad(nonBondE)(positions, box, pairs, H.getGenerators()[2].params)) +``` + +also, we can write a simple gradient descent to optimize parameters: + +``` +import optax +# start to do optmization +lr = 0.001 +optimizer = optax.adam(lr) +opt_state = optimizer.init(params) + +n_epochs = 1000 +for i_epoch in range(n_epochs): + loss, grads = value_and_grad(bondE, argnums=(0))(params, data[sid]) + updates, opt_state = optimizer.update(grads, opt_state) + params = optax.apply_updates(params, updates) + with open('params.pickle', 'wb') as ofile: + pickle.dump(params, ofile) +``` diff --git a/tests/test_nblist.py b/tests/test_nblist.py new file mode 100644 index 000000000..a8a04724b --- /dev/null +++ b/tests/test_nblist.py @@ -0,0 +1,50 @@ +import pytest +import jax.numpy as jnp +from dmff import NeighborList + +class TestNeighborList: + + @pytest.fixture(scope="class", name='nblist') + def test_nblist_init(self): + positions = jnp.array([[12.434, 3.404, 1.540], + [13.030, 2.664, 1.322], + [12.312, 3.814, 0.660], + [14.216, 1.424, 1.103], + [14.246, 1.144, 2.054], + [15.155, 1.542, 0.910]]) + + box = jnp.array([31.289, 31.289, 31.289]) + r_cutoff = 4.0 + nbobj = NeighborList(box, r_cutoff) + nbobj.allocate(positions) + yield nbobj + + def test_update(self, nblist): + + positions = jnp.array([[12.434, 3.404, 1.540], + [13.030, 2.664, 1.322], + [12.312, 3.814, 0.660], + [14.216, 1.424, 1.103], + [14.246, 1.144, 2.054], + [15.155, 1.542, 0.910]]) + + nblist.update(positions) + + def test_pairs(self, nblist): + + pairs = nblist.pairs + assert pairs.shape == (15, 2) + + def test_pair_mask(self, nblist): + + pair, mask = nblist.pair_mask + assert mask.shape == (15, ) + + def 
test_dr(self, nblist): + + dr = nblist.dr + assert dr.shape == (15, 3) + + def test_distance(self, nblist): + + assert nblist.distance.shape == (15, )