diff --git a/dmff/admp/disp_pme.py b/dmff/admp/disp_pme.py index 3df7d11d6..11c30eab0 100755 --- a/dmff/admp/disp_pme.py +++ b/dmff/admp/disp_pme.py @@ -1,6 +1,6 @@ import jax.numpy as jnp from jax import vmap, value_and_grad -from dmff.utils import jit_condition +from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales from dmff.admp.spatial import pbc_shift from dmff.admp.pme import setup_ewald_parameters from dmff.admp.recip import generate_pme_recip, Ck_6, Ck_8, Ck_10 @@ -14,17 +14,24 @@ class ADMPDispPmeForce: The so called "environment paramters" means parameters that do not need to be differentiable ''' - def __init__(self, box, covalent_map, rc, ethresh, pmax): + def __init__(self, box, covalent_map, rc, ethresh, pmax, lpme=True): self.covalent_map = covalent_map self.rc = rc self.ethresh = ethresh self.pmax = pmax # Need a different function for dispersion ??? Need tests - kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box) - self.kappa = kappa - self.K1 = K1 - self.K2 = K2 - self.K3 = K3 + self.lpme = lpme + if lpme: + kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box) + self.kappa = kappa + self.K1 = K1 + self.K2 = K2 + self.K3 = K3 + else: + self.kappa = 0.0 + self.K1 = 0 + self.K2 = 0 + self.K3 = 0 self.pme_order = 6 # setup calculators self.refresh_calculators() @@ -36,7 +43,7 @@ def get_energy(positions, box, pairs, c_list, mScales): return energy_disp_pme(positions, box, pairs, c_list, mScales, self.covalent_map, self.kappa, self.K1, self.K2, self.K3, self.pmax, - self.d6_recip, self.d8_recip, self.d10_recip) + self.d6_recip, self.d8_recip, self.d10_recip, lpme=self.lpme) return get_energy @@ -70,7 +77,7 @@ def refresh_calculators(self): def energy_disp_pme(positions, box, pairs, c_list, mScales, covalent_map, kappa, K1, K2, K3, pmax, - recip_fn6, recip_fn8, recip_fn10): + recip_fn6, recip_fn8, recip_fn10, lpme=True): ''' Top level wrapper for dispersion pme @@ -95,22 +102,29 @@ def energy_disp_pme(positions, box, pairs, int: max K for reciprocal calculations pmax: int array: maximal exponents (p) to compute, e.g., (6, 8, 10) + lpme: + bool: whether do pme or not, useful when doing cluster calculations Output: energy: total dispersion pme energy ''' - ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, covalent_map, kappa, pmax) + if lpme is False: + kappa = 0 - ene_recip = recip_fn6(positions, box, c_list[:, 0, jnp.newaxis]) - if pmax >= 8: - ene_recip += recip_fn8(positions, box, c_list[:, 1, jnp.newaxis]) - if pmax >= 10: - ene_recip += recip_fn10(positions, box, c_list[:, 2, jnp.newaxis]) + ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, covalent_map, kappa, pmax) - ene_self = disp_pme_self(c_list, kappa, pmax) + if lpme: + ene_recip = recip_fn6(positions, box, c_list[:, 0, jnp.newaxis]) + if pmax >= 8: + ene_recip += recip_fn8(positions, box, c_list[:, 1, jnp.newaxis]) + if pmax >= 10: + ene_recip += recip_fn10(positions, box, c_list[:, 2, jnp.newaxis]) + ene_self = disp_pme_self(c_list, kappa, pmax) + return ene_real + ene_recip + ene_self - return ene_real + ene_recip + ene_self + else: + return ene_real def disp_pme_real(positions, box, pairs, @@ -144,24 +158,26 @@ def disp_pme_real(positions, box, pairs, ''' # expand pairwise parameters - pairs = pairs[pairs[:, 0] < pairs[:, 1]] + # pairs = pairs[pairs[:, 0] < pairs[:, 1]] + pairs = regularize_pairs(pairs) box_inv = jnp.linalg.inv(box) ri = distribute_v3(positions, pairs[:, 0]) rj = distribute_v3(positions, pairs[:, 1]) - # ri = positions[pairs[:, 0]] - # rj = positions[pairs[:, 1]] nbonds = covalent_map[pairs[:, 0], pairs[:, 1]] mscales = distribute_scalar(mScales, nbonds-1) - # mscales = mScales[nbonds-1] + + buffer_scales = pair_buffer_scales(pairs) + mscales = mscales * buffer_scales ci = distribute_dispcoeff(c_list, pairs[:, 0]) cj = distribute_dispcoeff(c_list, pairs[:, 1]) - # ci = c_list[pairs[:, 0], :] - # cj = c_list[pairs[:, 1], :] - ene_real = jnp.sum(disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax)) + ene_real = jnp.sum( + disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax) + * buffer_scales + ) return jnp.sum(ene_real) @@ -193,6 +209,7 @@ def disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax): dr = ri - rj dr = pbc_shift(dr, box, box_inv) dr2 = jnp.dot(dr, dr) + x2 = kappa * kappa * dr2 g = g_p(x2, pmax) dr6 = dr2 * dr2 * dr2 @@ -269,85 +286,3 @@ def disp_pme_self(c_list, kappa, pmax): return E -# def validation(pdb): -# xml = 'mpidwater.xml' -# pdbinfo = read_pdb(pdb) -# serials = pdbinfo['serials'] -# names = pdbinfo['names'] -# resNames = pdbinfo['resNames'] -# resSeqs = pdbinfo['resSeqs'] -# positions = pdbinfo['positions'] -# box = pdbinfo['box'] # a, b, c, α, β, γ -# charges = pdbinfo['charges'] -# positions = jnp.asarray(positions) -# lx, ly, lz, _, _, _ = box -# box = jnp.eye(3)*jnp.array([lx, ly, lz]) - -# mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) -# pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) -# dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - -# rc = 4 # in Angstrom -# ethresh = 1e-4 - -# n_atoms = len(serials) - -# atomTemplate, residueTemplate = read_xml(xml) -# atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate) - -# covalent_map = assemble_covalent(residueDicts, n_atoms) -# displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False) -# neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse) -# nbr = neighbor_list_fn.allocate(positions) -# pairs = nbr.idx.T - -# pmax = 10 -# kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box) -# kappa = 0.657065221219616 - -# # construct the C list -# c_list = np.zeros((3,n_atoms)) -# nmol=int(n_atoms/3) -# for i in range(nmol): -# a = i*3 -# b = i*3+1 -# c = i*3+2 -# c_list[0][a]=37.19677405 -# c_list[0][b]=7.6111103 -# c_list[0][c]=7.6111103 -# c_list[1][a]=85.26810658 -# c_list[1][b]=11.90220148 -# c_list[1][c]=11.90220148 -# c_list[2][a]=134.44874488 -# c_list[2][b]=15.05074749 -# c_list[2][c]=15.05074749 -# c_list = jnp.array(c_list.T) - - -# # Finish data preparation -# # ------------------------------------------------------------------------------------- -# # pme_order = 6 -# # d6_recip = generate_pme_recip(Ck_6, kappa, True, pme_order, K1, K2, K3, 0) -# # d8_recip = generate_pme_recip(Ck_8, kappa, True, pme_order, K1, K2, K3, 0) -# # d10_recip = generate_pme_recip(Ck_10, kappa, True, pme_order, K1, K2, K3, 0) -# # disp_pme_recip_fns = [d6_recip, d8_recip, d10_recip] -# # energy_force_disp_pme = value_and_grad(energy_disp_pme) -# # e, f = energy_force_disp_pme(positions, box, pairs, c_list, mScales, covalent_map, kappa, K1, K2, K3, pmax, *disp_pme_recip_fns) -# # print('ok') -# # e, f = energy_force_disp_pme(positions, box, pairs, c_list, mScales, covalent_map, kappa, K1, K2, K3, pmax, *disp_pme_recip_fns) -# # print(e) - -# disp_pme_force = ADMPDispPmeForce(box, covalent_map, rc, ethresh, pmax) -# disp_pme_force.update_env('kappa', 0.657065221219616) - -# print(c_list[:4]) -# E, F = disp_pme_force.get_forces(positions, box, pairs, c_list, mScales) -# print('ok') -# E, F = disp_pme_force.get_forces(positions, box, pairs, c_list, mScales) -# print(E) -# return - - -# # below is the validation code -# if __name__ == '__main__': -# validation(sys.argv[1]) diff --git a/dmff/admp/mbpol_intra.py b/dmff/admp/mbpol_intra.py new file mode 100755 index 000000000..7b1ffff56 --- /dev/null +++ b/dmff/admp/mbpol_intra.py @@ -0,0 +1,529 @@ + +import sys +import numpy as np +import jax.numpy as jnp +from jax import grad, value_and_grad +from dmff.settings import DO_JIT +from dmff.utils import jit_condition +from dmff.admp.spatial import v_pbc_shift +from dmff.admp.pme import ADMPPmeForce +from dmff.admp.parser import * +from jax import vmap +import time +#from admp.multipole import convert_cart2harm +#from jax_md import partition, space + +#const +f5z = 0.999677885 +fbasis = 0.15860145369897 +fcore = -1.6351695982132 +frest = 1.0 +reoh = 0.958649; +thetae = 104.3475; +b1 = 2.0; +roh = 0.9519607159623009; +alphaoh = 2.587949757553683; +deohA = 42290.92019288289; +phh1A = 16.94879431193463; +phh2 = 12.66426998162947; + +c5zA = jnp.array([4.2278462684916e+04, 4.5859382909906e-02, 9.4804986183058e+03, + 7.5485566680955e+02, 1.9865052511496e+03, 4.3768071560862e+02, + 1.4466054104131e+03, 1.3591924557890e+02,-1.4299027252645e+03, + 6.6966329416373e+02, 3.8065088734195e+03,-5.0582552618154e+02, + -3.2067534385604e+03, 6.9673382568135e+02, 1.6789085874578e+03, + -3.5387509130093e+03,-1.2902326455736e+04,-6.4271125232353e+03, + -6.9346876863641e+03,-4.9765266152649e+02,-3.4380943579627e+03, + 3.9925274973255e+03,-1.2703668547457e+04,-1.5831591056092e+04, + 2.9431777405339e+04, 2.5071411925779e+04,-4.8518811956397e+04, + -1.4430705306580e+04, 2.5844109323395e+04,-2.3371683301770e+03, + 1.2333872678202e+04, 6.6525207018832e+03,-2.0884209672231e+03, + -6.3008463062877e+03, 4.2548148298119e+04, 2.1561445953347e+04, + -1.5517277060400e+05, 2.9277086555691e+04, 2.6154026873478e+05, + -1.3093666159230e+05,-1.6260425387088e+05, 1.2311652217133e+05, + -5.1764697159603e+04, 2.5287599662992e+03, 3.0114701659513e+04, + -2.0580084492150e+03, 3.3617940269402e+04, 1.3503379582016e+04, + -1.0401149481887e+05,-6.3248258344140e+04, 2.4576697811922e+05, + 8.9685253338525e+04,-2.3910076031416e+05,-6.5265145723160e+04, + 8.9184290973880e+04,-8.0850272976101e+03,-3.1054961140464e+04, + -1.3684354599285e+04, 9.3754012976495e+03,-7.4676475789329e+04, + -1.8122270942076e+05, 2.6987309391410e+05, 4.0582251904706e+05, + -4.7103517814752e+05,-3.6115503974010e+05, 3.2284775325099e+05, + 1.3264691929787e+04, 1.8025253924335e+05,-1.2235925565102e+04, + -9.1363898120735e+03,-4.1294242946858e+04,-3.4995730900098e+04, + 3.1769893347165e+05, 2.8395605362570e+05,-1.0784536354219e+06, + -5.9451106980882e+05, 1.5215430060937e+06, 4.5943167339298e+05, + -7.9957883936866e+05,-9.2432840622294e+04, 5.5825423140341e+03, + 3.0673594098716e+03, 8.7439532014842e+04, 1.9113438435651e+05, + -3.4306742659939e+05,-3.0711488132651e+05, 6.2118702580693e+05, + -1.5805976377422e+04,-4.2038045404190e+05, 3.4847108834282e+05, + -1.3486811106770e+04, 3.1256632170871e+04, 5.3344700235019e+03, + 2.6384242145376e+04, 1.2917121516510e+05,-1.3160848301195e+05, + -4.5853998051192e+05, 3.5760105069089e+05, 6.4570143281747e+05, + -3.6980075904167e+05,-3.2941029518332e+05,-3.5042507366553e+05, + 2.1513919629391e+03, 6.3403845616538e+04, 6.2152822008047e+04, + -4.8805335375295e+05,-6.3261951398766e+05, 1.8433340786742e+06, + 1.4650263449690e+06,-2.9204939728308e+06,-1.1011338105757e+06, + 1.7270664922758e+06, 3.4925947462024e+05,-1.9526251371308e+04, + -3.2271030511683e+04,-3.7601575719875e+05, 1.8295007005531e+05, + 1.5005699079799e+06,-1.2350076538617e+06,-1.8221938812193e+06, + 1.5438780841786e+06,-3.2729150692367e+03, 1.0546285883943e+04, + -4.7118461673723e+04,-1.1458551385925e+05, 2.7704588008958e+05, + 7.4145816862032e+05,-6.6864945408289e+05,-1.6992324545166e+06, + 6.7487333473248e+05, 1.4361670430046e+06,-2.0837555267331e+05, + 4.7678355561019e+05,-1.5194821786066e+04,-1.1987249931134e+05, + 1.3007675671713e+05, 9.6641544907323e+05,-5.3379849922258e+05, + -2.4303858824867e+06, 1.5261649025605e+06, 2.0186755858342e+06, + -1.6429544469130e+06,-1.7921520714752e+04, 1.4125624734639e+04, + -2.5345006031695e+04, 1.7853375909076e+05,-5.4318156343922e+04, + -3.6889685715963e+05, 4.2449670705837e+05, 3.5020329799394e+05, + 9.3825886484788e+03,-8.0012127425648e+05, 9.8554789856472e+04, + 4.9210554266522e+05,-6.4038493953446e+05,-2.8398085766046e+06, + 2.1390360019254e+06, 6.3452935017176e+06,-2.3677386290925e+06, + -3.9697874352050e+06,-1.9490691547041e+04, 4.4213579019433e+04, + 1.6113884156437e+05,-7.1247665213713e+05,-1.1808376404616e+06, + 3.0815171952564e+06, 1.3519809705593e+06,-3.4457898745450e+06, + 2.0705775494050e+05,-4.3778169926622e+05, 8.7041260169714e+03, + 1.8982512628535e+05,-2.9708215504578e+05,-8.8213012222074e+05, + 8.6031109049755e+05, 1.0968800857081e+06,-1.0114716732602e+06, + 1.9367263614108e+05, 2.8678295007137e+05,-9.4347729862989e+04, + 4.4154039394108e+04, 5.3686756196439e+05, 1.7254041770855e+05, + -2.5310674462399e+06,-2.0381171865455e+06, 3.3780796258176e+06, + 7.8836220768478e+05,-1.5307728782887e+05,-3.7573362053757e+05, + 1.0124501604626e+06, 2.0929686545723e+06,-5.7305706586465e+06, + -2.6200352535413e+06, 7.1543745536691e+06,-1.9733601879064e+04, + 8.5273008477607e+04, 6.1062454495045e+04,-2.2642508675984e+05, + 2.4581653864150e+05,-9.0376851105383e+05,-4.4367930945690e+05, + 1.5740351463593e+06, 2.4563041445249e+05,-3.4697646046367e+03, + -2.1391370322552e+05, 4.2358948404842e+05, 5.6270081955003e+05, + -8.5007851251980e+05,-6.1182429537130e+05, 5.6690751824341e+05, + -3.5617502919487e+05,-8.1875263381402e+02,-2.4506258140060e+05, + 2.5830513731509e+05, 6.0646114465433e+05,-6.9676584616955e+05, + 5.1937406389690e+05, 1.7261913546007e+05,-1.7405787307472e+04, + -3.8301842660567e+05, 5.4227693205154e+05, 2.5442083515211e+06, + -1.1837755702370e+06,-1.9381959088092e+06,-4.0642141553575e+05, + 1.1840693827934e+04,-1.5334500255967e+05, 4.9098619510989e+05, + 6.1688992640977e+05, 2.2351144690009e+05,-1.8550462739570e+06, + 9.6815110649918e+03,-8.1526584681055e+04,-8.0810433155289e+04, + 3.4520506615177e+05, 2.5509863381419e+05,-1.3331224992157e+05, + -4.3119301071653e+05,-5.9818343115856e+04, 1.7863692414573e+03, + 8.9440694919836e+04,-2.5558967650731e+05,-2.2130423988459e+04, + 4.4973674518316e+05,-2.2094939343618e+05]) + +cbasis = jnp.array([6.9770019624764e-04,-2.4209870001642e+01, 1.8113927151562e+01, + 3.5107416275981e+01,-5.4600021126735e+00,-4.8731149608386e+01, + 3.6007189184766e+01, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + -7.7178474355102e+01,-3.8460795013977e+01,-4.6622480912340e+01, + 5.5684951167513e+01, 1.2274939911242e+02,-1.4325154752086e+02, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00,-6.0800589055949e+00, + 8.6171499453475e+01,-8.4066835441327e+01,-5.8228085624620e+01, + 2.0237393793875e+02, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 3.3525582670313e+02, 7.0056962392208e+01,-4.5312502936708e+01, + -3.0441141194247e+02, 2.8111438108965e+02, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-1.2983583774779e+02, 3.9781671212935e+01, + -6.6793945229609e+01,-1.9259805675433e+02, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-8.2855757669957e+02,-5.7003072730941e+01, + -3.5604806670066e+01, 9.6277766002709e+01, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 8.8645622149112e+02,-7.6908409772041e+01, + 6.8111763314154e+01, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 2.5090493428062e+02,-2.3622141780572e+02, 5.8155647658455e+02, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 2.8919570295095e+03, + -1.7871014635921e+02,-1.3515667622500e+02, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-3.6965613754734e+03, 2.1148158286617e+02, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00,-1.4795670139431e+03, + 3.6210798138768e+02, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + -5.3552886800881e+03, 3.1006384016202e+02, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 1.6241824368764e+03, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 4.3764909606382e+03, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 1.0940849243716e+03, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 3.0743267832931e+03, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00]) + +ccore = jnp.array([2.4332191647159e-02,-2.9749090113656e+01, 1.8638980892831e+01, + -6.1272361746520e+00, 2.1567487597605e+00,-1.5552044084945e+01, + 8.9752150543954e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + -3.5693557878741e+02,-3.0398393196894e+00,-6.5936553294576e+00, + 1.6056619388911e+01, 7.8061422868204e+01,-8.6270891686359e+01, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00,-3.1688002530217e+01, + 3.7586725583944e+01,-3.2725765966657e+01,-5.6458213299259e+00, + 2.1502613314595e+01, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 5.2789943583277e+02,-4.2461079404962e+00,-2.4937638543122e+01, + -1.1963809321312e+02, 2.0240663228078e+02, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-6.2574211352272e+02,-6.9617539465382e+00, + -5.9440243471241e+01, 1.4944220180218e+01, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-1.2851139918332e+03,-6.5043516710835e+00, + 4.0410829440249e+01,-6.7162452402027e+01, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 1.0031942127832e+03, 7.6137226541944e+01, + -2.7279242226902e+01, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + -3.3059000871075e+01, 2.4384498749480e+01,-1.4597931874215e+02, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 1.6559579606045e+03, + 1.5038996611400e+02,-7.3865347730818e+01, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-1.9738401290808e+03,-1.4149993809415e+02, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00,-1.2756627454888e+02, + 4.1487702227579e+01, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + -1.7406770966429e+03,-9.3812204399266e+01, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-1.1890301282216e+03, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 2.3723447727360e+03, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-1.0279968223292e+03, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 5.7153838472603e+02, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00]) + +crest = jnp.array([ 0.0000000000000e+00,-4.7430930170000e+00,-1.4422132560000e+01, + -1.8061146510000e+01, 7.5186735000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + -2.7962099800000e+02, 1.7616414260000e+01,-9.9741392630000e+01, + 7.1402447000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00,-7.8571336480000e+01, + 5.2434353250000e+01, 7.7696745000000e+01, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 1.7799123760000e+02, 1.4564532380000e+02, 2.2347226000000e+02, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-4.3823284100000e+02,-7.2846553000000e+02, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00,-2.6752313750000e+02, 3.6170310000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00, 0.0000000000000e+00, + 0.0000000000000e+00, 0.0000000000000e+00]) + +idx1 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, + 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, + 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, + 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9, + 9, 9, 9, 9, 9]) + +idx2 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 4, + 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, + 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, + 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 1, 1, + 1, 1, 1, 1, 1]) + +idx3 = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15, 1, 2, 3, 4, 5, + 6, 7, 8, 9,10,11,12,13,14, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11, + 12,13, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13, 1, 2, 3, 4, 5, + 6, 7, 8, 9,10,11,12, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 1, + 2, 3, 4, 5, 6, 7, 8, 9,10,11, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, + 11, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11, 1, 2, 3, 4, 5, 6, 7, 8, + 9,10, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 1, 2, 3, 4, 5, 6, 7, 8, + 9,10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, + 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, + 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, + 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, + 3, 4, 5, 6, 7]) + +matrix1 = np.zeros((245,16)) +matrix2 = np.zeros((245,16)) +matrix3 = np.zeros((245,16)) +for i in range(245): + a = int(idx1[i]) + b = int(idx2[i]) + c = int(idx3[i]) + list1 = np.zeros(16) + list2 = np.zeros(16) + list3 = np.zeros(16) + list1[a] = 1 + list2[b] = 1 + list3[c] = 1 + matrix1[i] = list1 + matrix2[i] = list2 + matrix3[i] = list3 + +c5z = jnp.zeros(245) +for i in range(245): + c5z = c5z.at[i].set(f5z*c5zA[i] + fbasis*cbasis[i]+ fcore*ccore[i] + frest*crest[i]) +deoh = f5z*deohA +phh1 = f5z*phh1A*jnp.exp(phh2) +costhe = -0.24780227221366464506 + +Eh_J = 4.35974434e-18 +Na = 6.02214129e+23 +kcal_J = 4184.0 +c0 = 299792458.0 +h_Js = 6.62606957e-34 +cal2joule = 4.184 +Eh_kcalmol = Eh_J*Na/kcal_J +Eh_cm1 = 1.0e-2*Eh_J/(c0*h_Js) +cm1_kcalmol = Eh_kcalmol/Eh_cm1 + + +## compute intra +def onebodyenergy(positions, box): + box_inv = jnp.linalg.inv(box) + O = positions[::3] + H1 = positions[1::3] + H2 = positions[2::3] + ROH1 = H1 - O + ROH2 = H2 - O + RHH = H1 - H2 + ROH1 = v_pbc_shift(ROH1, box, box_inv) + ROH2 = v_pbc_shift(ROH2, box, box_inv) + RHH = v_pbc_shift(RHH, box, box_inv) + dROH1 = jnp.linalg.norm(ROH1, axis=1) + dROH2 = jnp.linalg.norm(ROH2, axis=1) + dRHH = jnp.linalg.norm(RHH, axis=1) + costh = jnp.sum(ROH1 * ROH2, axis=1) / (dROH1 * dROH2) + exp1 = jnp.exp(-alphaoh*(dROH1 - roh)) + exp2 = jnp.exp(-alphaoh*(dROH2 - roh)) + Va = deoh*(exp1*(exp1 - 2.0) + exp2*(exp2 - 2.0)) + Vb = phh1*jnp.exp(-phh2*dRHH) + x1 = (dROH1 - reoh)/reoh + x2 = (dROH2 - reoh)/reoh + x3 = costh - costhe + efac = jnp.exp(-b1*(dROH1 - reoh)**2 + (dROH2 - reoh)**2) + energy = jnp.sum(onebody_kernel(x1, x2, x3, Va, Vb, efac)) + return energy + + + +@vmap +@jit_condition(static_argnums={}) +def onebody_kernel(x1, x2, x3, Va, Vb, efac): + const = jnp.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + CONST = jnp.array([const,const,const]) + list1 = jnp.array([x1**i for i in range(-1, 15)]) + list2 = jnp.array([x2**i for i in range(-1, 15)]) + list3 = jnp.array([x3**i for i in range(-1, 15)]) + fmat = jnp.array([list1, list2, list3]) + fmat *= CONST + F1 = jnp.sum(fmat[0].T * matrix1, axis=1) # fmat[0][inI] 1*245 + F2 = jnp.sum(fmat[1].T * matrix2, axis=1) #fmat[1][inJ] 1*245 + F3 = jnp.sum(fmat[0].T * matrix2, axis=1) #fmat[0][inJ] 1*245 + F4 = jnp.sum(fmat[1].T * matrix1, axis=1) #fmat[1][inI] 1*245 + F5 = jnp.sum(fmat[2].T * matrix3, axis=1) #fmat[2][inK] 1*245 + total = c5z * (F1*F2 + F3*F4)* F5 + sum0 = jnp.sum(total[1:245]) + Vc = 2*c5z[0] + efac*sum0 + e1 = Va + Vb + Vc + e1 += 0.44739574026257 + e1 *= cm1_kcalmol + e1 *= cal2joule # conver cal 2 j + return e1 + + +def validation(pdb): + xml = 'mpidwater.xml' + pdbinfo = read_pdb(pdb) + serials = pdbinfo['serials'] + names = pdbinfo['names'] + resNames = pdbinfo['resNames'] + resSeqs = pdbinfo['resSeqs'] + positions = pdbinfo['positions'] + box = pdbinfo['box'] # a, b, c, α, β, γ + charges = pdbinfo['charges'] + positions = jnp.asarray(positions) + lx, ly, lz, _, _, _ = box + box = jnp.eye(3)*jnp.array([lx, ly, lz]) + + mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) + pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) + dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) + + rc = 4 # in Angstrom + ethresh = 1e-4 + + n_atoms = len(serials) + + # compute intra + + + grad_E1 = value_and_grad(onebodyenergy,argnums=(0)) + ene, force = grad_E1(positions, box) + print(ene,force) + return + + +# below is the validation code +if __name__ == '__main__': + validation(sys.argv[1]) + + diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py index d7b6547ce..d8b7b80d2 100755 --- a/dmff/admp/pairwise.py +++ b/dmff/admp/pairwise.py @@ -1,10 +1,12 @@ import sys from jax import vmap import jax.numpy as jnp -from dmff.utils import jit_condition +from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales from dmff.admp.spatial import v_pbc_shift from functools import partial +DIELECTRIC = 1389.35455846 + # for debug # from jax_md import partition, space # from admp.parser import * @@ -62,13 +64,17 @@ def generate_pairwise_interaction(pair_int_kernel, covalent_map, static_args): ''' def pair_int(positions, box, pairs, mScales, *atomic_params): - pairs = pairs[pairs[:, 0] < pairs[:, 1]] + pairs = regularize_pairs(pairs) + ri = distribute_v3(positions, pairs[:, 0]) rj = distribute_v3(positions, pairs[:, 1]) # ri = positions[pairs[:, 0]] # rj = positions[pairs[:, 1]] nbonds = covalent_map[pairs[:, 0], pairs[:, 1]] mscales = distribute_scalar(mScales, nbonds-1) + + buffer_scales = pair_buffer_scales(pairs) + mscales = mscales * buffer_scales # mscales = mScales[nbonds-1] box_inv = jnp.linalg.inv(box) dr = ri - rj @@ -82,7 +88,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params): # pair_params.append(param[pairs[:, 0]]) # pair_params.append(param[pairs[:, 1]]) - energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params)) + energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params) * buffer_scales) return energy return pair_int @@ -110,6 +116,58 @@ def TT_damping_qq_c6_kernel(dr, m, ai, aj, bi, bj, qi, qj, ci, cj): return f * m +@vmap +@jit_condition(static_argnums={}) +def TT_damping_qq_kernel(dr, m, bi, bj, qi, qj): + b = jnp.sqrt(bi * bj) + q = qi * qj + br = b * dr + exp_br = jnp.exp(-br) + f = - DIELECTRIC * exp_br * (1+br) * q / dr + return f * m + + +@vmap +@jit_condition(static_argnums=()) +def slater_disp_damping_kernel(dr, m, bi, bj, c6i, c6j, c8i, c8j, c10i, c10j): + ''' + Slater-ISA type damping for dispersion: + f(x) = -e^{-x} * \sum_{k} x^k/k! + x = Br - \frac{2*(Br)^2 + 3Br}{(Br)^2 + 3*Br + 3} + see jctc 12 3851 + ''' + b = jnp.sqrt(bi * bj) + c6 = c6i * c6j + c8 = c8i * c8j + c10 = c10i * c10j + br = b * dr + br2 = br * br + x = br - (2*br2 + 3*br) / (br2 + 3*br + 3) + s6 = 1 + x + x**2/2 + x**3/6 + x**4/24 + x**5/120 + x**6/720 + s8 = s6 + x**7/5040 + x**8/40320 + s10 = s8 + x**9/362880 + x**10/3628800 + exp_x = jnp.exp(-x) + f6 = exp_x * s6 + f8 = exp_x * s8 + f10 = exp_x * s10 + return (f6*c6/dr**6 + f8*c8/dr**8 + f10*c10/dr**10) * m + + +@vmap +@jit_condition(static_argnums=()) +def slater_sr_kernel(dr, m, ai, aj, bi, bj): + ''' + Slater-ISA type short range terms + see jctc 12 3851 + ''' + b = jnp.sqrt(bi * bj) + a = ai * aj + br = b * dr + br2 = br * br + P = 1/3 * br2 + br + 1 + return a * P * jnp.exp(-br) * m + + def validation(pdb): xml = 'mpidwater.xml' pdbinfo = read_pdb(pdb) diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index 6112950d7..2cde47a1c 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -6,7 +6,7 @@ from jax.scipy.special import erf from dmff.settings import DO_JIT from dmff.admp.settings import POL_CONV, MAX_N_POL -from dmff.utils import jit_condition +from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales from dmff.admp.multipole import C1_c2h, convert_cart2harm from dmff.admp.multipole import rot_ind_global2local, rot_global2local, rot_local2global from dmff.admp.spatial import v_pbc_shift, generate_construct_local_frames, build_quasi_internal @@ -34,20 +34,58 @@ class ADMPPmeForce: The so called "environment paramters" means parameters that do not need to be differentiable ''' - def __init__(self, box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax, lpol=False): + def __init__(self, box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None): + ''' + Initialize the ADMPPmeForce calculator. + + Input: + box: + (3, 3) float, box size in row + axis_type: + (na,) int, types of local axis (bisector, z-then-x etc.) + covalent_map: + (na, na) int, covalent map matrix, labels the topological distances between atoms + rc: + float: cutoff distance + ethresh: + float: pme energy threshold + lmax: + int: max L for multipoles + lpol: + bool: polarize or not? + lpme: + bool: do pme or simple cutoff? + if False, the kappa will be set to zero and the reciprocal part will not be computed + steps: + None or int: Whether do fixed number of dipole iteration steps? + if None: converge dipoles until convergence threshold is met + if int: optimize for this many steps and stop, this is useful if you want to jit the entire function + + Output: + + ''' self.axis_type = axis_type self.axis_indices = axis_indices self.rc = rc self.ethresh = ethresh self.lmax = int(lmax) # jichen: type checking - kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box) - self.kappa = kappa - self.K1 = K1 - self.K2 = K2 - self.K3 = K3 + # turn off pme if lpme is False, this is useful when doing cluster calculations + self.lpme = lpme + if self.lpme is False: + self.kappa = 0 + self.K1 = 0 + self.K2 = 0 + self.K3 = 0 + else: + kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box) + self.kappa = kappa + self.K1 = K1 + self.K2 = K2 + self.K3 = K3 self.pme_order = 6 self.covalent_map = covalent_map self.lpol = lpol + self.steps_pol = steps_pol self.n_atoms = int(covalent_map.shape[0]) # len(axis_type) # setup calculators @@ -63,7 +101,7 @@ def get_energy(positions, box, pairs, Q_local, mScales): Q_local, None, None, None, mScales, None, None, self.covalent_map, self.construct_local_frames, self.pme_recip, - self.kappa, self.K1, self.K2, self.K3, self.lmax, False) + self.kappa, self.K1, self.K2, self.K3, self.lmax, False, lpme=self.lpme) return get_energy else: # this is the bare energy calculator, with Uind as explicit input @@ -72,14 +110,20 @@ def energy_fn(positions, box, pairs, Q_local, Uind_global, pol, tholes, mScales, Q_local, Uind_global, pol, tholes, mScales, pScales, dScales, self.covalent_map, self.construct_local_frames, self.pme_recip, - self.kappa, self.K1, self.K2, self.K3, self.lmax, True) + self.kappa, self.K1, self.K2, self.K3, self.lmax, True, lpme=self.lpme) self.energy_fn = energy_fn self.grad_U_fn = grad(self.energy_fn, argnums=(4)) self.grad_pos_fn = grad(self.energy_fn, argnums=(0)) self.U_ind = jnp.zeros((self.n_atoms, 3)) # this is the wrapper that include a Uind optimizer - def get_energy(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=self.U_ind): - self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=U_init) + def get_energy( + positions, box, pairs, + Q_local, pol, tholes, mScales, pScales, dScales, + U_init=self.U_ind): + self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind( + positions, box, pairs, Q_local, pol, tholes, + mScales, pScales, dScales, + U_init=U_init, steps_pol=self.steps_pol) # here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr ! # self.U_ind = jax.lax.stop_gradient(U_ind) return self.energy_fn(positions, box, pairs, Q_local, self.U_ind, pol, tholes, mScales, pScales, dScales) @@ -112,7 +156,11 @@ def refresh_calculators(self): self.get_forces = value_and_grad(self.get_energy) return - def optimize_Uind(self, positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=None, maxiter=MAX_N_POL, thresh=POL_CONV): + def optimize_Uind(self, + positions, box, pairs, + Q_local, pol, tholes, mScales, pScales, dScales, + U_init=None, steps_pol=None, + maxiter=MAX_N_POL, thresh=POL_CONV): ''' This function converges the induced dipole Note that we cut all the gradient chain passing through this function as we assume Feynman-Hellman theorem @@ -131,19 +179,28 @@ def optimize_Uind(self, positions, box, pairs, Q_local, pol, tholes, mScales, pS U = jnp.zeros((self.n_atoms, 3)) else: U = U_init - site_filter = (pol>0.001) # focus on the actual polarizable sites - - for i in range(maxiter): - field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales) - E = self.energy_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales) - if jnp.max(jnp.abs(field[site_filter])) < thresh: - break - U = U - field * pol[:, jnp.newaxis] / DIELECTRIC - if i == maxiter-1: - flag = False - else: # converged + if steps_pol is None: + site_filter = (pol>0.001) # focus on the actual polarizable sites + + if steps_pol is None: + for i in range(maxiter): + field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales) + # E = self.energy_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales) + if jnp.max(jnp.abs(field[site_filter])) < thresh: + break + U = U - field * pol[:, jnp.newaxis] / DIELECTRIC + if i == maxiter-1: + flag = False + else: # converged + flag = True + else: + def update_U(i, U): + field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales) + U = U - field * pol[:, jnp.newaxis] / DIELECTRIC + return U + U = jax.lax.fori_loop(0, steps_pol, update_U, U) flag = True - return U, flag, i + return U, flag, steps_pol def setup_ewald_parameters(rc, ethresh, box): @@ -179,7 +236,7 @@ def setup_ewald_parameters(rc, ethresh, box): def energy_pme(positions, box, pairs, Q_local, Uind_global, pol, tholes, mScales, pScales, dScales, covalent_map, - construct_local_frame_fn, pme_recip_fn, kappa, K1, K2, K3, lmax, lpol): + construct_local_frame_fn, pme_recip_fn, kappa, K1, K2, K3, lmax, lpol, lpme=True): ''' This is the top-level wrapper for multipole PME @@ -213,8 +270,10 @@ def energy_pme(positions, box, pairs, int: max K for reciprocal calculations lmax: int: maximum L - bool: - int: if polarizable or not? if yes, 1, otherwise 0 + lpol: + bool: if polarizable or not? if yes, 1, otherwise 0 + lpme: + bool: doing pme? If false, then turn off reciprocal space and set kappa = 0 Output: energy: total pme energy @@ -240,6 +299,9 @@ def energy_pme(positions, box, pairs, else: Q_global_tot = Q_global + if lpme is False: + kappa = 0 + if lpol: ene_real = pme_real(positions, box, pairs, Q_global, U_ind, pol, tholes, mScales, pScales, dScales, covalent_map, kappa, lmax, True) @@ -247,14 +309,23 @@ def energy_pme(positions, box, pairs, ene_real = pme_real(positions, box, pairs, Q_global, None, None, None, mScales, None, None, covalent_map, kappa, lmax, False) - ene_recip = pme_recip_fn(positions, box, Q_global_tot) + if lpme: + ene_recip = pme_recip_fn(positions, box, Q_global_tot) - ene_self = pme_self(Q_global_tot, kappa, lmax) + ene_self = pme_self(Q_global_tot, kappa, lmax) - if lpol: - ene_self += pol_penalty(U_ind, pol) + if lpol: + ene_self += pol_penalty(U_ind, pol) + + return ene_real + ene_recip + ene_self - return ene_real + ene_recip + ene_self + else: + if lpol: + ene_self = pol_penalty(U_ind, pol) + else: + ene_self = 0.0 + + return ene_real + ene_self # @partial(vmap, in_axes=(0, 0, None, None), out_axes=0) @@ -669,7 +740,10 @@ def pme_real(positions, box, pairs, ''' # expand pairwise parameters, from atomic parameters - pairs = pairs[pairs[:, 0] < pairs[:, 1]] + # debug + # pairs = pairs[pairs[:, 0] < pairs[:, 1]] + pairs = regularize_pairs(pairs) + buffer_scales = pair_buffer_scales(pairs) box_inv = jnp.linalg.inv(box) r1 = distribute_v3(positions, pairs[:, 0]) r2 = distribute_v3(positions, pairs[:, 1]) @@ -682,6 +756,7 @@ def pme_real(positions, box, pairs, nbonds = covalent_map[pairs[:, 0], pairs[:, 1]] indices = nbonds-1 mscales = distribute_scalar(mScales, indices) + mscales = mscales * buffer_scales # mscales = mScales[nbonds-1] if lpol: pol1 = distribute_scalar(pol, pairs[:, 0]) @@ -691,7 +766,9 @@ def pme_real(positions, box, pairs, Uind_extendi = distribute_v3(Uind_global, pairs[:, 0]) Uind_extendj = distribute_v3(Uind_global, pairs[:, 1]) pscales = distribute_scalar(pScales, indices) + pscales = pscales * buffer_scales dscales = distribute_scalar(dScales, indices) + dscales = dscales * buffer_scales # pol1 = pol[pairs[:,0]] # pol2 = pol[pairs[:,1]] # thole1 = tholes[pairs[:,0]] @@ -725,7 +802,10 @@ def pme_real(positions, box, pairs, qiUindJ = None # everything should be pair-specific now - ene = jnp.sum(pme_real_kernel(norm_dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscales, pscales, dscales, kappa, lmax, lpol)) + ene = jnp.sum( + pme_real_kernel(norm_dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscales, pscales, dscales, kappa, lmax, lpol) + * buffer_scales + ) return ene diff --git a/dmff/admp/settings.py b/dmff/admp/settings.py index f08e88162..aa05b5704 100644 --- a/dmff/admp/settings.py +++ b/dmff/admp/settings.py @@ -1,3 +1,3 @@ # DEFAULT THRESHOLDS -POL_CONV = 10.0 # gradient convergence thresh for induced dipoles -MAX_N_POL = 30 # maximum number of cyles for optimizing induced dipole \ No newline at end of file +POL_CONV = 1.0 # gradient convergence thresh for induced dipoles +MAX_N_POL = 30 # maximum number of cyles for optimizing induced dipole diff --git a/dmff/api.py b/dmff/api.py index ba9823ff3..fc8788231 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -10,16 +10,18 @@ from .admp.disp_pme import ADMPDispPmeForce from .admp.multipole import convert_cart2harm, rot_local2global from .admp.pairwise import TT_damping_qq_c6_kernel, generate_pairwise_interaction +from .admp.pairwise import slater_disp_damping_kernel, slater_sr_kernel, TT_damping_qq_kernel from .admp.pme import ADMPPmeForce from .admp.spatial import generate_construct_local_frames from .admp.recip import Ck_1, generate_pme_recip +from .utils import jit_condition from .classical.intra import ( HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce, ) from jax_md import space, partition -from jax import grad +from jax import grad, vmap import linecache import itertools from .classical.inter import ( @@ -87,7 +89,7 @@ def build_covalent_map(data, max_neighbor): if k != i and k not in j_list: covalent_map[i, k] = n_curr + 1 covalent_map[k, i] = n_curr + 1 - return covalent_map + return jnp.array(covalent_map) def findAtomTypeTexts(attribs, num): @@ -136,12 +138,25 @@ def parseElement(element, hamiltonian): def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): + methodMap = { + app.CutoffPeriodic: "CutoffPeriodic", + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise ValueError("Illegal nonbonded method for ADMPDispForce") + if nonbondedMethod is app.CutoffPeriodic: + self.lpme = False + else: + self.lpme = True + n_atoms = len(data.atoms) # build index map map_atomtype = np.zeros(n_atoms, dtype=int) for i in range(n_atoms): atype = data.atomType[data.atoms[i]] map_atomtype[i] = np.where(self.types == atype)[0][0] + self.map_atomtype = map_atomtype # build covalent map covalent_map = build_covalent_map(data, 6) @@ -155,13 +170,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): if "ethresh" in args: self.ethresh = args["ethresh"] - Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, self.pmax) + Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, + self.pmax, lpme=self.lpme) self.disp_pme_force = Force_DispPME - # debugging - # Force_DispPME.update_env('kappa', 0.657065221219616) - # Force_DispPME.update_env('K1', 96) - # Force_DispPME.update_env('K2', 96) - # Force_DispPME.update_env('K3', 96) pot_fn_lr = Force_DispPME.get_energy pot_fn_sr = generate_pairwise_interaction( TT_damping_qq_c6_kernel, covalent_map, static_args={} @@ -195,11 +206,369 @@ def renderXML(self): # generate xml force field file pass - # register all parsers app.forcefield.parsers["ADMPDispForce"] = ADMPDispGenerator.parseElement +class ADMPDispPmeGenerator: + ''' + This one computes the undamped C6/C8/C10 interactions + u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10 + ''' + + def __init__(self, hamiltonian): + self.ff = hamiltonian + self.params = { + "C6": [], + "C8": [], + "C10": [] + } + self._jaxPotential = None + self.types = [] + self.ethresh = 5e-4 + self.pmax = 10 + + def registerAtomType(self, atom): + self.types.append(atom["type"]) + self.params["C6"].append(float(atom["C6"])) + self.params["C8"].append(float(atom["C8"])) + self.params["C10"].append(float(atom["C10"])) + + @staticmethod + def parseElement(element, hamiltonian): + generator = ADMPDispPmeGenerator(hamiltonian) + hamiltonian.registerGenerator(generator) + # covalent scales + mScales = [] + for i in range(2, 7): + mScales.append(float(element.attrib["mScale1%d" % i])) + mScales.append(1.0) + generator.params["mScales"] = mScales + for atomtype in element.findall("Atom"): + generator.registerAtomType(atomtype.attrib) + # jax it! + for k in generator.params.keys(): + generator.params[k] = jnp.array(generator.params[k]) + generator.types = np.array(generator.types) + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + methodMap = { + app.CutoffPeriodic: "CutoffPeriodic", + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise ValueError("Illegal nonbonded method for ADMPDispPmeForce") + if nonbondedMethod is app.CutoffPeriodic: + self.lpme = False + else: + self.lpme = True + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.types == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + covalent_map = build_covalent_map(data, 6) + + # here box is only used to setup ewald parameters, no need to be differentiable + a, b, c = system.getDefaultPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) * 10 + # get the admp calculator + rc = nonbondedCutoff.value_in_unit(unit.angstrom) + + # get calculator + if 'ethresh' in args: + self.ethresh = args['ethresh'] + + disp_force = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, + self.pmax, self.lpme) + self.disp_force = disp_force + pot_fn_lr = disp_force.get_energy + + def potential_fn(positions, box, pairs, params): + mScales = params["mScales"] + C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6 + C8_list = params["C8"][map_atomtype] * 1e8 + C10_list = params["C10"][map_atomtype] * 1e10 + c6_list = jnp.sqrt(C6_list) + c8_list = jnp.sqrt(C8_list) + c10_list = jnp.sqrt(C10_list) + c_list = jnp.vstack((c6_list, c8_list, c10_list)) + E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales) + return - E_lr + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + + def renderXML(self): + # generate xml force field file + pass + +# register all parsers +app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement + +class QqTtDampingGenerator: + ''' + This one calculates the tang-tonnies damping of charge-charge interaction + E = \sum_ij exp(-B*r)*(1+B*r)*q_i*q_j/r + ''' + def __init__(self, hamiltonian): + self.ff = hamiltonian + self.params = { + "B": [], + "Q": [], + } + self._jaxPotential = None + self.types = [] + + def registerAtomType(self, atom): + self.types.append(atom["type"]) + self.params["B"].append(float(atom["B"])) + self.params["Q"].append(float(atom["Q"])) + + @staticmethod + def parseElement(element, hamiltonian): + generator = QqTtDampingGenerator(hamiltonian) + hamiltonian.registerGenerator(generator) + # covalent scales + mScales = [] + for i in range(2, 7): + mScales.append(float(element.attrib["mScale1%d" % i])) + mScales.append(1.0) + generator.params["mScales"] = mScales + for atomtype in element.findall("Atom"): + generator.registerAtomType(atomtype.attrib) + # jax it! + for k in generator.params.keys(): + generator.params[k] = jnp.array(generator.params[k]) + generator.types = np.array(generator.types) + + # on working + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.types == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + covalent_map = build_covalent_map(data, 6) + + pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_kernel, + covalent_map, + static_args={}) + + def potential_fn(positions, box, pairs, params): + mScales = params["mScales"] + b_list = params["B"][map_atomtype] / 10 # convert to A^-1 + q_list = params["Q"][map_atomtype] + + E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, q_list) + return E_sr + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + + def renderXML(self): + # generate xml force field file + pass + +# register all parsers +app.forcefield.parsers["QqTtDampingForce"] = QqTtDampingGenerator.parseElement + + +class SlaterDampingGenerator: + ''' + This one computes the slater-type damping function for c6/c8/c10 dispersion + E = \sum_ij (f6-1)*c6/r6 + (f8-1)*c8/r8 + (f10-1)*c10/r10 + fn = f_tt(x, n) + x = br - (2*br2 + 3*br) / (br2 + 3*br + 3) + ''' + def __init__(self, hamiltonian): + self.ff = hamiltonian + self.params = { + "B": [], + "C6": [], + "C8": [], + "C10": [], + } + self._jaxPotential = None + self.types = [] + + def registerAtomType(self, atom): + self.types.append(atom["type"]) + self.params["B"].append(float(atom["B"])) + self.params["C6"].append(float(atom["C6"])) + self.params["C8"].append(float(atom["C8"])) + self.params["C10"].append(float(atom["C10"])) + + @staticmethod + def parseElement(element, hamiltonian): + generator = SlaterDampingGenerator(hamiltonian) + hamiltonian.registerGenerator(generator) + # covalent scales + mScales = [] + for i in range(2, 7): + mScales.append(float(element.attrib["mScale1%d" % i])) + mScales.append(1.0) + generator.params["mScales"] = mScales + for atomtype in element.findall("Atom"): + generator.registerAtomType(atomtype.attrib) + # jax it! + for k in generator.params.keys(): + generator.params[k] = jnp.array(generator.params[k]) + generator.types = np.array(generator.types) + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.types == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + covalent_map = build_covalent_map(data, 6) + + # WORKING + pot_fn_sr = generate_pairwise_interaction(slater_disp_damping_kernel, + covalent_map, + static_args={}) + + def potential_fn(positions, box, pairs, params): + mScales = params["mScales"] + b_list = params["B"][map_atomtype] / 10 # convert to A^-1 + c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) # to kj/mol * A**6 + c8_list = jnp.sqrt(params["C8"][map_atomtype] * 1e8) + c10_list = jnp.sqrt(params["C10"][map_atomtype] * 1e10) + E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, c6_list, c8_list, c10_list) + return E_sr + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + + def renderXML(self): + # generate xml force field file + pass + +app.forcefield.parsers["SlaterDampingForce"] = SlaterDampingGenerator.parseElement + + +class SlaterExGenerator: + ''' + This one computes the Slater-ISA type exchange interaction + u = \sum_ij A * (1/3*(Br)^2 + Br + 1) + ''' + + def __init__(self, hamiltonian): + self.ff = hamiltonian + self.params = { + "A": [], + "B": [], + } + self._jaxPotential = None + self.types = [] + + def registerAtomType(self, atom): + self.types.append(atom["type"]) + self.params["A"].append(float(atom["A"])) + self.params["B"].append(float(atom["B"])) + + @staticmethod + def parseElement(element, hamiltonian): + generator = SlaterExGenerator(hamiltonian) + hamiltonian.registerGenerator(generator) + # covalent scales + mScales = [] + for i in range(2, 7): + mScales.append(float(element.attrib["mScale1%d" % i])) + mScales.append(1.0) + generator.params["mScales"] = mScales + for atomtype in element.findall("Atom"): + generator.registerAtomType(atomtype.attrib) + # jax it! + for k in generator.params.keys(): + generator.params[k] = jnp.array(generator.params[k]) + generator.types = np.array(generator.types) + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.types == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + covalent_map = build_covalent_map(data, 6) + + pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, + covalent_map, + static_args={}) + + def potential_fn(positions, box, pairs, params): + mScales = params["mScales"] + a_list = params["A"][map_atomtype] + b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1 + + return pot_fn_sr(positions, box, pairs, mScales, a_list, b_list) + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + + def renderXML(self): + # generate xml force field file + pass + +app.forcefield.parsers["SlaterExForce"] = SlaterExGenerator.parseElement + + +# Here are all the short range "charge penetration" terms +# They all have the exchange form +class SlaterSrEsGenerator(SlaterExGenerator): + def __init__(self): + super().__init__(self) +class SlaterSrPolGenerator(SlaterExGenerator): + def __init__(self): + super().__init__(self) +class SlaterSrDispGenerator(SlaterExGenerator): + def __init__(self): + super().__init__(self) +class SlaterDhfGenerator(SlaterExGenerator): + def __init__(self): + super().__init__(self) + +# register all parsers +app.forcefield.parsers["SlaterSrEsForce"] = SlaterSrEsGenerator.parseElement +app.forcefield.parsers["SlaterSrPolForce"] = SlaterSrPolGenerator.parseElement +app.forcefield.parsers["SlaterSrDispForce"] = SlaterSrDispGenerator.parseElement +app.forcefield.parsers["SlaterDhfForce"] = SlaterDhfGenerator.parseElement + + class ADMPPmeGenerator: def __init__(self, hamiltonian): self.ff = hamiltonian @@ -382,12 +751,25 @@ def parseElement(element, hamiltonian): def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): + methodMap = { + app.CutoffPeriodic: "CutoffPeriodic", + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise ValueError("Illegal nonbonded method for ADMPPmeForce") + if nonbondedMethod is app.CutoffPeriodic: + self.lpme = False + else: + self.lpme = True + n_atoms = len(data.atoms) map_atomtype = np.zeros(n_atoms, dtype=int) for i in range(n_atoms): atype = data.atomType[data.atoms[i]] map_atomtype[i] = np.where(self.types == atype)[0][0] + self.map_atomtype = map_atomtype # here box is only used to setup ewald parameters, no need to be differentiable a, b, c = system.getDefaultPeriodicBoxVectors() @@ -614,7 +996,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): self.axis_types = None self.axis_indices = None - # get calculator if "ethresh" in args: self.ethresh = args["ethresh"] @@ -627,6 +1008,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): self.ethresh, self.lmax, self.lpol, + lpme=self.lpme ) self.pme_force = pme_force diff --git a/dmff/settings.py b/dmff/settings.py index 0ba35abe1..c7965bc2d 100644 --- a/dmff/settings.py +++ b/dmff/settings.py @@ -6,4 +6,4 @@ if PRECISION == 'double': config.update("jax_enable_x64", True) - \ No newline at end of file + diff --git a/dmff/utils.py b/dmff/utils.py index 6182b8123..1d5d9ac0d 100644 --- a/dmff/utils.py +++ b/dmff/utils.py @@ -1,5 +1,6 @@ from dmff.settings import DO_JIT -from jax import jit +from jax import jit, vmap +import jax.numpy as jnp def jit_condition(*args, **kwargs): def jit_deco(func): @@ -7,4 +8,23 @@ def jit_deco(func): return jit(func, *args, **kwargs) else: return func - return jit_deco \ No newline at end of file + return jit_deco + + +@jit_condition() +@vmap +def regularize_pairs(p): + dp = p[1] - p[0] + dp = jnp.piecewise(dp, (dp<=0, dp>0), (lambda x: jnp.array(1), lambda x: jnp.array(0))) + dp_vec = jnp.array([dp, 2*dp]) + p = p - dp_vec + return p + + +@jit_condition() +@vmap +def pair_buffer_scales(p): + return jnp.piecewise( + p[0] - p[1], + (p[0] - p[1] < 0, p[0] - p[1] >= 0), + (lambda x: jnp.array(1), lambda x: jnp.array(0))) diff --git a/examples/peg_slater_isa/benchmark.xml b/examples/peg_slater_isa/benchmark.xml new file mode 100644 index 000000000..2dcc5f1a4 --- /dev/null +++ b/examples/peg_slater_isa/benchmark.xml @@ -0,0 +1,77 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/peg_slater_isa/calc_energy_comps.py b/examples/peg_slater_isa/calc_energy_comps.py new file mode 100755 index 000000000..8066a28d9 --- /dev/null +++ b/examples/peg_slater_isa/calc_energy_comps.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python +import numpy as np +import openmm +from openmm import * +from openmm.app import * +from openmm.unit import * +import jax +import jax_md +import jax.numpy as jnp +import dmff +from dmff.api import Hamiltonian +import pickle +import time + + +if __name__ == '__main__': + ff = 'forcefield.xml' + pdb_AB = PDBFile('peg2_dimer.pdb') + pdb_A = PDBFile('peg2.pdb') + pdb_B = PDBFile('peg2.pdb') + H_AB = Hamiltonian(ff) + H_A = Hamiltonian(ff) + H_B = Hamiltonian(ff) + pme_generator_AB, \ + disp_generator_AB, \ + ex_generator_AB, \ + sr_es_generator_AB, \ + sr_pol_generator_AB, \ + sr_disp_generator_AB, \ + dhf_generator_AB, \ + dmp_es_generator_AB, \ + dmp_disp_generator_AB = H_AB.getGenerators() + pme_generator_A, \ + disp_generator_A, \ + ex_generator_A, \ + sr_es_generator_A, \ + sr_pol_generator_A, \ + sr_disp_generator_A, \ + dhf_generator_A, \ + dmp_es_generator_A, \ + dmp_disp_generator_A = H_A.getGenerators() + pme_generator_B, \ + disp_generator_B, \ + ex_generator_B, \ + sr_es_generator_B, \ + sr_pol_generator_B, \ + sr_disp_generator_B, \ + dhf_generator_B, \ + dmp_es_generator_B, \ + dmp_disp_generator_B = H_B.getGenerators() + + rc = 15 + + # get potential functions + potentials_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_AB, \ + pot_disp_AB, \ + pot_ex_AB, \ + pot_sr_es_AB, \ + pot_sr_pol_AB, \ + pot_sr_disp_AB, \ + pot_dhf_AB, \ + pot_dmp_es_AB, \ + pot_dmp_disp_AB = potentials_AB + potentials_A = H_A.createPotential(pdb_A.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_A, \ + pot_disp_A, \ + pot_ex_A, \ + pot_sr_es_A, \ + pot_sr_pol_A, \ + pot_sr_disp_A, \ + pot_dhf_A, \ + pot_dmp_es_A, \ + pot_dmp_disp_A = potentials_A + potentials_B = H_B.createPotential(pdb_B.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_B, \ + pot_disp_B, \ + pot_ex_B, \ + pot_sr_es_B, \ + pot_sr_pol_B, \ + pot_sr_disp_B, \ + pot_dhf_B, \ + pot_dmp_es_B, \ + pot_dmp_disp_B = potentials_B + + pos_AB0 = jnp.array(pdb_AB.positions._value) * 10 + n_atoms = len(pos_AB0) + n_atoms_A = n_atoms // 2 + n_atoms_B = n_atoms // 2 + pos_A0 = jnp.array(pdb_AB.positions._value[:n_atoms_A]) * 10 + pos_B0 = jnp.array(pdb_AB.positions._value[n_atoms_A:n_atoms]) * 10 + box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10 + # nn list initial allocation + displacement_fn, shift_fn = jax_md.space.periodic_general(box, fractional_coordinates=False) + neighbor_list_fn = jax_md.partition.neighbor_list(displacement_fn, box, rc, 0, format=jax_md.partition.OrderedSparse) + nbr_AB = neighbor_list_fn.allocate(pos_AB0) + nbr_A = neighbor_list_fn.allocate(pos_A0) + nbr_B = neighbor_list_fn.allocate(pos_B0) + pairs_AB = np.array(nbr_AB.idx.T) + pairs_A = np.array(nbr_A.idx.T) + pairs_B = np.array(nbr_B.idx.T) + pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]] + pairs_A = pairs_A[pairs_A[:, 0] < pairs_A[:, 1]] + pairs_B = pairs_B[pairs_B[:, 0] < pairs_B[:, 1]] + + # load data + with open('data.pickle', 'rb') as ifile: + data = pickle.load(ifile) + + keys = list(data.keys()) + keys.sort() + # for sid in keys: + for sid in ['000']: + scan_res = data[sid] + scan_res['tot_full'] = scan_res['tot'].copy() + npts = len(scan_res['tot']) + print(sid) + + for ipt in range(npts): + E_es_ref = scan_res['es'][ipt] + E_pol_ref = scan_res['pol'][ipt] + E_disp_ref = scan_res['disp'][ipt] + E_ex_ref = scan_res['ex'][ipt] + E_dhf_ref = scan_res['dhf'][ipt] + E_tot_ref = scan_res['tot'][ipt] + + # get position array + pos_A = jnp.array(scan_res['posA'][ipt]) + pos_B = jnp.array(scan_res['posB'][ipt]) + pos_AB = jnp.concatenate([pos_A, pos_B], axis=0) + + + ##################### + # exchange repulsion + ##################### + E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, ex_generator_AB.params) + E_ex_A = pot_ex_A(pos_A, box, pairs_A, ex_generator_AB.params) + E_ex_B = pot_ex_B(pos_B, box, pairs_B, ex_generator_AB.params) + E_ex = E_ex_AB - E_ex_A - E_ex_B + + ####################### + # electrostatic + pol + ####################### + E_AB = pot_pme_AB(pos_AB, box, pairs_AB, pme_generator_AB.params) + E_A = pot_pme_A(pos_A, box, pairs_A, pme_generator_A.params) + E_B = pot_pme_B(pos_B, box, pairs_A, pme_generator_B.params) + E_espol = E_AB - E_A - E_B + + # use induced dipole of monomers to compute electrostatic interaction + U_ind_AB = jnp.vstack((pme_generator_A.pme_force.U_ind, pme_generator_B.pme_force.U_ind)) + params = pme_generator_AB.params + map_atypes = pme_generator_AB.map_atomtype + Q_local = params['Q_local'][map_atypes] + pol = params['pol'][map_atypes] + tholes = params['tholes'][map_atypes] + pme_force = pme_generator_AB.pme_force + E_AB_nonpol = pme_force.energy_fn(pos_AB, box, pairs_AB, Q_local, U_ind_AB, pol, tholes, params['mScales'], params['pScales'], params['dScales']) + E_es = E_AB_nonpol - E_A - E_B + E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, dmp_es_generator_AB.params) \ + - pot_dmp_es_A(pos_A, box, pairs_A, dmp_es_generator_A.params) \ + - pot_dmp_es_B(pos_B, box, pairs_B, dmp_es_generator_B.params) + E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, sr_es_generator_AB.params) \ + - pot_sr_es_A(pos_A, box, pairs_A, sr_es_generator_AB.params) \ + - pot_sr_es_B(pos_B, box, pairs_B, sr_es_generator_AB.params) + + + ################################### + # polarization (induction) energy + ################################### + E_pol = E_espol - E_es + E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, sr_pol_generator_AB.params) \ + - pot_sr_pol_A(pos_A, box, pairs_A, sr_pol_generator_AB.params) \ + - pot_sr_pol_B(pos_B, box, pairs_B, sr_pol_generator_AB.params) + + + ############# + # dispersion + ############# + E_AB_disp = pot_disp_AB(pos_AB, box, pairs_AB, disp_generator_AB.params) + E_A_disp = pot_disp_A(pos_A, box, pairs_A, disp_generator_AB.params) + E_B_disp = pot_disp_B(pos_B, box, pairs_B, disp_generator_AB.params) + E_disp = E_AB_disp - E_A_disp - E_B_disp + E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, dmp_disp_generator_AB.params) \ + - pot_dmp_disp_A(pos_A, box, pairs_A, dmp_disp_generator_A.params) \ + - pot_dmp_disp_B(pos_B, box, pairs_B, dmp_disp_generator_B.params) + E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, sr_disp_generator_AB.params) \ + - pot_sr_disp_A(pos_A, box, pairs_A, sr_disp_generator_AB.params) \ + - pot_sr_disp_B(pos_B, box, pairs_B, sr_disp_generator_AB.params) + + ########### + # dhf + ########### + E_AB_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, dhf_generator_AB.params) + E_A_dhf = pot_dhf_A(pos_A, box, pairs_A, dhf_generator_AB.params) + E_B_dhf = pot_dhf_B(pos_B, box, pairs_B, dhf_generator_AB.params) + E_dhf = E_AB_dhf - E_A_dhf - E_B_dhf + + # total energy + E_tot = (E_es + E_sr_es + E_dmp_es) + (E_ex) + (E_pol + E_sr_pol) + (E_disp + E_dmp_disp + E_sr_disp) + (E_dhf) + # print(E_dmp_es + E_disp + E_dmp_disp) + print(E_es + E_pol) + diff --git a/examples/peg_slater_isa/check.py b/examples/peg_slater_isa/check.py new file mode 100755 index 000000000..5b2f22abd --- /dev/null +++ b/examples/peg_slater_isa/check.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python +import sys +import numpy as np +import openmm +from openmm import * +from openmm.app import * +from openmm.unit import * +import jax +import jax_md +import jax.numpy as jnp +import dmff +from dmff.api import Hamiltonian +import pickle +import time +from jax import value_and_grad, jit +import optax + + +if __name__ == '__main__': + ff = 'forcefield.xml' + pdb_AB = PDBFile('peg2_dimer.pdb') + pdb_A = PDBFile('peg2.pdb') + pdb_B = PDBFile('peg2.pdb') + param_file = 'params.0.pickle' + H_AB = Hamiltonian(ff) + H_A = Hamiltonian(ff) + H_B = Hamiltonian(ff) + pme_generator_AB, \ + disp_generator_AB, \ + ex_generator_AB, \ + sr_es_generator_AB, \ + sr_pol_generator_AB, \ + sr_disp_generator_AB, \ + dhf_generator_AB, \ + dmp_es_generator_AB, \ + dmp_disp_generator_AB = H_AB.getGenerators() + pme_generator_A, \ + disp_generator_A, \ + ex_generator_A, \ + sr_es_generator_A, \ + sr_pol_generator_A, \ + sr_disp_generator_A, \ + dhf_generator_A, \ + dmp_es_generator_A, \ + dmp_disp_generator_A = H_A.getGenerators() + pme_generator_B, \ + disp_generator_B, \ + ex_generator_B, \ + sr_es_generator_B, \ + sr_pol_generator_B, \ + sr_disp_generator_B, \ + dhf_generator_B, \ + dmp_es_generator_B, \ + dmp_disp_generator_B = H_B.getGenerators() + + rc = 15 + + # get potential functions + potentials_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_AB, \ + pot_disp_AB, \ + pot_ex_AB, \ + pot_sr_es_AB, \ + pot_sr_pol_AB, \ + pot_sr_disp_AB, \ + pot_dhf_AB, \ + pot_dmp_es_AB, \ + pot_dmp_disp_AB = potentials_AB + potentials_A = H_A.createPotential(pdb_A.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_A, \ + pot_disp_A, \ + pot_ex_A, \ + pot_sr_es_A, \ + pot_sr_pol_A, \ + pot_sr_disp_A, \ + pot_dhf_A, \ + pot_dmp_es_A, \ + pot_dmp_disp_A = potentials_A + potentials_B = H_B.createPotential(pdb_B.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_B, \ + pot_disp_B, \ + pot_ex_B, \ + pot_sr_es_B, \ + pot_sr_pol_B, \ + pot_sr_disp_B, \ + pot_dhf_B, \ + pot_dmp_es_B, \ + pot_dmp_disp_B = potentials_B + + # init positions used to set up neighbor list + pos_AB0 = jnp.array(pdb_AB.positions._value) * 10 + n_atoms = len(pos_AB0) + n_atoms_A = n_atoms // 2 + n_atoms_B = n_atoms // 2 + pos_A0 = jnp.array(pdb_AB.positions._value[:n_atoms_A]) * 10 + pos_B0 = jnp.array(pdb_AB.positions._value[n_atoms_A:n_atoms]) * 10 + box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10 + + # nn list initial allocation + displacement_fn, shift_fn = jax_md.space.periodic_general(box, fractional_coordinates=False) + neighbor_list_fn = jax_md.partition.neighbor_list(displacement_fn, box, rc, 0, format=jax_md.partition.OrderedSparse) + nbr_AB = neighbor_list_fn.allocate(pos_AB0) + nbr_A = neighbor_list_fn.allocate(pos_A0) + nbr_B = neighbor_list_fn.allocate(pos_B0) + pairs_AB = np.array(nbr_AB.idx.T) + pairs_A = np.array(nbr_A.idx.T) + pairs_B = np.array(nbr_B.idx.T) + pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]] + pairs_A = pairs_A[pairs_A[:, 0] < pairs_A[:, 1]] + pairs_B = pairs_B[pairs_B[:, 0] < pairs_B[:, 1]] + + + # construct total force field params + comps = ['ex', 'es', 'pol', 'disp', 'dhf', 'tot'] + # load parameters + with open(param_file, 'rb') as ifile: + params = pickle.load(ifile) + + # setting up params for all calculators + params_ex = {} + params_sr_es = {} + params_sr_pol = {} + params_sr_disp = {} + params_dhf = {} + params_dmp_es = {} + params_dmp_disp = {} + for k in ['B', 'mScales']: + params_ex[k] = params[k] + params_sr_es[k] = params[k] + params_sr_pol[k] = params[k] + params_sr_disp[k] = params[k] + params_dhf[k] = params[k] + params_dmp_es[k] = params[k] + params_dmp_disp[k] = params[k] + params_ex['A'] = params['A_ex'] + params_sr_es['A'] = params['A_es'] + params_sr_pol['A'] = params['A_pol'] + params_sr_disp['A'] = params['A_disp'] + params_dhf['A'] = params['A_dhf'] + # damping parameters + params_dmp_es['Q'] = params['Q'] + params_dmp_disp['C6'] = params['C6'] + params_dmp_disp['C8'] = params['C8'] + params_dmp_disp['C10'] = params['C10'] + # long range parameters + params_espol = {} + for k in ['mScales', 'pScales', 'dScales', 'Q_local', 'pol', 'tholes']: + params_espol[k] = params[k] + params_disp = {} + for k in ['B', 'C6', 'C8', 'C10', 'mScales']: + params_disp[k] = params[k] + + + # load data + with open('data.pickle', 'rb') as ifile: + data = pickle.load(ifile) + with open('data_sr.pickle', 'rb') as ifile: + data_sr = pickle.load(ifile) + with open('data_lr.pickle', 'rb') as ifile: + data_lr = pickle.load(ifile) + sids = list(data.keys()) + sids.sort() + + # run test + # for sid in sids: + for sid in [sys.argv[1]]: + scan_res = data[sid] + scan_res_sr = data_sr[sid] + scan_res_lr = data_lr[sid] + npts = len(scan_res['tot']) + + for ipt in range(npts): + E_es_ref = scan_res['es'][ipt] + E_pol_ref = scan_res['pol'][ipt] + E_disp_ref = scan_res['disp'][ipt] + E_ex_ref = scan_res['ex'][ipt] + E_dhf_ref = scan_res['dhf'][ipt] + E_tot_ref = scan_res['tot'][ipt] + + pos_A = jnp.array(scan_res['posA'][ipt]) + pos_B = jnp.array(scan_res['posB'][ipt]) + pos_AB = jnp.concatenate([pos_A, pos_B], axis=0) + + ##################### + # exchange repulsion + ##################### + E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, params_ex) + E_ex_A = pot_ex_A(pos_A, box, pairs_A, params_ex) + E_ex_B = pot_ex_B(pos_B, box, pairs_B, params_ex) + E_ex = E_ex_AB - E_ex_A - E_ex_B + + ####################### + # electrostatic + pol + ####################### + E_AB = pot_pme_AB(pos_AB, box, pairs_AB, params_espol) + E_A = pot_pme_A(pos_A, box, pairs_A, params_espol) + E_B = pot_pme_B(pos_B, box, pairs_A, params_espol) + E_espol = E_AB - E_A - E_B + + # use induced dipole of monomers to compute electrostatic interaction + U_ind_AB = jnp.vstack((pme_generator_A.pme_force.U_ind, pme_generator_B.pme_force.U_ind)) + params = params_espol + map_atypes = pme_generator_AB.map_atomtype + Q_local = params['Q_local'][map_atypes] + pol = params['pol'][map_atypes] + tholes = params['tholes'][map_atypes] + pme_force = pme_generator_AB.pme_force + E_AB_nonpol = pme_force.energy_fn(pos_AB, box, pairs_AB, Q_local, U_ind_AB, pol, tholes, params['mScales'], params['pScales'], params['dScales']) + E_es = E_AB_nonpol - E_A - E_B + E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, params_dmp_es) \ + - pot_dmp_es_A(pos_A, box, pairs_A, params_dmp_es) \ + - pot_dmp_es_B(pos_B, box, pairs_B, params_dmp_es) + E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, params_sr_es) \ + - pot_sr_es_A(pos_A, box, pairs_A, params_sr_es) \ + - pot_sr_es_B(pos_B, box, pairs_B, params_sr_es) + + + ################################### + # polarization (induction) energy + ################################### + E_pol = E_espol - E_es + E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, params_sr_pol) \ + - pot_sr_pol_A(pos_A, box, pairs_A, params_sr_pol) \ + - pot_sr_pol_B(pos_B, box, pairs_B, params_sr_pol) + + + ############# + # dispersion + ############# + E_AB_disp = pot_disp_AB(pos_AB, box, pairs_AB, params_disp) + E_A_disp = pot_disp_A(pos_A, box, pairs_A, params_disp) + E_B_disp = pot_disp_B(pos_B, box, pairs_B, params_disp) + E_disp = E_AB_disp - E_A_disp - E_B_disp + E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, params_dmp_disp) \ + - pot_dmp_disp_A(pos_A, box, pairs_A, params_dmp_disp) \ + - pot_dmp_disp_B(pos_B, box, pairs_B, params_dmp_disp) + E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, params_sr_disp) \ + - pot_sr_disp_A(pos_A, box, pairs_A, params_sr_disp) \ + - pot_sr_disp_B(pos_B, box, pairs_B, params_sr_disp) + + ########### + # dhf + ########### + E_AB_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, params_dhf) + E_A_dhf = pot_dhf_A(pos_A, box, pairs_A, params_dhf) + E_B_dhf = pot_dhf_B(pos_B, box, pairs_B, params_dhf) + E_dhf = E_AB_dhf - E_A_dhf - E_B_dhf + + # total energy + E_tot = (E_es + E_sr_es + E_dmp_es) + (E_ex) + (E_pol + E_sr_pol) + (E_disp + E_dmp_disp + E_sr_disp) + (E_dhf) + E_tot_sr = (E_sr_es + E_dmp_es) + (E_ex) + (E_sr_pol) + (E_sr_disp + E_dmp_disp) + (E_dhf) + E_tot_lr = E_es + E_pol + E_disp + + print(ipt, E_tot, E_tot_ref) + # print(ipt, E_tot, E_tot_ref, E_tot_sr, data_sr[sid]['tot'][ipt], E_tot_lr, data[sid]['tot'][ipt]-data_sr[sid]['tot'][ipt]) + # print(ipt, E_tot_lr, scan_res_lr['tot'][ipt]) + # print(ipt, E_tot_sr, scan_res_sr['tot'][ipt], scan_res['tot'][ipt]) + # if scan_res['tot'][ipt] < 25: + # print(scan_res_sr['tot'][ipt], scan_res_sr['tot'][ipt], E_tot_sr) + # # print(scan_res['tot'][ipt], scan_res['tot'][ipt], E_tot) + # sys.stdout.flush() + diff --git a/examples/peg_slater_isa/data.pickle b/examples/peg_slater_isa/data.pickle new file mode 100644 index 000000000..e2ab58d9c Binary files /dev/null and b/examples/peg_slater_isa/data.pickle differ diff --git a/examples/peg_slater_isa/data_lr.pickle b/examples/peg_slater_isa/data_lr.pickle new file mode 100644 index 000000000..928c60557 Binary files /dev/null and b/examples/peg_slater_isa/data_lr.pickle differ diff --git a/examples/peg_slater_isa/data_sr.pickle b/examples/peg_slater_isa/data_sr.pickle new file mode 100644 index 000000000..82276c420 Binary files /dev/null and b/examples/peg_slater_isa/data_sr.pickle differ diff --git a/examples/peg_slater_isa/fit.py b/examples/peg_slater_isa/fit.py new file mode 100755 index 000000000..d5b5bccb8 --- /dev/null +++ b/examples/peg_slater_isa/fit.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python +import sys +import numpy as np +import openmm +from openmm import * +from openmm.app import * +from openmm.unit import * +import jax +import jax_md +import jax.numpy as jnp +import dmff +from dmff.api import Hamiltonian +import pickle +import time +from jax import value_and_grad, jit +import optax + + +if __name__ == '__main__': + restart = 'params.0.pickle' # None + ff = 'forcefield.xml' + pdb_AB = PDBFile('peg2_dimer.pdb') + pdb_A = PDBFile('peg2.pdb') + pdb_B = PDBFile('peg2.pdb') + H_AB = Hamiltonian(ff) + H_A = Hamiltonian(ff) + H_B = Hamiltonian(ff) + pme_generator_AB, \ + disp_generator_AB, \ + ex_generator_AB, \ + sr_es_generator_AB, \ + sr_pol_generator_AB, \ + sr_disp_generator_AB, \ + dhf_generator_AB, \ + dmp_es_generator_AB, \ + dmp_disp_generator_AB = H_AB.getGenerators() + pme_generator_A, \ + disp_generator_A, \ + ex_generator_A, \ + sr_es_generator_A, \ + sr_pol_generator_A, \ + sr_disp_generator_A, \ + dhf_generator_A, \ + dmp_es_generator_A, \ + dmp_disp_generator_A = H_A.getGenerators() + pme_generator_B, \ + disp_generator_B, \ + ex_generator_B, \ + sr_es_generator_B, \ + sr_pol_generator_B, \ + sr_disp_generator_B, \ + dhf_generator_B, \ + dmp_es_generator_B, \ + dmp_disp_generator_B = H_B.getGenerators() + + rc = 15 + + # get potential functions + potentials_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_AB, \ + pot_disp_AB, \ + pot_ex_AB, \ + pot_sr_es_AB, \ + pot_sr_pol_AB, \ + pot_sr_disp_AB, \ + pot_dhf_AB, \ + pot_dmp_es_AB, \ + pot_dmp_disp_AB = potentials_AB + potentials_A = H_A.createPotential(pdb_A.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_A, \ + pot_disp_A, \ + pot_ex_A, \ + pot_sr_es_A, \ + pot_sr_pol_A, \ + pot_sr_disp_A, \ + pot_dhf_A, \ + pot_dmp_es_A, \ + pot_dmp_disp_A = potentials_A + potentials_B = H_B.createPotential(pdb_B.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_B, \ + pot_disp_B, \ + pot_ex_B, \ + pot_sr_es_B, \ + pot_sr_pol_B, \ + pot_sr_disp_B, \ + pot_dhf_B, \ + pot_dmp_es_B, \ + pot_dmp_disp_B = potentials_B + + pos_AB0 = jnp.array(pdb_AB.positions._value) * 10 + n_atoms = len(pos_AB0) + n_atoms_A = n_atoms // 2 + n_atoms_B = n_atoms // 2 + pos_A0 = jnp.array(pdb_AB.positions._value[:n_atoms_A]) * 10 + pos_B0 = jnp.array(pdb_AB.positions._value[n_atoms_A:n_atoms]) * 10 + box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10 + # nn list initial allocation + displacement_fn, shift_fn = jax_md.space.periodic_general(box, fractional_coordinates=False) + neighbor_list_fn = jax_md.partition.neighbor_list(displacement_fn, box, rc, 0, format=jax_md.partition.OrderedSparse) + nbr_AB = neighbor_list_fn.allocate(pos_AB0) + nbr_A = neighbor_list_fn.allocate(pos_A0) + nbr_B = neighbor_list_fn.allocate(pos_B0) + pairs_AB = np.array(nbr_AB.idx.T) + pairs_A = np.array(nbr_A.idx.T) + pairs_B = np.array(nbr_B.idx.T) + + + # construct total force field params + comps = ['ex', 'es', 'pol', 'disp', 'dhf', 'tot'] + weights_comps = jnp.array([0.001, 0.001, 0.001, 0.001, 0.001, 1.0]) + if restart is None: + params = {} + sr_generators = { + 'ex': ex_generator_AB, + 'es': sr_es_generator_AB, + 'pol': sr_pol_generator_AB, + 'disp': sr_disp_generator_AB, + 'dhf': dhf_generator_AB, + } + for k in pme_generator_AB.params: + params[k] = pme_generator_AB.params[k] + for k in disp_generator_AB.params: + params[k] = disp_generator_AB.params[k] + for c in comps: + if c == 'tot': + continue + gen = sr_generators[c] + for k in gen.params: + if k == 'A': + params['A_'+c] = gen.params[k] + else: + params[k] = gen.params[k] + # a random initialization of A + for c in comps: + if c == 'tot': + continue + params['A_'+c] = jnp.array(np.random.random(params['A_'+c].shape)) + # specify charges for es damping + params['Q'] = dmp_es_generator_AB.params['Q'] + else: + with open(restart, 'rb') as ifile: + params = pickle.load(ifile) + + + @jit + def MSELoss(params, scan_res): + ''' + The weighted mean squared error loss function + Conducted for each scan + ''' + E_tot_full = scan_res['tot_full'] + kT = 2.494 # 300 K = 2.494 kJ/mol + weights_pts = jnp.piecewise(E_tot_full, [E_tot_full<25, E_tot_full>=25], [lambda x: jnp.array(1.0), lambda x: jnp.exp(-(x-25)/kT)]) + npts = len(weights_pts) + + energies = { + 'ex': jnp.zeros(npts), + 'es': jnp.zeros(npts), + 'pol': jnp.zeros(npts), + 'disp': jnp.zeros(npts), + 'dhf': jnp.zeros(npts), + 'tot': jnp.zeros(npts) + } + + # setting up params for all calculators + params_ex = {} + params_sr_es = {} + params_sr_pol = {} + params_sr_disp = {} + params_dhf = {} + params_dmp_es = {} # electrostatic damping + params_dmp_disp = {} # dispersion damping + for k in ['B', 'mScales']: + params_ex[k] = params[k] + params_sr_es[k] = params[k] + params_sr_pol[k] = params[k] + params_sr_disp[k] = params[k] + params_dhf[k] = params[k] + params_dmp_es[k] = params[k] + params_dmp_disp[k] = params[k] + params_ex['A'] = params['A_ex'] + params_sr_es['A'] = params['A_es'] + params_sr_pol['A'] = params['A_pol'] + params_sr_disp['A'] = params['A_disp'] + params_dhf['A'] = params['A_dhf'] + # damping parameters + params_dmp_es['Q'] = params['Q'] + params_dmp_disp['C6'] = params['C6'] + params_dmp_disp['C8'] = params['C8'] + params_dmp_disp['C10'] = params['C10'] + + # calculate each points, only the short range and damping components + for ipt in range(npts): + # get position array + pos_A = jnp.array(scan_res['posA'][ipt]) + pos_B = jnp.array(scan_res['posB'][ipt]) + pos_AB = jnp.concatenate([pos_A, pos_B], axis=0) + + ##################### + # exchange repulsion + ##################### + E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, params_ex) + E_ex_A = pot_ex_A(pos_A, box, pairs_A, params_ex) + E_ex_B = pot_ex_B(pos_B, box, pairs_B, params_ex) + E_ex = E_ex_AB - E_ex_A - E_ex_B + + ####################### + # electrostatic + pol + ####################### + E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, params_dmp_es) \ + - pot_dmp_es_A(pos_A, box, pairs_A, params_dmp_es) \ + - pot_dmp_es_B(pos_B, box, pairs_B, params_dmp_es) + E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, params_sr_es) \ + - pot_sr_es_A(pos_A, box, pairs_A, params_sr_es) \ + - pot_sr_es_B(pos_B, box, pairs_B, params_sr_es) + + ################################### + # polarization (induction) energy + ################################### + E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, params_sr_pol) \ + - pot_sr_pol_A(pos_A, box, pairs_A, params_sr_pol) \ + - pot_sr_pol_B(pos_B, box, pairs_B, params_sr_pol) + + ############# + # dispersion + ############# + E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, params_dmp_disp) \ + - pot_dmp_disp_A(pos_A, box, pairs_A, params_dmp_disp) \ + - pot_dmp_disp_B(pos_B, box, pairs_B, params_dmp_disp) + E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, params_sr_disp) \ + - pot_sr_disp_A(pos_A, box, pairs_A, params_sr_disp) \ + - pot_sr_disp_B(pos_B, box, pairs_B, params_sr_disp) + + ########### + # dhf + ########### + E_AB_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, params_dhf) + E_A_dhf = pot_dhf_A(pos_A, box, pairs_A, params_dhf) + E_B_dhf = pot_dhf_B(pos_B, box, pairs_B, params_dhf) + E_dhf = E_AB_dhf - E_A_dhf - E_B_dhf + + energies['ex'] = energies['ex'].at[ipt].set(E_ex) + energies['es'] = energies['es'].at[ipt].set(E_dmp_es + E_sr_es) + energies['pol'] = energies['pol'].at[ipt].set(E_sr_pol) + energies['disp'] = energies['disp'].at[ipt].set(E_dmp_disp + E_sr_disp) + energies['dhf'] = energies['dhf'].at[ipt].set(E_dhf) + energies['tot'] = energies['tot'].at[ipt].set(E_ex + + E_dmp_es + E_sr_es + + E_sr_pol + + E_dmp_disp + E_sr_disp + + E_dhf) + + + errs = jnp.zeros(len(comps)) + for ic, c in enumerate(comps): + dE = energies[c] - scan_res[c] + mse = dE**2 * weights_pts / jnp.sum(weights_pts) + errs = errs.at[ic].set(jnp.sum(mse)) + + return jnp.sum(weights_comps * errs) + + + # load data + with open('data_sr.pickle', 'rb') as ifile: + data = pickle.load(ifile) + + err, gradients = value_and_grad(MSELoss, argnums=(0))(params, data['000']) + sids = np.array(list(data.keys())) + + + # only optimize these parameters A/B + def mask_fn(grads): + for k in grads: + if k.startswith('A_') or k == 'B': + continue + else: + grads[k] = 0.0 + return grads + + # start to do optmization + lr = 0.001 + optimizer = optax.adam(lr) + opt_state = optimizer.init(params) + + n_epochs = 1000 + for i_epoch in range(n_epochs): + np.random.shuffle(sids) + for sid in sids: + loss, grads = value_and_grad(MSELoss, argnums=(0))(params, data[sid]) + grads = mask_fn(grads) + print(loss) + sys.stdout.flush() + updates, opt_state = optimizer.update(grads, opt_state) + params = optax.apply_updates(params, updates) + with open('params.pickle', 'wb') as ofile: + pickle.dump(params, ofile) + + diff --git a/examples/peg_slater_isa/fit.sh b/examples/peg_slater_isa/fit.sh new file mode 100644 index 000000000..6e2a876a1 --- /dev/null +++ b/examples/peg_slater_isa/fit.sh @@ -0,0 +1,6 @@ +#!/bin/bash +#SBATCH -N 1 -n 1 --gres=gpu:1 +#SBATCH -t 24:00:00 -o out -e err +#SBATCH -p rtx3090 + +./fit.py > log diff --git a/examples/peg_slater_isa/forcefield.xml b/examples/peg_slater_isa/forcefield.xml new file mode 100644 index 000000000..25d90db50 --- /dev/null +++ b/examples/peg_slater_isa/forcefield.xml @@ -0,0 +1,145 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/peg_slater_isa/forcefield_nonpol.xml b/examples/peg_slater_isa/forcefield_nonpol.xml new file mode 100644 index 000000000..7160d4f2d --- /dev/null +++ b/examples/peg_slater_isa/forcefield_nonpol.xml @@ -0,0 +1,75 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/peg_slater_isa/params.0.pickle b/examples/peg_slater_isa/params.0.pickle new file mode 100644 index 000000000..c6e9cd939 Binary files /dev/null and b/examples/peg_slater_isa/params.0.pickle differ diff --git a/examples/peg_slater_isa/peg2.pdb b/examples/peg_slater_isa/peg2.pdb new file mode 100644 index 000000000..e44d9f14b --- /dev/null +++ b/examples/peg_slater_isa/peg2.pdb @@ -0,0 +1,35 @@ +TITLE MDANALYSIS FRAME 0: Created by PDBWriter +CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C00 TER M 1 14.058 13.500 16.329 0.00 1.00 SYST C +ATOM 2 H01 TER M 1 14.410 14.387 16.870 0.00 1.00 SYST H +ATOM 3 H02 TER M 1 14.412 12.614 16.871 0.00 1.00 SYST H +ATOM 4 O03 TER M 1 14.578 13.500 15.000 0.00 1.00 SYST O +ATOM 5 C04 TER M 1 16.000 13.500 15.000 0.00 1.00 SYST C +ATOM 6 H05 TER M 1 16.344 13.499 13.962 0.00 1.00 SYST H +ATOM 7 H06 TER M 1 16.382 12.602 15.496 0.00 1.00 SYST H +ATOM 8 H07 TER M 1 16.382 14.399 15.493 0.00 1.00 SYST H +ATOM 9 C00 TER M 2 12.535 13.498 16.276 0.00 1.00 SYST C +ATOM 10 H01 TER M 2 12.184 12.612 15.734 0.00 1.00 SYST H +ATOM 11 H02 TER M 2 12.182 14.385 15.735 0.00 1.00 SYST H +ATOM 12 O03 TER M 2 12.015 13.496 17.605 0.00 1.00 SYST O +ATOM 13 C04 TER M 2 10.593 13.493 17.605 0.00 1.00 SYST C +ATOM 14 H05 TER M 2 10.250 13.491 18.643 0.00 1.00 SYST H +ATOM 15 H06 TER M 2 10.213 12.595 17.109 0.00 1.00 SYST H +ATOM 16 H07 TER M 2 10.209 14.391 17.112 0.00 1.00 SYST H +CONECT 1 2 3 4 9 +CONECT 2 1 +CONECT 3 1 +CONECT 4 1 5 +CONECT 5 4 6 7 8 +CONECT 6 5 +CONECT 7 5 +CONECT 8 5 +CONECT 9 1 10 11 12 +CONECT 10 9 +CONECT 11 9 +CONECT 12 9 13 +CONECT 13 12 14 15 16 +CONECT 14 13 +CONECT 15 13 +CONECT 16 13 +END diff --git a/examples/peg_slater_isa/peg2_dimer.pdb b/examples/peg_slater_isa/peg2_dimer.pdb new file mode 100644 index 000000000..5bff255ca --- /dev/null +++ b/examples/peg_slater_isa/peg2_dimer.pdb @@ -0,0 +1,67 @@ +TITLE MDANALYSIS FRAME 0: Created by PDBWriter +CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C00 TER M 1 14.058 13.500 16.329 0.00 1.00 SYST C +ATOM 2 H01 TER M 1 14.410 14.387 16.870 0.00 1.00 SYST H +ATOM 3 H02 TER M 1 14.412 12.614 16.871 0.00 1.00 SYST H +ATOM 4 O03 TER M 1 14.578 13.500 15.000 0.00 1.00 SYST O +ATOM 5 C04 TER M 1 16.000 13.500 15.000 0.00 1.00 SYST C +ATOM 6 H05 TER M 1 16.344 13.499 13.962 0.00 1.00 SYST H +ATOM 7 H06 TER M 1 16.382 12.602 15.496 0.00 1.00 SYST H +ATOM 8 H07 TER M 1 16.382 14.399 15.493 0.00 1.00 SYST H +ATOM 9 C00 TER M 2 12.535 13.498 16.276 0.00 1.00 SYST C +ATOM 10 H01 TER M 2 12.184 12.612 15.734 0.00 1.00 SYST H +ATOM 11 H02 TER M 2 12.182 14.385 15.735 0.00 1.00 SYST H +ATOM 12 O03 TER M 2 12.015 13.496 17.605 0.00 1.00 SYST O +ATOM 13 C04 TER M 2 10.593 13.493 17.605 0.00 1.00 SYST C +ATOM 14 H05 TER M 2 10.250 13.491 18.643 0.00 1.00 SYST H +ATOM 15 H06 TER M 2 10.213 12.595 17.109 0.00 1.00 SYST H +ATOM 16 H07 TER M 2 10.209 14.391 17.112 0.00 1.00 SYST H +ATOM 17 C00 TER M 3 14.058 18.500 16.329 0.00 1.00 SYST C +ATOM 18 H01 TER M 3 14.410 19.387 16.870 0.00 1.00 SYST H +ATOM 19 H02 TER M 3 14.412 17.614 16.871 0.00 1.00 SYST H +ATOM 20 O03 TER M 3 14.578 18.500 15.000 0.00 1.00 SYST O +ATOM 21 C04 TER M 3 16.000 18.500 15.000 0.00 1.00 SYST C +ATOM 22 H05 TER M 3 16.344 18.499 13.962 0.00 1.00 SYST H +ATOM 23 H06 TER M 3 16.382 17.602 15.496 0.00 1.00 SYST H +ATOM 24 H07 TER M 3 16.382 19.399 15.493 0.00 1.00 SYST H +ATOM 25 C00 TER M 4 12.535 18.498 16.276 0.00 1.00 SYST C +ATOM 26 H01 TER M 4 12.184 17.612 15.734 0.00 1.00 SYST H +ATOM 27 H02 TER M 4 12.182 19.385 15.735 0.00 1.00 SYST H +ATOM 28 O03 TER M 4 12.015 18.496 17.605 0.00 1.00 SYST O +ATOM 29 C04 TER M 4 10.593 18.493 17.605 0.00 1.00 SYST C +ATOM 30 H05 TER M 4 10.250 18.491 18.643 0.00 1.00 SYST H +ATOM 31 H06 TER M 4 10.213 17.595 17.109 0.00 1.00 SYST H +ATOM 32 H07 TER M 4 10.209 19.391 17.112 0.00 1.00 SYST H +CONECT 1 2 3 4 9 +CONECT 2 1 +CONECT 3 1 +CONECT 4 1 5 +CONECT 5 4 6 7 8 +CONECT 6 5 +CONECT 7 5 +CONECT 8 5 +CONECT 9 1 10 11 12 +CONECT 10 9 +CONECT 11 9 +CONECT 12 9 13 +CONECT 13 12 14 15 16 +CONECT 14 13 +CONECT 15 13 +CONECT 16 13 +CONECT 17 18 19 20 25 +CONECT 18 17 +CONECT 19 17 +CONECT 20 17 21 +CONECT 21 20 22 23 24 +CONECT 22 21 +CONECT 23 21 +CONECT 24 21 +CONECT 25 17 26 27 28 +CONECT 26 25 +CONECT 27 25 +CONECT 28 25 29 +CONECT 29 28 30 31 32 +CONECT 30 29 +CONECT 31 29 +CONECT 32 29 +END diff --git a/examples/peg_slater_isa/peg4.pdb b/examples/peg_slater_isa/peg4.pdb new file mode 100644 index 000000000..2c11081d1 --- /dev/null +++ b/examples/peg_slater_isa/peg4.pdb @@ -0,0 +1,63 @@ +REMARK +CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C00 TER 1 -2.962 3.637 -1.170 +ATOM 2 H01 TER 1 -2.608 4.142 -0.296 +ATOM 3 H02 TER 1 -4.032 3.635 -1.171 +ATOM 4 O03 TER 1 -2.484 2.289 -1.168 +ATOM 5 C04 TER 1 -2.961 1.615 0.000 +ATOM 6 H05 TER 1 -2.604 0.606 0.000 +ATOM 7 H06 TER 1 -2.604 2.119 0.874 +ATOM 8 H07 TER 1 -4.031 1.615 0.000 +ATOM 9 C00 INT 2 -2.449 6.384 -3.596 +ATOM 10 H01 INT 2 -2.804 5.879 -4.470 +ATOM 11 H02 INT 2 -1.379 6.386 -3.595 +ATOM 12 O03 INT 2 -2.927 5.710 -2.429 +ATOM 13 C04 INT 2 -2.448 4.362 -2.427 +ATOM 14 H05 INT 2 -2.803 3.856 -3.301 +ATOM 15 H06 INT 2 -1.378 4.364 -2.425 +ATOM 16 C00 INT 3 -2.966 9.857 -4.767 +ATOM 17 H01 INT 3 -2.612 10.363 -3.893 +ATOM 18 H02 INT 3 -4.036 9.855 -4.768 +ATOM 19 O03 INT 3 -2.488 8.509 -4.765 +ATOM 20 C04 INT 3 -2.965 7.835 -3.597 +ATOM 21 H05 INT 3 -2.610 8.340 -2.724 +ATOM 22 H06 INT 3 -4.035 7.833 -3.599 +ATOM 23 C00 TER 4 -2.452 10.582 -6.024 +ATOM 24 H01 TER 4 -2.807 10.077 -6.898 +ATOM 25 H02 TER 4 -1.382 10.584 -6.022 +ATOM 26 O03 TER 4 -2.931 11.930 -6.026 +ATOM 27 C04 TER 4 -2.453 12.604 -7.193 +ATOM 28 H05 TER 4 -2.808 12.099 -8.067 +ATOM 29 H06 TER 4 -2.812 13.613 -7.194 +ATOM 30 H07 TER 4 -1.383 12.606 -7.192 +TER +CONECT 5 6 +CONECT 5 7 +CONECT 5 8 +CONECT 5 4 +CONECT 4 1 +CONECT 1 2 +CONECT 1 3 +CONECT 1 13 +CONECT 13 14 +CONECT 13 15 +CONECT 13 12 +CONECT 12 9 +CONECT 9 10 +CONECT 9 11 +CONECT 9 20 +CONECT 20 21 +CONECT 20 22 +CONECT 20 19 +CONECT 19 16 +CONECT 16 17 +CONECT 16 18 +CONECT 16 23 +CONECT 23 24 +CONECT 23 25 +CONECT 23 26 +CONECT 26 27 +CONECT 27 28 +CONECT 27 29 +CONECT 27 30 +END diff --git a/examples/peg_slater_isa/remove_lr.py b/examples/peg_slater_isa/remove_lr.py new file mode 100755 index 000000000..125ab83e1 --- /dev/null +++ b/examples/peg_slater_isa/remove_lr.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +import numpy as np +import openmm +from openmm import * +from openmm.app import * +from openmm.unit import * +import jax +import jax_md +import jax.numpy as jnp +import dmff +from dmff.api import Hamiltonian +import pickle +import time + + +if __name__ == '__main__': + ff = 'forcefield.xml' + pdb_AB = PDBFile('peg2_dimer.pdb') + pdb_A = PDBFile('peg2.pdb') + pdb_B = PDBFile('peg2.pdb') + H_AB = Hamiltonian(ff) + H_A = Hamiltonian(ff) + H_B = Hamiltonian(ff) + pme_generator_AB, \ + disp_generator_AB, \ + ex_generator_AB, \ + sr_es_generator_AB, \ + sr_pol_generator_AB, \ + sr_disp_generator_AB, \ + dhf_generator_AB, \ + dmp_es_generator_AB, \ + dmp_disp_generator_AB = H_AB.getGenerators() + pme_generator_A, \ + disp_generator_A, \ + ex_generator_A, \ + sr_es_generator_A, \ + sr_pol_generator_A, \ + sr_disp_generator_A, \ + dhf_generator_A, \ + dmp_es_generator_A, \ + dmp_disp_generator_A = H_A.getGenerators() + pme_generator_B, \ + disp_generator_B, \ + ex_generator_B, \ + sr_es_generator_B, \ + sr_pol_generator_B, \ + sr_disp_generator_B, \ + dhf_generator_B, \ + dmp_es_generator_B, \ + dmp_disp_generator_B = H_B.getGenerators() + + rc = 15 + + # get potential functions + potentials_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_AB, \ + pot_disp_AB, \ + pot_ex_AB, \ + pot_sr_es_AB, \ + pot_sr_pol_AB, \ + pot_sr_disp_AB, \ + pot_dhf_AB, \ + pot_dmp_es_AB, \ + pot_dmp_disp_AB = potentials_AB + potentials_A = H_A.createPotential(pdb_A.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_A, \ + pot_disp_A, \ + pot_ex_A, \ + pot_sr_es_A, \ + pot_sr_pol_A, \ + pot_sr_disp_A, \ + pot_dhf_A, \ + pot_dmp_es_A, \ + pot_dmp_disp_A = potentials_A + potentials_B = H_B.createPotential(pdb_B.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_B, \ + pot_disp_B, \ + pot_ex_B, \ + pot_sr_es_B, \ + pot_sr_pol_B, \ + pot_sr_disp_B, \ + pot_dhf_B, \ + pot_dmp_es_B, \ + pot_dmp_disp_B = potentials_B + + pos_AB0 = jnp.array(pdb_AB.positions._value) * 10 + n_atoms = len(pos_AB0) + n_atoms_A = n_atoms // 2 + n_atoms_B = n_atoms // 2 + pos_A0 = jnp.array(pdb_AB.positions._value[:n_atoms_A]) * 10 + pos_B0 = jnp.array(pdb_AB.positions._value[n_atoms_A:n_atoms]) * 10 + box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10 + # nn list initial allocation + displacement_fn, shift_fn = jax_md.space.periodic_general(box, fractional_coordinates=False) + neighbor_list_fn = jax_md.partition.neighbor_list(displacement_fn, box, rc, 0, format=jax_md.partition.OrderedSparse) + nbr_AB = neighbor_list_fn.allocate(pos_AB0) + nbr_A = neighbor_list_fn.allocate(pos_A0) + nbr_B = neighbor_list_fn.allocate(pos_B0) + pairs_AB = np.array(nbr_AB.idx.T) + pairs_A = np.array(nbr_A.idx.T) + pairs_B = np.array(nbr_B.idx.T) + pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]] + pairs_A = pairs_A[pairs_A[:, 0] < pairs_A[:, 1]] + pairs_B = pairs_B[pairs_B[:, 0] < pairs_B[:, 1]] + + # load data + with open('data.pickle', 'rb') as ifile: + data = pickle.load(ifile) + + keys = list(data.keys()) + keys.sort() + data_lr = {} + for sid in keys: + scan_res = data[sid] + scan_res['tot_full'] = scan_res['tot'].copy() + npts = len(scan_res['tot']) + # long range + scan_res_lr = {} + scan_res_lr['es'] = np.zeros(npts) + scan_res_lr['pol'] = np.zeros(npts) + scan_res_lr['disp'] = np.zeros(npts) + scan_res_lr['tot'] = np.zeros(npts) + print(sid) + + for ipt in range(npts): + E_es_ref = scan_res['es'][ipt] + E_pol_ref = scan_res['pol'][ipt] + E_disp_ref = scan_res['disp'][ipt] + E_ex_ref = scan_res['ex'][ipt] + E_dhf_ref = scan_res['dhf'][ipt] + E_tot_ref = scan_res['tot'][ipt] + + # get position array + pos_A = jnp.array(scan_res['posA'][ipt]) + pos_B = jnp.array(scan_res['posB'][ipt]) + pos_AB = jnp.concatenate([pos_A, pos_B], axis=0) + + + ##################### + # exchange repulsion + ##################### + # E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, ex_generator_AB.params) + # E_ex_A = pot_ex_A(pos_A, box, pairs_A, ex_generator_AB.params) + # E_ex_B = pot_ex_B(pos_B, box, pairs_B, ex_generator_AB.params) + # E_ex = E_ex_AB - E_ex_A - E_ex_B + + ####################### + # electrostatic + pol + ####################### + E_AB = pot_pme_AB(pos_AB, box, pairs_AB, pme_generator_AB.params) + E_A = pot_pme_A(pos_A, box, pairs_A, pme_generator_A.params) + E_B = pot_pme_B(pos_B, box, pairs_A, pme_generator_B.params) + E_espol = E_AB - E_A - E_B + + # use induced dipole of monomers to compute electrostatic interaction + U_ind_AB = jnp.vstack((pme_generator_A.pme_force.U_ind, pme_generator_B.pme_force.U_ind)) + params = pme_generator_AB.params + map_atypes = pme_generator_AB.map_atomtype + Q_local = params['Q_local'][map_atypes] + pol = params['pol'][map_atypes] + tholes = params['tholes'][map_atypes] + pme_force = pme_generator_AB.pme_force + E_AB_nonpol = pme_force.energy_fn(pos_AB, box, pairs_AB, Q_local, U_ind_AB, pol, tholes, params['mScales'], params['pScales'], params['dScales']) + E_es = E_AB_nonpol - E_A - E_B + # E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, dmp_es_generator_AB.params) \ + # - pot_dmp_es_A(pos_A, box, pairs_A, dmp_es_generator_A.params) \ + # - pot_dmp_es_B(pos_B, box, pairs_B, dmp_es_generator_B.params) + # E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, sr_es_generator_AB.params) \ + # - pot_sr_es_A(pos_A, box, pairs_A, sr_es_generator_AB.params) \ + # - pot_sr_es_B(pos_B, box, pairs_B, sr_es_generator_AB.params) + + + ################################### + # polarization (induction) energy + ################################### + E_pol = E_espol - E_es + # E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, sr_pol_generator_AB.params) \ + # - pot_sr_pol_A(pos_A, box, pairs_A, sr_pol_generator_AB.params) \ + # - pot_sr_pol_B(pos_B, box, pairs_B, sr_pol_generator_AB.params) + + + ############# + # dispersion + ############# + E_AB_disp = pot_disp_AB(pos_AB, box, pairs_AB, disp_generator_AB.params) + E_A_disp = pot_disp_A(pos_A, box, pairs_A, disp_generator_AB.params) + E_B_disp = pot_disp_B(pos_B, box, pairs_B, disp_generator_AB.params) + E_disp = E_AB_disp - E_A_disp - E_B_disp + # E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, dmp_disp_generator_AB.params) \ + # - pot_dmp_disp_A(pos_A, box, pairs_A, dmp_disp_generator_A.params) \ + # - pot_dmp_disp_B(pos_B, box, pairs_B, dmp_disp_generator_B.params) + # E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, sr_disp_generator_AB.params) \ + # - pot_sr_disp_A(pos_A, box, pairs_A, sr_disp_generator_AB.params) \ + # - pot_sr_disp_B(pos_B, box, pairs_B, sr_disp_generator_AB.params) + + ########### + # dhf + ########### + # E_AB_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, dhf_generator_AB.params) + # E_A_dhf = pot_dhf_A(pos_A, box, pairs_A, dhf_generator_AB.params) + # E_B_dhf = pot_dhf_B(pos_B, box, pairs_B, dhf_generator_AB.params) + # E_dhf = E_AB_dhf - E_A_dhf - E_B_dhf + + # remove long range + scan_res['es'][ipt] -= E_es + scan_res['pol'][ipt] -= E_pol + scan_res['disp'][ipt] -= E_disp + scan_res['tot'][ipt] -= (E_es + E_pol + E_disp) + # save long range + scan_res_lr['es'][ipt] = E_es + scan_res_lr['pol'][ipt] = E_pol + scan_res_lr['disp'][ipt] = E_disp + scan_res_lr['tot'][ipt] = E_es + E_pol + E_disp + data[sid] = scan_res + data_lr[sid] = scan_res_lr + + +with open('data_sr.pickle', 'wb') as ofile: + pickle.dump(data, ofile) + +with open('data_lr.pickle', 'wb') as ofile: + pickle.dump(data_lr, ofile) diff --git a/examples/peg_slater_isa/residues.xml b/examples/peg_slater_isa/residues.xml new file mode 100644 index 000000000..d8e0ea8e7 --- /dev/null +++ b/examples/peg_slater_isa/residues.xml @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/peg_slater_isa/run_amoeba.py b/examples/peg_slater_isa/run_amoeba.py new file mode 100755 index 000000000..bf2f9b891 --- /dev/null +++ b/examples/peg_slater_isa/run_amoeba.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +import sys +import numpy as np +import openmm +from openmm import * +from openmm.app import * +from openmm.unit import * +import pickle + +if __name__ == '__main__': + pdb_AB = PDBFile('peg2_dimer.pdb') + pdb_A = PDBFile('peg2.pdb') + pdb_B = PDBFile('peg2.pdb') + forcefield = ForceField('benchmark.xml') + + system_AB = forcefield.createSystem(pdb_AB.topology, nonbondedMethod=PME, nonbondedCutoff=15*angstrom) + system_A = forcefield.createSystem(pdb_A.topology, nonbondedMethod=PME, nonbondedCutoff=15*angstrom) + system_B = forcefield.createSystem(pdb_B.topology, nonbondedMethod=PME, nonbondedCutoff=15*angstrom) + forces_AB = system_AB.getForces() + forces_A = system_A.getForces() + forces_B = system_B.getForces() + for i in range(len(forces_AB)): + forces_AB[i].setForceGroup(i) + forces_A[i].setForceGroup(i) + forces_B[i].setForceGroup(i) + + platform_AB = Platform.getPlatformByName('CUDA') + platform_A = Platform.getPlatformByName('CUDA') + platform_B = Platform.getPlatformByName('CUDA') + properties = {} + + integrator_AB = LangevinIntegrator(300*kelvin, 1.0/picosecond, 1*femtosecond) + integrator_A = LangevinIntegrator(300*kelvin, 1.0/picosecond, 1*femtosecond) + integrator_B = LangevinIntegrator(300*kelvin, 1.0/picosecond, 1*femtosecond) + + simulation_AB = Simulation(pdb_AB.topology, system_AB, integrator_AB, platform=platform_AB) + simulation_A = Simulation(pdb_A.topology, system_A, integrator_A, platform=platform_A) + simulation_B = Simulation(pdb_B.topology, system_B, integrator_B, platform=platform_B) + + + pos_AB0 = np.array(pdb_AB.positions._value) * 10 + n_atoms = len(pos_AB0) + n_atoms_A = n_atoms // 2 + n_atoms_B = n_atoms // 2 + pos_A0 = pos_AB0[:n_atoms_A] + pos_B0 = pos_AB0[n_atoms_A: n_atoms] + + # dr = np.average(pos_B0) - np.average(pos_A0) + # dn = dr / np.linalg.norm(dr) + + # for dz in np.arange(0, 4, 0.1): + # pos_A = pos_A0 + # pos_B = pos_B0 + dz * dn + # pos_AB = np.vstack((pos_A, pos_B)) + # simulation_AB.context.setPositions(pos_AB * angstrom) + # simulation_A.context.setPositions(pos_A * angstrom) + # simulation_B.context.setPositions(pos_B * angstrom) + + # state_AB = simulation_AB.context.getState(getEnergy=True) + # state_A = simulation_A.context.getState(getEnergy=True) + # state_B = simulation_B.context.getState(getEnergy=True) + + # E_AB = state_AB.getPotentialEnergy()._value + # E_A = state_A.getPotentialEnergy()._value + # E_B = state_B.getPotentialEnergy()._value + + # print(dz, E_AB - E_A - E_B) + + with open('data.pickle', 'rb') as ifile: + data = pickle.load(ifile) + + for sid in ['000']: + scan_res = data[sid] + + for ipt in range(len(scan_res['posA'])): + pos_A = np.array(scan_res['posA'][ipt]) + pos_B = np.array(scan_res['posB'][ipt]) + pos_AB = np.vstack([pos_A, pos_B]) + E_es_ref = scan_res['es'][ipt] + E_pol_ref = scan_res['pol'][ipt] + + simulation_AB.context.setPositions(pos_AB * angstrom) + simulation_A.context.setPositions(pos_A * angstrom) + simulation_B.context.setPositions(pos_B * angstrom) + + state_AB = simulation_AB.context.getState(getEnergy=True, groups=2**0) + state_A = simulation_A.context.getState(getEnergy=True, groups=2**0) + state_B = simulation_B.context.getState(getEnergy=True, groups=2**0) + # state_AB = simulation_AB.context.getState(getEnergy=True) + # state_A = simulation_A.context.getState(getEnergy=True) + # state_B = simulation_B.context.getState(getEnergy=True) + + E_AB = state_AB.getPotentialEnergy()._value + E_A = state_A.getPotentialEnergy()._value + E_B = state_B.getPotentialEnergy()._value + + print(E_AB - E_A - E_B) diff --git a/examples/peg_slater_isa/tot_slater.xml b/examples/peg_slater_isa/tot_slater.xml new file mode 100644 index 000000000..105d498cf --- /dev/null +++ b/examples/peg_slater_isa/tot_slater.xml @@ -0,0 +1,78 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/sgnn/model1.pickle b/examples/sgnn/model1.pickle index a9efb1262..0c3959cd9 100644 Binary files a/examples/sgnn/model1.pickle and b/examples/sgnn/model1.pickle differ diff --git a/examples/sgnn/pth2pickle.py b/examples/sgnn/pth2pickle.py index 0ef791de3..d05bb9081 100755 --- a/examples/sgnn/pth2pickle.py +++ b/examples/sgnn/pth2pickle.py @@ -8,6 +8,9 @@ pth = sys.argv[1] state_dict = torch.load(sys.argv[1]) +for k in state_dict: + state_dict[k] = state_dict[k].numpy() + ofn = re.sub('\.pth$', '.pickle', pth) with open(ofn, 'wb') as ofile: pickle.dump(state_dict, ofile)