diff --git a/dmff/admp/disp_pme.py b/dmff/admp/disp_pme.py
index 11c30eab0..722bf22fb 100755
--- a/dmff/admp/disp_pme.py
+++ b/dmff/admp/disp_pme.py
@@ -1,11 +1,14 @@
+from functools import partial
+
import jax.numpy as jnp
-from jax import vmap, value_and_grad
-from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
-from dmff.admp.spatial import pbc_shift
+from dmff.admp.pairwise import (distribute_dispcoeff, distribute_scalar,
+ distribute_v3)
from dmff.admp.pme import setup_ewald_parameters
-from dmff.admp.recip import generate_pme_recip, Ck_6, Ck_8, Ck_10
-from dmff.admp.pairwise import distribute_scalar, distribute_v3, distribute_dispcoeff
-from functools import partial
+from dmff.admp.recip import Ck_6, Ck_8, Ck_10, generate_pme_recip
+from dmff.admp.spatial import pbc_shift
+from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs
+from jax import value_and_grad, vmap
+
class ADMPDispPmeForce:
'''
diff --git a/dmff/admp/mbpol_intra.py b/dmff/admp/mbpol_intra.py
index c850b0acc..3a9966e13 100755
--- a/dmff/admp/mbpol_intra.py
+++ b/dmff/admp/mbpol_intra.py
@@ -1,28 +1,24 @@
-import sys
+
import numpy as np
import jax.numpy as jnp
-from jax import grad, value_and_grad
-from dmff.settings import DO_JIT
-from dmff.utils import jit_condition
+import numpy as np
from dmff.admp.spatial import v_pbc_shift
-from dmff.admp.pme import ADMPPmeForce
-from dmff.admp.parser import *
+from dmff.utils import jit_condition
from jax import vmap
-import time
#const
f5z = 0.999677885
fbasis = 0.15860145369897
fcore = -1.6351695982132
frest = 1.0
-reoh = 0.958649;
-thetae = 104.3475;
-b1 = 2.0;
-roh = 0.9519607159623009;
-alphaoh = 2.587949757553683;
-deohA = 42290.92019288289;
-phh1A = 16.94879431193463;
-phh2 = 12.66426998162947;
+reoh = 0.958649
+thetae = 104.3475
+b1 = 2.0
+roh = 0.9519607159623009
+alphaoh = 2.587949757553683
+deohA = 42290.92019288289
+phh1A = 16.94879431193463
+phh2 = 12.66426998162947
c5zA = jnp.array([4.2278462684916e+04, 4.5859382909906e-02, 9.4804986183058e+03,
7.5485566680955e+02, 1.9865052511496e+03, 4.3768071560862e+02,
@@ -487,4 +483,3 @@ def onebody_kernel(x1, x2, x3, Va, Vb, efac):
e1 *= cm1_kcalmol
e1 *= cal2joule # conver cal 2 j
return e1
-
diff --git a/dmff/admp/multipole.py b/dmff/admp/multipole.py
index 7f957a823..e863f9e72 100644
--- a/dmff/admp/multipole.py
+++ b/dmff/admp/multipole.py
@@ -1,8 +1,8 @@
-import sys
+from functools import partial
+
import jax.numpy as jnp
-from jax import vmap
from dmff.utils import jit_condition
-from functools import partial
+from jax import vmap
# This module deals with the transformations and rotations of multipoles
@@ -48,7 +48,7 @@ def convert_cart2harm(Theta, lmax):
n * (l+1)^2, stores the spherical multipoles
'''
if lmax > 2:
- sys.exit('l > 2 (beyond quadrupole) not supported')
+ raise ValueError('l > 2 (beyond quadrupole) not supported')
Q_mono = Theta[0:1]
@@ -90,7 +90,7 @@ def convert_harm2cart(Q, lmax):
'''
if lmax > 2:
- sys.exit('l > 2 (beyond quadrupole) not supported')
+ raise ValueError('l > 2 (beyond quadrupole) not supported')
T_mono = Q[0:1]
diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py
index 3ff33ad1a..c0adc0837 100755
--- a/dmff/admp/pairwise.py
+++ b/dmff/admp/pairwise.py
@@ -1,9 +1,9 @@
-import sys
-from jax import vmap
+from functools import partial
+
import jax.numpy as jnp
-from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.spatial import v_pbc_shift
-from functools import partial
+from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs
+from jax import vmap
DIELECTRIC = 1389.35455846
@@ -170,106 +170,3 @@ def slater_sr_kernel(dr, m, ai, aj, bi, bj):
P = 1/3 * br2 + br + 1
return a * P * jnp.exp(-br) * m
-
-def validation(pdb):
- xml = 'mpidwater.xml'
- pdbinfo = read_pdb(pdb)
- serials = pdbinfo['serials']
- names = pdbinfo['names']
- resNames = pdbinfo['resNames']
- resSeqs = pdbinfo['resSeqs']
- positions = pdbinfo['positions']
- box = pdbinfo['box'] # a, b, c, α, β, γ
- charges = pdbinfo['charges']
- positions = jnp.asarray(positions)
- lx, ly, lz, _, _, _ = box
- box = jnp.eye(3)*jnp.array([lx, ly, lz])
-
- mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
- pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
- dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
-
- rc = 4 # in Angstrom
- ethresh = 1e-4
-
- n_atoms = len(serials)
-
- atomTemplate, residueTemplate = read_xml(xml)
- atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate)
-
- covalent_map = assemble_covalent(residueDicts, n_atoms)
- displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False)
- neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse)
- nbr = neighbor_list_fn.allocate(positions)
- pairs = nbr.idx.T
-
- pmax = 10
- kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
- kappa = 0.657065221219616
-
- # construct the C list
- c_list = np.zeros((3, n_atoms))
- a_list = np.zeros(n_atoms)
- q_list = np.zeros(n_atoms)
- b_list = np.zeros(n_atoms)
- nmol=int(n_atoms/3)
- for i in range(nmol):
- a = i*3
- b = i*3+1
- c = i*3+2
- # dispersion coeff
- c_list[0][a]=37.199677405
- c_list[0][b]=7.6111103
- c_list[0][c]=7.6111103
- c_list[1][a]=85.26810658
- c_list[1][b]=11.90220148
- c_list[1][c]=11.90220148
- c_list[2][a]=134.44874488
- c_list[2][b]=15.05074749
- c_list[2][c]=15.05074749
- # q
- q_list[a] = -0.741706
- q_list[b] = 0.370853
- q_list[c] = 0.370853
- # b, Bohr^-1
- b_list[a] = 2.00095977
- b_list[b] = 1.999519942
- b_list[c] = 1.999519942
- # a, Hartree
- a_list[a] = 458.3777
- a_list[b] = 0.0317
- a_list[c] = 0.0317
-
-
- c_list = jnp.array(c_list)
-
-# @partial(vmap, in_axes=(0, 0, 0, 0), out_axes=(0))
-# @jit_condition(static_argnums=())
-# def disp6_pme_real_kernel(dr, m, ci, cj):
-# # unpack static arguments
-# kappa = static_args['kappa']
-# # calculate distance
-# dr2 = dr ** 2
-# dr6 = dr2 ** 3
-# # do calculation
-# x2 = kappa**2 * dr2
-# exp_x2 = jnp.exp(-x2)
-# x4 = x2 * x2
-# g = (1 + x2 + 0.5*x4) * exp_x2
-# return (m + g - 1) * ci * cj / dr6
-
-# static_args = {'kappa': kappa}
-# disp6_pme_real = generate_pairwise_interaction(disp6_pme_real_kernel, covalent_map, static_args)
-# print(disp6_pme_real(positions, box, pairs, mScales, c_list[0, :]))
-
- TT_damping_qq_c6 = generate_pairwise_interaction(TT_damping_qq_c6_kernel, covalent_map, static_args={})
-
- TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0])
- print('ok')
- print(TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0]))
- return
-
-
-# below is the validation code
-if __name__ == '__main__':
- validation(sys.argv[1])
diff --git a/dmff/admp/recip.py b/dmff/admp/recip.py
index 66ee5e7bc..3987f23d8 100755
--- a/dmff/admp/recip.py
+++ b/dmff/admp/recip.py
@@ -1,4 +1,3 @@
-
import numpy as np
import jax.numpy as jnp
import jax.scipy as jsp
diff --git a/dmff/api.py b/dmff/api.py
index 85864d210..dfd7fa6de 100644
--- a/dmff/api.py
+++ b/dmff/api.py
@@ -10,8 +10,8 @@
import openmm as mm
import openmm.app as app
-import openmm.unit as unit
import openmm.app.element as elem
+import openmm.unit as unit
from dmff.admp.disp_pme import ADMPDispPmeForce
from dmff.admp.multipole import convert_cart2harm, convert_harm2cart
@@ -33,7 +33,14 @@
LennardJonesLongRangeForce,
CoulombPMEForce,
CoulNoCutoffForce,
+ CoulombPMEForce,
CoulReactionFieldForce,
+ LennardJonesForce,
+)
+from .classical.intra import (
+ HarmonicAngleJaxForce,
+ HarmonicBondJaxForce,
+ PeriodicTorsionJaxForce,
)
from dmff.classical.fep import (
LennardJonesFreeEnergyForce,
@@ -44,11 +51,9 @@
class XMLNodeInfo:
-
@staticmethod
- def to_str(value)->str:
- """ convert value to string if it can
- """
+ def to_str(value) -> str:
+ """convert value to string if it can"""
if isinstance(value, str):
return value
elif isinstance(value, (jnp.ndarray, np.ndarray)):
@@ -62,17 +67,16 @@ def to_str(value)->str:
return str(value)
class XMLElementInfo:
-
def __init__(self, name):
self.name = name
self.attributes = {}
-
+
def addAttribute(self, key, value):
self.attributes[key] = XMLNodeInfo.to_str(value)
-
+
def __repr__(self):
return f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}>'
-
+
def __getitem__(self, name):
return self.attributes[name]
@@ -80,13 +84,12 @@ def __init__(self, name):
self.name = name
self.attributes = {}
self.elements = []
-
+
def __getitem__(self, name):
if isinstance(name, str):
return self.attributes[name]
elif isinstance(name, int):
return self.elements[name]
-
def addAttribute(self, key, value):
self.attributes[key] = XMLNodeInfo.to_str(value)
@@ -103,10 +106,9 @@ def modResidue(self, residue, atom, key, value):
def __repr__(self):
# tricy string formatting
left = f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}> \n\t'
- right = f'<\\{self.name}>'
- content = '\n\t'.join([repr(e) for e in self.elements])
- return left + content + '\n' + right
-
+ right = f"<\\{self.name}>"
+ content = "\n\t".join([repr(e) for e in self.elements])
+ return left + content + "\n" + right
def get_line_context(file_path, line_number):
@@ -136,8 +138,8 @@ def build_covalent_map(data, max_neighbor):
def findAtomTypeTexts(attribs, num):
typetxt = []
- for n in range(1, num+1):
- for key in ["type%i"%n, "class%i"%n]:
+ for n in range(1, num + 1):
+ for key in ["type%i" % n, "class%i" % n]:
if key in attribs:
typetxt.append((key, attribs[key]))
break
@@ -213,8 +215,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
if "ethresh" in args:
self.ethresh = args["ethresh"]
- Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh,
- self.pmax, lpme=self.lpme)
+ Force_DispPME = ADMPDispPmeForce(
+ box, covalent_map, rc, self.ethresh, self.pmax, lpme=self.lpme
+ )
self.disp_pme_force = Force_DispPME
pot_fn_lr = Force_DispPME.get_energy
pot_fn_sr = generate_pairwise_interaction(
@@ -247,36 +250,41 @@ def getJaxPotential(self):
def renderXML(self):
# generate xml force field file
- finfo = XMLNodeInfo('ADMPDispForce')
- finfo.addAttribute('mScale12', self.params["mScales"][0])
- finfo.addAttribute('mScale13', self.params["mScales"][1])
- finfo.addAttribute('mScale14', self.params["mScales"][2])
- finfo.addAttribute('mScale15', self.params["mScales"][3])
- finfo.addAttribute('mScale16', self.params["mScales"][4])
-
+ finfo = XMLNodeInfo("ADMPDispForce")
+ finfo.addAttribute("mScale12", self.params["mScales"][0])
+ finfo.addAttribute("mScale13", self.params["mScales"][1])
+ finfo.addAttribute("mScale14", self.params["mScales"][2])
+ finfo.addAttribute("mScale15", self.params["mScales"][3])
+ finfo.addAttribute("mScale16", self.params["mScales"][4])
+
for i in range(len(self.types)):
- ainfo = {'type': self.types[i], 'A': self.params["A"][i], 'B': self.params["B"][i], 'Q': self.params["Q"][i], 'C6': self.params["C6"][i], 'C8': self.params["C8"][i], 'C10': self.params["C10"][i]}
- finfo.addElement('Atom', ainfo)
-
+ ainfo = {
+ "type": self.types[i],
+ "A": self.params["A"][i],
+ "B": self.params["B"][i],
+ "Q": self.params["Q"][i],
+ "C6": self.params["C6"][i],
+ "C8": self.params["C8"][i],
+ "C10": self.params["C10"][i],
+ }
+ finfo.addElement("Atom", ainfo)
+
return finfo
+
# register all parsers
app.forcefield.parsers["ADMPDispForce"] = ADMPDispGenerator.parseElement
class ADMPDispPmeGenerator:
- r'''
+ r"""
This one computes the undamped C6/C8/C10 interactions
u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10
- '''
+ """
def __init__(self, hamiltonian):
self.ff = hamiltonian
- self.params = {
- "C6": [],
- "C8": [],
- "C10": []
- }
+ self.params = {"C6": [], "C8": [], "C10": []}
self._jaxPotential = None
self.types = []
self.ethresh = 5e-4
@@ -306,8 +314,7 @@ def parseElement(element, hamiltonian):
generator.params[k] = jnp.array(generator.params[k])
generator.types = np.array(generator.types)
- def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
- args):
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
methodMap = {
app.CutoffPeriodic: "CutoffPeriodic",
app.NoCutoff: "NoCutoff",
@@ -337,17 +344,18 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
rc = nonbondedCutoff.value_in_unit(unit.angstrom)
# get calculator
- if 'ethresh' in args:
- self.ethresh = args['ethresh']
+ if "ethresh" in args:
+ self.ethresh = args["ethresh"]
- disp_force = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh,
- self.pmax, self.lpme)
+ disp_force = ADMPDispPmeForce(
+ box, covalent_map, rc, self.ethresh, self.pmax, self.lpme
+ )
self.disp_force = disp_force
pot_fn_lr = disp_force.get_energy
def potential_fn(positions, box, pairs, params):
mScales = params["mScales"]
- C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6
+ C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6
C8_list = params["C8"][map_atomtype] * 1e8
C10_list = params["C10"][map_atomtype] * 1e10
c6_list = jnp.sqrt(C6_list)
@@ -355,7 +363,7 @@ def potential_fn(positions, box, pairs, params):
c10_list = jnp.sqrt(C10_list)
c_list = jnp.vstack((c6_list, c8_list, c10_list))
E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales)
- return - E_lr
+ return -E_lr
self._jaxPotential = potential_fn
# self._top_data = data
@@ -367,14 +375,17 @@ def renderXML(self):
# generate xml force field file
pass
+
# register all parsers
app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement
+
class QqTtDampingGenerator:
- r'''
+ r"""
This one calculates the tang-tonnies damping of charge-charge interaction
E = \sum_ij exp(-B*r)*(1+B*r)*q_i*q_j/r
- '''
+ """
+
def __init__(self, hamiltonian):
self.ff = hamiltonian
self.params = {
@@ -408,8 +419,7 @@ def parseElement(element, hamiltonian):
generator.types = np.array(generator.types)
# on working
- def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
- args):
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
n_atoms = len(data.atoms)
# build index map
@@ -421,13 +431,13 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
# build covalent map
covalent_map = build_covalent_map(data, 6)
- pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_kernel,
- covalent_map,
- static_args={})
+ pot_fn_sr = generate_pairwise_interaction(
+ TT_damping_qq_kernel, covalent_map, static_args={}
+ )
def potential_fn(positions, box, pairs, params):
mScales = params["mScales"]
- b_list = params["B"][map_atomtype] / 10 # convert to A^-1
+ b_list = params["B"][map_atomtype] / 10 # convert to A^-1
q_list = params["Q"][map_atomtype]
E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, q_list)
@@ -443,17 +453,19 @@ def renderXML(self):
# generate xml force field file
pass
+
# register all parsers
app.forcefield.parsers["QqTtDampingForce"] = QqTtDampingGenerator.parseElement
class SlaterDampingGenerator:
- r'''
+ r"""
This one computes the slater-type damping function for c6/c8/c10 dispersion
E = \sum_ij (f6-1)*c6/r6 + (f8-1)*c8/r8 + (f10-1)*c10/r10
fn = f_tt(x, n)
x = br - (2*br2 + 3*br) / (br2 + 3*br + 3)
- '''
+ """
+
def __init__(self, hamiltonian):
self.ff = hamiltonian
self.params = {
@@ -490,8 +502,7 @@ def parseElement(element, hamiltonian):
generator.params[k] = jnp.array(generator.params[k])
generator.types = np.array(generator.types)
- def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
- args):
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
n_atoms = len(data.atoms)
# build index map
@@ -504,22 +515,24 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
covalent_map = build_covalent_map(data, 6)
# WORKING
- pot_fn_sr = generate_pairwise_interaction(slater_disp_damping_kernel,
- covalent_map,
- static_args={})
+ pot_fn_sr = generate_pairwise_interaction(
+ slater_disp_damping_kernel, covalent_map, static_args={}
+ )
def potential_fn(positions, box, pairs, params):
mScales = params["mScales"]
- b_list = params["B"][map_atomtype] / 10 # convert to A^-1
- c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) # to kj/mol * A**6
+ b_list = params["B"][map_atomtype] / 10 # convert to A^-1
+ c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) # to kj/mol * A**6
c8_list = jnp.sqrt(params["C8"][map_atomtype] * 1e8)
c10_list = jnp.sqrt(params["C10"][map_atomtype] * 1e10)
- E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, c6_list, c8_list, c10_list)
+ E_sr = pot_fn_sr(
+ positions, box, pairs, mScales, b_list, c6_list, c8_list, c10_list
+ )
return E_sr
self._jaxPotential = potential_fn
# self._top_data = data
-
+
def getJaxPotential(self):
return self._jaxPotential
@@ -527,21 +540,22 @@ def renderXML(self):
# generate xml force field file
pass
+
app.forcefield.parsers["SlaterDampingForce"] = SlaterDampingGenerator.parseElement
class SlaterExGenerator:
- r'''
+ r"""
This one computes the Slater-ISA type exchange interaction
u = \sum_ij A * (1/3*(Br)^2 + Br + 1)
- '''
+ """
def __init__(self, hamiltonian):
self.ff = hamiltonian
self.params = {
- "A": [],
- "B": [],
- }
+ "A": [],
+ "B": [],
+ }
self._jaxPotential = None
self.types = []
self.name = "SlaterEx"
@@ -568,8 +582,7 @@ def parseElement(element, hamiltonian):
generator.params[k] = jnp.array(generator.params[k])
generator.types = np.array(generator.types)
- def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
- args):
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
n_atoms = len(data.atoms)
# build index map
@@ -581,14 +594,14 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
# build covalent map
covalent_map = build_covalent_map(data, 6)
- pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel,
- covalent_map,
- static_args={})
+ pot_fn_sr = generate_pairwise_interaction(
+ slater_sr_kernel, covalent_map, static_args={}
+ )
def potential_fn(positions, box, pairs, params):
mScales = params["mScales"]
a_list = params["A"][map_atomtype]
- b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1
+ b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1
return pot_fn_sr(positions, box, pairs, mScales, a_list, b_list)
@@ -602,6 +615,7 @@ def renderXML(self):
# generate xml force field file
pass
+
app.forcefield.parsers["SlaterExForce"] = SlaterExGenerator.parseElement
@@ -624,6 +638,7 @@ def __init__(self):
super().__init__(self)
self.name = "SlaterDhf"
+
# register all parsers
app.forcefield.parsers["SlaterSrEsForce"] = SlaterSrEsGenerator.parseElement
app.forcefield.parsers["SlaterSrPolForce"] = SlaterSrPolGenerator.parseElement
@@ -697,33 +712,33 @@ def registerAtomType(self, atom: dict):
@staticmethod
def parseElement(element, hamiltonian):
- r""" parse admp related parameters in XML file
-
- example:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ r"""parse admp related parameters in XML file
+
+ example:
+
+
+
+
+
+
+
+
+
+
+
+
+
+
"""
generator = ADMPPmeGenerator(hamiltonian)
@@ -1112,39 +1127,50 @@ def getJaxPotential(self):
def renderXML(self):
#
-
- finfo = XMLNodeInfo('ADMPPmeForce')
- finfo.addAttribute('lmax', str(self.lmax))
+
+ finfo = XMLNodeInfo("ADMPPmeForce")
+ finfo.addAttribute("lmax", str(self.lmax))
outputparams = deepcopy(self.params)
- mScales = outputparams.pop('mScales')
- pScales = outputparams.pop('pScales')
- dScales = outputparams.pop('dScales')
+ mScales = outputparams.pop("mScales")
+ pScales = outputparams.pop("pScales")
+ dScales = outputparams.pop("dScales")
for i in range(len(mScales)):
- finfo.addAttribute(f'mScale1{i+2}', str(mScales[i]))
+ finfo.addAttribute(f"mScale1{i+2}", str(mScales[i]))
for i in range(len(pScales)):
- finfo.addAttribute(f'pScale{i+1}', str(pScales[i]))
+ finfo.addAttribute(f"pScale{i+1}", str(pScales[i]))
for i in range(len(dScales)):
- finfo.addAttribute(f'dScale{i+1}', str(dScales[i]))
-
- Q = outputparams['Q_local']
+ finfo.addAttribute(f"dScale{i+1}", str(dScales[i]))
+
+ Q = outputparams["Q_local"]
Q_global = convert_harm2cart(Q, self.lmax)
-
+
#
for atom in range(self.n_atoms):
- info = {'type': self.map_atomtype[atom]}
- info.update({ktype:self.kStrings[ktype][atom] for ktype in ['kz', 'kx', 'ky']})
- for i, key in enumerate(['c0', 'dX', 'dY', 'dZ', 'qXX', 'qXY', 'qXZ', 'qYY', 'qYZ', 'qZZ']):
+ info = {"type": self.map_atomtype[atom]}
+ info.update(
+ {ktype: self.kStrings[ktype][atom] for ktype in ["kz", "kx", "ky"]}
+ )
+ for i, key in enumerate(
+ ["c0", "dX", "dY", "dZ", "qXX", "qXY", "qXZ", "qYY", "qYZ", "qZZ"]
+ ):
info[key] = "%.8f" % Q_global[atom][i]
- finfo.addElement('Atom', info)
-
+ finfo.addElement("Atom", info)
+
#
for t in range(len(self.types)):
- info = {
- 'type': self.types[t]
- }
- info.update({p: "%.8f" % self.params['pol'][t] for p in ['polarizabilityXX', 'polarizabilityYY', 'polarizabilityZZ']})
- finfo.addElement('Polarize', info)
-
+ info = {"type": self.types[t]}
+ info.update(
+ {
+ p: "%.8f" % self.params["pol"][t]
+ for p in [
+ "polarizabilityXX",
+ "polarizabilityYY",
+ "polarizabilityZZ",
+ ]
+ }
+ )
+ finfo.addElement("Polarize", info)
+
return finfo
@@ -1172,14 +1198,14 @@ def registerBondType(self, bond):
def parseElement(element, hamiltonian):
r"""parse section in XML file
-
- example:
-
-
-
-
- <\HarmonicBondForce>
-
+
+ example:
+
+
+
+
+ <\HarmonicBondForce>
+
"""
existing = [f for f in hamiltonian._forces if isinstance(f, HarmonicBondJaxGenerator)]
if len(existing) == 0:
@@ -1242,7 +1268,7 @@ def renderXML(self):
binfo[k1] = v1
binfo[k2] = v2
for key in self.params.keys():
- binfo[key] = "%.8f"%self.params[key][ntype]
+ binfo[key] = "%.8f" % self.params[key][ntype]
finfo.addElement("Bond", binfo)
return finfo
@@ -1267,13 +1293,13 @@ def registerAngleType(self, angle):
@staticmethod
def parseElement(element, hamiltonian):
- r""" parse section in XML file
+ r"""parse section in XML file
- example:
-
-
-
- <\HarmonicAngleForce>
+ example:
+
+
+
+ <\HarmonicAngleForce>
"""
generator = HarmonicAngleJaxGenerator(hamiltonian)
@@ -1343,9 +1369,15 @@ def renderXML(self):
finfo = XMLNodeInfo("HarmonicAngleForce")
for i, type in enumerate(self.types):
t1, t2, t3 = type
- ainfo = {'type1': t1, 'type2': t2, 'type3': t3, 'k': self.params['k'][i], 'angle': self.params['angle'][i]}
- finfo.addElement('Angle', ainfo)
-
+ ainfo = {
+ "type1": t1,
+ "type2": t2,
+ "type3": t3,
+ "k": self.params["k"][i],
+ "angle": self.params["angle"][i],
+ }
+ finfo.addElement("Angle", ainfo)
+
return finfo
@@ -1353,7 +1385,6 @@ def renderXML(self):
app.forcefield.parsers["HarmonicAngleForce"] = HarmonicAngleJaxGenerator.parseElement
-
def _matchImproper(data, torsion, generator):
type1 = data.atomType[data.atoms[torsion[0]]]
type2 = data.atomType[data.atoms[torsion[1]]]
@@ -1557,14 +1588,14 @@ def registerImproperTorsion(self, parameters, ordering="default"):
@staticmethod
def parseElement(element, ff):
- """ parse section in XML file
-
- example:
-
-
-
-
-
+ """parse section in XML file
+
+ example:
+
+
+
+
+
"""
existing = [f for f in ff._forces if isinstance(f, PeriodicTorsionJaxGenerator)]
@@ -1901,50 +1932,58 @@ def getJaxPotential(self):
def renderXML(self):
params = self.params
# generate xml force field file
- finfo = XMLNodeInfo('PeriodicTorsionForce')
+ finfo = XMLNodeInfo("PeriodicTorsionForce")
for i in range(len(self.proper)):
proper = self.proper[i]
-
- finfo.addElement('Proper',
- {'type1': proper.types1, 'type2': proper.types2,
- 'type3': proper.types3, 'type4': proper.types4,
- 'periodicity1': proper.periodicity[0],
- 'phase1': params['psi1_p'][i],
- 'k1': params['k1_p'][i],
- 'periodicity2': proper.periodicity[1],
- 'phase2': params['psi2_p'][i],
- 'k2': params['k2_p'][i],
- 'periodicity3': proper.periodicity[2],
- 'phase3': params['psi3_p'][i],
- 'k3': params['k3_p'][i],
- 'periodicity4': proper.periodicity[3],
- 'phase4': params['psi4_p'][i],
- 'k4': params['k4_p'][i],
- }
+
+ finfo.addElement(
+ "Proper",
+ {
+ "type1": proper.types1,
+ "type2": proper.types2,
+ "type3": proper.types3,
+ "type4": proper.types4,
+ "periodicity1": proper.periodicity[0],
+ "phase1": params["psi1_p"][i],
+ "k1": params["k1_p"][i],
+ "periodicity2": proper.periodicity[1],
+ "phase2": params["psi2_p"][i],
+ "k2": params["k2_p"][i],
+ "periodicity3": proper.periodicity[2],
+ "phase3": params["psi3_p"][i],
+ "k3": params["k3_p"][i],
+ "periodicity4": proper.periodicity[3],
+ "phase4": params["psi4_p"][i],
+ "k4": params["k4_p"][i],
+ },
)
-
+
for i in range(len(self.improper)):
-
+
improper = self.improper[i]
-
- finfo.addElement('Improper',
- {'type1': improper.types1, 'type2': improper.types2,
- 'type3': improper.types3, 'type4': improper.types4,
- 'periodicity1': improper.periodicity[0],
- 'phase1': params['psi1_i'][i],
- 'k1': params['k1_i'][i],
- 'periodicity2': proper.periodicity[1],
- 'phase2': params['psi2_i'][i],
- 'k2': params['k2_i'][i],
- 'periodicity3': proper.periodicity[2],
- 'phase3': params['psi3_i'][i],
- 'k3': params['k3_i'][i],
- 'periodicity4': proper.periodicity[3],
- 'phase4': params['psi4_i'][i],
- 'k4': params['k4_i'][i],
- }
+
+ finfo.addElement(
+ "Improper",
+ {
+ "type1": improper.types1,
+ "type2": improper.types2,
+ "type3": improper.types3,
+ "type4": improper.types4,
+ "periodicity1": improper.periodicity[0],
+ "phase1": params["psi1_i"][i],
+ "k1": params["k1_i"][i],
+ "periodicity2": proper.periodicity[1],
+ "phase2": params["psi2_i"][i],
+ "k2": params["k2_i"][i],
+ "periodicity3": proper.periodicity[2],
+ "phase3": params["psi3_i"][i],
+ "k3": params["k3_i"][i],
+ "periodicity4": proper.periodicity[3],
+ "phase4": params["psi4_i"][i],
+ "k4": params["k4_i"][i],
+ },
)
-
+
return finfo
@@ -1989,9 +2028,9 @@ def registerAtom(self, atom):
@staticmethod
def parseElement(element, ff):
"""parse section in XML file
-
+
example:
-
+
@@ -2023,9 +2062,9 @@ def parseElement(element, ff):
generator.useAttributeFromResidue.append(eprm)
for atom in element.findall("Atom"):
generator.registerAtom(atom.attrib)
-
+
generator.n_atoms = len(element.findall("Atom"))
-
+
# jax it!
for k in generator.params.keys():
generator.params[k] = jnp.array(generator.params[k])
@@ -2103,7 +2142,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
# TODO: implement NBFIX
map_nbfix = []
map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 2))
-
colv_map = build_covalent_map(data, 6)
@@ -2352,23 +2390,27 @@ def getJaxPotential(self):
return self._jaxPotential
def renderXML(self):
-
+
#
- finfo = XMLNodeInfo('NonbondedForce')
- finfo.addAttribute('coulomb14scale', str(self.coulomb14scale))
- finfo.addAttribute('lj14scale', str(self.lj14scale))
-
+ finfo = XMLNodeInfo("NonbondedForce")
+ finfo.addAttribute("coulomb14scale", str(self.coulomb14scale))
+ finfo.addAttribute("lj14scale", str(self.lj14scale))
+
for atom in range(self.n_atoms):
- info = {'type': self.types[atom], 'charge': self.params['charge'][atom], 'sigma': self.params['sigma'][atom], 'epsilon': self.params['epsilon'][atom]}
- finfo.addElement('Atom', info)
-
+ info = {
+ "type": self.types[atom],
+ "charge": self.params["charge"][atom],
+ "sigma": self.params["sigma"][atom],
+ "epsilon": self.params["epsilon"][atom],
+ }
+ finfo.addElement("Atom", info)
+
return finfo
app.forcefield.parsers["NonbondedForce"] = NonbondJaxGenerator.parseElement
-
class Hamiltonian(app.forcefield.ForceField):
def __init__(self, *xmlnames):
super().__init__(*xmlnames)
@@ -2382,13 +2424,13 @@ def createPotential(
topology,
nonbondedMethod=app.NoCutoff,
nonbondedCutoff=1.0 * unit.nanometer,
- **args
+ **args,
):
system = self.createSystem(
topology,
nonbondedMethod=nonbondedMethod,
nonbondedCutoff=nonbondedCutoff,
- **args
+ **args,
)
# load_constraints_from_system_if_needed
# create potentials
@@ -2438,4 +2480,4 @@ def getParameters(self):
params = {}
for gen in self.getGenerators():
params[gen.name] = gen.params
- return params
\ No newline at end of file
+ return params
diff --git a/dmff/common/nblist.py b/dmff/common/nblist.py
index 7cb5f9437..16de9715f 100644
--- a/dmff/common/nblist.py
+++ b/dmff/common/nblist.py
@@ -98,4 +98,4 @@ def distance(self):
jnp.ndarray: (nPairs, )
"""
- return jnp.linalg.norm(self.dr, axis=1)
\ No newline at end of file
+ return jnp.linalg.norm(self.dr, axis=1)
diff --git a/docs/about.md b/docs/about.md
deleted file mode 100644
index 835deeb6f..000000000
--- a/docs/about.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Lisense
-
-# Contributor
\ No newline at end of file
diff --git a/docs/assets/js/mathjax.js b/docs/assets/js/mathjax.js
new file mode 100644
index 000000000..e648674dd
--- /dev/null
+++ b/docs/assets/js/mathjax.js
@@ -0,0 +1,16 @@
+window.MathJax = {
+ tex: {
+ inlineMath: [["\\(", "\\)"]],
+ displayMath: [["\\[", "\\]"]],
+ processEscapes: true,
+ processEnvironments: true
+ },
+ options: {
+ ignoreHtmlClass: ".*|",
+ processHtmlClass: "arithmatex"
+ }
+ };
+
+ document$.subscribe(() => {
+ MathJax.typesetPromise()
+ })
\ No newline at end of file
diff --git a/docs/dev_guide/profile.md b/docs/dev_guide/profile.md
deleted file mode 100644
index 14a981b77..000000000
--- a/docs/dev_guide/profile.md
+++ /dev/null
@@ -1,2 +0,0 @@
-# How to profile
-
diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py
new file mode 100644
index 000000000..77fbb0a3c
--- /dev/null
+++ b/docs/gen_ref_pages.py
@@ -0,0 +1,38 @@
+"""Generate the code reference pages."""
+
+from pathlib import Path
+
+import mkdocs_gen_files
+
+nav = mkdocs_gen_files.Nav()
+
+for path in sorted(Path("dmff").rglob("*.py")): #
+
+ module_path = path.relative_to('dmff').with_suffix("") #
+
+ doc_path = path.relative_to('dmff').with_suffix(".md") #
+
+ full_doc_path = Path("refs", doc_path) #
+
+ parts = list(module_path.parts)
+
+ if parts[-1] == "__init__": #
+ continue
+ elif parts[-1] == "__main__":
+ continue
+
+ nav[parts] = doc_path.as_posix()
+ print(full_doc_path)
+ with mkdocs_gen_files.open(full_doc_path, "w") as fd: #
+
+ identifier = ".".join(parts) #
+
+ print("::: dmff." + identifier, file=fd) #
+
+
+ mkdocs_gen_files.set_edit_path(full_doc_path, path) #
+
+with mkdocs_gen_files.open("refs/SUMMARY.md", "w") as nav_file: #
+
+ nav_file.writelines(nav.build_literate_nav()) #
+
diff --git a/docs/license.md b/docs/license.md
new file mode 100644
index 000000000..d2af63711
--- /dev/null
+++ b/docs/license.md
@@ -0,0 +1,3 @@
+# Lisense
+
+The project DeePMD-kit is licensed under [GNU LGPLv3.0](https://github.com/deepmodeling/deepmd-kit/blob/master/LICENSE).
\ No newline at end of file
diff --git a/docs/user_guide/xml_spec.md b/docs/user_guide/xml_spec.md
index 5304e3141..7386c6345 100644
--- a/docs/user_guide/xml_spec.md
+++ b/docs/user_guide/xml_spec.md
@@ -151,7 +151,7 @@ The `` node of the residue part defines all the atoms involved in the resi
- <\displaylines{Bond atomName1="CA" atomName2="HA"/>
+
diff --git a/mkdocs.yml b/mkdocs.yml
index ea1b37ba5..8a751b209 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,4 +1,5 @@
site_name: DMFF
+nav:
nav:
- Home: index.md
- User Guide:
@@ -18,7 +19,11 @@ nav:
- ADMP:
- Introduction: admp/readme.md
- Frontends: admp/frontend.md
- - About: about.md
+
+ - API: refs/
+
+ - About:
+ - License: license.md
theme: readthedocs
@@ -26,7 +31,17 @@ markdown_extensions:
- pymdownx.arithmatex:
generic: true
+plugins:
+- search
+- gen-files:
+ scripts:
+ - docs/gen_ref_pages.py
+- literate-nav:
+ nav_file: SUMMARY.md
+- mkdocstrings:
+
+
extra_javascript:
- - javascripts/mathjax.js
+ - assets/js/mathjax.js
- https://polyfill.io/v3/polyfill.min.js?features=es6
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
diff --git a/requirements.txt b/requirements.txt
index 7360452fb..343447502 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,9 @@
numpy>=1.18
jax>=0.3.7
jax-md>=0.1.28
+mkdocs>=1.3.0
+mkdocs-autorefs>=0.4.1
+mkdocs-gen-files>=0.3.4
+mkdocs-literate-nav>=0.4.1
+mkdocstrings>=0.19.0
+mkdocstrings-python>=0.7.0
\ No newline at end of file
diff --git a/tests/test_common/test_nblist.py b/tests/test_common/test_nblist.py
index b2b66484e..d5f709103 100644
--- a/tests/test_common/test_nblist.py
+++ b/tests/test_common/test_nblist.py
@@ -51,4 +51,4 @@ def test_dr(self, nblist):
def test_distance(self, nblist):
- assert nblist.distance.shape == (15, )
\ No newline at end of file
+ assert nblist.distance.shape == (15, )