diff --git a/dmff/admp/disp_pme.py b/dmff/admp/disp_pme.py
index 3df7d11d6..11c30eab0 100755
--- a/dmff/admp/disp_pme.py
+++ b/dmff/admp/disp_pme.py
@@ -1,6 +1,6 @@
import jax.numpy as jnp
from jax import vmap, value_and_grad
-from dmff.utils import jit_condition
+from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.spatial import pbc_shift
from dmff.admp.pme import setup_ewald_parameters
from dmff.admp.recip import generate_pme_recip, Ck_6, Ck_8, Ck_10
@@ -14,17 +14,24 @@ class ADMPDispPmeForce:
The so called "environment paramters" means parameters that do not need to be differentiable
'''
- def __init__(self, box, covalent_map, rc, ethresh, pmax):
+ def __init__(self, box, covalent_map, rc, ethresh, pmax, lpme=True):
self.covalent_map = covalent_map
self.rc = rc
self.ethresh = ethresh
self.pmax = pmax
# Need a different function for dispersion ??? Need tests
- kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
- self.kappa = kappa
- self.K1 = K1
- self.K2 = K2
- self.K3 = K3
+ self.lpme = lpme
+ if lpme:
+ kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
+ self.kappa = kappa
+ self.K1 = K1
+ self.K2 = K2
+ self.K3 = K3
+ else:
+ self.kappa = 0.0
+ self.K1 = 0
+ self.K2 = 0
+ self.K3 = 0
self.pme_order = 6
# setup calculators
self.refresh_calculators()
@@ -36,7 +43,7 @@ def get_energy(positions, box, pairs, c_list, mScales):
return energy_disp_pme(positions, box, pairs,
c_list, mScales, self.covalent_map,
self.kappa, self.K1, self.K2, self.K3, self.pmax,
- self.d6_recip, self.d8_recip, self.d10_recip)
+ self.d6_recip, self.d8_recip, self.d10_recip, lpme=self.lpme)
return get_energy
@@ -70,7 +77,7 @@ def refresh_calculators(self):
def energy_disp_pme(positions, box, pairs,
c_list, mScales, covalent_map,
kappa, K1, K2, K3, pmax,
- recip_fn6, recip_fn8, recip_fn10):
+ recip_fn6, recip_fn8, recip_fn10, lpme=True):
'''
Top level wrapper for dispersion pme
@@ -95,22 +102,29 @@ def energy_disp_pme(positions, box, pairs,
int: max K for reciprocal calculations
pmax:
int array: maximal exponents (p) to compute, e.g., (6, 8, 10)
+ lpme:
+ bool: whether do pme or not, useful when doing cluster calculations
Output:
energy: total dispersion pme energy
'''
- ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, covalent_map, kappa, pmax)
+ if lpme is False:
+ kappa = 0
- ene_recip = recip_fn6(positions, box, c_list[:, 0, jnp.newaxis])
- if pmax >= 8:
- ene_recip += recip_fn8(positions, box, c_list[:, 1, jnp.newaxis])
- if pmax >= 10:
- ene_recip += recip_fn10(positions, box, c_list[:, 2, jnp.newaxis])
+ ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, covalent_map, kappa, pmax)
- ene_self = disp_pme_self(c_list, kappa, pmax)
+ if lpme:
+ ene_recip = recip_fn6(positions, box, c_list[:, 0, jnp.newaxis])
+ if pmax >= 8:
+ ene_recip += recip_fn8(positions, box, c_list[:, 1, jnp.newaxis])
+ if pmax >= 10:
+ ene_recip += recip_fn10(positions, box, c_list[:, 2, jnp.newaxis])
+ ene_self = disp_pme_self(c_list, kappa, pmax)
+ return ene_real + ene_recip + ene_self
- return ene_real + ene_recip + ene_self
+ else:
+ return ene_real
def disp_pme_real(positions, box, pairs,
@@ -144,24 +158,26 @@ def disp_pme_real(positions, box, pairs,
'''
# expand pairwise parameters
- pairs = pairs[pairs[:, 0] < pairs[:, 1]]
+ # pairs = pairs[pairs[:, 0] < pairs[:, 1]]
+ pairs = regularize_pairs(pairs)
box_inv = jnp.linalg.inv(box)
ri = distribute_v3(positions, pairs[:, 0])
rj = distribute_v3(positions, pairs[:, 1])
- # ri = positions[pairs[:, 0]]
- # rj = positions[pairs[:, 1]]
nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
mscales = distribute_scalar(mScales, nbonds-1)
- # mscales = mScales[nbonds-1]
+
+ buffer_scales = pair_buffer_scales(pairs)
+ mscales = mscales * buffer_scales
ci = distribute_dispcoeff(c_list, pairs[:, 0])
cj = distribute_dispcoeff(c_list, pairs[:, 1])
- # ci = c_list[pairs[:, 0], :]
- # cj = c_list[pairs[:, 1], :]
- ene_real = jnp.sum(disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax))
+ ene_real = jnp.sum(
+ disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax)
+ * buffer_scales
+ )
return jnp.sum(ene_real)
@@ -193,6 +209,7 @@ def disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax):
dr = ri - rj
dr = pbc_shift(dr, box, box_inv)
dr2 = jnp.dot(dr, dr)
+
x2 = kappa * kappa * dr2
g = g_p(x2, pmax)
dr6 = dr2 * dr2 * dr2
@@ -269,85 +286,3 @@ def disp_pme_self(c_list, kappa, pmax):
return E
-# def validation(pdb):
-# xml = 'mpidwater.xml'
-# pdbinfo = read_pdb(pdb)
-# serials = pdbinfo['serials']
-# names = pdbinfo['names']
-# resNames = pdbinfo['resNames']
-# resSeqs = pdbinfo['resSeqs']
-# positions = pdbinfo['positions']
-# box = pdbinfo['box'] # a, b, c, α, β, γ
-# charges = pdbinfo['charges']
-# positions = jnp.asarray(positions)
-# lx, ly, lz, _, _, _ = box
-# box = jnp.eye(3)*jnp.array([lx, ly, lz])
-
-# mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
-# pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
-# dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
-
-# rc = 4 # in Angstrom
-# ethresh = 1e-4
-
-# n_atoms = len(serials)
-
-# atomTemplate, residueTemplate = read_xml(xml)
-# atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate)
-
-# covalent_map = assemble_covalent(residueDicts, n_atoms)
-# displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False)
-# neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse)
-# nbr = neighbor_list_fn.allocate(positions)
-# pairs = nbr.idx.T
-
-# pmax = 10
-# kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
-# kappa = 0.657065221219616
-
-# # construct the C list
-# c_list = np.zeros((3,n_atoms))
-# nmol=int(n_atoms/3)
-# for i in range(nmol):
-# a = i*3
-# b = i*3+1
-# c = i*3+2
-# c_list[0][a]=37.19677405
-# c_list[0][b]=7.6111103
-# c_list[0][c]=7.6111103
-# c_list[1][a]=85.26810658
-# c_list[1][b]=11.90220148
-# c_list[1][c]=11.90220148
-# c_list[2][a]=134.44874488
-# c_list[2][b]=15.05074749
-# c_list[2][c]=15.05074749
-# c_list = jnp.array(c_list.T)
-
-
-# # Finish data preparation
-# # -------------------------------------------------------------------------------------
-# # pme_order = 6
-# # d6_recip = generate_pme_recip(Ck_6, kappa, True, pme_order, K1, K2, K3, 0)
-# # d8_recip = generate_pme_recip(Ck_8, kappa, True, pme_order, K1, K2, K3, 0)
-# # d10_recip = generate_pme_recip(Ck_10, kappa, True, pme_order, K1, K2, K3, 0)
-# # disp_pme_recip_fns = [d6_recip, d8_recip, d10_recip]
-# # energy_force_disp_pme = value_and_grad(energy_disp_pme)
-# # e, f = energy_force_disp_pme(positions, box, pairs, c_list, mScales, covalent_map, kappa, K1, K2, K3, pmax, *disp_pme_recip_fns)
-# # print('ok')
-# # e, f = energy_force_disp_pme(positions, box, pairs, c_list, mScales, covalent_map, kappa, K1, K2, K3, pmax, *disp_pme_recip_fns)
-# # print(e)
-
-# disp_pme_force = ADMPDispPmeForce(box, covalent_map, rc, ethresh, pmax)
-# disp_pme_force.update_env('kappa', 0.657065221219616)
-
-# print(c_list[:4])
-# E, F = disp_pme_force.get_forces(positions, box, pairs, c_list, mScales)
-# print('ok')
-# E, F = disp_pme_force.get_forces(positions, box, pairs, c_list, mScales)
-# print(E)
-# return
-
-
-# # below is the validation code
-# if __name__ == '__main__':
-# validation(sys.argv[1])
diff --git a/dmff/admp/mbpol_intra.py b/dmff/admp/mbpol_intra.py
new file mode 100755
index 000000000..7b1ffff56
--- /dev/null
+++ b/dmff/admp/mbpol_intra.py
@@ -0,0 +1,529 @@
+
+import sys
+import numpy as np
+import jax.numpy as jnp
+from jax import grad, value_and_grad
+from dmff.settings import DO_JIT
+from dmff.utils import jit_condition
+from dmff.admp.spatial import v_pbc_shift
+from dmff.admp.pme import ADMPPmeForce
+from dmff.admp.parser import *
+from jax import vmap
+import time
+#from admp.multipole import convert_cart2harm
+#from jax_md import partition, space
+
+#const
+f5z = 0.999677885
+fbasis = 0.15860145369897
+fcore = -1.6351695982132
+frest = 1.0
+reoh = 0.958649;
+thetae = 104.3475;
+b1 = 2.0;
+roh = 0.9519607159623009;
+alphaoh = 2.587949757553683;
+deohA = 42290.92019288289;
+phh1A = 16.94879431193463;
+phh2 = 12.66426998162947;
+
+c5zA = jnp.array([4.2278462684916e+04, 4.5859382909906e-02, 9.4804986183058e+03,
+ 7.5485566680955e+02, 1.9865052511496e+03, 4.3768071560862e+02,
+ 1.4466054104131e+03, 1.3591924557890e+02,-1.4299027252645e+03,
+ 6.6966329416373e+02, 3.8065088734195e+03,-5.0582552618154e+02,
+ -3.2067534385604e+03, 6.9673382568135e+02, 1.6789085874578e+03,
+ -3.5387509130093e+03,-1.2902326455736e+04,-6.4271125232353e+03,
+ -6.9346876863641e+03,-4.9765266152649e+02,-3.4380943579627e+03,
+ 3.9925274973255e+03,-1.2703668547457e+04,-1.5831591056092e+04,
+ 2.9431777405339e+04, 2.5071411925779e+04,-4.8518811956397e+04,
+ -1.4430705306580e+04, 2.5844109323395e+04,-2.3371683301770e+03,
+ 1.2333872678202e+04, 6.6525207018832e+03,-2.0884209672231e+03,
+ -6.3008463062877e+03, 4.2548148298119e+04, 2.1561445953347e+04,
+ -1.5517277060400e+05, 2.9277086555691e+04, 2.6154026873478e+05,
+ -1.3093666159230e+05,-1.6260425387088e+05, 1.2311652217133e+05,
+ -5.1764697159603e+04, 2.5287599662992e+03, 3.0114701659513e+04,
+ -2.0580084492150e+03, 3.3617940269402e+04, 1.3503379582016e+04,
+ -1.0401149481887e+05,-6.3248258344140e+04, 2.4576697811922e+05,
+ 8.9685253338525e+04,-2.3910076031416e+05,-6.5265145723160e+04,
+ 8.9184290973880e+04,-8.0850272976101e+03,-3.1054961140464e+04,
+ -1.3684354599285e+04, 9.3754012976495e+03,-7.4676475789329e+04,
+ -1.8122270942076e+05, 2.6987309391410e+05, 4.0582251904706e+05,
+ -4.7103517814752e+05,-3.6115503974010e+05, 3.2284775325099e+05,
+ 1.3264691929787e+04, 1.8025253924335e+05,-1.2235925565102e+04,
+ -9.1363898120735e+03,-4.1294242946858e+04,-3.4995730900098e+04,
+ 3.1769893347165e+05, 2.8395605362570e+05,-1.0784536354219e+06,
+ -5.9451106980882e+05, 1.5215430060937e+06, 4.5943167339298e+05,
+ -7.9957883936866e+05,-9.2432840622294e+04, 5.5825423140341e+03,
+ 3.0673594098716e+03, 8.7439532014842e+04, 1.9113438435651e+05,
+ -3.4306742659939e+05,-3.0711488132651e+05, 6.2118702580693e+05,
+ -1.5805976377422e+04,-4.2038045404190e+05, 3.4847108834282e+05,
+ -1.3486811106770e+04, 3.1256632170871e+04, 5.3344700235019e+03,
+ 2.6384242145376e+04, 1.2917121516510e+05,-1.3160848301195e+05,
+ -4.5853998051192e+05, 3.5760105069089e+05, 6.4570143281747e+05,
+ -3.6980075904167e+05,-3.2941029518332e+05,-3.5042507366553e+05,
+ 2.1513919629391e+03, 6.3403845616538e+04, 6.2152822008047e+04,
+ -4.8805335375295e+05,-6.3261951398766e+05, 1.8433340786742e+06,
+ 1.4650263449690e+06,-2.9204939728308e+06,-1.1011338105757e+06,
+ 1.7270664922758e+06, 3.4925947462024e+05,-1.9526251371308e+04,
+ -3.2271030511683e+04,-3.7601575719875e+05, 1.8295007005531e+05,
+ 1.5005699079799e+06,-1.2350076538617e+06,-1.8221938812193e+06,
+ 1.5438780841786e+06,-3.2729150692367e+03, 1.0546285883943e+04,
+ -4.7118461673723e+04,-1.1458551385925e+05, 2.7704588008958e+05,
+ 7.4145816862032e+05,-6.6864945408289e+05,-1.6992324545166e+06,
+ 6.7487333473248e+05, 1.4361670430046e+06,-2.0837555267331e+05,
+ 4.7678355561019e+05,-1.5194821786066e+04,-1.1987249931134e+05,
+ 1.3007675671713e+05, 9.6641544907323e+05,-5.3379849922258e+05,
+ -2.4303858824867e+06, 1.5261649025605e+06, 2.0186755858342e+06,
+ -1.6429544469130e+06,-1.7921520714752e+04, 1.4125624734639e+04,
+ -2.5345006031695e+04, 1.7853375909076e+05,-5.4318156343922e+04,
+ -3.6889685715963e+05, 4.2449670705837e+05, 3.5020329799394e+05,
+ 9.3825886484788e+03,-8.0012127425648e+05, 9.8554789856472e+04,
+ 4.9210554266522e+05,-6.4038493953446e+05,-2.8398085766046e+06,
+ 2.1390360019254e+06, 6.3452935017176e+06,-2.3677386290925e+06,
+ -3.9697874352050e+06,-1.9490691547041e+04, 4.4213579019433e+04,
+ 1.6113884156437e+05,-7.1247665213713e+05,-1.1808376404616e+06,
+ 3.0815171952564e+06, 1.3519809705593e+06,-3.4457898745450e+06,
+ 2.0705775494050e+05,-4.3778169926622e+05, 8.7041260169714e+03,
+ 1.8982512628535e+05,-2.9708215504578e+05,-8.8213012222074e+05,
+ 8.6031109049755e+05, 1.0968800857081e+06,-1.0114716732602e+06,
+ 1.9367263614108e+05, 2.8678295007137e+05,-9.4347729862989e+04,
+ 4.4154039394108e+04, 5.3686756196439e+05, 1.7254041770855e+05,
+ -2.5310674462399e+06,-2.0381171865455e+06, 3.3780796258176e+06,
+ 7.8836220768478e+05,-1.5307728782887e+05,-3.7573362053757e+05,
+ 1.0124501604626e+06, 2.0929686545723e+06,-5.7305706586465e+06,
+ -2.6200352535413e+06, 7.1543745536691e+06,-1.9733601879064e+04,
+ 8.5273008477607e+04, 6.1062454495045e+04,-2.2642508675984e+05,
+ 2.4581653864150e+05,-9.0376851105383e+05,-4.4367930945690e+05,
+ 1.5740351463593e+06, 2.4563041445249e+05,-3.4697646046367e+03,
+ -2.1391370322552e+05, 4.2358948404842e+05, 5.6270081955003e+05,
+ -8.5007851251980e+05,-6.1182429537130e+05, 5.6690751824341e+05,
+ -3.5617502919487e+05,-8.1875263381402e+02,-2.4506258140060e+05,
+ 2.5830513731509e+05, 6.0646114465433e+05,-6.9676584616955e+05,
+ 5.1937406389690e+05, 1.7261913546007e+05,-1.7405787307472e+04,
+ -3.8301842660567e+05, 5.4227693205154e+05, 2.5442083515211e+06,
+ -1.1837755702370e+06,-1.9381959088092e+06,-4.0642141553575e+05,
+ 1.1840693827934e+04,-1.5334500255967e+05, 4.9098619510989e+05,
+ 6.1688992640977e+05, 2.2351144690009e+05,-1.8550462739570e+06,
+ 9.6815110649918e+03,-8.1526584681055e+04,-8.0810433155289e+04,
+ 3.4520506615177e+05, 2.5509863381419e+05,-1.3331224992157e+05,
+ -4.3119301071653e+05,-5.9818343115856e+04, 1.7863692414573e+03,
+ 8.9440694919836e+04,-2.5558967650731e+05,-2.2130423988459e+04,
+ 4.4973674518316e+05,-2.2094939343618e+05])
+
+cbasis = jnp.array([6.9770019624764e-04,-2.4209870001642e+01, 1.8113927151562e+01,
+ 3.5107416275981e+01,-5.4600021126735e+00,-4.8731149608386e+01,
+ 3.6007189184766e+01, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ -7.7178474355102e+01,-3.8460795013977e+01,-4.6622480912340e+01,
+ 5.5684951167513e+01, 1.2274939911242e+02,-1.4325154752086e+02,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00,-6.0800589055949e+00,
+ 8.6171499453475e+01,-8.4066835441327e+01,-5.8228085624620e+01,
+ 2.0237393793875e+02, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 3.3525582670313e+02, 7.0056962392208e+01,-4.5312502936708e+01,
+ -3.0441141194247e+02, 2.8111438108965e+02, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-1.2983583774779e+02, 3.9781671212935e+01,
+ -6.6793945229609e+01,-1.9259805675433e+02, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-8.2855757669957e+02,-5.7003072730941e+01,
+ -3.5604806670066e+01, 9.6277766002709e+01, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 8.8645622149112e+02,-7.6908409772041e+01,
+ 6.8111763314154e+01, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 2.5090493428062e+02,-2.3622141780572e+02, 5.8155647658455e+02,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 2.8919570295095e+03,
+ -1.7871014635921e+02,-1.3515667622500e+02, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-3.6965613754734e+03, 2.1148158286617e+02,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00,-1.4795670139431e+03,
+ 3.6210798138768e+02, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ -5.3552886800881e+03, 3.1006384016202e+02, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 1.6241824368764e+03, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 4.3764909606382e+03, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 1.0940849243716e+03, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 3.0743267832931e+03, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00])
+
+ccore = jnp.array([2.4332191647159e-02,-2.9749090113656e+01, 1.8638980892831e+01,
+ -6.1272361746520e+00, 2.1567487597605e+00,-1.5552044084945e+01,
+ 8.9752150543954e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ -3.5693557878741e+02,-3.0398393196894e+00,-6.5936553294576e+00,
+ 1.6056619388911e+01, 7.8061422868204e+01,-8.6270891686359e+01,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00,-3.1688002530217e+01,
+ 3.7586725583944e+01,-3.2725765966657e+01,-5.6458213299259e+00,
+ 2.1502613314595e+01, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 5.2789943583277e+02,-4.2461079404962e+00,-2.4937638543122e+01,
+ -1.1963809321312e+02, 2.0240663228078e+02, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-6.2574211352272e+02,-6.9617539465382e+00,
+ -5.9440243471241e+01, 1.4944220180218e+01, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-1.2851139918332e+03,-6.5043516710835e+00,
+ 4.0410829440249e+01,-6.7162452402027e+01, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 1.0031942127832e+03, 7.6137226541944e+01,
+ -2.7279242226902e+01, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ -3.3059000871075e+01, 2.4384498749480e+01,-1.4597931874215e+02,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 1.6559579606045e+03,
+ 1.5038996611400e+02,-7.3865347730818e+01, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-1.9738401290808e+03,-1.4149993809415e+02,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00,-1.2756627454888e+02,
+ 4.1487702227579e+01, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ -1.7406770966429e+03,-9.3812204399266e+01, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-1.1890301282216e+03, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 2.3723447727360e+03, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-1.0279968223292e+03, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 5.7153838472603e+02, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00])
+
+crest = jnp.array([ 0.0000000000000e+00,-4.7430930170000e+00,-1.4422132560000e+01,
+ -1.8061146510000e+01, 7.5186735000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ -2.7962099800000e+02, 1.7616414260000e+01,-9.9741392630000e+01,
+ 7.1402447000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00,-7.8571336480000e+01,
+ 5.2434353250000e+01, 7.7696745000000e+01, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 1.7799123760000e+02, 1.4564532380000e+02, 2.2347226000000e+02,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-4.3823284100000e+02,-7.2846553000000e+02,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00,-2.6752313750000e+02, 3.6170310000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00,
+ 0.0000000000000e+00, 0.0000000000000e+00])
+
+idx1 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+ 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4,
+ 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6,
+ 6, 6, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5,
+ 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7,
+ 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6,
+ 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9,
+ 9, 9, 9, 9, 9])
+
+idx2 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3,
+ 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 4,
+ 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2,
+ 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4,
+ 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 1, 1,
+ 1, 1, 1, 1, 1])
+
+idx3 = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15, 1, 2, 3, 4, 5,
+ 6, 7, 8, 9,10,11,12,13,14, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,
+ 12,13, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13, 1, 2, 3, 4, 5,
+ 6, 7, 8, 9,10,11,12, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 1,
+ 2, 3, 4, 5, 6, 7, 8, 9,10,11, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,
+ 11, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11, 1, 2, 3, 4, 5, 6, 7, 8,
+ 9,10, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 1, 2, 3, 4, 5, 6, 7, 8,
+ 9,10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2,
+ 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,
+ 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3,
+ 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2,
+ 3, 4, 5, 6, 7])
+
+matrix1 = np.zeros((245,16))
+matrix2 = np.zeros((245,16))
+matrix3 = np.zeros((245,16))
+for i in range(245):
+ a = int(idx1[i])
+ b = int(idx2[i])
+ c = int(idx3[i])
+ list1 = np.zeros(16)
+ list2 = np.zeros(16)
+ list3 = np.zeros(16)
+ list1[a] = 1
+ list2[b] = 1
+ list3[c] = 1
+ matrix1[i] = list1
+ matrix2[i] = list2
+ matrix3[i] = list3
+
+c5z = jnp.zeros(245)
+for i in range(245):
+ c5z = c5z.at[i].set(f5z*c5zA[i] + fbasis*cbasis[i]+ fcore*ccore[i] + frest*crest[i])
+deoh = f5z*deohA
+phh1 = f5z*phh1A*jnp.exp(phh2)
+costhe = -0.24780227221366464506
+
+Eh_J = 4.35974434e-18
+Na = 6.02214129e+23
+kcal_J = 4184.0
+c0 = 299792458.0
+h_Js = 6.62606957e-34
+cal2joule = 4.184
+Eh_kcalmol = Eh_J*Na/kcal_J
+Eh_cm1 = 1.0e-2*Eh_J/(c0*h_Js)
+cm1_kcalmol = Eh_kcalmol/Eh_cm1
+
+
+## compute intra
+def onebodyenergy(positions, box):
+ box_inv = jnp.linalg.inv(box)
+ O = positions[::3]
+ H1 = positions[1::3]
+ H2 = positions[2::3]
+ ROH1 = H1 - O
+ ROH2 = H2 - O
+ RHH = H1 - H2
+ ROH1 = v_pbc_shift(ROH1, box, box_inv)
+ ROH2 = v_pbc_shift(ROH2, box, box_inv)
+ RHH = v_pbc_shift(RHH, box, box_inv)
+ dROH1 = jnp.linalg.norm(ROH1, axis=1)
+ dROH2 = jnp.linalg.norm(ROH2, axis=1)
+ dRHH = jnp.linalg.norm(RHH, axis=1)
+ costh = jnp.sum(ROH1 * ROH2, axis=1) / (dROH1 * dROH2)
+ exp1 = jnp.exp(-alphaoh*(dROH1 - roh))
+ exp2 = jnp.exp(-alphaoh*(dROH2 - roh))
+ Va = deoh*(exp1*(exp1 - 2.0) + exp2*(exp2 - 2.0))
+ Vb = phh1*jnp.exp(-phh2*dRHH)
+ x1 = (dROH1 - reoh)/reoh
+ x2 = (dROH2 - reoh)/reoh
+ x3 = costh - costhe
+ efac = jnp.exp(-b1*(dROH1 - reoh)**2 + (dROH2 - reoh)**2)
+ energy = jnp.sum(onebody_kernel(x1, x2, x3, Va, Vb, efac))
+ return energy
+
+
+
+@vmap
+@jit_condition(static_argnums={})
+def onebody_kernel(x1, x2, x3, Va, Vb, efac):
+ const = jnp.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+ CONST = jnp.array([const,const,const])
+ list1 = jnp.array([x1**i for i in range(-1, 15)])
+ list2 = jnp.array([x2**i for i in range(-1, 15)])
+ list3 = jnp.array([x3**i for i in range(-1, 15)])
+ fmat = jnp.array([list1, list2, list3])
+ fmat *= CONST
+ F1 = jnp.sum(fmat[0].T * matrix1, axis=1) # fmat[0][inI] 1*245
+ F2 = jnp.sum(fmat[1].T * matrix2, axis=1) #fmat[1][inJ] 1*245
+ F3 = jnp.sum(fmat[0].T * matrix2, axis=1) #fmat[0][inJ] 1*245
+ F4 = jnp.sum(fmat[1].T * matrix1, axis=1) #fmat[1][inI] 1*245
+ F5 = jnp.sum(fmat[2].T * matrix3, axis=1) #fmat[2][inK] 1*245
+ total = c5z * (F1*F2 + F3*F4)* F5
+ sum0 = jnp.sum(total[1:245])
+ Vc = 2*c5z[0] + efac*sum0
+ e1 = Va + Vb + Vc
+ e1 += 0.44739574026257
+ e1 *= cm1_kcalmol
+ e1 *= cal2joule # conver cal 2 j
+ return e1
+
+
+def validation(pdb):
+ xml = 'mpidwater.xml'
+ pdbinfo = read_pdb(pdb)
+ serials = pdbinfo['serials']
+ names = pdbinfo['names']
+ resNames = pdbinfo['resNames']
+ resSeqs = pdbinfo['resSeqs']
+ positions = pdbinfo['positions']
+ box = pdbinfo['box'] # a, b, c, α, β, γ
+ charges = pdbinfo['charges']
+ positions = jnp.asarray(positions)
+ lx, ly, lz, _, _, _ = box
+ box = jnp.eye(3)*jnp.array([lx, ly, lz])
+
+ mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
+ pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
+ dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
+
+ rc = 4 # in Angstrom
+ ethresh = 1e-4
+
+ n_atoms = len(serials)
+
+ # compute intra
+
+
+ grad_E1 = value_and_grad(onebodyenergy,argnums=(0))
+ ene, force = grad_E1(positions, box)
+ print(ene,force)
+ return
+
+
+# below is the validation code
+if __name__ == '__main__':
+ validation(sys.argv[1])
+
+
diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py
index d7b6547ce..d8b7b80d2 100755
--- a/dmff/admp/pairwise.py
+++ b/dmff/admp/pairwise.py
@@ -1,10 +1,12 @@
import sys
from jax import vmap
import jax.numpy as jnp
-from dmff.utils import jit_condition
+from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.spatial import v_pbc_shift
from functools import partial
+DIELECTRIC = 1389.35455846
+
# for debug
# from jax_md import partition, space
# from admp.parser import *
@@ -62,13 +64,17 @@ def generate_pairwise_interaction(pair_int_kernel, covalent_map, static_args):
'''
def pair_int(positions, box, pairs, mScales, *atomic_params):
- pairs = pairs[pairs[:, 0] < pairs[:, 1]]
+ pairs = regularize_pairs(pairs)
+
ri = distribute_v3(positions, pairs[:, 0])
rj = distribute_v3(positions, pairs[:, 1])
# ri = positions[pairs[:, 0]]
# rj = positions[pairs[:, 1]]
nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
mscales = distribute_scalar(mScales, nbonds-1)
+
+ buffer_scales = pair_buffer_scales(pairs)
+ mscales = mscales * buffer_scales
# mscales = mScales[nbonds-1]
box_inv = jnp.linalg.inv(box)
dr = ri - rj
@@ -82,7 +88,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params):
# pair_params.append(param[pairs[:, 0]])
# pair_params.append(param[pairs[:, 1]])
- energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params))
+ energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params) * buffer_scales)
return energy
return pair_int
@@ -110,6 +116,58 @@ def TT_damping_qq_c6_kernel(dr, m, ai, aj, bi, bj, qi, qj, ci, cj):
return f * m
+@vmap
+@jit_condition(static_argnums={})
+def TT_damping_qq_kernel(dr, m, bi, bj, qi, qj):
+ b = jnp.sqrt(bi * bj)
+ q = qi * qj
+ br = b * dr
+ exp_br = jnp.exp(-br)
+ f = - DIELECTRIC * exp_br * (1+br) * q / dr
+ return f * m
+
+
+@vmap
+@jit_condition(static_argnums=())
+def slater_disp_damping_kernel(dr, m, bi, bj, c6i, c6j, c8i, c8j, c10i, c10j):
+ '''
+ Slater-ISA type damping for dispersion:
+ f(x) = -e^{-x} * \sum_{k} x^k/k!
+ x = Br - \frac{2*(Br)^2 + 3Br}{(Br)^2 + 3*Br + 3}
+ see jctc 12 3851
+ '''
+ b = jnp.sqrt(bi * bj)
+ c6 = c6i * c6j
+ c8 = c8i * c8j
+ c10 = c10i * c10j
+ br = b * dr
+ br2 = br * br
+ x = br - (2*br2 + 3*br) / (br2 + 3*br + 3)
+ s6 = 1 + x + x**2/2 + x**3/6 + x**4/24 + x**5/120 + x**6/720
+ s8 = s6 + x**7/5040 + x**8/40320
+ s10 = s8 + x**9/362880 + x**10/3628800
+ exp_x = jnp.exp(-x)
+ f6 = exp_x * s6
+ f8 = exp_x * s8
+ f10 = exp_x * s10
+ return (f6*c6/dr**6 + f8*c8/dr**8 + f10*c10/dr**10) * m
+
+
+@vmap
+@jit_condition(static_argnums=())
+def slater_sr_kernel(dr, m, ai, aj, bi, bj):
+ '''
+ Slater-ISA type short range terms
+ see jctc 12 3851
+ '''
+ b = jnp.sqrt(bi * bj)
+ a = ai * aj
+ br = b * dr
+ br2 = br * br
+ P = 1/3 * br2 + br + 1
+ return a * P * jnp.exp(-br) * m
+
+
def validation(pdb):
xml = 'mpidwater.xml'
pdbinfo = read_pdb(pdb)
diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py
index 6112950d7..2cde47a1c 100755
--- a/dmff/admp/pme.py
+++ b/dmff/admp/pme.py
@@ -6,7 +6,7 @@
from jax.scipy.special import erf
from dmff.settings import DO_JIT
from dmff.admp.settings import POL_CONV, MAX_N_POL
-from dmff.utils import jit_condition
+from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.multipole import C1_c2h, convert_cart2harm
from dmff.admp.multipole import rot_ind_global2local, rot_global2local, rot_local2global
from dmff.admp.spatial import v_pbc_shift, generate_construct_local_frames, build_quasi_internal
@@ -34,20 +34,58 @@ class ADMPPmeForce:
The so called "environment paramters" means parameters that do not need to be differentiable
'''
- def __init__(self, box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax, lpol=False):
+ def __init__(self, box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None):
+ '''
+ Initialize the ADMPPmeForce calculator.
+
+ Input:
+ box:
+ (3, 3) float, box size in row
+ axis_type:
+ (na,) int, types of local axis (bisector, z-then-x etc.)
+ covalent_map:
+ (na, na) int, covalent map matrix, labels the topological distances between atoms
+ rc:
+ float: cutoff distance
+ ethresh:
+ float: pme energy threshold
+ lmax:
+ int: max L for multipoles
+ lpol:
+ bool: polarize or not?
+ lpme:
+ bool: do pme or simple cutoff?
+ if False, the kappa will be set to zero and the reciprocal part will not be computed
+ steps:
+ None or int: Whether do fixed number of dipole iteration steps?
+ if None: converge dipoles until convergence threshold is met
+ if int: optimize for this many steps and stop, this is useful if you want to jit the entire function
+
+ Output:
+
+ '''
self.axis_type = axis_type
self.axis_indices = axis_indices
self.rc = rc
self.ethresh = ethresh
self.lmax = int(lmax) # jichen: type checking
- kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
- self.kappa = kappa
- self.K1 = K1
- self.K2 = K2
- self.K3 = K3
+ # turn off pme if lpme is False, this is useful when doing cluster calculations
+ self.lpme = lpme
+ if self.lpme is False:
+ self.kappa = 0
+ self.K1 = 0
+ self.K2 = 0
+ self.K3 = 0
+ else:
+ kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
+ self.kappa = kappa
+ self.K1 = K1
+ self.K2 = K2
+ self.K3 = K3
self.pme_order = 6
self.covalent_map = covalent_map
self.lpol = lpol
+ self.steps_pol = steps_pol
self.n_atoms = int(covalent_map.shape[0]) # len(axis_type)
# setup calculators
@@ -63,7 +101,7 @@ def get_energy(positions, box, pairs, Q_local, mScales):
Q_local, None, None, None,
mScales, None, None, self.covalent_map,
self.construct_local_frames, self.pme_recip,
- self.kappa, self.K1, self.K2, self.K3, self.lmax, False)
+ self.kappa, self.K1, self.K2, self.K3, self.lmax, False, lpme=self.lpme)
return get_energy
else:
# this is the bare energy calculator, with Uind as explicit input
@@ -72,14 +110,20 @@ def energy_fn(positions, box, pairs, Q_local, Uind_global, pol, tholes, mScales,
Q_local, Uind_global, pol, tholes,
mScales, pScales, dScales, self.covalent_map,
self.construct_local_frames, self.pme_recip,
- self.kappa, self.K1, self.K2, self.K3, self.lmax, True)
+ self.kappa, self.K1, self.K2, self.K3, self.lmax, True, lpme=self.lpme)
self.energy_fn = energy_fn
self.grad_U_fn = grad(self.energy_fn, argnums=(4))
self.grad_pos_fn = grad(self.energy_fn, argnums=(0))
self.U_ind = jnp.zeros((self.n_atoms, 3))
# this is the wrapper that include a Uind optimizer
- def get_energy(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=self.U_ind):
- self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=U_init)
+ def get_energy(
+ positions, box, pairs,
+ Q_local, pol, tholes, mScales, pScales, dScales,
+ U_init=self.U_ind):
+ self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind(
+ positions, box, pairs, Q_local, pol, tholes,
+ mScales, pScales, dScales,
+ U_init=U_init, steps_pol=self.steps_pol)
# here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr !
# self.U_ind = jax.lax.stop_gradient(U_ind)
return self.energy_fn(positions, box, pairs, Q_local, self.U_ind, pol, tholes, mScales, pScales, dScales)
@@ -112,7 +156,11 @@ def refresh_calculators(self):
self.get_forces = value_and_grad(self.get_energy)
return
- def optimize_Uind(self, positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=None, maxiter=MAX_N_POL, thresh=POL_CONV):
+ def optimize_Uind(self,
+ positions, box, pairs,
+ Q_local, pol, tholes, mScales, pScales, dScales,
+ U_init=None, steps_pol=None,
+ maxiter=MAX_N_POL, thresh=POL_CONV):
'''
This function converges the induced dipole
Note that we cut all the gradient chain passing through this function as we assume Feynman-Hellman theorem
@@ -131,19 +179,28 @@ def optimize_Uind(self, positions, box, pairs, Q_local, pol, tholes, mScales, pS
U = jnp.zeros((self.n_atoms, 3))
else:
U = U_init
- site_filter = (pol>0.001) # focus on the actual polarizable sites
-
- for i in range(maxiter):
- field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales)
- E = self.energy_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales)
- if jnp.max(jnp.abs(field[site_filter])) < thresh:
- break
- U = U - field * pol[:, jnp.newaxis] / DIELECTRIC
- if i == maxiter-1:
- flag = False
- else: # converged
+ if steps_pol is None:
+ site_filter = (pol>0.001) # focus on the actual polarizable sites
+
+ if steps_pol is None:
+ for i in range(maxiter):
+ field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales)
+ # E = self.energy_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales)
+ if jnp.max(jnp.abs(field[site_filter])) < thresh:
+ break
+ U = U - field * pol[:, jnp.newaxis] / DIELECTRIC
+ if i == maxiter-1:
+ flag = False
+ else: # converged
+ flag = True
+ else:
+ def update_U(i, U):
+ field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales)
+ U = U - field * pol[:, jnp.newaxis] / DIELECTRIC
+ return U
+ U = jax.lax.fori_loop(0, steps_pol, update_U, U)
flag = True
- return U, flag, i
+ return U, flag, steps_pol
def setup_ewald_parameters(rc, ethresh, box):
@@ -179,7 +236,7 @@ def setup_ewald_parameters(rc, ethresh, box):
def energy_pme(positions, box, pairs,
Q_local, Uind_global, pol, tholes,
mScales, pScales, dScales, covalent_map,
- construct_local_frame_fn, pme_recip_fn, kappa, K1, K2, K3, lmax, lpol):
+ construct_local_frame_fn, pme_recip_fn, kappa, K1, K2, K3, lmax, lpol, lpme=True):
'''
This is the top-level wrapper for multipole PME
@@ -213,8 +270,10 @@ def energy_pme(positions, box, pairs,
int: max K for reciprocal calculations
lmax:
int: maximum L
- bool:
- int: if polarizable or not? if yes, 1, otherwise 0
+ lpol:
+ bool: if polarizable or not? if yes, 1, otherwise 0
+ lpme:
+ bool: doing pme? If false, then turn off reciprocal space and set kappa = 0
Output:
energy: total pme energy
@@ -240,6 +299,9 @@ def energy_pme(positions, box, pairs,
else:
Q_global_tot = Q_global
+ if lpme is False:
+ kappa = 0
+
if lpol:
ene_real = pme_real(positions, box, pairs, Q_global, U_ind, pol, tholes,
mScales, pScales, dScales, covalent_map, kappa, lmax, True)
@@ -247,14 +309,23 @@ def energy_pme(positions, box, pairs,
ene_real = pme_real(positions, box, pairs, Q_global, None, None, None,
mScales, None, None, covalent_map, kappa, lmax, False)
- ene_recip = pme_recip_fn(positions, box, Q_global_tot)
+ if lpme:
+ ene_recip = pme_recip_fn(positions, box, Q_global_tot)
- ene_self = pme_self(Q_global_tot, kappa, lmax)
+ ene_self = pme_self(Q_global_tot, kappa, lmax)
- if lpol:
- ene_self += pol_penalty(U_ind, pol)
+ if lpol:
+ ene_self += pol_penalty(U_ind, pol)
+
+ return ene_real + ene_recip + ene_self
- return ene_real + ene_recip + ene_self
+ else:
+ if lpol:
+ ene_self = pol_penalty(U_ind, pol)
+ else:
+ ene_self = 0.0
+
+ return ene_real + ene_self
# @partial(vmap, in_axes=(0, 0, None, None), out_axes=0)
@@ -669,7 +740,10 @@ def pme_real(positions, box, pairs,
'''
# expand pairwise parameters, from atomic parameters
- pairs = pairs[pairs[:, 0] < pairs[:, 1]]
+ # debug
+ # pairs = pairs[pairs[:, 0] < pairs[:, 1]]
+ pairs = regularize_pairs(pairs)
+ buffer_scales = pair_buffer_scales(pairs)
box_inv = jnp.linalg.inv(box)
r1 = distribute_v3(positions, pairs[:, 0])
r2 = distribute_v3(positions, pairs[:, 1])
@@ -682,6 +756,7 @@ def pme_real(positions, box, pairs,
nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
indices = nbonds-1
mscales = distribute_scalar(mScales, indices)
+ mscales = mscales * buffer_scales
# mscales = mScales[nbonds-1]
if lpol:
pol1 = distribute_scalar(pol, pairs[:, 0])
@@ -691,7 +766,9 @@ def pme_real(positions, box, pairs,
Uind_extendi = distribute_v3(Uind_global, pairs[:, 0])
Uind_extendj = distribute_v3(Uind_global, pairs[:, 1])
pscales = distribute_scalar(pScales, indices)
+ pscales = pscales * buffer_scales
dscales = distribute_scalar(dScales, indices)
+ dscales = dscales * buffer_scales
# pol1 = pol[pairs[:,0]]
# pol2 = pol[pairs[:,1]]
# thole1 = tholes[pairs[:,0]]
@@ -725,7 +802,10 @@ def pme_real(positions, box, pairs,
qiUindJ = None
# everything should be pair-specific now
- ene = jnp.sum(pme_real_kernel(norm_dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscales, pscales, dscales, kappa, lmax, lpol))
+ ene = jnp.sum(
+ pme_real_kernel(norm_dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscales, pscales, dscales, kappa, lmax, lpol)
+ * buffer_scales
+ )
return ene
diff --git a/dmff/admp/settings.py b/dmff/admp/settings.py
index f08e88162..aa05b5704 100644
--- a/dmff/admp/settings.py
+++ b/dmff/admp/settings.py
@@ -1,3 +1,3 @@
# DEFAULT THRESHOLDS
-POL_CONV = 10.0 # gradient convergence thresh for induced dipoles
-MAX_N_POL = 30 # maximum number of cyles for optimizing induced dipole
\ No newline at end of file
+POL_CONV = 1.0 # gradient convergence thresh for induced dipoles
+MAX_N_POL = 30 # maximum number of cyles for optimizing induced dipole
diff --git a/dmff/api.py b/dmff/api.py
index ba9823ff3..fc8788231 100644
--- a/dmff/api.py
+++ b/dmff/api.py
@@ -10,16 +10,18 @@
from .admp.disp_pme import ADMPDispPmeForce
from .admp.multipole import convert_cart2harm, rot_local2global
from .admp.pairwise import TT_damping_qq_c6_kernel, generate_pairwise_interaction
+from .admp.pairwise import slater_disp_damping_kernel, slater_sr_kernel, TT_damping_qq_kernel
from .admp.pme import ADMPPmeForce
from .admp.spatial import generate_construct_local_frames
from .admp.recip import Ck_1, generate_pme_recip
+from .utils import jit_condition
from .classical.intra import (
HarmonicBondJaxForce,
HarmonicAngleJaxForce,
PeriodicTorsionJaxForce,
)
from jax_md import space, partition
-from jax import grad
+from jax import grad, vmap
import linecache
import itertools
from .classical.inter import (
@@ -87,7 +89,7 @@ def build_covalent_map(data, max_neighbor):
if k != i and k not in j_list:
covalent_map[i, k] = n_curr + 1
covalent_map[k, i] = n_curr + 1
- return covalent_map
+ return jnp.array(covalent_map)
def findAtomTypeTexts(attribs, num):
@@ -136,12 +138,25 @@ def parseElement(element, hamiltonian):
def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
+ methodMap = {
+ app.CutoffPeriodic: "CutoffPeriodic",
+ app.NoCutoff: "NoCutoff",
+ app.PME: "PME",
+ }
+ if nonbondedMethod not in methodMap:
+ raise ValueError("Illegal nonbonded method for ADMPDispForce")
+ if nonbondedMethod is app.CutoffPeriodic:
+ self.lpme = False
+ else:
+ self.lpme = True
+
n_atoms = len(data.atoms)
# build index map
map_atomtype = np.zeros(n_atoms, dtype=int)
for i in range(n_atoms):
atype = data.atomType[data.atoms[i]]
map_atomtype[i] = np.where(self.types == atype)[0][0]
+ self.map_atomtype = map_atomtype
# build covalent map
covalent_map = build_covalent_map(data, 6)
@@ -155,13 +170,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
if "ethresh" in args:
self.ethresh = args["ethresh"]
- Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, self.pmax)
+ Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh,
+ self.pmax, lpme=self.lpme)
self.disp_pme_force = Force_DispPME
- # debugging
- # Force_DispPME.update_env('kappa', 0.657065221219616)
- # Force_DispPME.update_env('K1', 96)
- # Force_DispPME.update_env('K2', 96)
- # Force_DispPME.update_env('K3', 96)
pot_fn_lr = Force_DispPME.get_energy
pot_fn_sr = generate_pairwise_interaction(
TT_damping_qq_c6_kernel, covalent_map, static_args={}
@@ -195,11 +206,369 @@ def renderXML(self):
# generate xml force field file
pass
-
# register all parsers
app.forcefield.parsers["ADMPDispForce"] = ADMPDispGenerator.parseElement
+class ADMPDispPmeGenerator:
+ '''
+ This one computes the undamped C6/C8/C10 interactions
+ u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10
+ '''
+
+ def __init__(self, hamiltonian):
+ self.ff = hamiltonian
+ self.params = {
+ "C6": [],
+ "C8": [],
+ "C10": []
+ }
+ self._jaxPotential = None
+ self.types = []
+ self.ethresh = 5e-4
+ self.pmax = 10
+
+ def registerAtomType(self, atom):
+ self.types.append(atom["type"])
+ self.params["C6"].append(float(atom["C6"]))
+ self.params["C8"].append(float(atom["C8"]))
+ self.params["C10"].append(float(atom["C10"]))
+
+ @staticmethod
+ def parseElement(element, hamiltonian):
+ generator = ADMPDispPmeGenerator(hamiltonian)
+ hamiltonian.registerGenerator(generator)
+ # covalent scales
+ mScales = []
+ for i in range(2, 7):
+ mScales.append(float(element.attrib["mScale1%d" % i]))
+ mScales.append(1.0)
+ generator.params["mScales"] = mScales
+ for atomtype in element.findall("Atom"):
+ generator.registerAtomType(atomtype.attrib)
+ # jax it!
+ for k in generator.params.keys():
+ generator.params[k] = jnp.array(generator.params[k])
+ generator.types = np.array(generator.types)
+
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
+ args):
+ methodMap = {
+ app.CutoffPeriodic: "CutoffPeriodic",
+ app.NoCutoff: "NoCutoff",
+ app.PME: "PME",
+ }
+ if nonbondedMethod not in methodMap:
+ raise ValueError("Illegal nonbonded method for ADMPDispPmeForce")
+ if nonbondedMethod is app.CutoffPeriodic:
+ self.lpme = False
+ else:
+ self.lpme = True
+
+ n_atoms = len(data.atoms)
+ # build index map
+ map_atomtype = np.zeros(n_atoms, dtype=int)
+ for i in range(n_atoms):
+ atype = data.atomType[data.atoms[i]]
+ map_atomtype[i] = np.where(self.types == atype)[0][0]
+ self.map_atomtype = map_atomtype
+ # build covalent map
+ covalent_map = build_covalent_map(data, 6)
+
+ # here box is only used to setup ewald parameters, no need to be differentiable
+ a, b, c = system.getDefaultPeriodicBoxVectors()
+ box = jnp.array([a._value, b._value, c._value]) * 10
+ # get the admp calculator
+ rc = nonbondedCutoff.value_in_unit(unit.angstrom)
+
+ # get calculator
+ if 'ethresh' in args:
+ self.ethresh = args['ethresh']
+
+ disp_force = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh,
+ self.pmax, self.lpme)
+ self.disp_force = disp_force
+ pot_fn_lr = disp_force.get_energy
+
+ def potential_fn(positions, box, pairs, params):
+ mScales = params["mScales"]
+ C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6
+ C8_list = params["C8"][map_atomtype] * 1e8
+ C10_list = params["C10"][map_atomtype] * 1e10
+ c6_list = jnp.sqrt(C6_list)
+ c8_list = jnp.sqrt(C8_list)
+ c10_list = jnp.sqrt(C10_list)
+ c_list = jnp.vstack((c6_list, c8_list, c10_list))
+ E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales)
+ return - E_lr
+
+ self._jaxPotential = potential_fn
+ # self._top_data = data
+
+ def getJaxPotential(self):
+ return self._jaxPotential
+
+ def renderXML(self):
+ # generate xml force field file
+ pass
+
+# register all parsers
+app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement
+
+class QqTtDampingGenerator:
+ '''
+ This one calculates the tang-tonnies damping of charge-charge interaction
+ E = \sum_ij exp(-B*r)*(1+B*r)*q_i*q_j/r
+ '''
+ def __init__(self, hamiltonian):
+ self.ff = hamiltonian
+ self.params = {
+ "B": [],
+ "Q": [],
+ }
+ self._jaxPotential = None
+ self.types = []
+
+ def registerAtomType(self, atom):
+ self.types.append(atom["type"])
+ self.params["B"].append(float(atom["B"]))
+ self.params["Q"].append(float(atom["Q"]))
+
+ @staticmethod
+ def parseElement(element, hamiltonian):
+ generator = QqTtDampingGenerator(hamiltonian)
+ hamiltonian.registerGenerator(generator)
+ # covalent scales
+ mScales = []
+ for i in range(2, 7):
+ mScales.append(float(element.attrib["mScale1%d" % i]))
+ mScales.append(1.0)
+ generator.params["mScales"] = mScales
+ for atomtype in element.findall("Atom"):
+ generator.registerAtomType(atomtype.attrib)
+ # jax it!
+ for k in generator.params.keys():
+ generator.params[k] = jnp.array(generator.params[k])
+ generator.types = np.array(generator.types)
+
+ # on working
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
+ args):
+
+ n_atoms = len(data.atoms)
+ # build index map
+ map_atomtype = np.zeros(n_atoms, dtype=int)
+ for i in range(n_atoms):
+ atype = data.atomType[data.atoms[i]]
+ map_atomtype[i] = np.where(self.types == atype)[0][0]
+ self.map_atomtype = map_atomtype
+ # build covalent map
+ covalent_map = build_covalent_map(data, 6)
+
+ pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_kernel,
+ covalent_map,
+ static_args={})
+
+ def potential_fn(positions, box, pairs, params):
+ mScales = params["mScales"]
+ b_list = params["B"][map_atomtype] / 10 # convert to A^-1
+ q_list = params["Q"][map_atomtype]
+
+ E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, q_list)
+ return E_sr
+
+ self._jaxPotential = potential_fn
+ # self._top_data = data
+
+ def getJaxPotential(self):
+ return self._jaxPotential
+
+ def renderXML(self):
+ # generate xml force field file
+ pass
+
+# register all parsers
+app.forcefield.parsers["QqTtDampingForce"] = QqTtDampingGenerator.parseElement
+
+
+class SlaterDampingGenerator:
+ '''
+ This one computes the slater-type damping function for c6/c8/c10 dispersion
+ E = \sum_ij (f6-1)*c6/r6 + (f8-1)*c8/r8 + (f10-1)*c10/r10
+ fn = f_tt(x, n)
+ x = br - (2*br2 + 3*br) / (br2 + 3*br + 3)
+ '''
+ def __init__(self, hamiltonian):
+ self.ff = hamiltonian
+ self.params = {
+ "B": [],
+ "C6": [],
+ "C8": [],
+ "C10": [],
+ }
+ self._jaxPotential = None
+ self.types = []
+
+ def registerAtomType(self, atom):
+ self.types.append(atom["type"])
+ self.params["B"].append(float(atom["B"]))
+ self.params["C6"].append(float(atom["C6"]))
+ self.params["C8"].append(float(atom["C8"]))
+ self.params["C10"].append(float(atom["C10"]))
+
+ @staticmethod
+ def parseElement(element, hamiltonian):
+ generator = SlaterDampingGenerator(hamiltonian)
+ hamiltonian.registerGenerator(generator)
+ # covalent scales
+ mScales = []
+ for i in range(2, 7):
+ mScales.append(float(element.attrib["mScale1%d" % i]))
+ mScales.append(1.0)
+ generator.params["mScales"] = mScales
+ for atomtype in element.findall("Atom"):
+ generator.registerAtomType(atomtype.attrib)
+ # jax it!
+ for k in generator.params.keys():
+ generator.params[k] = jnp.array(generator.params[k])
+ generator.types = np.array(generator.types)
+
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
+ args):
+
+ n_atoms = len(data.atoms)
+ # build index map
+ map_atomtype = np.zeros(n_atoms, dtype=int)
+ for i in range(n_atoms):
+ atype = data.atomType[data.atoms[i]]
+ map_atomtype[i] = np.where(self.types == atype)[0][0]
+ self.map_atomtype = map_atomtype
+ # build covalent map
+ covalent_map = build_covalent_map(data, 6)
+
+ # WORKING
+ pot_fn_sr = generate_pairwise_interaction(slater_disp_damping_kernel,
+ covalent_map,
+ static_args={})
+
+ def potential_fn(positions, box, pairs, params):
+ mScales = params["mScales"]
+ b_list = params["B"][map_atomtype] / 10 # convert to A^-1
+ c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) # to kj/mol * A**6
+ c8_list = jnp.sqrt(params["C8"][map_atomtype] * 1e8)
+ c10_list = jnp.sqrt(params["C10"][map_atomtype] * 1e10)
+ E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, c6_list, c8_list, c10_list)
+ return E_sr
+
+ self._jaxPotential = potential_fn
+ # self._top_data = data
+
+ def getJaxPotential(self):
+ return self._jaxPotential
+
+ def renderXML(self):
+ # generate xml force field file
+ pass
+
+app.forcefield.parsers["SlaterDampingForce"] = SlaterDampingGenerator.parseElement
+
+
+class SlaterExGenerator:
+ '''
+ This one computes the Slater-ISA type exchange interaction
+ u = \sum_ij A * (1/3*(Br)^2 + Br + 1)
+ '''
+
+ def __init__(self, hamiltonian):
+ self.ff = hamiltonian
+ self.params = {
+ "A": [],
+ "B": [],
+ }
+ self._jaxPotential = None
+ self.types = []
+
+ def registerAtomType(self, atom):
+ self.types.append(atom["type"])
+ self.params["A"].append(float(atom["A"]))
+ self.params["B"].append(float(atom["B"]))
+
+ @staticmethod
+ def parseElement(element, hamiltonian):
+ generator = SlaterExGenerator(hamiltonian)
+ hamiltonian.registerGenerator(generator)
+ # covalent scales
+ mScales = []
+ for i in range(2, 7):
+ mScales.append(float(element.attrib["mScale1%d" % i]))
+ mScales.append(1.0)
+ generator.params["mScales"] = mScales
+ for atomtype in element.findall("Atom"):
+ generator.registerAtomType(atomtype.attrib)
+ # jax it!
+ for k in generator.params.keys():
+ generator.params[k] = jnp.array(generator.params[k])
+ generator.types = np.array(generator.types)
+
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
+ args):
+
+ n_atoms = len(data.atoms)
+ # build index map
+ map_atomtype = np.zeros(n_atoms, dtype=int)
+ for i in range(n_atoms):
+ atype = data.atomType[data.atoms[i]]
+ map_atomtype[i] = np.where(self.types == atype)[0][0]
+ self.map_atomtype = map_atomtype
+ # build covalent map
+ covalent_map = build_covalent_map(data, 6)
+
+ pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel,
+ covalent_map,
+ static_args={})
+
+ def potential_fn(positions, box, pairs, params):
+ mScales = params["mScales"]
+ a_list = params["A"][map_atomtype]
+ b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1
+
+ return pot_fn_sr(positions, box, pairs, mScales, a_list, b_list)
+
+ self._jaxPotential = potential_fn
+ # self._top_data = data
+
+ def getJaxPotential(self):
+ return self._jaxPotential
+
+ def renderXML(self):
+ # generate xml force field file
+ pass
+
+app.forcefield.parsers["SlaterExForce"] = SlaterExGenerator.parseElement
+
+
+# Here are all the short range "charge penetration" terms
+# They all have the exchange form
+class SlaterSrEsGenerator(SlaterExGenerator):
+ def __init__(self):
+ super().__init__(self)
+class SlaterSrPolGenerator(SlaterExGenerator):
+ def __init__(self):
+ super().__init__(self)
+class SlaterSrDispGenerator(SlaterExGenerator):
+ def __init__(self):
+ super().__init__(self)
+class SlaterDhfGenerator(SlaterExGenerator):
+ def __init__(self):
+ super().__init__(self)
+
+# register all parsers
+app.forcefield.parsers["SlaterSrEsForce"] = SlaterSrEsGenerator.parseElement
+app.forcefield.parsers["SlaterSrPolForce"] = SlaterSrPolGenerator.parseElement
+app.forcefield.parsers["SlaterSrDispForce"] = SlaterSrDispGenerator.parseElement
+app.forcefield.parsers["SlaterDhfForce"] = SlaterDhfGenerator.parseElement
+
+
class ADMPPmeGenerator:
def __init__(self, hamiltonian):
self.ff = hamiltonian
@@ -382,12 +751,25 @@ def parseElement(element, hamiltonian):
def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
+ methodMap = {
+ app.CutoffPeriodic: "CutoffPeriodic",
+ app.NoCutoff: "NoCutoff",
+ app.PME: "PME",
+ }
+ if nonbondedMethod not in methodMap:
+ raise ValueError("Illegal nonbonded method for ADMPPmeForce")
+ if nonbondedMethod is app.CutoffPeriodic:
+ self.lpme = False
+ else:
+ self.lpme = True
+
n_atoms = len(data.atoms)
map_atomtype = np.zeros(n_atoms, dtype=int)
for i in range(n_atoms):
atype = data.atomType[data.atoms[i]]
map_atomtype[i] = np.where(self.types == atype)[0][0]
+ self.map_atomtype = map_atomtype
# here box is only used to setup ewald parameters, no need to be differentiable
a, b, c = system.getDefaultPeriodicBoxVectors()
@@ -614,7 +996,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
self.axis_types = None
self.axis_indices = None
- # get calculator
if "ethresh" in args:
self.ethresh = args["ethresh"]
@@ -627,6 +1008,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
self.ethresh,
self.lmax,
self.lpol,
+ lpme=self.lpme
)
self.pme_force = pme_force
diff --git a/dmff/settings.py b/dmff/settings.py
index 0ba35abe1..c7965bc2d 100644
--- a/dmff/settings.py
+++ b/dmff/settings.py
@@ -6,4 +6,4 @@
if PRECISION == 'double':
config.update("jax_enable_x64", True)
-
\ No newline at end of file
+
diff --git a/dmff/utils.py b/dmff/utils.py
index 6182b8123..1d5d9ac0d 100644
--- a/dmff/utils.py
+++ b/dmff/utils.py
@@ -1,5 +1,6 @@
from dmff.settings import DO_JIT
-from jax import jit
+from jax import jit, vmap
+import jax.numpy as jnp
def jit_condition(*args, **kwargs):
def jit_deco(func):
@@ -7,4 +8,23 @@ def jit_deco(func):
return jit(func, *args, **kwargs)
else:
return func
- return jit_deco
\ No newline at end of file
+ return jit_deco
+
+
+@jit_condition()
+@vmap
+def regularize_pairs(p):
+ dp = p[1] - p[0]
+ dp = jnp.piecewise(dp, (dp<=0, dp>0), (lambda x: jnp.array(1), lambda x: jnp.array(0)))
+ dp_vec = jnp.array([dp, 2*dp])
+ p = p - dp_vec
+ return p
+
+
+@jit_condition()
+@vmap
+def pair_buffer_scales(p):
+ return jnp.piecewise(
+ p[0] - p[1],
+ (p[0] - p[1] < 0, p[0] - p[1] >= 0),
+ (lambda x: jnp.array(1), lambda x: jnp.array(0)))
diff --git a/examples/peg_slater_isa/benchmark.xml b/examples/peg_slater_isa/benchmark.xml
new file mode 100644
index 000000000..2dcc5f1a4
--- /dev/null
+++ b/examples/peg_slater_isa/benchmark.xml
@@ -0,0 +1,77 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/peg_slater_isa/calc_energy_comps.py b/examples/peg_slater_isa/calc_energy_comps.py
new file mode 100755
index 000000000..8066a28d9
--- /dev/null
+++ b/examples/peg_slater_isa/calc_energy_comps.py
@@ -0,0 +1,202 @@
+#!/usr/bin/env python
+import numpy as np
+import openmm
+from openmm import *
+from openmm.app import *
+from openmm.unit import *
+import jax
+import jax_md
+import jax.numpy as jnp
+import dmff
+from dmff.api import Hamiltonian
+import pickle
+import time
+
+
+if __name__ == '__main__':
+ ff = 'forcefield.xml'
+ pdb_AB = PDBFile('peg2_dimer.pdb')
+ pdb_A = PDBFile('peg2.pdb')
+ pdb_B = PDBFile('peg2.pdb')
+ H_AB = Hamiltonian(ff)
+ H_A = Hamiltonian(ff)
+ H_B = Hamiltonian(ff)
+ pme_generator_AB, \
+ disp_generator_AB, \
+ ex_generator_AB, \
+ sr_es_generator_AB, \
+ sr_pol_generator_AB, \
+ sr_disp_generator_AB, \
+ dhf_generator_AB, \
+ dmp_es_generator_AB, \
+ dmp_disp_generator_AB = H_AB.getGenerators()
+ pme_generator_A, \
+ disp_generator_A, \
+ ex_generator_A, \
+ sr_es_generator_A, \
+ sr_pol_generator_A, \
+ sr_disp_generator_A, \
+ dhf_generator_A, \
+ dmp_es_generator_A, \
+ dmp_disp_generator_A = H_A.getGenerators()
+ pme_generator_B, \
+ disp_generator_B, \
+ ex_generator_B, \
+ sr_es_generator_B, \
+ sr_pol_generator_B, \
+ sr_disp_generator_B, \
+ dhf_generator_B, \
+ dmp_es_generator_B, \
+ dmp_disp_generator_B = H_B.getGenerators()
+
+ rc = 15
+
+ # get potential functions
+ potentials_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_AB, \
+ pot_disp_AB, \
+ pot_ex_AB, \
+ pot_sr_es_AB, \
+ pot_sr_pol_AB, \
+ pot_sr_disp_AB, \
+ pot_dhf_AB, \
+ pot_dmp_es_AB, \
+ pot_dmp_disp_AB = potentials_AB
+ potentials_A = H_A.createPotential(pdb_A.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_A, \
+ pot_disp_A, \
+ pot_ex_A, \
+ pot_sr_es_A, \
+ pot_sr_pol_A, \
+ pot_sr_disp_A, \
+ pot_dhf_A, \
+ pot_dmp_es_A, \
+ pot_dmp_disp_A = potentials_A
+ potentials_B = H_B.createPotential(pdb_B.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_B, \
+ pot_disp_B, \
+ pot_ex_B, \
+ pot_sr_es_B, \
+ pot_sr_pol_B, \
+ pot_sr_disp_B, \
+ pot_dhf_B, \
+ pot_dmp_es_B, \
+ pot_dmp_disp_B = potentials_B
+
+ pos_AB0 = jnp.array(pdb_AB.positions._value) * 10
+ n_atoms = len(pos_AB0)
+ n_atoms_A = n_atoms // 2
+ n_atoms_B = n_atoms // 2
+ pos_A0 = jnp.array(pdb_AB.positions._value[:n_atoms_A]) * 10
+ pos_B0 = jnp.array(pdb_AB.positions._value[n_atoms_A:n_atoms]) * 10
+ box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10
+ # nn list initial allocation
+ displacement_fn, shift_fn = jax_md.space.periodic_general(box, fractional_coordinates=False)
+ neighbor_list_fn = jax_md.partition.neighbor_list(displacement_fn, box, rc, 0, format=jax_md.partition.OrderedSparse)
+ nbr_AB = neighbor_list_fn.allocate(pos_AB0)
+ nbr_A = neighbor_list_fn.allocate(pos_A0)
+ nbr_B = neighbor_list_fn.allocate(pos_B0)
+ pairs_AB = np.array(nbr_AB.idx.T)
+ pairs_A = np.array(nbr_A.idx.T)
+ pairs_B = np.array(nbr_B.idx.T)
+ pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]]
+ pairs_A = pairs_A[pairs_A[:, 0] < pairs_A[:, 1]]
+ pairs_B = pairs_B[pairs_B[:, 0] < pairs_B[:, 1]]
+
+ # load data
+ with open('data.pickle', 'rb') as ifile:
+ data = pickle.load(ifile)
+
+ keys = list(data.keys())
+ keys.sort()
+ # for sid in keys:
+ for sid in ['000']:
+ scan_res = data[sid]
+ scan_res['tot_full'] = scan_res['tot'].copy()
+ npts = len(scan_res['tot'])
+ print(sid)
+
+ for ipt in range(npts):
+ E_es_ref = scan_res['es'][ipt]
+ E_pol_ref = scan_res['pol'][ipt]
+ E_disp_ref = scan_res['disp'][ipt]
+ E_ex_ref = scan_res['ex'][ipt]
+ E_dhf_ref = scan_res['dhf'][ipt]
+ E_tot_ref = scan_res['tot'][ipt]
+
+ # get position array
+ pos_A = jnp.array(scan_res['posA'][ipt])
+ pos_B = jnp.array(scan_res['posB'][ipt])
+ pos_AB = jnp.concatenate([pos_A, pos_B], axis=0)
+
+
+ #####################
+ # exchange repulsion
+ #####################
+ E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, ex_generator_AB.params)
+ E_ex_A = pot_ex_A(pos_A, box, pairs_A, ex_generator_AB.params)
+ E_ex_B = pot_ex_B(pos_B, box, pairs_B, ex_generator_AB.params)
+ E_ex = E_ex_AB - E_ex_A - E_ex_B
+
+ #######################
+ # electrostatic + pol
+ #######################
+ E_AB = pot_pme_AB(pos_AB, box, pairs_AB, pme_generator_AB.params)
+ E_A = pot_pme_A(pos_A, box, pairs_A, pme_generator_A.params)
+ E_B = pot_pme_B(pos_B, box, pairs_A, pme_generator_B.params)
+ E_espol = E_AB - E_A - E_B
+
+ # use induced dipole of monomers to compute electrostatic interaction
+ U_ind_AB = jnp.vstack((pme_generator_A.pme_force.U_ind, pme_generator_B.pme_force.U_ind))
+ params = pme_generator_AB.params
+ map_atypes = pme_generator_AB.map_atomtype
+ Q_local = params['Q_local'][map_atypes]
+ pol = params['pol'][map_atypes]
+ tholes = params['tholes'][map_atypes]
+ pme_force = pme_generator_AB.pme_force
+ E_AB_nonpol = pme_force.energy_fn(pos_AB, box, pairs_AB, Q_local, U_ind_AB, pol, tholes, params['mScales'], params['pScales'], params['dScales'])
+ E_es = E_AB_nonpol - E_A - E_B
+ E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, dmp_es_generator_AB.params) \
+ - pot_dmp_es_A(pos_A, box, pairs_A, dmp_es_generator_A.params) \
+ - pot_dmp_es_B(pos_B, box, pairs_B, dmp_es_generator_B.params)
+ E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, sr_es_generator_AB.params) \
+ - pot_sr_es_A(pos_A, box, pairs_A, sr_es_generator_AB.params) \
+ - pot_sr_es_B(pos_B, box, pairs_B, sr_es_generator_AB.params)
+
+
+ ###################################
+ # polarization (induction) energy
+ ###################################
+ E_pol = E_espol - E_es
+ E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, sr_pol_generator_AB.params) \
+ - pot_sr_pol_A(pos_A, box, pairs_A, sr_pol_generator_AB.params) \
+ - pot_sr_pol_B(pos_B, box, pairs_B, sr_pol_generator_AB.params)
+
+
+ #############
+ # dispersion
+ #############
+ E_AB_disp = pot_disp_AB(pos_AB, box, pairs_AB, disp_generator_AB.params)
+ E_A_disp = pot_disp_A(pos_A, box, pairs_A, disp_generator_AB.params)
+ E_B_disp = pot_disp_B(pos_B, box, pairs_B, disp_generator_AB.params)
+ E_disp = E_AB_disp - E_A_disp - E_B_disp
+ E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, dmp_disp_generator_AB.params) \
+ - pot_dmp_disp_A(pos_A, box, pairs_A, dmp_disp_generator_A.params) \
+ - pot_dmp_disp_B(pos_B, box, pairs_B, dmp_disp_generator_B.params)
+ E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, sr_disp_generator_AB.params) \
+ - pot_sr_disp_A(pos_A, box, pairs_A, sr_disp_generator_AB.params) \
+ - pot_sr_disp_B(pos_B, box, pairs_B, sr_disp_generator_AB.params)
+
+ ###########
+ # dhf
+ ###########
+ E_AB_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, dhf_generator_AB.params)
+ E_A_dhf = pot_dhf_A(pos_A, box, pairs_A, dhf_generator_AB.params)
+ E_B_dhf = pot_dhf_B(pos_B, box, pairs_B, dhf_generator_AB.params)
+ E_dhf = E_AB_dhf - E_A_dhf - E_B_dhf
+
+ # total energy
+ E_tot = (E_es + E_sr_es + E_dmp_es) + (E_ex) + (E_pol + E_sr_pol) + (E_disp + E_dmp_disp + E_sr_disp) + (E_dhf)
+ # print(E_dmp_es + E_disp + E_dmp_disp)
+ print(E_es + E_pol)
+
diff --git a/examples/peg_slater_isa/check.py b/examples/peg_slater_isa/check.py
new file mode 100755
index 000000000..5b2f22abd
--- /dev/null
+++ b/examples/peg_slater_isa/check.py
@@ -0,0 +1,262 @@
+#!/usr/bin/env python
+import sys
+import numpy as np
+import openmm
+from openmm import *
+from openmm.app import *
+from openmm.unit import *
+import jax
+import jax_md
+import jax.numpy as jnp
+import dmff
+from dmff.api import Hamiltonian
+import pickle
+import time
+from jax import value_and_grad, jit
+import optax
+
+
+if __name__ == '__main__':
+ ff = 'forcefield.xml'
+ pdb_AB = PDBFile('peg2_dimer.pdb')
+ pdb_A = PDBFile('peg2.pdb')
+ pdb_B = PDBFile('peg2.pdb')
+ param_file = 'params.0.pickle'
+ H_AB = Hamiltonian(ff)
+ H_A = Hamiltonian(ff)
+ H_B = Hamiltonian(ff)
+ pme_generator_AB, \
+ disp_generator_AB, \
+ ex_generator_AB, \
+ sr_es_generator_AB, \
+ sr_pol_generator_AB, \
+ sr_disp_generator_AB, \
+ dhf_generator_AB, \
+ dmp_es_generator_AB, \
+ dmp_disp_generator_AB = H_AB.getGenerators()
+ pme_generator_A, \
+ disp_generator_A, \
+ ex_generator_A, \
+ sr_es_generator_A, \
+ sr_pol_generator_A, \
+ sr_disp_generator_A, \
+ dhf_generator_A, \
+ dmp_es_generator_A, \
+ dmp_disp_generator_A = H_A.getGenerators()
+ pme_generator_B, \
+ disp_generator_B, \
+ ex_generator_B, \
+ sr_es_generator_B, \
+ sr_pol_generator_B, \
+ sr_disp_generator_B, \
+ dhf_generator_B, \
+ dmp_es_generator_B, \
+ dmp_disp_generator_B = H_B.getGenerators()
+
+ rc = 15
+
+ # get potential functions
+ potentials_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_AB, \
+ pot_disp_AB, \
+ pot_ex_AB, \
+ pot_sr_es_AB, \
+ pot_sr_pol_AB, \
+ pot_sr_disp_AB, \
+ pot_dhf_AB, \
+ pot_dmp_es_AB, \
+ pot_dmp_disp_AB = potentials_AB
+ potentials_A = H_A.createPotential(pdb_A.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_A, \
+ pot_disp_A, \
+ pot_ex_A, \
+ pot_sr_es_A, \
+ pot_sr_pol_A, \
+ pot_sr_disp_A, \
+ pot_dhf_A, \
+ pot_dmp_es_A, \
+ pot_dmp_disp_A = potentials_A
+ potentials_B = H_B.createPotential(pdb_B.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_B, \
+ pot_disp_B, \
+ pot_ex_B, \
+ pot_sr_es_B, \
+ pot_sr_pol_B, \
+ pot_sr_disp_B, \
+ pot_dhf_B, \
+ pot_dmp_es_B, \
+ pot_dmp_disp_B = potentials_B
+
+ # init positions used to set up neighbor list
+ pos_AB0 = jnp.array(pdb_AB.positions._value) * 10
+ n_atoms = len(pos_AB0)
+ n_atoms_A = n_atoms // 2
+ n_atoms_B = n_atoms // 2
+ pos_A0 = jnp.array(pdb_AB.positions._value[:n_atoms_A]) * 10
+ pos_B0 = jnp.array(pdb_AB.positions._value[n_atoms_A:n_atoms]) * 10
+ box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10
+
+ # nn list initial allocation
+ displacement_fn, shift_fn = jax_md.space.periodic_general(box, fractional_coordinates=False)
+ neighbor_list_fn = jax_md.partition.neighbor_list(displacement_fn, box, rc, 0, format=jax_md.partition.OrderedSparse)
+ nbr_AB = neighbor_list_fn.allocate(pos_AB0)
+ nbr_A = neighbor_list_fn.allocate(pos_A0)
+ nbr_B = neighbor_list_fn.allocate(pos_B0)
+ pairs_AB = np.array(nbr_AB.idx.T)
+ pairs_A = np.array(nbr_A.idx.T)
+ pairs_B = np.array(nbr_B.idx.T)
+ pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]]
+ pairs_A = pairs_A[pairs_A[:, 0] < pairs_A[:, 1]]
+ pairs_B = pairs_B[pairs_B[:, 0] < pairs_B[:, 1]]
+
+
+ # construct total force field params
+ comps = ['ex', 'es', 'pol', 'disp', 'dhf', 'tot']
+ # load parameters
+ with open(param_file, 'rb') as ifile:
+ params = pickle.load(ifile)
+
+ # setting up params for all calculators
+ params_ex = {}
+ params_sr_es = {}
+ params_sr_pol = {}
+ params_sr_disp = {}
+ params_dhf = {}
+ params_dmp_es = {}
+ params_dmp_disp = {}
+ for k in ['B', 'mScales']:
+ params_ex[k] = params[k]
+ params_sr_es[k] = params[k]
+ params_sr_pol[k] = params[k]
+ params_sr_disp[k] = params[k]
+ params_dhf[k] = params[k]
+ params_dmp_es[k] = params[k]
+ params_dmp_disp[k] = params[k]
+ params_ex['A'] = params['A_ex']
+ params_sr_es['A'] = params['A_es']
+ params_sr_pol['A'] = params['A_pol']
+ params_sr_disp['A'] = params['A_disp']
+ params_dhf['A'] = params['A_dhf']
+ # damping parameters
+ params_dmp_es['Q'] = params['Q']
+ params_dmp_disp['C6'] = params['C6']
+ params_dmp_disp['C8'] = params['C8']
+ params_dmp_disp['C10'] = params['C10']
+ # long range parameters
+ params_espol = {}
+ for k in ['mScales', 'pScales', 'dScales', 'Q_local', 'pol', 'tholes']:
+ params_espol[k] = params[k]
+ params_disp = {}
+ for k in ['B', 'C6', 'C8', 'C10', 'mScales']:
+ params_disp[k] = params[k]
+
+
+ # load data
+ with open('data.pickle', 'rb') as ifile:
+ data = pickle.load(ifile)
+ with open('data_sr.pickle', 'rb') as ifile:
+ data_sr = pickle.load(ifile)
+ with open('data_lr.pickle', 'rb') as ifile:
+ data_lr = pickle.load(ifile)
+ sids = list(data.keys())
+ sids.sort()
+
+ # run test
+ # for sid in sids:
+ for sid in [sys.argv[1]]:
+ scan_res = data[sid]
+ scan_res_sr = data_sr[sid]
+ scan_res_lr = data_lr[sid]
+ npts = len(scan_res['tot'])
+
+ for ipt in range(npts):
+ E_es_ref = scan_res['es'][ipt]
+ E_pol_ref = scan_res['pol'][ipt]
+ E_disp_ref = scan_res['disp'][ipt]
+ E_ex_ref = scan_res['ex'][ipt]
+ E_dhf_ref = scan_res['dhf'][ipt]
+ E_tot_ref = scan_res['tot'][ipt]
+
+ pos_A = jnp.array(scan_res['posA'][ipt])
+ pos_B = jnp.array(scan_res['posB'][ipt])
+ pos_AB = jnp.concatenate([pos_A, pos_B], axis=0)
+
+ #####################
+ # exchange repulsion
+ #####################
+ E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, params_ex)
+ E_ex_A = pot_ex_A(pos_A, box, pairs_A, params_ex)
+ E_ex_B = pot_ex_B(pos_B, box, pairs_B, params_ex)
+ E_ex = E_ex_AB - E_ex_A - E_ex_B
+
+ #######################
+ # electrostatic + pol
+ #######################
+ E_AB = pot_pme_AB(pos_AB, box, pairs_AB, params_espol)
+ E_A = pot_pme_A(pos_A, box, pairs_A, params_espol)
+ E_B = pot_pme_B(pos_B, box, pairs_A, params_espol)
+ E_espol = E_AB - E_A - E_B
+
+ # use induced dipole of monomers to compute electrostatic interaction
+ U_ind_AB = jnp.vstack((pme_generator_A.pme_force.U_ind, pme_generator_B.pme_force.U_ind))
+ params = params_espol
+ map_atypes = pme_generator_AB.map_atomtype
+ Q_local = params['Q_local'][map_atypes]
+ pol = params['pol'][map_atypes]
+ tholes = params['tholes'][map_atypes]
+ pme_force = pme_generator_AB.pme_force
+ E_AB_nonpol = pme_force.energy_fn(pos_AB, box, pairs_AB, Q_local, U_ind_AB, pol, tholes, params['mScales'], params['pScales'], params['dScales'])
+ E_es = E_AB_nonpol - E_A - E_B
+ E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, params_dmp_es) \
+ - pot_dmp_es_A(pos_A, box, pairs_A, params_dmp_es) \
+ - pot_dmp_es_B(pos_B, box, pairs_B, params_dmp_es)
+ E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, params_sr_es) \
+ - pot_sr_es_A(pos_A, box, pairs_A, params_sr_es) \
+ - pot_sr_es_B(pos_B, box, pairs_B, params_sr_es)
+
+
+ ###################################
+ # polarization (induction) energy
+ ###################################
+ E_pol = E_espol - E_es
+ E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, params_sr_pol) \
+ - pot_sr_pol_A(pos_A, box, pairs_A, params_sr_pol) \
+ - pot_sr_pol_B(pos_B, box, pairs_B, params_sr_pol)
+
+
+ #############
+ # dispersion
+ #############
+ E_AB_disp = pot_disp_AB(pos_AB, box, pairs_AB, params_disp)
+ E_A_disp = pot_disp_A(pos_A, box, pairs_A, params_disp)
+ E_B_disp = pot_disp_B(pos_B, box, pairs_B, params_disp)
+ E_disp = E_AB_disp - E_A_disp - E_B_disp
+ E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, params_dmp_disp) \
+ - pot_dmp_disp_A(pos_A, box, pairs_A, params_dmp_disp) \
+ - pot_dmp_disp_B(pos_B, box, pairs_B, params_dmp_disp)
+ E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, params_sr_disp) \
+ - pot_sr_disp_A(pos_A, box, pairs_A, params_sr_disp) \
+ - pot_sr_disp_B(pos_B, box, pairs_B, params_sr_disp)
+
+ ###########
+ # dhf
+ ###########
+ E_AB_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, params_dhf)
+ E_A_dhf = pot_dhf_A(pos_A, box, pairs_A, params_dhf)
+ E_B_dhf = pot_dhf_B(pos_B, box, pairs_B, params_dhf)
+ E_dhf = E_AB_dhf - E_A_dhf - E_B_dhf
+
+ # total energy
+ E_tot = (E_es + E_sr_es + E_dmp_es) + (E_ex) + (E_pol + E_sr_pol) + (E_disp + E_dmp_disp + E_sr_disp) + (E_dhf)
+ E_tot_sr = (E_sr_es + E_dmp_es) + (E_ex) + (E_sr_pol) + (E_sr_disp + E_dmp_disp) + (E_dhf)
+ E_tot_lr = E_es + E_pol + E_disp
+
+ print(ipt, E_tot, E_tot_ref)
+ # print(ipt, E_tot, E_tot_ref, E_tot_sr, data_sr[sid]['tot'][ipt], E_tot_lr, data[sid]['tot'][ipt]-data_sr[sid]['tot'][ipt])
+ # print(ipt, E_tot_lr, scan_res_lr['tot'][ipt])
+ # print(ipt, E_tot_sr, scan_res_sr['tot'][ipt], scan_res['tot'][ipt])
+ # if scan_res['tot'][ipt] < 25:
+ # print(scan_res_sr['tot'][ipt], scan_res_sr['tot'][ipt], E_tot_sr)
+ # # print(scan_res['tot'][ipt], scan_res['tot'][ipt], E_tot)
+ # sys.stdout.flush()
+
diff --git a/examples/peg_slater_isa/data.pickle b/examples/peg_slater_isa/data.pickle
new file mode 100644
index 000000000..e2ab58d9c
Binary files /dev/null and b/examples/peg_slater_isa/data.pickle differ
diff --git a/examples/peg_slater_isa/data_lr.pickle b/examples/peg_slater_isa/data_lr.pickle
new file mode 100644
index 000000000..928c60557
Binary files /dev/null and b/examples/peg_slater_isa/data_lr.pickle differ
diff --git a/examples/peg_slater_isa/data_sr.pickle b/examples/peg_slater_isa/data_sr.pickle
new file mode 100644
index 000000000..82276c420
Binary files /dev/null and b/examples/peg_slater_isa/data_sr.pickle differ
diff --git a/examples/peg_slater_isa/fit.py b/examples/peg_slater_isa/fit.py
new file mode 100755
index 000000000..d5b5bccb8
--- /dev/null
+++ b/examples/peg_slater_isa/fit.py
@@ -0,0 +1,298 @@
+#!/usr/bin/env python
+import sys
+import numpy as np
+import openmm
+from openmm import *
+from openmm.app import *
+from openmm.unit import *
+import jax
+import jax_md
+import jax.numpy as jnp
+import dmff
+from dmff.api import Hamiltonian
+import pickle
+import time
+from jax import value_and_grad, jit
+import optax
+
+
+if __name__ == '__main__':
+ restart = 'params.0.pickle' # None
+ ff = 'forcefield.xml'
+ pdb_AB = PDBFile('peg2_dimer.pdb')
+ pdb_A = PDBFile('peg2.pdb')
+ pdb_B = PDBFile('peg2.pdb')
+ H_AB = Hamiltonian(ff)
+ H_A = Hamiltonian(ff)
+ H_B = Hamiltonian(ff)
+ pme_generator_AB, \
+ disp_generator_AB, \
+ ex_generator_AB, \
+ sr_es_generator_AB, \
+ sr_pol_generator_AB, \
+ sr_disp_generator_AB, \
+ dhf_generator_AB, \
+ dmp_es_generator_AB, \
+ dmp_disp_generator_AB = H_AB.getGenerators()
+ pme_generator_A, \
+ disp_generator_A, \
+ ex_generator_A, \
+ sr_es_generator_A, \
+ sr_pol_generator_A, \
+ sr_disp_generator_A, \
+ dhf_generator_A, \
+ dmp_es_generator_A, \
+ dmp_disp_generator_A = H_A.getGenerators()
+ pme_generator_B, \
+ disp_generator_B, \
+ ex_generator_B, \
+ sr_es_generator_B, \
+ sr_pol_generator_B, \
+ sr_disp_generator_B, \
+ dhf_generator_B, \
+ dmp_es_generator_B, \
+ dmp_disp_generator_B = H_B.getGenerators()
+
+ rc = 15
+
+ # get potential functions
+ potentials_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_AB, \
+ pot_disp_AB, \
+ pot_ex_AB, \
+ pot_sr_es_AB, \
+ pot_sr_pol_AB, \
+ pot_sr_disp_AB, \
+ pot_dhf_AB, \
+ pot_dmp_es_AB, \
+ pot_dmp_disp_AB = potentials_AB
+ potentials_A = H_A.createPotential(pdb_A.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_A, \
+ pot_disp_A, \
+ pot_ex_A, \
+ pot_sr_es_A, \
+ pot_sr_pol_A, \
+ pot_sr_disp_A, \
+ pot_dhf_A, \
+ pot_dmp_es_A, \
+ pot_dmp_disp_A = potentials_A
+ potentials_B = H_B.createPotential(pdb_B.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_B, \
+ pot_disp_B, \
+ pot_ex_B, \
+ pot_sr_es_B, \
+ pot_sr_pol_B, \
+ pot_sr_disp_B, \
+ pot_dhf_B, \
+ pot_dmp_es_B, \
+ pot_dmp_disp_B = potentials_B
+
+ pos_AB0 = jnp.array(pdb_AB.positions._value) * 10
+ n_atoms = len(pos_AB0)
+ n_atoms_A = n_atoms // 2
+ n_atoms_B = n_atoms // 2
+ pos_A0 = jnp.array(pdb_AB.positions._value[:n_atoms_A]) * 10
+ pos_B0 = jnp.array(pdb_AB.positions._value[n_atoms_A:n_atoms]) * 10
+ box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10
+ # nn list initial allocation
+ displacement_fn, shift_fn = jax_md.space.periodic_general(box, fractional_coordinates=False)
+ neighbor_list_fn = jax_md.partition.neighbor_list(displacement_fn, box, rc, 0, format=jax_md.partition.OrderedSparse)
+ nbr_AB = neighbor_list_fn.allocate(pos_AB0)
+ nbr_A = neighbor_list_fn.allocate(pos_A0)
+ nbr_B = neighbor_list_fn.allocate(pos_B0)
+ pairs_AB = np.array(nbr_AB.idx.T)
+ pairs_A = np.array(nbr_A.idx.T)
+ pairs_B = np.array(nbr_B.idx.T)
+
+
+ # construct total force field params
+ comps = ['ex', 'es', 'pol', 'disp', 'dhf', 'tot']
+ weights_comps = jnp.array([0.001, 0.001, 0.001, 0.001, 0.001, 1.0])
+ if restart is None:
+ params = {}
+ sr_generators = {
+ 'ex': ex_generator_AB,
+ 'es': sr_es_generator_AB,
+ 'pol': sr_pol_generator_AB,
+ 'disp': sr_disp_generator_AB,
+ 'dhf': dhf_generator_AB,
+ }
+ for k in pme_generator_AB.params:
+ params[k] = pme_generator_AB.params[k]
+ for k in disp_generator_AB.params:
+ params[k] = disp_generator_AB.params[k]
+ for c in comps:
+ if c == 'tot':
+ continue
+ gen = sr_generators[c]
+ for k in gen.params:
+ if k == 'A':
+ params['A_'+c] = gen.params[k]
+ else:
+ params[k] = gen.params[k]
+ # a random initialization of A
+ for c in comps:
+ if c == 'tot':
+ continue
+ params['A_'+c] = jnp.array(np.random.random(params['A_'+c].shape))
+ # specify charges for es damping
+ params['Q'] = dmp_es_generator_AB.params['Q']
+ else:
+ with open(restart, 'rb') as ifile:
+ params = pickle.load(ifile)
+
+
+ @jit
+ def MSELoss(params, scan_res):
+ '''
+ The weighted mean squared error loss function
+ Conducted for each scan
+ '''
+ E_tot_full = scan_res['tot_full']
+ kT = 2.494 # 300 K = 2.494 kJ/mol
+ weights_pts = jnp.piecewise(E_tot_full, [E_tot_full<25, E_tot_full>=25], [lambda x: jnp.array(1.0), lambda x: jnp.exp(-(x-25)/kT)])
+ npts = len(weights_pts)
+
+ energies = {
+ 'ex': jnp.zeros(npts),
+ 'es': jnp.zeros(npts),
+ 'pol': jnp.zeros(npts),
+ 'disp': jnp.zeros(npts),
+ 'dhf': jnp.zeros(npts),
+ 'tot': jnp.zeros(npts)
+ }
+
+ # setting up params for all calculators
+ params_ex = {}
+ params_sr_es = {}
+ params_sr_pol = {}
+ params_sr_disp = {}
+ params_dhf = {}
+ params_dmp_es = {} # electrostatic damping
+ params_dmp_disp = {} # dispersion damping
+ for k in ['B', 'mScales']:
+ params_ex[k] = params[k]
+ params_sr_es[k] = params[k]
+ params_sr_pol[k] = params[k]
+ params_sr_disp[k] = params[k]
+ params_dhf[k] = params[k]
+ params_dmp_es[k] = params[k]
+ params_dmp_disp[k] = params[k]
+ params_ex['A'] = params['A_ex']
+ params_sr_es['A'] = params['A_es']
+ params_sr_pol['A'] = params['A_pol']
+ params_sr_disp['A'] = params['A_disp']
+ params_dhf['A'] = params['A_dhf']
+ # damping parameters
+ params_dmp_es['Q'] = params['Q']
+ params_dmp_disp['C6'] = params['C6']
+ params_dmp_disp['C8'] = params['C8']
+ params_dmp_disp['C10'] = params['C10']
+
+ # calculate each points, only the short range and damping components
+ for ipt in range(npts):
+ # get position array
+ pos_A = jnp.array(scan_res['posA'][ipt])
+ pos_B = jnp.array(scan_res['posB'][ipt])
+ pos_AB = jnp.concatenate([pos_A, pos_B], axis=0)
+
+ #####################
+ # exchange repulsion
+ #####################
+ E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, params_ex)
+ E_ex_A = pot_ex_A(pos_A, box, pairs_A, params_ex)
+ E_ex_B = pot_ex_B(pos_B, box, pairs_B, params_ex)
+ E_ex = E_ex_AB - E_ex_A - E_ex_B
+
+ #######################
+ # electrostatic + pol
+ #######################
+ E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, params_dmp_es) \
+ - pot_dmp_es_A(pos_A, box, pairs_A, params_dmp_es) \
+ - pot_dmp_es_B(pos_B, box, pairs_B, params_dmp_es)
+ E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, params_sr_es) \
+ - pot_sr_es_A(pos_A, box, pairs_A, params_sr_es) \
+ - pot_sr_es_B(pos_B, box, pairs_B, params_sr_es)
+
+ ###################################
+ # polarization (induction) energy
+ ###################################
+ E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, params_sr_pol) \
+ - pot_sr_pol_A(pos_A, box, pairs_A, params_sr_pol) \
+ - pot_sr_pol_B(pos_B, box, pairs_B, params_sr_pol)
+
+ #############
+ # dispersion
+ #############
+ E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, params_dmp_disp) \
+ - pot_dmp_disp_A(pos_A, box, pairs_A, params_dmp_disp) \
+ - pot_dmp_disp_B(pos_B, box, pairs_B, params_dmp_disp)
+ E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, params_sr_disp) \
+ - pot_sr_disp_A(pos_A, box, pairs_A, params_sr_disp) \
+ - pot_sr_disp_B(pos_B, box, pairs_B, params_sr_disp)
+
+ ###########
+ # dhf
+ ###########
+ E_AB_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, params_dhf)
+ E_A_dhf = pot_dhf_A(pos_A, box, pairs_A, params_dhf)
+ E_B_dhf = pot_dhf_B(pos_B, box, pairs_B, params_dhf)
+ E_dhf = E_AB_dhf - E_A_dhf - E_B_dhf
+
+ energies['ex'] = energies['ex'].at[ipt].set(E_ex)
+ energies['es'] = energies['es'].at[ipt].set(E_dmp_es + E_sr_es)
+ energies['pol'] = energies['pol'].at[ipt].set(E_sr_pol)
+ energies['disp'] = energies['disp'].at[ipt].set(E_dmp_disp + E_sr_disp)
+ energies['dhf'] = energies['dhf'].at[ipt].set(E_dhf)
+ energies['tot'] = energies['tot'].at[ipt].set(E_ex
+ + E_dmp_es + E_sr_es
+ + E_sr_pol
+ + E_dmp_disp + E_sr_disp
+ + E_dhf)
+
+
+ errs = jnp.zeros(len(comps))
+ for ic, c in enumerate(comps):
+ dE = energies[c] - scan_res[c]
+ mse = dE**2 * weights_pts / jnp.sum(weights_pts)
+ errs = errs.at[ic].set(jnp.sum(mse))
+
+ return jnp.sum(weights_comps * errs)
+
+
+ # load data
+ with open('data_sr.pickle', 'rb') as ifile:
+ data = pickle.load(ifile)
+
+ err, gradients = value_and_grad(MSELoss, argnums=(0))(params, data['000'])
+ sids = np.array(list(data.keys()))
+
+
+ # only optimize these parameters A/B
+ def mask_fn(grads):
+ for k in grads:
+ if k.startswith('A_') or k == 'B':
+ continue
+ else:
+ grads[k] = 0.0
+ return grads
+
+ # start to do optmization
+ lr = 0.001
+ optimizer = optax.adam(lr)
+ opt_state = optimizer.init(params)
+
+ n_epochs = 1000
+ for i_epoch in range(n_epochs):
+ np.random.shuffle(sids)
+ for sid in sids:
+ loss, grads = value_and_grad(MSELoss, argnums=(0))(params, data[sid])
+ grads = mask_fn(grads)
+ print(loss)
+ sys.stdout.flush()
+ updates, opt_state = optimizer.update(grads, opt_state)
+ params = optax.apply_updates(params, updates)
+ with open('params.pickle', 'wb') as ofile:
+ pickle.dump(params, ofile)
+
+
diff --git a/examples/peg_slater_isa/fit.sh b/examples/peg_slater_isa/fit.sh
new file mode 100644
index 000000000..6e2a876a1
--- /dev/null
+++ b/examples/peg_slater_isa/fit.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+#SBATCH -N 1 -n 1 --gres=gpu:1
+#SBATCH -t 24:00:00 -o out -e err
+#SBATCH -p rtx3090
+
+./fit.py > log
diff --git a/examples/peg_slater_isa/forcefield.xml b/examples/peg_slater_isa/forcefield.xml
new file mode 100644
index 000000000..25d90db50
--- /dev/null
+++ b/examples/peg_slater_isa/forcefield.xml
@@ -0,0 +1,145 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/peg_slater_isa/forcefield_nonpol.xml b/examples/peg_slater_isa/forcefield_nonpol.xml
new file mode 100644
index 000000000..7160d4f2d
--- /dev/null
+++ b/examples/peg_slater_isa/forcefield_nonpol.xml
@@ -0,0 +1,75 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/peg_slater_isa/params.0.pickle b/examples/peg_slater_isa/params.0.pickle
new file mode 100644
index 000000000..c6e9cd939
Binary files /dev/null and b/examples/peg_slater_isa/params.0.pickle differ
diff --git a/examples/peg_slater_isa/peg2.pdb b/examples/peg_slater_isa/peg2.pdb
new file mode 100644
index 000000000..e44d9f14b
--- /dev/null
+++ b/examples/peg_slater_isa/peg2.pdb
@@ -0,0 +1,35 @@
+TITLE MDANALYSIS FRAME 0: Created by PDBWriter
+CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1
+ATOM 1 C00 TER M 1 14.058 13.500 16.329 0.00 1.00 SYST C
+ATOM 2 H01 TER M 1 14.410 14.387 16.870 0.00 1.00 SYST H
+ATOM 3 H02 TER M 1 14.412 12.614 16.871 0.00 1.00 SYST H
+ATOM 4 O03 TER M 1 14.578 13.500 15.000 0.00 1.00 SYST O
+ATOM 5 C04 TER M 1 16.000 13.500 15.000 0.00 1.00 SYST C
+ATOM 6 H05 TER M 1 16.344 13.499 13.962 0.00 1.00 SYST H
+ATOM 7 H06 TER M 1 16.382 12.602 15.496 0.00 1.00 SYST H
+ATOM 8 H07 TER M 1 16.382 14.399 15.493 0.00 1.00 SYST H
+ATOM 9 C00 TER M 2 12.535 13.498 16.276 0.00 1.00 SYST C
+ATOM 10 H01 TER M 2 12.184 12.612 15.734 0.00 1.00 SYST H
+ATOM 11 H02 TER M 2 12.182 14.385 15.735 0.00 1.00 SYST H
+ATOM 12 O03 TER M 2 12.015 13.496 17.605 0.00 1.00 SYST O
+ATOM 13 C04 TER M 2 10.593 13.493 17.605 0.00 1.00 SYST C
+ATOM 14 H05 TER M 2 10.250 13.491 18.643 0.00 1.00 SYST H
+ATOM 15 H06 TER M 2 10.213 12.595 17.109 0.00 1.00 SYST H
+ATOM 16 H07 TER M 2 10.209 14.391 17.112 0.00 1.00 SYST H
+CONECT 1 2 3 4 9
+CONECT 2 1
+CONECT 3 1
+CONECT 4 1 5
+CONECT 5 4 6 7 8
+CONECT 6 5
+CONECT 7 5
+CONECT 8 5
+CONECT 9 1 10 11 12
+CONECT 10 9
+CONECT 11 9
+CONECT 12 9 13
+CONECT 13 12 14 15 16
+CONECT 14 13
+CONECT 15 13
+CONECT 16 13
+END
diff --git a/examples/peg_slater_isa/peg2_dimer.pdb b/examples/peg_slater_isa/peg2_dimer.pdb
new file mode 100644
index 000000000..5bff255ca
--- /dev/null
+++ b/examples/peg_slater_isa/peg2_dimer.pdb
@@ -0,0 +1,67 @@
+TITLE MDANALYSIS FRAME 0: Created by PDBWriter
+CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1
+ATOM 1 C00 TER M 1 14.058 13.500 16.329 0.00 1.00 SYST C
+ATOM 2 H01 TER M 1 14.410 14.387 16.870 0.00 1.00 SYST H
+ATOM 3 H02 TER M 1 14.412 12.614 16.871 0.00 1.00 SYST H
+ATOM 4 O03 TER M 1 14.578 13.500 15.000 0.00 1.00 SYST O
+ATOM 5 C04 TER M 1 16.000 13.500 15.000 0.00 1.00 SYST C
+ATOM 6 H05 TER M 1 16.344 13.499 13.962 0.00 1.00 SYST H
+ATOM 7 H06 TER M 1 16.382 12.602 15.496 0.00 1.00 SYST H
+ATOM 8 H07 TER M 1 16.382 14.399 15.493 0.00 1.00 SYST H
+ATOM 9 C00 TER M 2 12.535 13.498 16.276 0.00 1.00 SYST C
+ATOM 10 H01 TER M 2 12.184 12.612 15.734 0.00 1.00 SYST H
+ATOM 11 H02 TER M 2 12.182 14.385 15.735 0.00 1.00 SYST H
+ATOM 12 O03 TER M 2 12.015 13.496 17.605 0.00 1.00 SYST O
+ATOM 13 C04 TER M 2 10.593 13.493 17.605 0.00 1.00 SYST C
+ATOM 14 H05 TER M 2 10.250 13.491 18.643 0.00 1.00 SYST H
+ATOM 15 H06 TER M 2 10.213 12.595 17.109 0.00 1.00 SYST H
+ATOM 16 H07 TER M 2 10.209 14.391 17.112 0.00 1.00 SYST H
+ATOM 17 C00 TER M 3 14.058 18.500 16.329 0.00 1.00 SYST C
+ATOM 18 H01 TER M 3 14.410 19.387 16.870 0.00 1.00 SYST H
+ATOM 19 H02 TER M 3 14.412 17.614 16.871 0.00 1.00 SYST H
+ATOM 20 O03 TER M 3 14.578 18.500 15.000 0.00 1.00 SYST O
+ATOM 21 C04 TER M 3 16.000 18.500 15.000 0.00 1.00 SYST C
+ATOM 22 H05 TER M 3 16.344 18.499 13.962 0.00 1.00 SYST H
+ATOM 23 H06 TER M 3 16.382 17.602 15.496 0.00 1.00 SYST H
+ATOM 24 H07 TER M 3 16.382 19.399 15.493 0.00 1.00 SYST H
+ATOM 25 C00 TER M 4 12.535 18.498 16.276 0.00 1.00 SYST C
+ATOM 26 H01 TER M 4 12.184 17.612 15.734 0.00 1.00 SYST H
+ATOM 27 H02 TER M 4 12.182 19.385 15.735 0.00 1.00 SYST H
+ATOM 28 O03 TER M 4 12.015 18.496 17.605 0.00 1.00 SYST O
+ATOM 29 C04 TER M 4 10.593 18.493 17.605 0.00 1.00 SYST C
+ATOM 30 H05 TER M 4 10.250 18.491 18.643 0.00 1.00 SYST H
+ATOM 31 H06 TER M 4 10.213 17.595 17.109 0.00 1.00 SYST H
+ATOM 32 H07 TER M 4 10.209 19.391 17.112 0.00 1.00 SYST H
+CONECT 1 2 3 4 9
+CONECT 2 1
+CONECT 3 1
+CONECT 4 1 5
+CONECT 5 4 6 7 8
+CONECT 6 5
+CONECT 7 5
+CONECT 8 5
+CONECT 9 1 10 11 12
+CONECT 10 9
+CONECT 11 9
+CONECT 12 9 13
+CONECT 13 12 14 15 16
+CONECT 14 13
+CONECT 15 13
+CONECT 16 13
+CONECT 17 18 19 20 25
+CONECT 18 17
+CONECT 19 17
+CONECT 20 17 21
+CONECT 21 20 22 23 24
+CONECT 22 21
+CONECT 23 21
+CONECT 24 21
+CONECT 25 17 26 27 28
+CONECT 26 25
+CONECT 27 25
+CONECT 28 25 29
+CONECT 29 28 30 31 32
+CONECT 30 29
+CONECT 31 29
+CONECT 32 29
+END
diff --git a/examples/peg_slater_isa/peg4.pdb b/examples/peg_slater_isa/peg4.pdb
new file mode 100644
index 000000000..2c11081d1
--- /dev/null
+++ b/examples/peg_slater_isa/peg4.pdb
@@ -0,0 +1,63 @@
+REMARK
+CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1
+ATOM 1 C00 TER 1 -2.962 3.637 -1.170
+ATOM 2 H01 TER 1 -2.608 4.142 -0.296
+ATOM 3 H02 TER 1 -4.032 3.635 -1.171
+ATOM 4 O03 TER 1 -2.484 2.289 -1.168
+ATOM 5 C04 TER 1 -2.961 1.615 0.000
+ATOM 6 H05 TER 1 -2.604 0.606 0.000
+ATOM 7 H06 TER 1 -2.604 2.119 0.874
+ATOM 8 H07 TER 1 -4.031 1.615 0.000
+ATOM 9 C00 INT 2 -2.449 6.384 -3.596
+ATOM 10 H01 INT 2 -2.804 5.879 -4.470
+ATOM 11 H02 INT 2 -1.379 6.386 -3.595
+ATOM 12 O03 INT 2 -2.927 5.710 -2.429
+ATOM 13 C04 INT 2 -2.448 4.362 -2.427
+ATOM 14 H05 INT 2 -2.803 3.856 -3.301
+ATOM 15 H06 INT 2 -1.378 4.364 -2.425
+ATOM 16 C00 INT 3 -2.966 9.857 -4.767
+ATOM 17 H01 INT 3 -2.612 10.363 -3.893
+ATOM 18 H02 INT 3 -4.036 9.855 -4.768
+ATOM 19 O03 INT 3 -2.488 8.509 -4.765
+ATOM 20 C04 INT 3 -2.965 7.835 -3.597
+ATOM 21 H05 INT 3 -2.610 8.340 -2.724
+ATOM 22 H06 INT 3 -4.035 7.833 -3.599
+ATOM 23 C00 TER 4 -2.452 10.582 -6.024
+ATOM 24 H01 TER 4 -2.807 10.077 -6.898
+ATOM 25 H02 TER 4 -1.382 10.584 -6.022
+ATOM 26 O03 TER 4 -2.931 11.930 -6.026
+ATOM 27 C04 TER 4 -2.453 12.604 -7.193
+ATOM 28 H05 TER 4 -2.808 12.099 -8.067
+ATOM 29 H06 TER 4 -2.812 13.613 -7.194
+ATOM 30 H07 TER 4 -1.383 12.606 -7.192
+TER
+CONECT 5 6
+CONECT 5 7
+CONECT 5 8
+CONECT 5 4
+CONECT 4 1
+CONECT 1 2
+CONECT 1 3
+CONECT 1 13
+CONECT 13 14
+CONECT 13 15
+CONECT 13 12
+CONECT 12 9
+CONECT 9 10
+CONECT 9 11
+CONECT 9 20
+CONECT 20 21
+CONECT 20 22
+CONECT 20 19
+CONECT 19 16
+CONECT 16 17
+CONECT 16 18
+CONECT 16 23
+CONECT 23 24
+CONECT 23 25
+CONECT 23 26
+CONECT 26 27
+CONECT 27 28
+CONECT 27 29
+CONECT 27 30
+END
diff --git a/examples/peg_slater_isa/remove_lr.py b/examples/peg_slater_isa/remove_lr.py
new file mode 100755
index 000000000..125ab83e1
--- /dev/null
+++ b/examples/peg_slater_isa/remove_lr.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python
+import numpy as np
+import openmm
+from openmm import *
+from openmm.app import *
+from openmm.unit import *
+import jax
+import jax_md
+import jax.numpy as jnp
+import dmff
+from dmff.api import Hamiltonian
+import pickle
+import time
+
+
+if __name__ == '__main__':
+ ff = 'forcefield.xml'
+ pdb_AB = PDBFile('peg2_dimer.pdb')
+ pdb_A = PDBFile('peg2.pdb')
+ pdb_B = PDBFile('peg2.pdb')
+ H_AB = Hamiltonian(ff)
+ H_A = Hamiltonian(ff)
+ H_B = Hamiltonian(ff)
+ pme_generator_AB, \
+ disp_generator_AB, \
+ ex_generator_AB, \
+ sr_es_generator_AB, \
+ sr_pol_generator_AB, \
+ sr_disp_generator_AB, \
+ dhf_generator_AB, \
+ dmp_es_generator_AB, \
+ dmp_disp_generator_AB = H_AB.getGenerators()
+ pme_generator_A, \
+ disp_generator_A, \
+ ex_generator_A, \
+ sr_es_generator_A, \
+ sr_pol_generator_A, \
+ sr_disp_generator_A, \
+ dhf_generator_A, \
+ dmp_es_generator_A, \
+ dmp_disp_generator_A = H_A.getGenerators()
+ pme_generator_B, \
+ disp_generator_B, \
+ ex_generator_B, \
+ sr_es_generator_B, \
+ sr_pol_generator_B, \
+ sr_disp_generator_B, \
+ dhf_generator_B, \
+ dmp_es_generator_B, \
+ dmp_disp_generator_B = H_B.getGenerators()
+
+ rc = 15
+
+ # get potential functions
+ potentials_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_AB, \
+ pot_disp_AB, \
+ pot_ex_AB, \
+ pot_sr_es_AB, \
+ pot_sr_pol_AB, \
+ pot_sr_disp_AB, \
+ pot_dhf_AB, \
+ pot_dmp_es_AB, \
+ pot_dmp_disp_AB = potentials_AB
+ potentials_A = H_A.createPotential(pdb_A.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_A, \
+ pot_disp_A, \
+ pot_ex_A, \
+ pot_sr_es_A, \
+ pot_sr_pol_A, \
+ pot_sr_disp_A, \
+ pot_dhf_A, \
+ pot_dmp_es_A, \
+ pot_dmp_disp_A = potentials_A
+ potentials_B = H_B.createPotential(pdb_B.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4)
+ pot_pme_B, \
+ pot_disp_B, \
+ pot_ex_B, \
+ pot_sr_es_B, \
+ pot_sr_pol_B, \
+ pot_sr_disp_B, \
+ pot_dhf_B, \
+ pot_dmp_es_B, \
+ pot_dmp_disp_B = potentials_B
+
+ pos_AB0 = jnp.array(pdb_AB.positions._value) * 10
+ n_atoms = len(pos_AB0)
+ n_atoms_A = n_atoms // 2
+ n_atoms_B = n_atoms // 2
+ pos_A0 = jnp.array(pdb_AB.positions._value[:n_atoms_A]) * 10
+ pos_B0 = jnp.array(pdb_AB.positions._value[n_atoms_A:n_atoms]) * 10
+ box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10
+ # nn list initial allocation
+ displacement_fn, shift_fn = jax_md.space.periodic_general(box, fractional_coordinates=False)
+ neighbor_list_fn = jax_md.partition.neighbor_list(displacement_fn, box, rc, 0, format=jax_md.partition.OrderedSparse)
+ nbr_AB = neighbor_list_fn.allocate(pos_AB0)
+ nbr_A = neighbor_list_fn.allocate(pos_A0)
+ nbr_B = neighbor_list_fn.allocate(pos_B0)
+ pairs_AB = np.array(nbr_AB.idx.T)
+ pairs_A = np.array(nbr_A.idx.T)
+ pairs_B = np.array(nbr_B.idx.T)
+ pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]]
+ pairs_A = pairs_A[pairs_A[:, 0] < pairs_A[:, 1]]
+ pairs_B = pairs_B[pairs_B[:, 0] < pairs_B[:, 1]]
+
+ # load data
+ with open('data.pickle', 'rb') as ifile:
+ data = pickle.load(ifile)
+
+ keys = list(data.keys())
+ keys.sort()
+ data_lr = {}
+ for sid in keys:
+ scan_res = data[sid]
+ scan_res['tot_full'] = scan_res['tot'].copy()
+ npts = len(scan_res['tot'])
+ # long range
+ scan_res_lr = {}
+ scan_res_lr['es'] = np.zeros(npts)
+ scan_res_lr['pol'] = np.zeros(npts)
+ scan_res_lr['disp'] = np.zeros(npts)
+ scan_res_lr['tot'] = np.zeros(npts)
+ print(sid)
+
+ for ipt in range(npts):
+ E_es_ref = scan_res['es'][ipt]
+ E_pol_ref = scan_res['pol'][ipt]
+ E_disp_ref = scan_res['disp'][ipt]
+ E_ex_ref = scan_res['ex'][ipt]
+ E_dhf_ref = scan_res['dhf'][ipt]
+ E_tot_ref = scan_res['tot'][ipt]
+
+ # get position array
+ pos_A = jnp.array(scan_res['posA'][ipt])
+ pos_B = jnp.array(scan_res['posB'][ipt])
+ pos_AB = jnp.concatenate([pos_A, pos_B], axis=0)
+
+
+ #####################
+ # exchange repulsion
+ #####################
+ # E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, ex_generator_AB.params)
+ # E_ex_A = pot_ex_A(pos_A, box, pairs_A, ex_generator_AB.params)
+ # E_ex_B = pot_ex_B(pos_B, box, pairs_B, ex_generator_AB.params)
+ # E_ex = E_ex_AB - E_ex_A - E_ex_B
+
+ #######################
+ # electrostatic + pol
+ #######################
+ E_AB = pot_pme_AB(pos_AB, box, pairs_AB, pme_generator_AB.params)
+ E_A = pot_pme_A(pos_A, box, pairs_A, pme_generator_A.params)
+ E_B = pot_pme_B(pos_B, box, pairs_A, pme_generator_B.params)
+ E_espol = E_AB - E_A - E_B
+
+ # use induced dipole of monomers to compute electrostatic interaction
+ U_ind_AB = jnp.vstack((pme_generator_A.pme_force.U_ind, pme_generator_B.pme_force.U_ind))
+ params = pme_generator_AB.params
+ map_atypes = pme_generator_AB.map_atomtype
+ Q_local = params['Q_local'][map_atypes]
+ pol = params['pol'][map_atypes]
+ tholes = params['tholes'][map_atypes]
+ pme_force = pme_generator_AB.pme_force
+ E_AB_nonpol = pme_force.energy_fn(pos_AB, box, pairs_AB, Q_local, U_ind_AB, pol, tholes, params['mScales'], params['pScales'], params['dScales'])
+ E_es = E_AB_nonpol - E_A - E_B
+ # E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, dmp_es_generator_AB.params) \
+ # - pot_dmp_es_A(pos_A, box, pairs_A, dmp_es_generator_A.params) \
+ # - pot_dmp_es_B(pos_B, box, pairs_B, dmp_es_generator_B.params)
+ # E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, sr_es_generator_AB.params) \
+ # - pot_sr_es_A(pos_A, box, pairs_A, sr_es_generator_AB.params) \
+ # - pot_sr_es_B(pos_B, box, pairs_B, sr_es_generator_AB.params)
+
+
+ ###################################
+ # polarization (induction) energy
+ ###################################
+ E_pol = E_espol - E_es
+ # E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, sr_pol_generator_AB.params) \
+ # - pot_sr_pol_A(pos_A, box, pairs_A, sr_pol_generator_AB.params) \
+ # - pot_sr_pol_B(pos_B, box, pairs_B, sr_pol_generator_AB.params)
+
+
+ #############
+ # dispersion
+ #############
+ E_AB_disp = pot_disp_AB(pos_AB, box, pairs_AB, disp_generator_AB.params)
+ E_A_disp = pot_disp_A(pos_A, box, pairs_A, disp_generator_AB.params)
+ E_B_disp = pot_disp_B(pos_B, box, pairs_B, disp_generator_AB.params)
+ E_disp = E_AB_disp - E_A_disp - E_B_disp
+ # E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, dmp_disp_generator_AB.params) \
+ # - pot_dmp_disp_A(pos_A, box, pairs_A, dmp_disp_generator_A.params) \
+ # - pot_dmp_disp_B(pos_B, box, pairs_B, dmp_disp_generator_B.params)
+ # E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, sr_disp_generator_AB.params) \
+ # - pot_sr_disp_A(pos_A, box, pairs_A, sr_disp_generator_AB.params) \
+ # - pot_sr_disp_B(pos_B, box, pairs_B, sr_disp_generator_AB.params)
+
+ ###########
+ # dhf
+ ###########
+ # E_AB_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, dhf_generator_AB.params)
+ # E_A_dhf = pot_dhf_A(pos_A, box, pairs_A, dhf_generator_AB.params)
+ # E_B_dhf = pot_dhf_B(pos_B, box, pairs_B, dhf_generator_AB.params)
+ # E_dhf = E_AB_dhf - E_A_dhf - E_B_dhf
+
+ # remove long range
+ scan_res['es'][ipt] -= E_es
+ scan_res['pol'][ipt] -= E_pol
+ scan_res['disp'][ipt] -= E_disp
+ scan_res['tot'][ipt] -= (E_es + E_pol + E_disp)
+ # save long range
+ scan_res_lr['es'][ipt] = E_es
+ scan_res_lr['pol'][ipt] = E_pol
+ scan_res_lr['disp'][ipt] = E_disp
+ scan_res_lr['tot'][ipt] = E_es + E_pol + E_disp
+ data[sid] = scan_res
+ data_lr[sid] = scan_res_lr
+
+
+with open('data_sr.pickle', 'wb') as ofile:
+ pickle.dump(data, ofile)
+
+with open('data_lr.pickle', 'wb') as ofile:
+ pickle.dump(data_lr, ofile)
diff --git a/examples/peg_slater_isa/residues.xml b/examples/peg_slater_isa/residues.xml
new file mode 100644
index 000000000..d8e0ea8e7
--- /dev/null
+++ b/examples/peg_slater_isa/residues.xml
@@ -0,0 +1,29 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/peg_slater_isa/run_amoeba.py b/examples/peg_slater_isa/run_amoeba.py
new file mode 100755
index 000000000..bf2f9b891
--- /dev/null
+++ b/examples/peg_slater_isa/run_amoeba.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python
+import sys
+import numpy as np
+import openmm
+from openmm import *
+from openmm.app import *
+from openmm.unit import *
+import pickle
+
+if __name__ == '__main__':
+ pdb_AB = PDBFile('peg2_dimer.pdb')
+ pdb_A = PDBFile('peg2.pdb')
+ pdb_B = PDBFile('peg2.pdb')
+ forcefield = ForceField('benchmark.xml')
+
+ system_AB = forcefield.createSystem(pdb_AB.topology, nonbondedMethod=PME, nonbondedCutoff=15*angstrom)
+ system_A = forcefield.createSystem(pdb_A.topology, nonbondedMethod=PME, nonbondedCutoff=15*angstrom)
+ system_B = forcefield.createSystem(pdb_B.topology, nonbondedMethod=PME, nonbondedCutoff=15*angstrom)
+ forces_AB = system_AB.getForces()
+ forces_A = system_A.getForces()
+ forces_B = system_B.getForces()
+ for i in range(len(forces_AB)):
+ forces_AB[i].setForceGroup(i)
+ forces_A[i].setForceGroup(i)
+ forces_B[i].setForceGroup(i)
+
+ platform_AB = Platform.getPlatformByName('CUDA')
+ platform_A = Platform.getPlatformByName('CUDA')
+ platform_B = Platform.getPlatformByName('CUDA')
+ properties = {}
+
+ integrator_AB = LangevinIntegrator(300*kelvin, 1.0/picosecond, 1*femtosecond)
+ integrator_A = LangevinIntegrator(300*kelvin, 1.0/picosecond, 1*femtosecond)
+ integrator_B = LangevinIntegrator(300*kelvin, 1.0/picosecond, 1*femtosecond)
+
+ simulation_AB = Simulation(pdb_AB.topology, system_AB, integrator_AB, platform=platform_AB)
+ simulation_A = Simulation(pdb_A.topology, system_A, integrator_A, platform=platform_A)
+ simulation_B = Simulation(pdb_B.topology, system_B, integrator_B, platform=platform_B)
+
+
+ pos_AB0 = np.array(pdb_AB.positions._value) * 10
+ n_atoms = len(pos_AB0)
+ n_atoms_A = n_atoms // 2
+ n_atoms_B = n_atoms // 2
+ pos_A0 = pos_AB0[:n_atoms_A]
+ pos_B0 = pos_AB0[n_atoms_A: n_atoms]
+
+ # dr = np.average(pos_B0) - np.average(pos_A0)
+ # dn = dr / np.linalg.norm(dr)
+
+ # for dz in np.arange(0, 4, 0.1):
+ # pos_A = pos_A0
+ # pos_B = pos_B0 + dz * dn
+ # pos_AB = np.vstack((pos_A, pos_B))
+ # simulation_AB.context.setPositions(pos_AB * angstrom)
+ # simulation_A.context.setPositions(pos_A * angstrom)
+ # simulation_B.context.setPositions(pos_B * angstrom)
+
+ # state_AB = simulation_AB.context.getState(getEnergy=True)
+ # state_A = simulation_A.context.getState(getEnergy=True)
+ # state_B = simulation_B.context.getState(getEnergy=True)
+
+ # E_AB = state_AB.getPotentialEnergy()._value
+ # E_A = state_A.getPotentialEnergy()._value
+ # E_B = state_B.getPotentialEnergy()._value
+
+ # print(dz, E_AB - E_A - E_B)
+
+ with open('data.pickle', 'rb') as ifile:
+ data = pickle.load(ifile)
+
+ for sid in ['000']:
+ scan_res = data[sid]
+
+ for ipt in range(len(scan_res['posA'])):
+ pos_A = np.array(scan_res['posA'][ipt])
+ pos_B = np.array(scan_res['posB'][ipt])
+ pos_AB = np.vstack([pos_A, pos_B])
+ E_es_ref = scan_res['es'][ipt]
+ E_pol_ref = scan_res['pol'][ipt]
+
+ simulation_AB.context.setPositions(pos_AB * angstrom)
+ simulation_A.context.setPositions(pos_A * angstrom)
+ simulation_B.context.setPositions(pos_B * angstrom)
+
+ state_AB = simulation_AB.context.getState(getEnergy=True, groups=2**0)
+ state_A = simulation_A.context.getState(getEnergy=True, groups=2**0)
+ state_B = simulation_B.context.getState(getEnergy=True, groups=2**0)
+ # state_AB = simulation_AB.context.getState(getEnergy=True)
+ # state_A = simulation_A.context.getState(getEnergy=True)
+ # state_B = simulation_B.context.getState(getEnergy=True)
+
+ E_AB = state_AB.getPotentialEnergy()._value
+ E_A = state_A.getPotentialEnergy()._value
+ E_B = state_B.getPotentialEnergy()._value
+
+ print(E_AB - E_A - E_B)
diff --git a/examples/peg_slater_isa/tot_slater.xml b/examples/peg_slater_isa/tot_slater.xml
new file mode 100644
index 000000000..105d498cf
--- /dev/null
+++ b/examples/peg_slater_isa/tot_slater.xml
@@ -0,0 +1,78 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/sgnn/model1.pickle b/examples/sgnn/model1.pickle
index a9efb1262..0c3959cd9 100644
Binary files a/examples/sgnn/model1.pickle and b/examples/sgnn/model1.pickle differ
diff --git a/examples/sgnn/pth2pickle.py b/examples/sgnn/pth2pickle.py
index 0ef791de3..d05bb9081 100755
--- a/examples/sgnn/pth2pickle.py
+++ b/examples/sgnn/pth2pickle.py
@@ -8,6 +8,9 @@
pth = sys.argv[1]
state_dict = torch.load(sys.argv[1])
+for k in state_dict:
+ state_dict[k] = state_dict[k].numpy()
+
ofn = re.sub('\.pth$', '.pickle', pth)
with open(ofn, 'wb') as ofile:
pickle.dump(state_dict, ofile)