Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,20 @@
import jax.numpy as jnp
from collections import defaultdict
import xml.etree.ElementTree as ET

from dmff.utils import isinstance_jnp
from .admp.disp_pme import ADMPDispPmeForce
from .admp.multipole import convert_cart2harm, rot_local2global
from .admp.multipole import convert_cart2harm
from .admp.pairwise import TT_damping_qq_c6_kernel, generate_pairwise_interaction
from .admp.pairwise import slater_disp_damping_kernel, slater_sr_kernel, TT_damping_qq_kernel
from .admp.pme import ADMPPmeForce
from .admp.spatial import generate_construct_local_frames
from .admp.recip import Ck_1, generate_pme_recip
from .utils import jit_condition
from .classical.intra import (
HarmonicBondJaxForce,
HarmonicAngleJaxForce,
PeriodicTorsionJaxForce,
)
from jax_md import space, partition
from jax import grad, vmap
from jax import grad
import linecache
import itertools
from .classical.inter import (
Expand All @@ -30,7 +29,6 @@
CoulNoCutoffForce,
CoulReactionFieldForce,
)

import sys


Expand Down Expand Up @@ -211,7 +209,7 @@ def renderXML(self):


class ADMPDispPmeGenerator:
'''
r'''
This one computes the undamped C6/C8/C10 interactions
u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10
'''
Expand Down Expand Up @@ -316,7 +314,7 @@ def renderXML(self):
app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement

class QqTtDampingGenerator:
'''
r'''
This one calculates the tang-tonnies damping of charge-charge interaction
E = \sum_ij exp(-B*r)*(1+B*r)*q_i*q_j/r
'''
Expand Down Expand Up @@ -392,7 +390,7 @@ def renderXML(self):


class SlaterDampingGenerator:
'''
r'''
This one computes the slater-type damping function for c6/c8/c10 dispersion
E = \sum_ij (f6-1)*c6/r6 + (f8-1)*c8/r8 + (f10-1)*c10/r10
fn = f_tt(x, n)
Expand Down Expand Up @@ -474,7 +472,7 @@ def renderXML(self):


