Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 41 additions & 106 deletions dmff/admp/disp_pme.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax.numpy as jnp
from jax import vmap, value_and_grad
from dmff.utils import jit_condition
from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.spatial import pbc_shift
from dmff.admp.pme import setup_ewald_parameters
from dmff.admp.recip import generate_pme_recip, Ck_6, Ck_8, Ck_10
Expand All @@ -14,17 +14,24 @@ class ADMPDispPmeForce:
The so called "environment paramters" means parameters that do not need to be differentiable
'''

def __init__(self, box, covalent_map, rc, ethresh, pmax):
def __init__(self, box, covalent_map, rc, ethresh, pmax, lpme=True):
self.covalent_map = covalent_map
self.rc = rc
self.ethresh = ethresh
self.pmax = pmax
# Need a different function for dispersion ??? Need tests
kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
self.kappa = kappa
self.K1 = K1
self.K2 = K2
self.K3 = K3
self.lpme = lpme
if lpme:
kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
self.kappa = kappa
self.K1 = K1
self.K2 = K2
self.K3 = K3
else:
self.kappa = 0.0
self.K1 = 0
self.K2 = 0
self.K3 = 0
self.pme_order = 6
# setup calculators
self.refresh_calculators()
Expand All @@ -36,7 +43,7 @@ def get_energy(positions, box, pairs, c_list, mScales):
return energy_disp_pme(positions, box, pairs,
c_list, mScales, self.covalent_map,
self.kappa, self.K1, self.K2, self.K3, self.pmax,
self.d6_recip, self.d8_recip, self.d10_recip)
self.d6_recip, self.d8_recip, self.d10_recip, lpme=self.lpme)
return get_energy


Expand Down Expand Up @@ -70,7 +77,7 @@ def refresh_calculators(self):
def energy_disp_pme(positions, box, pairs,
c_list, mScales, covalent_map,
kappa, K1, K2, K3, pmax,
recip_fn6, recip_fn8, recip_fn10):
recip_fn6, recip_fn8, recip_fn10, lpme=True):
'''
Top level wrapper for dispersion pme

Expand All @@ -95,22 +102,29 @@ def energy_disp_pme(positions, box, pairs,
int: max K for reciprocal calculations
pmax:
int array: maximal exponents (p) to compute, e.g., (6, 8, 10)
lpme:
bool: whether do pme or not, useful when doing cluster calculations

Output:
energy: total dispersion pme energy
'''

ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, covalent_map, kappa, pmax)
if lpme is False:
kappa = 0

ene_recip = recip_fn6(positions, box, c_list[:, 0, jnp.newaxis])
if pmax >= 8:
ene_recip += recip_fn8(positions, box, c_list[:, 1, jnp.newaxis])
if pmax >= 10:
ene_recip += recip_fn10(positions, box, c_list[:, 2, jnp.newaxis])
ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, covalent_map, kappa, pmax)

ene_self = disp_pme_self(c_list, kappa, pmax)
if lpme:
ene_recip = recip_fn6(positions, box, c_list[:, 0, jnp.newaxis])
if pmax >= 8:
ene_recip += recip_fn8(positions, box, c_list[:, 1, jnp.newaxis])
if pmax >= 10:
ene_recip += recip_fn10(positions, box, c_list[:, 2, jnp.newaxis])
ene_self = disp_pme_self(c_list, kappa, pmax)
return ene_real + ene_recip + ene_self

return ene_real + ene_recip + ene_self
else:
return ene_real


def disp_pme_real(positions, box, pairs,
Expand Down Expand Up @@ -144,24 +158,26 @@ def disp_pme_real(positions, box, pairs,
'''

# expand pairwise parameters
pairs = pairs[pairs[:, 0] < pairs[:, 1]]
# pairs = pairs[pairs[:, 0] < pairs[:, 1]]
pairs = regularize_pairs(pairs)

box_inv = jnp.linalg.inv(box)

ri = distribute_v3(positions, pairs[:, 0])
rj = distribute_v3(positions, pairs[:, 1])
# ri = positions[pairs[:, 0]]
# rj = positions[pairs[:, 1]]
nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
mscales = distribute_scalar(mScales, nbonds-1)
# mscales = mScales[nbonds-1]

buffer_scales = pair_buffer_scales(pairs)
mscales = mscales * buffer_scales

ci = distribute_dispcoeff(c_list, pairs[:, 0])
cj = distribute_dispcoeff(c_list, pairs[:, 1])
# ci = c_list[pairs[:, 0], :]
# cj = c_list[pairs[:, 1], :]

ene_real = jnp.sum(disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax))
ene_real = jnp.sum(
disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax)
* buffer_scales
)

