Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions dmff/admp/disp_pme.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from functools import partial

import jax.numpy as jnp
from jax import vmap, value_and_grad
from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.spatial import pbc_shift
from dmff.admp.pairwise import (distribute_dispcoeff, distribute_scalar,
distribute_v3)
from dmff.admp.pme import setup_ewald_parameters
from dmff.admp.recip import generate_pme_recip, Ck_6, Ck_8, Ck_10
from dmff.admp.pairwise import distribute_scalar, distribute_v3, distribute_dispcoeff
from functools import partial
from dmff.admp.recip import Ck_6, Ck_8, Ck_10, generate_pme_recip
from dmff.admp.spatial import pbc_shift
from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs
from jax import value_and_grad, vmap


class ADMPDispPmeForce:
'''
Expand Down
27 changes: 11 additions & 16 deletions dmff/admp/mbpol_intra.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
import sys

import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from dmff.settings import DO_JIT
from dmff.utils import jit_condition
import numpy as np
from dmff.admp.spatial import v_pbc_shift
from dmff.admp.pme import ADMPPmeForce
from dmff.admp.parser import *
from dmff.utils import jit_condition
from jax import vmap
import time

#const
f5z = 0.999677885
fbasis = 0.15860145369897
fcore = -1.6351695982132
frest = 1.0
reoh = 0.958649;
thetae = 104.3475;
b1 = 2.0;
roh = 0.9519607159623009;
alphaoh = 2.587949757553683;
deohA = 42290.92019288289;
phh1A = 16.94879431193463;
phh2 = 12.66426998162947;
reoh = 0.958649
thetae = 104.3475
b1 = 2.0
roh = 0.9519607159623009
alphaoh = 2.587949757553683
deohA = 42290.92019288289
phh1A = 16.94879431193463
phh2 = 12.66426998162947

c5zA = jnp.array([4.2278462684916e+04, 4.5859382909906e-02, 9.4804986183058e+03,
7.5485566680955e+02, 1.9865052511496e+03, 4.3768071560862e+02,
Expand Down Expand Up @@ -487,4 +483,3 @@ def onebody_kernel(x1, x2, x3, Va, Vb, efac):
e1 *= cm1_kcalmol
e1 *= cal2joule # conver cal 2 j
return e1

10 changes: 5 additions & 5 deletions dmff/admp/multipole.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sys
from functools import partial

import jax.numpy as jnp
from jax import vmap
from dmff.utils import jit_condition
from functools import partial
from jax import vmap

# This module deals with the transformations and rotations of multipoles

Expand Down Expand Up @@ -48,7 +48,7 @@ def convert_cart2harm(Theta, lmax):
n * (l+1)^2, stores the spherical multipoles
'''
if lmax > 2:
sys.exit('l > 2 (beyond quadrupole) not supported')
raise ValueError('l > 2 (beyond quadrupole) not supported')

Q_mono = Theta[0:1]

Expand Down Expand Up @@ -90,7 +90,7 @@ def convert_harm2cart(Q, lmax):
'''

if lmax > 2:
sys.exit('l > 2 (beyond quadrupole) not supported')
raise ValueError('l > 2 (beyond quadrupole) not supported')

T_mono = Q[0:1]

Expand Down
111 changes: 4 additions & 107 deletions dmff/admp/pairwise.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sys
from jax import vmap
from functools import partial

import jax.numpy as jnp
from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.spatial import v_pbc_shift
from functools import partial
from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs
from jax import vmap

DIELECTRIC = 1389.35455846

Expand Down Expand Up @@ -170,106 +170,3 @@ def slater_sr_kernel(dr, m, ai, aj, bi, bj):
P = 1/3 * br2 + br + 1
return a * P * jnp.exp(-br) * m


def validation(pdb):
xml = 'mpidwater.xml'
pdbinfo = read_pdb(pdb)
serials = pdbinfo['serials']
names = pdbinfo['names']
resNames = pdbinfo['resNames']
resSeqs = pdbinfo['resSeqs']
positions = pdbinfo['positions']
box = pdbinfo['box'] # a, b, c, α, β, γ
charges = pdbinfo['charges']
positions = jnp.asarray(positions)
lx, ly, lz, _, _, _ = box
box = jnp.eye(3)*jnp.array([lx, ly, lz])

mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])

rc = 4 # in Angstrom
ethresh = 1e-4

n_atoms = len(serials)

atomTemplate, residueTemplate = read_xml(xml)
atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate)

covalent_map = assemble_covalent(residueDicts, n_atoms)
displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False)
neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse)
nbr = neighbor_list_fn.allocate(positions)
pairs = nbr.idx.T

pmax = 10
kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
kappa = 0.657065221219616

# construct the C list
c_list = np.zeros((3, n_atoms))
a_list = np.zeros(n_atoms)
q_list = np.zeros(n_atoms)
b_list = np.zeros(n_atoms)
nmol=int(n_atoms/3)
for i in range(nmol):
a = i*3
b = i*3+1
c = i*3+2
# dispersion coeff
c_list[0][a]=37.199677405
c_list[0][b]=7.6111103
c_list[0][c]=7.6111103
c_list[1][a]=85.26810658
c_list[1][b]=11.90220148
c_list[1][c]=11.90220148
c_list[2][a]=134.44874488
c_list[2][b]=15.05074749
c_list[2][c]=15.05074749
# q
q_list[a] = -0.741706
q_list[b] = 0.370853
q_list[c] = 0.370853
# b, Bohr^-1
b_list[a] = 2.00095977
b_list[b] = 1.999519942
b_list[c] = 1.999519942
# a, Hartree
a_list[a] = 458.3777
a_list[b] = 0.0317
a_list[c] = 0.0317


c_list = jnp.array(c_list)

# @partial(vmap, in_axes=(0, 0, 0, 0), out_axes=(0))
# @jit_condition(static_argnums=())
# def disp6_pme_real_kernel(dr, m, ci, cj):
# # unpack static arguments
# kappa = static_args['kappa']
# # calculate distance
# dr2 = dr ** 2
# dr6 = dr2 ** 3
# # do calculation
# x2 = kappa**2 * dr2
# exp_x2 = jnp.exp(-x2)
# x4 = x2 * x2
# g = (1 + x2 + 0.5*x4) * exp_x2
# return (m + g - 1) * ci * cj / dr6

# static_args = {'kappa': kappa}
# disp6_pme_real = generate_pairwise_interaction(disp6_pme_real_kernel, covalent_map, static_args)
# print(disp6_pme_real(positions, box, pairs, mScales, c_list[0, :]))

TT_damping_qq_c6 = generate_pairwise_interaction(TT_damping_qq_c6_kernel, covalent_map, static_args={})

TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0])
print('ok')
print(TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0]))
return


# below is the validation code
if __name__ == '__main__':
validation(sys.argv[1])
1 change: 0 additions & 1 deletion dmff/admp/recip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import numpy as np
import jax.numpy as jnp
import jax.scipy as jsp
Expand Down
Loading