class SlaterExGenerator:
'''
r'''
This one computes the Slater-ISA type exchange interaction
u = \sum_ij A * (1/3*(Br)^2 + Br + 1)
'''
Expand Down Expand Up @@ -633,7 +631,7 @@ def registerAtomType(self, atom: dict):
@staticmethod
def parseElement(element, hamiltonian):

""" parse admp related parameters in XML file
r""" parse admp related parameters in XML file

example:

Expand Down Expand Up @@ -1068,7 +1066,7 @@ def registerBondType(self, bond):
@staticmethod
def parseElement(element, hamiltonian):

"""parse <HarmonicBondForce> section in XML file
r"""parse <HarmonicBondForce> section in XML file

example:

Expand Down Expand Up @@ -1160,7 +1158,7 @@ def registerAngleType(self, angle):

@staticmethod
def parseElement(element, hamiltonian):
""" parse <HarmonicAngleForce> section in XML file
r""" parse <HarmonicAngleForce> section in XML file

example:
<HarmonicAngleForce>
Expand Down Expand Up @@ -1988,7 +1986,12 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
coulenergy = coulforce.generate_get_energy()

def potential_fn(positions, box, pairs, params):


# check whether args passed into potential_fn are jnp.array and differentiable
# note this check will be optimized away by jit
# it is jit-compatiable
isinstance_jnp(positions, box, params)

ljE = ljenergy(
positions,
box,
Expand Down
32 changes: 22 additions & 10 deletions dmff/classical/inter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dmff.admp.pairwise import distribute_scalar
from dmff.utils import pair_buffer_scales, regularize_pairs
import jax.numpy as jnp
from dmff.admp.pme import energy_pme, setup_ewald_parameters
from dmff.admp.recip import generate_pme_recip
Expand Down Expand Up @@ -28,7 +28,7 @@ def __init__(
self.r_switch = r_switch
self.r_cut = r_cut

self.map_prm = map_prm
self.map_prm = jnp.array(map_prm)
self.map_nbfix = map_nbfix
self.ifPBC = isPBC
self.ifNoCut = isNoCut
Expand Down Expand Up @@ -59,6 +59,10 @@ def get_LJ_energy(dr_vec, sig, eps, box):
return E

def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales):

pairs = regularize_pairs(pairs)
mask = pair_buffer_scales(pairs)
map_prm = self.map_prm

eps_m1 = jnp.repeat(epsilon.reshape((-1, 1)), epsilon.shape[0], axis=1)
eps_m2 = eps_m1.T
Expand All @@ -77,16 +81,16 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales):


dr_vec = positions[pairs[:, 0]] - positions[pairs[:, 1]]
prm_pair0 = self.map_prm[pairs[:, 0]]
prm_pair1 = self.map_prm[pairs[:, 1]]
prm_pair0 = map_prm[pairs[:, 0]]
prm_pair1 = map_prm[pairs[:, 1]]
eps = eps_mat[prm_pair0, prm_pair1]
sig = sig_mat[prm_pair0, prm_pair1]

eps_scale = eps * mscale_pair

E_inter = get_LJ_energy(dr_vec, sig, eps_scale, box)

return jnp.sum(E_inter)
return jnp.sum(E_inter * mask)

return get_energy

Expand All @@ -110,12 +114,16 @@ def get_coul_energy(dr_vec, chrgprod, box):
return E

def get_energy(positions, box, pairs, charges, mscales):

pairs = regularize_pairs(pairs)
mask = pair_buffer_scales(pairs)
map_prm = jnp.array(self.map_prm)

colv_pair = self.colvmap[pairs[:,0],pairs[:,1]]
mscale_pair = mscales[colv_pair-1]

chrg_map0 = self.map_prm[pairs[:, 0]]
chrg_map1 = self.map_prm[pairs[:, 1]]
chrg_map0 = map_prm[pairs[:, 0]]
chrg_map1 = map_prm[pairs[:, 1]]
charge0 = charges[chrg_map0]
charge1 = charges[chrg_map1]
chrgprod = charge0 * charge1
Expand All @@ -124,7 +132,7 @@ def get_energy(positions, box, pairs, charges, mscales):

E_inter = get_coul_energy(dr_vec, chrgprod_scale, box)

return jnp.sum(E_inter)
return jnp.sum(E_inter * mask)

return get_energy

Expand Down Expand Up @@ -169,6 +177,9 @@ def get_rf_energy(dr_vec, chrgprod, box):
return E

def get_energy(positions, box, pairs, charges, mscales):

pairs = regularize_pairs(pairs)
mask = pair_buffer_scales(pairs)

colv_pair = self.colvmap[pairs[:,0],pairs[:,1]]
mscale_pair = mscales[colv_pair-1]
Expand All @@ -183,7 +194,7 @@ def get_energy(positions, box, pairs, charges, mscales):

E_inter = get_rf_energy(dr_vec, chrgprod_scale, box)

return jnp.sum(E_inter)
return jnp.sum(E_inter * mask)

return get_energy

Expand All @@ -198,7 +209,8 @@ def __init__(self, box, rc, ethresh, covalent_map):

def generate_get_energy(self):
def get_energy(positions, box, pairs, Q, mScales):

# Not required regularize_pairs
# already done in the pme code
return energy_pme(
positions,
box,
Expand Down
14 changes: 13 additions & 1 deletion dmff/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dmff.settings import DO_JIT
from jax import jit, vmap
from jax import jit, vmap, tree_util
import jax.numpy as jnp

def jit_condition(*args, **kwargs):
Expand Down Expand Up @@ -28,3 +28,15 @@ def pair_buffer_scales(p):
p[0] - p[1],
(p[0] - p[1] < 0, p[0] - p[1] >= 0),
(lambda x: jnp.array(1), lambda x: jnp.array(0)))


def isinstance_jnp(*args):

def _check(arg):
if not isinstance(arg, jnp.ndarray):
raise TypeError('all arguments must be jnp.array, \
otherwise they won\'t be able to take derivatives \
on these variables from outside of potential_fn anyway')

for arg in args:
tree_util.tree_map(lambda arg: _check(arg), args[0])
3 changes: 2 additions & 1 deletion docs/dev_guide/arch.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ returned to users:
```python

def potential_fn(positions, box, pairs, params):
isinstance_jnp(positions, box, params)
return bforce.get_energy(
positions, box, pairs, params["k"], params["length"]
)
Expand All @@ -145,7 +146,7 @@ self._jaxPotential = potential_fn
```

The `potential_fn` function only takes `(positions, box, pairs, params)` as explicit input arguments. All these arguments except
`pairs` (neighbor list) should be differentiable. Non differentiable parameters are passed into it by closure (see code convention section).
`pairs` (neighbor list) should be differentiable. A helper function `isinstance_jnp` in `utils.py` can check take-in arguments whether they are `jnp.array`. Non differentiable parameters are passed into it by closure (see code convention section).
Meanwhile, if the generator needs to initialize multiple calculators (e.g. `NonBondedJaxGenerator` will call both `LJ` and `PME` calculators),
`potential_fn` should return the summation of the results of all calculators.

Expand Down
Loading