return jnp.sum(ene_real)

Expand Down Expand Up @@ -193,6 +209,7 @@ def disp_pme_real_kernel(ri, rj, ci, cj, box, box_inv, mscales, kappa, pmax):
dr = ri - rj
dr = pbc_shift(dr, box, box_inv)
dr2 = jnp.dot(dr, dr)

x2 = kappa * kappa * dr2
g = g_p(x2, pmax)
dr6 = dr2 * dr2 * dr2
Expand Down Expand Up @@ -269,85 +286,3 @@ def disp_pme_self(c_list, kappa, pmax):
return E


# def validation(pdb):
# xml = 'mpidwater.xml'
# pdbinfo = read_pdb(pdb)
# serials = pdbinfo['serials']
# names = pdbinfo['names']
# resNames = pdbinfo['resNames']
# resSeqs = pdbinfo['resSeqs']
# positions = pdbinfo['positions']
# box = pdbinfo['box'] # a, b, c, α, β, γ
# charges = pdbinfo['charges']
# positions = jnp.asarray(positions)
# lx, ly, lz, _, _, _ = box
# box = jnp.eye(3)*jnp.array([lx, ly, lz])

# mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
# pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
# dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])

# rc = 4 # in Angstrom
# ethresh = 1e-4

# n_atoms = len(serials)

# atomTemplate, residueTemplate = read_xml(xml)
# atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate)

# covalent_map = assemble_covalent(residueDicts, n_atoms)
# displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False)
# neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse)
# nbr = neighbor_list_fn.allocate(positions)
# pairs = nbr.idx.T

# pmax = 10
# kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
# kappa = 0.657065221219616

# # construct the C list
# c_list = np.zeros((3,n_atoms))
# nmol=int(n_atoms/3)
# for i in range(nmol):
# a = i*3
# b = i*3+1
# c = i*3+2
# c_list[0][a]=37.19677405
# c_list[0][b]=7.6111103
# c_list[0][c]=7.6111103
# c_list[1][a]=85.26810658
# c_list[1][b]=11.90220148
# c_list[1][c]=11.90220148
# c_list[2][a]=134.44874488
# c_list[2][b]=15.05074749
# c_list[2][c]=15.05074749
# c_list = jnp.array(c_list.T)


# # Finish data preparation
# # -------------------------------------------------------------------------------------
# # pme_order = 6
# # d6_recip = generate_pme_recip(Ck_6, kappa, True, pme_order, K1, K2, K3, 0)
# # d8_recip = generate_pme_recip(Ck_8, kappa, True, pme_order, K1, K2, K3, 0)
# # d10_recip = generate_pme_recip(Ck_10, kappa, True, pme_order, K1, K2, K3, 0)
# # disp_pme_recip_fns = [d6_recip, d8_recip, d10_recip]
# # energy_force_disp_pme = value_and_grad(energy_disp_pme)
# # e, f = energy_force_disp_pme(positions, box, pairs, c_list, mScales, covalent_map, kappa, K1, K2, K3, pmax, *disp_pme_recip_fns)
# # print('ok')
# # e, f = energy_force_disp_pme(positions, box, pairs, c_list, mScales, covalent_map, kappa, K1, K2, K3, pmax, *disp_pme_recip_fns)
# # print(e)

# disp_pme_force = ADMPDispPmeForce(box, covalent_map, rc, ethresh, pmax)
# disp_pme_force.update_env('kappa', 0.657065221219616)

# print(c_list[:4])
# E, F = disp_pme_force.get_forces(positions, box, pairs, c_list, mScales)
# print('ok')
# E, F = disp_pme_force.get_forces(positions, box, pairs, c_list, mScales)
# print(E)
# return


# # below is the validation code
# if __name__ == '__main__':
# validation(sys.argv[1])
Loading