diff --git a/dmff/admp/disp_pme.py b/dmff/admp/disp_pme.py index 11c30eab0..722bf22fb 100755 --- a/dmff/admp/disp_pme.py +++ b/dmff/admp/disp_pme.py @@ -1,11 +1,14 @@ +from functools import partial + import jax.numpy as jnp -from jax import vmap, value_and_grad -from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales -from dmff.admp.spatial import pbc_shift +from dmff.admp.pairwise import (distribute_dispcoeff, distribute_scalar, + distribute_v3) from dmff.admp.pme import setup_ewald_parameters -from dmff.admp.recip import generate_pme_recip, Ck_6, Ck_8, Ck_10 -from dmff.admp.pairwise import distribute_scalar, distribute_v3, distribute_dispcoeff -from functools import partial +from dmff.admp.recip import Ck_6, Ck_8, Ck_10, generate_pme_recip +from dmff.admp.spatial import pbc_shift +from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs +from jax import value_and_grad, vmap + class ADMPDispPmeForce: ''' diff --git a/dmff/admp/mbpol_intra.py b/dmff/admp/mbpol_intra.py index c850b0acc..3a9966e13 100755 --- a/dmff/admp/mbpol_intra.py +++ b/dmff/admp/mbpol_intra.py @@ -1,28 +1,24 @@ -import sys + import numpy as np import jax.numpy as jnp -from jax import grad, value_and_grad -from dmff.settings import DO_JIT -from dmff.utils import jit_condition +import numpy as np from dmff.admp.spatial import v_pbc_shift -from dmff.admp.pme import ADMPPmeForce -from dmff.admp.parser import * +from dmff.utils import jit_condition from jax import vmap -import time #const f5z = 0.999677885 fbasis = 0.15860145369897 fcore = -1.6351695982132 frest = 1.0 -reoh = 0.958649; -thetae = 104.3475; -b1 = 2.0; -roh = 0.9519607159623009; -alphaoh = 2.587949757553683; -deohA = 42290.92019288289; -phh1A = 16.94879431193463; -phh2 = 12.66426998162947; +reoh = 0.958649 +thetae = 104.3475 +b1 = 2.0 +roh = 0.9519607159623009 +alphaoh = 2.587949757553683 +deohA = 42290.92019288289 +phh1A = 16.94879431193463 +phh2 = 12.66426998162947 c5zA = jnp.array([4.2278462684916e+04, 4.5859382909906e-02, 9.4804986183058e+03, 7.5485566680955e+02, 1.9865052511496e+03, 4.3768071560862e+02, @@ -487,4 +483,3 @@ def onebody_kernel(x1, x2, x3, Va, Vb, efac): e1 *= cm1_kcalmol e1 *= cal2joule # conver cal 2 j return e1 - diff --git a/dmff/admp/multipole.py b/dmff/admp/multipole.py index 7f957a823..e863f9e72 100644 --- a/dmff/admp/multipole.py +++ b/dmff/admp/multipole.py @@ -1,8 +1,8 @@ -import sys +from functools import partial + import jax.numpy as jnp -from jax import vmap from dmff.utils import jit_condition -from functools import partial +from jax import vmap # This module deals with the transformations and rotations of multipoles @@ -48,7 +48,7 @@ def convert_cart2harm(Theta, lmax): n * (l+1)^2, stores the spherical multipoles ''' if lmax > 2: - sys.exit('l > 2 (beyond quadrupole) not supported') + raise ValueError('l > 2 (beyond quadrupole) not supported') Q_mono = Theta[0:1] @@ -90,7 +90,7 @@ def convert_harm2cart(Q, lmax): ''' if lmax > 2: - sys.exit('l > 2 (beyond quadrupole) not supported') + raise ValueError('l > 2 (beyond quadrupole) not supported') T_mono = Q[0:1] diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py index 3ff33ad1a..c0adc0837 100755 --- a/dmff/admp/pairwise.py +++ b/dmff/admp/pairwise.py @@ -1,9 +1,9 @@ -import sys -from jax import vmap +from functools import partial + import jax.numpy as jnp -from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales from dmff.admp.spatial import v_pbc_shift -from functools import partial +from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs +from jax import vmap DIELECTRIC = 1389.35455846 @@ -170,106 +170,3 @@ def slater_sr_kernel(dr, m, ai, aj, bi, bj): P = 1/3 * br2 + br + 1 return a * P * jnp.exp(-br) * m - -def validation(pdb): - xml = 'mpidwater.xml' - pdbinfo = read_pdb(pdb) - serials = pdbinfo['serials'] - names = pdbinfo['names'] - resNames = pdbinfo['resNames'] - resSeqs = pdbinfo['resSeqs'] - positions = pdbinfo['positions'] - box = pdbinfo['box'] # a, b, c, α, β, γ - charges = pdbinfo['charges'] - positions = jnp.asarray(positions) - lx, ly, lz, _, _, _ = box - box = jnp.eye(3)*jnp.array([lx, ly, lz]) - - mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - - rc = 4 # in Angstrom - ethresh = 1e-4 - - n_atoms = len(serials) - - atomTemplate, residueTemplate = read_xml(xml) - atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate) - - covalent_map = assemble_covalent(residueDicts, n_atoms) - displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False) - neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse) - nbr = neighbor_list_fn.allocate(positions) - pairs = nbr.idx.T - - pmax = 10 - kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box) - kappa = 0.657065221219616 - - # construct the C list - c_list = np.zeros((3, n_atoms)) - a_list = np.zeros(n_atoms) - q_list = np.zeros(n_atoms) - b_list = np.zeros(n_atoms) - nmol=int(n_atoms/3) - for i in range(nmol): - a = i*3 - b = i*3+1 - c = i*3+2 - # dispersion coeff - c_list[0][a]=37.199677405 - c_list[0][b]=7.6111103 - c_list[0][c]=7.6111103 - c_list[1][a]=85.26810658 - c_list[1][b]=11.90220148 - c_list[1][c]=11.90220148 - c_list[2][a]=134.44874488 - c_list[2][b]=15.05074749 - c_list[2][c]=15.05074749 - # q - q_list[a] = -0.741706 - q_list[b] = 0.370853 - q_list[c] = 0.370853 - # b, Bohr^-1 - b_list[a] = 2.00095977 - b_list[b] = 1.999519942 - b_list[c] = 1.999519942 - # a, Hartree - a_list[a] = 458.3777 - a_list[b] = 0.0317 - a_list[c] = 0.0317 - - - c_list = jnp.array(c_list) - -# @partial(vmap, in_axes=(0, 0, 0, 0), out_axes=(0)) -# @jit_condition(static_argnums=()) -# def disp6_pme_real_kernel(dr, m, ci, cj): -# # unpack static arguments -# kappa = static_args['kappa'] -# # calculate distance -# dr2 = dr ** 2 -# dr6 = dr2 ** 3 -# # do calculation -# x2 = kappa**2 * dr2 -# exp_x2 = jnp.exp(-x2) -# x4 = x2 * x2 -# g = (1 + x2 + 0.5*x4) * exp_x2 -# return (m + g - 1) * ci * cj / dr6 - -# static_args = {'kappa': kappa} -# disp6_pme_real = generate_pairwise_interaction(disp6_pme_real_kernel, covalent_map, static_args) -# print(disp6_pme_real(positions, box, pairs, mScales, c_list[0, :])) - - TT_damping_qq_c6 = generate_pairwise_interaction(TT_damping_qq_c6_kernel, covalent_map, static_args={}) - - TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0]) - print('ok') - print(TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0])) - return - - -# below is the validation code -if __name__ == '__main__': - validation(sys.argv[1]) diff --git a/dmff/admp/recip.py b/dmff/admp/recip.py index 66ee5e7bc..3987f23d8 100755 --- a/dmff/admp/recip.py +++ b/dmff/admp/recip.py @@ -1,4 +1,3 @@ - import numpy as np import jax.numpy as jnp import jax.scipy as jsp diff --git a/dmff/api.py b/dmff/api.py index 85864d210..dfd7fa6de 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -10,8 +10,8 @@ import openmm as mm import openmm.app as app -import openmm.unit as unit import openmm.app.element as elem +import openmm.unit as unit from dmff.admp.disp_pme import ADMPDispPmeForce from dmff.admp.multipole import convert_cart2harm, convert_harm2cart @@ -33,7 +33,14 @@ LennardJonesLongRangeForce, CoulombPMEForce, CoulNoCutoffForce, + CoulombPMEForce, CoulReactionFieldForce, + LennardJonesForce, +) +from .classical.intra import ( + HarmonicAngleJaxForce, + HarmonicBondJaxForce, + PeriodicTorsionJaxForce, ) from dmff.classical.fep import ( LennardJonesFreeEnergyForce, @@ -44,11 +51,9 @@ class XMLNodeInfo: - @staticmethod - def to_str(value)->str: - """ convert value to string if it can - """ + def to_str(value) -> str: + """convert value to string if it can""" if isinstance(value, str): return value elif isinstance(value, (jnp.ndarray, np.ndarray)): @@ -62,17 +67,16 @@ def to_str(value)->str: return str(value) class XMLElementInfo: - def __init__(self, name): self.name = name self.attributes = {} - + def addAttribute(self, key, value): self.attributes[key] = XMLNodeInfo.to_str(value) - + def __repr__(self): return f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}>' - + def __getitem__(self, name): return self.attributes[name] @@ -80,13 +84,12 @@ def __init__(self, name): self.name = name self.attributes = {} self.elements = [] - + def __getitem__(self, name): if isinstance(name, str): return self.attributes[name] elif isinstance(name, int): return self.elements[name] - def addAttribute(self, key, value): self.attributes[key] = XMLNodeInfo.to_str(value) @@ -103,10 +106,9 @@ def modResidue(self, residue, atom, key, value): def __repr__(self): # tricy string formatting left = f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}> \n\t' - right = f'<\\{self.name}>' - content = '\n\t'.join([repr(e) for e in self.elements]) - return left + content + '\n' + right - + right = f"<\\{self.name}>" + content = "\n\t".join([repr(e) for e in self.elements]) + return left + content + "\n" + right def get_line_context(file_path, line_number): @@ -136,8 +138,8 @@ def build_covalent_map(data, max_neighbor): def findAtomTypeTexts(attribs, num): typetxt = [] - for n in range(1, num+1): - for key in ["type%i"%n, "class%i"%n]: + for n in range(1, num + 1): + for key in ["type%i" % n, "class%i" % n]: if key in attribs: typetxt.append((key, attribs[key])) break @@ -213,8 +215,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): if "ethresh" in args: self.ethresh = args["ethresh"] - Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, - self.pmax, lpme=self.lpme) + Force_DispPME = ADMPDispPmeForce( + box, covalent_map, rc, self.ethresh, self.pmax, lpme=self.lpme + ) self.disp_pme_force = Force_DispPME pot_fn_lr = Force_DispPME.get_energy pot_fn_sr = generate_pairwise_interaction( @@ -247,36 +250,41 @@ def getJaxPotential(self): def renderXML(self): # generate xml force field file - finfo = XMLNodeInfo('ADMPDispForce') - finfo.addAttribute('mScale12', self.params["mScales"][0]) - finfo.addAttribute('mScale13', self.params["mScales"][1]) - finfo.addAttribute('mScale14', self.params["mScales"][2]) - finfo.addAttribute('mScale15', self.params["mScales"][3]) - finfo.addAttribute('mScale16', self.params["mScales"][4]) - + finfo = XMLNodeInfo("ADMPDispForce") + finfo.addAttribute("mScale12", self.params["mScales"][0]) + finfo.addAttribute("mScale13", self.params["mScales"][1]) + finfo.addAttribute("mScale14", self.params["mScales"][2]) + finfo.addAttribute("mScale15", self.params["mScales"][3]) + finfo.addAttribute("mScale16", self.params["mScales"][4]) + for i in range(len(self.types)): - ainfo = {'type': self.types[i], 'A': self.params["A"][i], 'B': self.params["B"][i], 'Q': self.params["Q"][i], 'C6': self.params["C6"][i], 'C8': self.params["C8"][i], 'C10': self.params["C10"][i]} - finfo.addElement('Atom', ainfo) - + ainfo = { + "type": self.types[i], + "A": self.params["A"][i], + "B": self.params["B"][i], + "Q": self.params["Q"][i], + "C6": self.params["C6"][i], + "C8": self.params["C8"][i], + "C10": self.params["C10"][i], + } + finfo.addElement("Atom", ainfo) + return finfo + # register all parsers app.forcefield.parsers["ADMPDispForce"] = ADMPDispGenerator.parseElement class ADMPDispPmeGenerator: - r''' + r""" This one computes the undamped C6/C8/C10 interactions u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10 - ''' + """ def __init__(self, hamiltonian): self.ff = hamiltonian - self.params = { - "C6": [], - "C8": [], - "C10": [] - } + self.params = {"C6": [], "C8": [], "C10": []} self._jaxPotential = None self.types = [] self.ethresh = 5e-4 @@ -306,8 +314,7 @@ def parseElement(element, hamiltonian): generator.params[k] = jnp.array(generator.params[k]) generator.types = np.array(generator.types) - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, - args): + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): methodMap = { app.CutoffPeriodic: "CutoffPeriodic", app.NoCutoff: "NoCutoff", @@ -337,17 +344,18 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, rc = nonbondedCutoff.value_in_unit(unit.angstrom) # get calculator - if 'ethresh' in args: - self.ethresh = args['ethresh'] + if "ethresh" in args: + self.ethresh = args["ethresh"] - disp_force = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, - self.pmax, self.lpme) + disp_force = ADMPDispPmeForce( + box, covalent_map, rc, self.ethresh, self.pmax, self.lpme + ) self.disp_force = disp_force pot_fn_lr = disp_force.get_energy def potential_fn(positions, box, pairs, params): mScales = params["mScales"] - C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6 + C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6 C8_list = params["C8"][map_atomtype] * 1e8 C10_list = params["C10"][map_atomtype] * 1e10 c6_list = jnp.sqrt(C6_list) @@ -355,7 +363,7 @@ def potential_fn(positions, box, pairs, params): c10_list = jnp.sqrt(C10_list) c_list = jnp.vstack((c6_list, c8_list, c10_list)) E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales) - return - E_lr + return -E_lr self._jaxPotential = potential_fn # self._top_data = data @@ -367,14 +375,17 @@ def renderXML(self): # generate xml force field file pass + # register all parsers app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement + class QqTtDampingGenerator: - r''' + r""" This one calculates the tang-tonnies damping of charge-charge interaction E = \sum_ij exp(-B*r)*(1+B*r)*q_i*q_j/r - ''' + """ + def __init__(self, hamiltonian): self.ff = hamiltonian self.params = { @@ -408,8 +419,7 @@ def parseElement(element, hamiltonian): generator.types = np.array(generator.types) # on working - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, - args): + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): n_atoms = len(data.atoms) # build index map @@ -421,13 +431,13 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # build covalent map covalent_map = build_covalent_map(data, 6) - pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_kernel, - covalent_map, - static_args={}) + pot_fn_sr = generate_pairwise_interaction( + TT_damping_qq_kernel, covalent_map, static_args={} + ) def potential_fn(positions, box, pairs, params): mScales = params["mScales"] - b_list = params["B"][map_atomtype] / 10 # convert to A^-1 + b_list = params["B"][map_atomtype] / 10 # convert to A^-1 q_list = params["Q"][map_atomtype] E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, q_list) @@ -443,17 +453,19 @@ def renderXML(self): # generate xml force field file pass + # register all parsers app.forcefield.parsers["QqTtDampingForce"] = QqTtDampingGenerator.parseElement class SlaterDampingGenerator: - r''' + r""" This one computes the slater-type damping function for c6/c8/c10 dispersion E = \sum_ij (f6-1)*c6/r6 + (f8-1)*c8/r8 + (f10-1)*c10/r10 fn = f_tt(x, n) x = br - (2*br2 + 3*br) / (br2 + 3*br + 3) - ''' + """ + def __init__(self, hamiltonian): self.ff = hamiltonian self.params = { @@ -490,8 +502,7 @@ def parseElement(element, hamiltonian): generator.params[k] = jnp.array(generator.params[k]) generator.types = np.array(generator.types) - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, - args): + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): n_atoms = len(data.atoms) # build index map @@ -504,22 +515,24 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, covalent_map = build_covalent_map(data, 6) # WORKING - pot_fn_sr = generate_pairwise_interaction(slater_disp_damping_kernel, - covalent_map, - static_args={}) + pot_fn_sr = generate_pairwise_interaction( + slater_disp_damping_kernel, covalent_map, static_args={} + ) def potential_fn(positions, box, pairs, params): mScales = params["mScales"] - b_list = params["B"][map_atomtype] / 10 # convert to A^-1 - c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) # to kj/mol * A**6 + b_list = params["B"][map_atomtype] / 10 # convert to A^-1 + c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) # to kj/mol * A**6 c8_list = jnp.sqrt(params["C8"][map_atomtype] * 1e8) c10_list = jnp.sqrt(params["C10"][map_atomtype] * 1e10) - E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, c6_list, c8_list, c10_list) + E_sr = pot_fn_sr( + positions, box, pairs, mScales, b_list, c6_list, c8_list, c10_list + ) return E_sr self._jaxPotential = potential_fn # self._top_data = data - + def getJaxPotential(self): return self._jaxPotential @@ -527,21 +540,22 @@ def renderXML(self): # generate xml force field file pass + app.forcefield.parsers["SlaterDampingForce"] = SlaterDampingGenerator.parseElement class SlaterExGenerator: - r''' + r""" This one computes the Slater-ISA type exchange interaction u = \sum_ij A * (1/3*(Br)^2 + Br + 1) - ''' + """ def __init__(self, hamiltonian): self.ff = hamiltonian self.params = { - "A": [], - "B": [], - } + "A": [], + "B": [], + } self._jaxPotential = None self.types = [] self.name = "SlaterEx" @@ -568,8 +582,7 @@ def parseElement(element, hamiltonian): generator.params[k] = jnp.array(generator.params[k]) generator.types = np.array(generator.types) - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, - args): + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): n_atoms = len(data.atoms) # build index map @@ -581,14 +594,14 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # build covalent map covalent_map = build_covalent_map(data, 6) - pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, - covalent_map, - static_args={}) + pot_fn_sr = generate_pairwise_interaction( + slater_sr_kernel, covalent_map, static_args={} + ) def potential_fn(positions, box, pairs, params): mScales = params["mScales"] a_list = params["A"][map_atomtype] - b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1 + b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1 return pot_fn_sr(positions, box, pairs, mScales, a_list, b_list) @@ -602,6 +615,7 @@ def renderXML(self): # generate xml force field file pass + app.forcefield.parsers["SlaterExForce"] = SlaterExGenerator.parseElement @@ -624,6 +638,7 @@ def __init__(self): super().__init__(self) self.name = "SlaterDhf" + # register all parsers app.forcefield.parsers["SlaterSrEsForce"] = SlaterSrEsGenerator.parseElement app.forcefield.parsers["SlaterSrPolForce"] = SlaterSrPolGenerator.parseElement @@ -697,33 +712,33 @@ def registerAtomType(self, atom: dict): @staticmethod def parseElement(element, hamiltonian): - r""" parse admp related parameters in XML file - - example: - - - - - - - - - - - - - - + r"""parse admp related parameters in XML file + + example: + + + + + + + + + + + + + + """ generator = ADMPPmeGenerator(hamiltonian) @@ -1112,39 +1127,50 @@ def getJaxPotential(self): def renderXML(self): # - - finfo = XMLNodeInfo('ADMPPmeForce') - finfo.addAttribute('lmax', str(self.lmax)) + + finfo = XMLNodeInfo("ADMPPmeForce") + finfo.addAttribute("lmax", str(self.lmax)) outputparams = deepcopy(self.params) - mScales = outputparams.pop('mScales') - pScales = outputparams.pop('pScales') - dScales = outputparams.pop('dScales') + mScales = outputparams.pop("mScales") + pScales = outputparams.pop("pScales") + dScales = outputparams.pop("dScales") for i in range(len(mScales)): - finfo.addAttribute(f'mScale1{i+2}', str(mScales[i])) + finfo.addAttribute(f"mScale1{i+2}", str(mScales[i])) for i in range(len(pScales)): - finfo.addAttribute(f'pScale{i+1}', str(pScales[i])) + finfo.addAttribute(f"pScale{i+1}", str(pScales[i])) for i in range(len(dScales)): - finfo.addAttribute(f'dScale{i+1}', str(dScales[i])) - - Q = outputparams['Q_local'] + finfo.addAttribute(f"dScale{i+1}", str(dScales[i])) + + Q = outputparams["Q_local"] Q_global = convert_harm2cart(Q, self.lmax) - + # for atom in range(self.n_atoms): - info = {'type': self.map_atomtype[atom]} - info.update({ktype:self.kStrings[ktype][atom] for ktype in ['kz', 'kx', 'ky']}) - for i, key in enumerate(['c0', 'dX', 'dY', 'dZ', 'qXX', 'qXY', 'qXZ', 'qYY', 'qYZ', 'qZZ']): + info = {"type": self.map_atomtype[atom]} + info.update( + {ktype: self.kStrings[ktype][atom] for ktype in ["kz", "kx", "ky"]} + ) + for i, key in enumerate( + ["c0", "dX", "dY", "dZ", "qXX", "qXY", "qXZ", "qYY", "qYZ", "qZZ"] + ): info[key] = "%.8f" % Q_global[atom][i] - finfo.addElement('Atom', info) - + finfo.addElement("Atom", info) + # for t in range(len(self.types)): - info = { - 'type': self.types[t] - } - info.update({p: "%.8f" % self.params['pol'][t] for p in ['polarizabilityXX', 'polarizabilityYY', 'polarizabilityZZ']}) - finfo.addElement('Polarize', info) - + info = {"type": self.types[t]} + info.update( + { + p: "%.8f" % self.params["pol"][t] + for p in [ + "polarizabilityXX", + "polarizabilityYY", + "polarizabilityZZ", + ] + } + ) + finfo.addElement("Polarize", info) + return finfo @@ -1172,14 +1198,14 @@ def registerBondType(self, bond): def parseElement(element, hamiltonian): r"""parse section in XML file - - example: - - - - - <\HarmonicBondForce> - + + example: + + + + + <\HarmonicBondForce> + """ existing = [f for f in hamiltonian._forces if isinstance(f, HarmonicBondJaxGenerator)] if len(existing) == 0: @@ -1242,7 +1268,7 @@ def renderXML(self): binfo[k1] = v1 binfo[k2] = v2 for key in self.params.keys(): - binfo[key] = "%.8f"%self.params[key][ntype] + binfo[key] = "%.8f" % self.params[key][ntype] finfo.addElement("Bond", binfo) return finfo @@ -1267,13 +1293,13 @@ def registerAngleType(self, angle): @staticmethod def parseElement(element, hamiltonian): - r""" parse section in XML file + r"""parse section in XML file - example: - - - - <\HarmonicAngleForce> + example: + + + + <\HarmonicAngleForce> """ generator = HarmonicAngleJaxGenerator(hamiltonian) @@ -1343,9 +1369,15 @@ def renderXML(self): finfo = XMLNodeInfo("HarmonicAngleForce") for i, type in enumerate(self.types): t1, t2, t3 = type - ainfo = {'type1': t1, 'type2': t2, 'type3': t3, 'k': self.params['k'][i], 'angle': self.params['angle'][i]} - finfo.addElement('Angle', ainfo) - + ainfo = { + "type1": t1, + "type2": t2, + "type3": t3, + "k": self.params["k"][i], + "angle": self.params["angle"][i], + } + finfo.addElement("Angle", ainfo) + return finfo @@ -1353,7 +1385,6 @@ def renderXML(self): app.forcefield.parsers["HarmonicAngleForce"] = HarmonicAngleJaxGenerator.parseElement - def _matchImproper(data, torsion, generator): type1 = data.atomType[data.atoms[torsion[0]]] type2 = data.atomType[data.atoms[torsion[1]]] @@ -1557,14 +1588,14 @@ def registerImproperTorsion(self, parameters, ordering="default"): @staticmethod def parseElement(element, ff): - """ parse section in XML file - - example: - - - - - + """parse section in XML file + + example: + + + + + """ existing = [f for f in ff._forces if isinstance(f, PeriodicTorsionJaxGenerator)] @@ -1901,50 +1932,58 @@ def getJaxPotential(self): def renderXML(self): params = self.params # generate xml force field file - finfo = XMLNodeInfo('PeriodicTorsionForce') + finfo = XMLNodeInfo("PeriodicTorsionForce") for i in range(len(self.proper)): proper = self.proper[i] - - finfo.addElement('Proper', - {'type1': proper.types1, 'type2': proper.types2, - 'type3': proper.types3, 'type4': proper.types4, - 'periodicity1': proper.periodicity[0], - 'phase1': params['psi1_p'][i], - 'k1': params['k1_p'][i], - 'periodicity2': proper.periodicity[1], - 'phase2': params['psi2_p'][i], - 'k2': params['k2_p'][i], - 'periodicity3': proper.periodicity[2], - 'phase3': params['psi3_p'][i], - 'k3': params['k3_p'][i], - 'periodicity4': proper.periodicity[3], - 'phase4': params['psi4_p'][i], - 'k4': params['k4_p'][i], - } + + finfo.addElement( + "Proper", + { + "type1": proper.types1, + "type2": proper.types2, + "type3": proper.types3, + "type4": proper.types4, + "periodicity1": proper.periodicity[0], + "phase1": params["psi1_p"][i], + "k1": params["k1_p"][i], + "periodicity2": proper.periodicity[1], + "phase2": params["psi2_p"][i], + "k2": params["k2_p"][i], + "periodicity3": proper.periodicity[2], + "phase3": params["psi3_p"][i], + "k3": params["k3_p"][i], + "periodicity4": proper.periodicity[3], + "phase4": params["psi4_p"][i], + "k4": params["k4_p"][i], + }, ) - + for i in range(len(self.improper)): - + improper = self.improper[i] - - finfo.addElement('Improper', - {'type1': improper.types1, 'type2': improper.types2, - 'type3': improper.types3, 'type4': improper.types4, - 'periodicity1': improper.periodicity[0], - 'phase1': params['psi1_i'][i], - 'k1': params['k1_i'][i], - 'periodicity2': proper.periodicity[1], - 'phase2': params['psi2_i'][i], - 'k2': params['k2_i'][i], - 'periodicity3': proper.periodicity[2], - 'phase3': params['psi3_i'][i], - 'k3': params['k3_i'][i], - 'periodicity4': proper.periodicity[3], - 'phase4': params['psi4_i'][i], - 'k4': params['k4_i'][i], - } + + finfo.addElement( + "Improper", + { + "type1": improper.types1, + "type2": improper.types2, + "type3": improper.types3, + "type4": improper.types4, + "periodicity1": improper.periodicity[0], + "phase1": params["psi1_i"][i], + "k1": params["k1_i"][i], + "periodicity2": proper.periodicity[1], + "phase2": params["psi2_i"][i], + "k2": params["k2_i"][i], + "periodicity3": proper.periodicity[2], + "phase3": params["psi3_i"][i], + "k3": params["k3_i"][i], + "periodicity4": proper.periodicity[3], + "phase4": params["psi4_i"][i], + "k4": params["k4_i"][i], + }, ) - + return finfo @@ -1989,9 +2028,9 @@ def registerAtom(self, atom): @staticmethod def parseElement(element, ff): """parse section in XML file - + example: - + @@ -2023,9 +2062,9 @@ def parseElement(element, ff): generator.useAttributeFromResidue.append(eprm) for atom in element.findall("Atom"): generator.registerAtom(atom.attrib) - + generator.n_atoms = len(element.findall("Atom")) - + # jax it! for k in generator.params.keys(): generator.params[k] = jnp.array(generator.params[k]) @@ -2103,7 +2142,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): # TODO: implement NBFIX map_nbfix = [] map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 2)) - colv_map = build_covalent_map(data, 6) @@ -2352,23 +2390,27 @@ def getJaxPotential(self): return self._jaxPotential def renderXML(self): - + # - finfo = XMLNodeInfo('NonbondedForce') - finfo.addAttribute('coulomb14scale', str(self.coulomb14scale)) - finfo.addAttribute('lj14scale', str(self.lj14scale)) - + finfo = XMLNodeInfo("NonbondedForce") + finfo.addAttribute("coulomb14scale", str(self.coulomb14scale)) + finfo.addAttribute("lj14scale", str(self.lj14scale)) + for atom in range(self.n_atoms): - info = {'type': self.types[atom], 'charge': self.params['charge'][atom], 'sigma': self.params['sigma'][atom], 'epsilon': self.params['epsilon'][atom]} - finfo.addElement('Atom', info) - + info = { + "type": self.types[atom], + "charge": self.params["charge"][atom], + "sigma": self.params["sigma"][atom], + "epsilon": self.params["epsilon"][atom], + } + finfo.addElement("Atom", info) + return finfo app.forcefield.parsers["NonbondedForce"] = NonbondJaxGenerator.parseElement - class Hamiltonian(app.forcefield.ForceField): def __init__(self, *xmlnames): super().__init__(*xmlnames) @@ -2382,13 +2424,13 @@ def createPotential( topology, nonbondedMethod=app.NoCutoff, nonbondedCutoff=1.0 * unit.nanometer, - **args + **args, ): system = self.createSystem( topology, nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff, - **args + **args, ) # load_constraints_from_system_if_needed # create potentials @@ -2438,4 +2480,4 @@ def getParameters(self): params = {} for gen in self.getGenerators(): params[gen.name] = gen.params - return params \ No newline at end of file + return params diff --git a/dmff/common/nblist.py b/dmff/common/nblist.py index 7cb5f9437..16de9715f 100644 --- a/dmff/common/nblist.py +++ b/dmff/common/nblist.py @@ -98,4 +98,4 @@ def distance(self): jnp.ndarray: (nPairs, ) """ - return jnp.linalg.norm(self.dr, axis=1) \ No newline at end of file + return jnp.linalg.norm(self.dr, axis=1) diff --git a/docs/about.md b/docs/about.md deleted file mode 100644 index 835deeb6f..000000000 --- a/docs/about.md +++ /dev/null @@ -1,3 +0,0 @@ -# Lisense - -# Contributor \ No newline at end of file diff --git a/docs/assets/js/mathjax.js b/docs/assets/js/mathjax.js new file mode 100644 index 000000000..e648674dd --- /dev/null +++ b/docs/assets/js/mathjax.js @@ -0,0 +1,16 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } + }; + + document$.subscribe(() => { + MathJax.typesetPromise() + }) \ No newline at end of file diff --git a/docs/dev_guide/profile.md b/docs/dev_guide/profile.md deleted file mode 100644 index 14a981b77..000000000 --- a/docs/dev_guide/profile.md +++ /dev/null @@ -1,2 +0,0 @@ -# How to profile - diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py new file mode 100644 index 000000000..77fbb0a3c --- /dev/null +++ b/docs/gen_ref_pages.py @@ -0,0 +1,38 @@ +"""Generate the code reference pages.""" + +from pathlib import Path + +import mkdocs_gen_files + +nav = mkdocs_gen_files.Nav() + +for path in sorted(Path("dmff").rglob("*.py")): # + + module_path = path.relative_to('dmff').with_suffix("") # + + doc_path = path.relative_to('dmff').with_suffix(".md") # + + full_doc_path = Path("refs", doc_path) # + + parts = list(module_path.parts) + + if parts[-1] == "__init__": # + continue + elif parts[-1] == "__main__": + continue + + nav[parts] = doc_path.as_posix() + print(full_doc_path) + with mkdocs_gen_files.open(full_doc_path, "w") as fd: # + + identifier = ".".join(parts) # + + print("::: dmff." + identifier, file=fd) # + + + mkdocs_gen_files.set_edit_path(full_doc_path, path) # + +with mkdocs_gen_files.open("refs/SUMMARY.md", "w") as nav_file: # + + nav_file.writelines(nav.build_literate_nav()) # + diff --git a/docs/license.md b/docs/license.md new file mode 100644 index 000000000..d2af63711 --- /dev/null +++ b/docs/license.md @@ -0,0 +1,3 @@ +# Lisense + +The project DeePMD-kit is licensed under [GNU LGPLv3.0](https://github.com/deepmodeling/deepmd-kit/blob/master/LICENSE). \ No newline at end of file diff --git a/docs/user_guide/xml_spec.md b/docs/user_guide/xml_spec.md index 5304e3141..7386c6345 100644 --- a/docs/user_guide/xml_spec.md +++ b/docs/user_guide/xml_spec.md @@ -151,7 +151,7 @@ The `` node of the residue part defines all the atoms involved in the resi - <\displaylines{Bond atomName1="CA" atomName2="HA"/> + diff --git a/mkdocs.yml b/mkdocs.yml index ea1b37ba5..8a751b209 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,4 +1,5 @@ site_name: DMFF +nav: nav: - Home: index.md - User Guide: @@ -18,7 +19,11 @@ nav: - ADMP: - Introduction: admp/readme.md - Frontends: admp/frontend.md - - About: about.md + + - API: refs/ + + - About: + - License: license.md theme: readthedocs @@ -26,7 +31,17 @@ markdown_extensions: - pymdownx.arithmatex: generic: true +plugins: +- search +- gen-files: + scripts: + - docs/gen_ref_pages.py +- literate-nav: + nav_file: SUMMARY.md +- mkdocstrings: + + extra_javascript: - - javascripts/mathjax.js + - assets/js/mathjax.js - https://polyfill.io/v3/polyfill.min.js?features=es6 - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js diff --git a/requirements.txt b/requirements.txt index 7360452fb..343447502 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,9 @@ numpy>=1.18 jax>=0.3.7 jax-md>=0.1.28 +mkdocs>=1.3.0 +mkdocs-autorefs>=0.4.1 +mkdocs-gen-files>=0.3.4 +mkdocs-literate-nav>=0.4.1 +mkdocstrings>=0.19.0 +mkdocstrings-python>=0.7.0 \ No newline at end of file diff --git a/tests/test_common/test_nblist.py b/tests/test_common/test_nblist.py index b2b66484e..d5f709103 100644 --- a/tests/test_common/test_nblist.py +++ b/tests/test_common/test_nblist.py @@ -51,4 +51,4 @@ def test_dr(self, nblist): def test_distance(self, nblist): - assert nblist.distance.shape == (15, ) \ No newline at end of file + assert nblist.distance.shape == (15, )