From 8b2b031aa7bfa83c06ad1b6176262fb93c6ddf00 Mon Sep 17 00:00:00 2001 From: Roy-Kid Date: Thu, 30 Jun 2022 22:26:09 +0800 Subject: [PATCH 1/5] refactor: admp api.py using old xmlholder api --- dmff/api.py | 381 +++++++++++++++++++++++++- tests/data/methane_water_modified.xml | 59 ++++ tests/test_admp/test_compute.py | 22 +- 3 files changed, 458 insertions(+), 4 deletions(-) create mode 100644 tests/data/methane_water_modified.xml diff --git a/dmff/api.py b/dmff/api.py index 3833198c0..fc3a89dd5 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -1025,9 +1025,388 @@ def potential_fn(positions, box, pairs, params): def getJaxPotential(self): return self._jaxPotential +class ADMPPmeGenerator: -# app.forcefield.parsers["ADMPPmeForce"] = ADMPPmeGenerator.parseElement + def __init__(self, ff): + + self.name = 'ADMPPmeForce' + self.ff = ff + self.fftree = ff.fftree + self.paramtree = ff.paramtree + + def extract(self): + + self.lmax = self.fftree.get_attrib(f'{self.name}', 'lmax')[0] # return [lmax] + + mScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}') for i in range(2, 7)] + pScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}') for i in range(2, 7)] + dScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}') for i in range(2, 7)] + + # make sure the last digit is 1.0 + mScales.append(1.0) + pScales.append(1.0) + dScales.append(1.0) + + self.paramtree[self.name] = {} + self.paramtree[self.name]['mScales'] = mScales + self.paramtree[self.name]['pScales'] = pScales + self.paramtree[self.name]['dScales'] = dScales + + # check if polarize + polarize = self.fftree.get_node(f'{self.name}/Polarize') + if polarize: + self.lpol = True + else: + self.lpol = False + + atomTypes = self.fftree.get_attrib(f'{self.name}/Atom', 'type') + self.atomTypes = np.array(atomTypes, dtype=int).astype(str) + # kx = self.fftree.get_attrib(f'{self.name}/Atom', 'kx') + # ky = self.fftree.get_attrib(f'{self.name}/Atom', 'ky') + # kz = self.fftree.get_attrib(f'{self.name}/Atom', 'kz') + c0 = self.fftree.get_attrib(f'{self.name}/Atom', 'c0') + dX = self.fftree.get_attrib(f'{self.name}/Atom', 'dX') + dY = self.fftree.get_attrib(f'{self.name}/Atom', 'dY') + dZ = self.fftree.get_attrib(f'{self.name}/Atom', 'dZ') + qXX = self.fftree.get_attrib(f'{self.name}/Atom', 'qXX') + qYY = self.fftree.get_attrib(f'{self.name}/Atom', 'qYY') + qZZ = self.fftree.get_attrib(f'{self.name}/Atom', 'qZZ') + qXY = self.fftree.get_attrib(f'{self.name}/Atom', 'qXY') + qYZ = self.fftree.get_attrib(f'{self.name}/Atom', 'qYZ') + + # assume that polarize tag match the per atom type + polarizabilityXX = self.fftree.get_attrib(f'{self.name}/Polarize', 'polarizabilityXX') + polarizabilityYY = self.fftree.get_attrib(f'{self.name}/Polarize', 'polarizabilityYY') + polarizabilityZZ = self.fftree.get_attrib(f'{self.name}/Polarize', 'polarizabilityZZ') + thole = self.fftree.get_attrib(f'{self.name}/Polarize', 'thole') + + n_atoms = len(atomTypes) + assert n_atoms == len(polarizabilityXX), "Number of polarizabilityXX does not match number of atoms!" + + # map atom multipole moments + if self.lmax == 0: + n_mtps = 1 + elif self.lmax == 1: + n_mtps = 4 + elif self.lmax == 2: + n_mtps = 10 + Q = np.zeros((n_atoms, n_mtps)) + + # TDDO: unit conversion + Q[:, 0] = c0 + if self.lmax >= 1: + Q[:, 1] = dX + Q[:, 2] = dY + Q[:, 3] = dZ + Q[:, 1:4] *= 10 + if self.lmax >= 2: + Q[:, 4] = qXX + Q[:, 5] = qYY + Q[:, 6] = qZZ + Q[:, 7] = qXY + Q[:, 8] = qYZ + Q[:, 4:9] *= 300 + + # add all differentiable params to self.params + Q_local = convert_cart2harm(Q, self.lmax) + self.paramtree[self.name]["Q_local"] = Q_local + + if self.lpol: + pol = jnp.vstack(( + polarizabilityXX, + polarizabilityYY, + polarizabilityZZ, + )).T + pol = 1000 * jnp.mean(pol, axis=1) + tholes = jnp.array(thole) + self.paramtree[self.name]["pol"] = pol + self.paramtree[self.name]["tholes"] = tholes + else: + pol = None + tholes = None + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + methodMap = { + app.CutoffPeriodic: "CutoffPeriodic", + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise ValueError("Illegal nonbonded method for ADMPPmeForce") + if nonbondedMethod is app.CutoffPeriodic: + self.lpme = False + else: + self.lpme = True + + n_atoms = len(data.atoms) + map_atomtype = np.zeros(n_atoms, dtype=int) + + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] # convert str to int to match atomTypes + map_atomtype[i] = np.where(self.atomTypes == atype)[0][0] + self.map_atomtype = map_atomtype + + # here box is only used to setup ewald parameters, no need to be differentiable + a, b, c = system.getDefaultPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) * 10 + + # get the admp calculator + rc = nonbondedCutoff.value_in_unit(unit.angstrom) + + # build covalent map + covalent_map = build_covalent_map(data, 6) + + # build intra-molecule axis + # the following code is the direct transplant of forcefield.py in openmm 7.4.0 + + if self.lmax > 0: + + # setting up axis_indices and axis_type + ZThenX = 0 + Bisector = 1 + ZBisect = 2 + ThreeFold = 3 + ZOnly = 4 # typo fix + NoAxisType = 5 + LastAxisTypeIndex = 6 + + self.axis_types = [] + self.axis_indices = [] + for i_atom in range(n_atoms): + atom = data.atoms[i_atom] + t = data.atomType[atom] + # if t is in type list? + if t in self.atomTypes: + itypes = np.where(self.atomTypes == t)[0] + hit = 0 + # try to assign multipole parameters via only 1-2 connected atoms + for itype in itypes: + if hit != 0: + break + kz = int(self.kStrings["kz"][itype]) + kx = int(self.kStrings["kx"][itype]) + ky = int(self.kStrings["ky"][itype]) + neighbors = np.where(covalent_map[i_atom] == 1)[0] + zaxis = -1 + xaxis = -1 + yaxis = -1 + for z_index in neighbors: + if hit != 0: + break + z_type = int(data.atomType[data.atoms[z_index]]) + if z_type == abs( + kz + ): # find the z atom, start searching for x + for x_index in neighbors: + if x_index == z_index or hit != 0: + continue + x_type = int( + data.atomType[data.atoms[x_index]]) + if x_type == abs( + kx + ): # find the x atom, start searching for y + if ky == 0: + zaxis = z_index + xaxis = x_index + # cannot ditinguish x and z? use the smaller index for z, and the larger index for x + if x_type == z_type and xaxis < zaxis: + swap = z_axis + z_axis = x_axis + x_axis = swap + # otherwise, try to see if we can find an even smaller index for x? + else: + for x_index in neighbors: + x_type1 = int( + data.atomType[ + data. + atoms[x_index]]) + if (x_type1 == abs(kx) and + x_index != z_index + and + x_index < xaxis): + xaxis = x_index + hit = 1 # hit, finish matching + matched_itype = itype + else: + for y_index in neighbors: + if (y_index == z_index + or y_index == x_index + or hit != 0): + continue + y_type = int(data.atomType[ + data.atoms[y_index]]) + if y_type == abs(ky): + zaxis = z_index + xaxis = x_index + yaxis = y_index + hit = 2 + matched_itype = itype + # assign multipole parameters via 1-2 and 1-3 connected atoms + for itype in itypes: + if hit != 0: + break + kz = int(self.kStrings["kz"][itype]) + kx = int(self.kStrings["kx"][itype]) + ky = int(self.kStrings["ky"][itype]) + neighbors_1st = np.where(covalent_map[i_atom] == 1)[0] + neighbors_2nd = np.where(covalent_map[i_atom] == 2)[0] + zaxis = -1 + xaxis = -1 + yaxis = -1 + for z_index in neighbors_1st: + if hit != 0: + break + z_type = int(data.atomType[data.atoms[z_index]]) + if z_type == abs(kz): + for x_index in neighbors_2nd: + if x_index == z_index or hit != 0: + continue + x_type = int( + data.atomType[data.atoms[x_index]]) + # we ask x to be in 2'nd neighbor, and x is z's neighbor + if (x_type == abs(kx) + and covalent_map[z_index, + x_index] == 1): + if ky == 0: + zaxis = z_index + xaxis = x_index + # select smallest x index + for x_index in neighbors_2nd: + x_type1 = int(data.atomType[ + data.atoms[x_index]]) + if (x_type1 == abs(kx) + and x_index != z_index + and + covalent_map[x_index, + z_index] + == 1 + and x_index < xaxis): + xaxis = x_index + hit = 3 + matched_itype = itype + else: + for y_index in neighbors_2nd: + if (y_index == z_index + or y_index == x_index + or hit != 0): + continue + y_type = int(data.atomType[ + data.atoms[y_index]]) + if (y_type == abs(ky) and + covalent_map[y_index, + z_index] + == 1): + zaxis = z_index + xaxis = x_index + yaxis = y_index + hit = 4 + matched_itype = itype + # assign multipole parameters via only a z-defining atom + for itype in itypes: + if hit != 0: + break + kz = int(self.kStrings["kz"][itype]) + kx = int(self.kStrings["kx"][itype]) + zaxis = -1 + xaxis = -1 + yaxis = -1 + neighbors = np.where(covalent_map[i_atom] == 1)[0] + for z_index in neighbors: + if hit != 0: + break + z_type = int(data.atomType[data.atoms[z_index]]) + if kx == 0 and z_type == abs(kz): + zaxis = z_index + hit = 5 + matched_itype = itype + # assign multipole parameters via no connected atoms + for itype in itypes: + if hit != 0: + break + kz = int(self.kStrings["kz"][itype]) + zaxis = -1 + xaxis = -1 + yaxis = -1 + if kz == 0: + hit = 6 + matched_itype = itype + # add particle if there was a hit + if hit != 0: + map_atomtype[i_atom] = matched_itype + self.axis_indices.append([zaxis, xaxis, yaxis]) + kz = int(self.kStrings["kz"][matched_itype]) + kx = int(self.kStrings["kx"][matched_itype]) + ky = int(self.kStrings["ky"][matched_itype]) + axisType = ZThenX + if kz == 0: + axisType = NoAxisType + if kz != 0 and kx == 0: + axisType = ZOnly + if kz < 0 or kx < 0: + axisType = Bisector + if kx < 0 and ky < 0: + axisType = ZBisect + if kz < 0 and kx < 0 and ky < 0: + axisType = ThreeFold + self.axis_types.append(axisType) + + else: + sys.exit("Atom %d not matched in forcefield!" % i_atom) + + else: + sys.exit("Atom %d not matched in forcefield!" % i_atom) + self.axis_indices = np.array(self.axis_indices) + self.axis_types = np.array(self.axis_types) + else: + self.axis_types = None + self.axis_indices = None + + if "ethresh" in args: + self.ethresh = args["ethresh"] + if "step_pol" in args: + self.step_pol = args["step_pol"] + + pme_force = ADMPPmeForce(box, self.axis_types, self.axis_indices, + covalent_map, rc, self.ethresh, self.lmax, + self.lpol, self.lpme, self.step_pol) + self.pme_force = pme_force + + def potential_fn(positions, box, pairs, params): + + mScales = params["mScales"] + Q_local = params["Q_local"][map_atomtype] + if self.lpol: + pScales = params["pScales"] + dScales = params["dScales"] + pol = params["pol"][map_atomtype] + tholes = params["tholes"][map_atomtype] + + return pme_force.get_energy( + positions, + box, + pairs, + Q_local, + pol, + tholes, + mScales, + pScales, + dScales, + pme_force.U_ind, + ) + else: + return pme_force.get_energy(positions, box, pairs, Q_local, + mScales) + + self._jaxPotential = potential_fn + + def getJaxPotential(self): + return self._jaxPotential + +# app.forcefield.parsers["ADMPPmeForce"] = ADMPPmeGenerator.parseElement +jaxGenerators["ADMPPmeForce"] = ADMPPmeGenerator class Potential: def __init__(self): diff --git a/tests/data/methane_water_modified.xml b/tests/data/methane_water_modified.xml new file mode 100644 index 000000000..95a84f28e --- /dev/null +++ b/tests/data/methane_water_modified.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/test_admp/test_compute.py b/tests/test_admp/test_compute.py index e6ea1f3bb..444cf8d5e 100644 --- a/tests/test_admp/test_compute.py +++ b/tests/test_admp/test_compute.py @@ -12,7 +12,7 @@ class TestADMPAPI: """ Test ADMP related generators """ - @pytest.fixture(scope='class', name='generators') + @pytest.fixture(scope='class', name='potentials') def test_init(self): """load generators from XML file @@ -25,9 +25,25 @@ def test_init(self): rc = 4.0 H = Hamiltonian('tests/data/admp.xml') pdb = app.PDBFile('tests/data/water_dimer.pdb') - H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) - yield H.getGenerators() + yield potential + + def test_ADMPPmeForce(self, potentials): + + rc = 4.0 + pdb = app.PDBFile('tests/data/water_dimer.pdb') + positions = np.array(pdb.positions._value) * 10 + a, b, c = pdb.topology.getPeriodicBoxVectors() + box = np.array([a._value, b._value, c._value]) * 10 + # neighbor list + nblist = NeighborList(box, rc) + nblist.allocate(positions) + pairs = nblist.pairs + + pot = potentials.getPotentialFunc(names=['ADMPPmeForce']) + energy = pot(positions, box, pairs, potentials.params) + def test_ADMPPmeForce_jit(self, generators): From c1be36183d60c5978efac8ce14690a6f4d9ecd35 Mon Sep 17 00:00:00 2001 From: Roy-Kid Date: Sat, 2 Jul 2022 18:02:41 +0800 Subject: [PATCH 2/5] update --- dmff/api.py | 40 ++++++++++++++++++++++++++++++--------- tests/data/admp.xml | 4 ++-- tests/test_api.py | 46 ++++++++++++++++++++++----------------------- 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index 077c66449..174b1e737 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -1034,13 +1034,21 @@ def __init__(self, ff): self.fftree = ff.fftree self.paramtree = ff.paramtree + # default params + self._jaxPotential = None + self.types = [] + self.ethresh = 5e-4 + self.step_pol = None + self.lpol = False + self.ref_dip = "" + def extract(self): self.lmax = self.fftree.get_attrib(f'{self.name}', 'lmax')[0] # return [lmax] - mScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}') for i in range(2, 7)] - pScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}') for i in range(2, 7)] - dScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}') for i in range(2, 7)] + mScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] + pScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] + dScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] # make sure the last digit is 1.0 mScales.append(1.0) @@ -1048,9 +1056,9 @@ def extract(self): dScales.append(1.0) self.paramtree[self.name] = {} - self.paramtree[self.name]['mScales'] = mScales - self.paramtree[self.name]['pScales'] = pScales - self.paramtree[self.name]['dScales'] = dScales + self.paramtree[self.name]['mScales'] = jnp.array(mScales) + self.paramtree[self.name]['pScales'] = jnp.array(pScales) + self.paramtree[self.name]['dScales'] = jnp.array(dScales) # check if polarize polarize = self.fftree.get_node(f'{self.name}/Polarize') @@ -1061,9 +1069,20 @@ def extract(self): atomTypes = self.fftree.get_attrib(f'{self.name}/Atom', 'type') self.atomTypes = np.array(atomTypes, dtype=int).astype(str) - # kx = self.fftree.get_attrib(f'{self.name}/Atom', 'kx') - # ky = self.fftree.get_attrib(f'{self.name}/Atom', 'ky') - # kz = self.fftree.get_attrib(f'{self.name}/Atom', 'kz') + kx = self.fftree.get_attrib(f'{self.name}/Atom', 'kx') + ky = self.fftree.get_attrib(f'{self.name}/Atom', 'ky') + kz = self.fftree.get_attrib(f'{self.name}/Atom', 'kz') + + kx = [ 0 if kx_ is None else int(kx_) for kx_ in kx ] + ky = [ 0 if ky_ is None else int(ky_) for ky_ in ky ] + kz = [ 0 if kz_ is None else int(kz_) for kz_ in kz ] + + # invoke by `self.kStrings["kz"][itype]` + self.kStrings = {} + self.kStrings['kx'] = kx + self.kStrings['ky'] = ky + self.kStrings['kz'] = kz + c0 = self.fftree.get_attrib(f'{self.name}/Atom', 'c0') dX = self.fftree.get_attrib(f'{self.name}/Atom', 'dX') dY = self.fftree.get_attrib(f'{self.name}/Atom', 'dY') @@ -1464,6 +1483,9 @@ def __init__(self, *xmlnames): for jaxGen in self._jaxGenerators: self._forces.append(jaxGen) + def getGenerators(self): + return self._jaxGenerators + def extractParameterTree(self): # load Force info for jaxgen in self._jaxGenerators: diff --git a/tests/data/admp.xml b/tests/data/admp.xml index 53e25cb14..94b78b628 100644 --- a/tests/data/admp.xml +++ b/tests/data/admp.xml @@ -21,12 +21,12 @@ - - Date: Sat, 2 Jul 2022 22:45:44 +0800 Subject: [PATCH 3/5] update: admp related api --- dmff/api.py | 322 +++++++++++++++++++++++++++++++++++++++++--- dmff/fftree.py | 2 +- tests/data/admp.xml | 4 +- tests/test_api.py | 32 ++--- 4 files changed, 319 insertions(+), 41 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index 119653c8f..8469f53e9 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -200,6 +200,132 @@ def getJaxPotential(self): # register all parsers # app.forcefield.parsers["ADMPDispForce"] = ADMPDispGenerator.parseElement +class ADMPDispGenerator: + def __init__(self, ff): + + self.name = "ADMPDispForce" + self.ff = ff + self.fftree = ff.fftree + self.paramtree = ff.paramtree + + # default params + self._jaxPotential = None + self.types = [] + self.ethresh = 5e-4 + self.pmax = 10 + + def extract(self): + + mScales = [self.fftree.get_attribs(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] + mScales.append(1.0) + self.paramtree[self.name] = {} + self.paramtree[self.name]['mScales'] = jnp.array(mScales) + + ABQC = self.fftree.get_attribs(f'{self.name}/Atom', ['A', 'B', 'Q', 'C6', 'C8', 'C10']) + + ABQC = np.array(ABQC) + A = ABQC[:, 0] + B = ABQC[:, 1] + Q = ABQC[:, 2] + C6 = ABQC[:, 3] + C8 = ABQC[:, 4] + C10 = ABQC[:, 5] + + self.paramtree[self.name]['A'] = jnp.array(A) + self.paramtree[self.name]['B'] = jnp.array(B) + self.paramtree[self.name]['Q'] = jnp.array(Q) + self.paramtree[self.name]['C6'] = jnp.array(C6) + self.paramtree[self.name]['C8'] = jnp.array(C8) + self.paramtree[self.name]['C10'] = jnp.array(C10) + + atomTypes = self.fftree.get_attribs(f'{self.name}/Atom', f'type') + self.atomTypes = np.array(atomTypes, dtype=int).astype(str) + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + methodMap = { + app.CutoffPeriodic: "CutoffPeriodic", + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise ValueError("Illegal nonbonded method for ADMPDispForce") + if nonbondedMethod is app.CutoffPeriodic: + self.lpme = False + else: + self.lpme = True + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.atomTypes == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + covalent_map = build_covalent_map(data, 6) + # here box is only used to setup ewald parameters, no need to be differentiable + a, b, c = system.getDefaultPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) * 10 + # get the admp calculator + rc = nonbondedCutoff.value_in_unit(unit.angstrom) + + # get calculator + if "ethresh" in args: + self.ethresh = args["ethresh"] + + Force_DispPME = ADMPDispPmeForce(box, + covalent_map, + rc, + self.ethresh, + self.pmax, + lpme=self.lpme) + self.disp_pme_force = Force_DispPME + pot_fn_lr = Force_DispPME.get_energy + pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_c6_kernel, + covalent_map, + static_args={}) + + def potential_fn(positions, box, pairs, params): + mScales = params["mScales"] + a_list = (params["A"][map_atomtype] / 2625.5 + ) # kj/mol to au, as expected by TT_damping kernel + b_list = params["B"][map_atomtype] * 0.0529177249 # nm^-1 to au + q_list = params["Q"][map_atomtype] + c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) + c8_list = jnp.sqrt(params["C8"][map_atomtype] * 1e8) + c10_list = jnp.sqrt(params["C10"][map_atomtype] * 1e10) + c_list = jnp.vstack((c6_list, c8_list, c10_list)) + + E_sr = pot_fn_sr(positions, box, pairs, mScales, a_list, b_list, + q_list, c_list[0]) + E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales) + return E_sr - E_lr + + self._jaxPotential = potential_fn + # self._top_data = data + + def overwrite(self): + + self.fftree.set_attrib(f'{self.name}', 'mScale12', self.paramtree[self.name]['mScales'][0]) + self.fftree.set_attrib(f'{self.name}', 'mScale13', self.paramtree[self.name]['mScales'][1]) + self.fftree.set_attrib(f'{self.name}', 'mScale14', self.paramtree[self.name]['mScales'][2]) + self.fftree.set_attrib(f'{self.name}', 'mScale15', self.paramtree[self.name]['mScales'][3]) + self.fftree.set_attrib(f'{self.name}', 'mScale16', self.paramtree[self.name]['mScales'][4]) + + self.fftree.set_attrib(f'{self.name}/Atom', 'A', self.paramtree[self.name]['A']) + self.fftree.set_attrib(f'{self.name}/Atom', 'B', self.paramtree[self.name]['B']) + self.fftree.set_attrib(f'{self.name}/Atom', 'Q', self.paramtree[self.name]['Q']) + self.fftree.set_attrib(f'{self.name}/Atom', 'C6', self.paramtree[self.name]['C6']) + self.fftree.set_attrib(f'{self.name}/Atom', 'C8', self.paramtree[self.name]['C8']) + self.fftree.set_attrib(f'{self.name}/Atom', 'C10', self.paramtree[self.name]['C10']) + + + def getJaxPotential(self): + return self._jaxPotential + +jaxGenerators['ADMPDispForce'] = ADMPDispGenerator class ADMPDispPmeGenerator: r""" @@ -299,6 +425,114 @@ def getJaxPotential(self): # register all parsers # app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement +class ADMPDispPmeGenerator: + r""" + This one computes the undamped C6/C8/C10 interactions + u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10 + """ + def __init__(self, ff): + self.ff = ff + self.fftree = ff.fftree + self.paramtree = ff.paramtree + + self.params = {"C6": [], "C8": [], "C10": []} + self._jaxPotential = None + self.atomTypes = None + self.ethresh = 5e-4 + self.pmax = 10 + self.name = "ADMPDispPmeForce" + + def extract(self): + + mScales = [self.fftree.get_attribs(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] + mScales.append(1.0) + + self.paramtree[self.name] = {} + self.paramtree[self.name]['mScales'] = jnp.array(mScales) + + C6 = self.fftree.get_attribs(f'{self.name}/Atom', f'C6') + C8 = self.fftree.get_attribs(f'{self.name}/Atom', f'C8') + C10 = self.fftree.get_attribs(f'{self.name}/Atom', f'C10') + + self.paramtree[self.name]['C6'] = jnp.array(C6) + self.paramtree[self.name]['C8'] = jnp.array(C8) + self.paramtree[self.name]['C10'] = jnp.array(C10) + + atomTypes = self.fftree.get_attribs(f'{self.name}/Atom', f'type') + self.atomTypes = np.array(atomTypes, dtype=int).astype(str) + + def overwrite(self): + + self.fftree.set_attrib(f'{self.name}', 'mScale12', self.paramtree[self.name]['mScales'][0]) + self.fftree.set_attrib(f'{self.name}', 'mScale13', self.paramtree[self.name]['mScales'][1]) + self.fftree.set_attrib(f'{self.name}', 'mScale14', self.paramtree[self.name]['mScales'][2]) + self.fftree.set_attrib(f'{self.name}', 'mScale15', self.paramtree[self.name]['mScales'][3]) + self.fftree.set_attrib(f'{self.name}', 'mScale16', self.paramtree[self.name]['mScales'][4]) + + self.fftree.set_attrib(f'{self.name}/Atom', 'C6', self.paramtree[self.name]['C6']) + self.fftree.set_attrib(f'{self.name}/Atom', 'C8', self.paramtree[self.name]['C8']) + self.fftree.set_attrib(f'{self.name}/Atom', 'C10', self.paramtree[self.name]['C10']) + + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + methodMap = { + app.CutoffPeriodic: "CutoffPeriodic", + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise ValueError("Illegal nonbonded method for ADMPDispPmeForce") + if nonbondedMethod is app.CutoffPeriodic: + self.lpme = False + else: + self.lpme = True + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.atomTypes == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + covalent_map = build_covalent_map(data, 6) + + # here box is only used to setup ewald parameters, no need to be differentiable + a, b, c = system.getDefaultPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) * 10 + # get the admp calculator + rc = nonbondedCutoff.value_in_unit(unit.angstrom) + + # get calculator + if "ethresh" in args: + self.ethresh = args["ethresh"] + + disp_force = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, + self.pmax, self.lpme) + self.disp_force = disp_force + pot_fn_lr = disp_force.get_energy + + def potential_fn(positions, box, pairs, params): + mScales = params["mScales"] + C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6 + C8_list = params["C8"][map_atomtype] * 1e8 + C10_list = params["C10"][map_atomtype] * 1e10 + c6_list = jnp.sqrt(C6_list) + c8_list = jnp.sqrt(C8_list) + c10_list = jnp.sqrt(C10_list) + c_list = jnp.vstack((c6_list, c8_list, c10_list)) + E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales) + return -E_lr + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + +jaxGenerators['ADMPDispPmeForce'] = ADMPDispPmeGenerator + class QqTtDampingGenerator: r""" @@ -1044,11 +1278,11 @@ def __init__(self, ff): def extract(self): - self.lmax = self.fftree.get_attrib(f'{self.name}', 'lmax')[0] # return [lmax] + self.lmax = self.fftree.get_attribs(f'{self.name}', 'lmax')[0] # return [lmax] - mScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] - pScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] - dScales = [self.fftree.get_attrib(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] + mScales = [self.fftree.get_attribs(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] + pScales = [self.fftree.get_attribs(f'{self.name}', f'pScale1{i}')[0] for i in range(2, 7)] + dScales = [self.fftree.get_attribs(f'{self.name}', f'dScale1{i}')[0] for i in range(2, 7)] # make sure the last digit is 1.0 mScales.append(1.0) @@ -1061,17 +1295,17 @@ def extract(self): self.paramtree[self.name]['dScales'] = jnp.array(dScales) # check if polarize - polarize = self.fftree.get_node(f'{self.name}/Polarize') + polarize = self.fftree.get_nodes(f'{self.name}/Polarize') if polarize: self.lpol = True else: self.lpol = False - atomTypes = self.fftree.get_attrib(f'{self.name}/Atom', 'type') + atomTypes = self.fftree.get_attribs(f'{self.name}/Atom', 'type') self.atomTypes = np.array(atomTypes, dtype=int).astype(str) - kx = self.fftree.get_attrib(f'{self.name}/Atom', 'kx') - ky = self.fftree.get_attrib(f'{self.name}/Atom', 'ky') - kz = self.fftree.get_attrib(f'{self.name}/Atom', 'kz') + kx = self.fftree.get_attribs(f'{self.name}/Atom', 'kx') + ky = self.fftree.get_attribs(f'{self.name}/Atom', 'ky') + kz = self.fftree.get_attribs(f'{self.name}/Atom', 'kz') kx = [ 0 if kx_ is None else int(kx_) for kx_ in kx ] ky = [ 0 if ky_ is None else int(ky_) for ky_ in ky ] @@ -1083,21 +1317,21 @@ def extract(self): self.kStrings['ky'] = ky self.kStrings['kz'] = kz - c0 = self.fftree.get_attrib(f'{self.name}/Atom', 'c0') - dX = self.fftree.get_attrib(f'{self.name}/Atom', 'dX') - dY = self.fftree.get_attrib(f'{self.name}/Atom', 'dY') - dZ = self.fftree.get_attrib(f'{self.name}/Atom', 'dZ') - qXX = self.fftree.get_attrib(f'{self.name}/Atom', 'qXX') - qYY = self.fftree.get_attrib(f'{self.name}/Atom', 'qYY') - qZZ = self.fftree.get_attrib(f'{self.name}/Atom', 'qZZ') - qXY = self.fftree.get_attrib(f'{self.name}/Atom', 'qXY') - qYZ = self.fftree.get_attrib(f'{self.name}/Atom', 'qYZ') + c0 = self.fftree.get_attribs(f'{self.name}/Atom', 'c0') + dX = self.fftree.get_attribs(f'{self.name}/Atom', 'dX') + dY = self.fftree.get_attribs(f'{self.name}/Atom', 'dY') + dZ = self.fftree.get_attribs(f'{self.name}/Atom', 'dZ') + qXX = self.fftree.get_attribs(f'{self.name}/Atom', 'qXX') + qYY = self.fftree.get_attribs(f'{self.name}/Atom', 'qYY') + qZZ = self.fftree.get_attribs(f'{self.name}/Atom', 'qZZ') + qXY = self.fftree.get_attribs(f'{self.name}/Atom', 'qXY') + qYZ = self.fftree.get_attribs(f'{self.name}/Atom', 'qYZ') # assume that polarize tag match the per atom type - polarizabilityXX = self.fftree.get_attrib(f'{self.name}/Polarize', 'polarizabilityXX') - polarizabilityYY = self.fftree.get_attrib(f'{self.name}/Polarize', 'polarizabilityYY') - polarizabilityZZ = self.fftree.get_attrib(f'{self.name}/Polarize', 'polarizabilityZZ') - thole = self.fftree.get_attrib(f'{self.name}/Polarize', 'thole') + polarizabilityXX = self.fftree.get_attribs(f'{self.name}/Polarize', 'polarizabilityXX') + polarizabilityYY = self.fftree.get_attribs(f'{self.name}/Polarize', 'polarizabilityYY') + polarizabilityZZ = self.fftree.get_attribs(f'{self.name}/Polarize', 'polarizabilityZZ') + thole = self.fftree.get_attribs(f'{self.name}/Polarize', 'thole') n_atoms = len(atomTypes) assert n_atoms == len(polarizabilityXX), "Number of polarizabilityXX does not match number of atoms!" @@ -1144,6 +1378,50 @@ def extract(self): pol = None tholes = None + def overwrite(self): + + self.fftree.set_attrib(f'{self.name}', 'mScale12', self.paramtree[self.name]['mScales'][0]) + self.fftree.set_attrib(f'{self.name}', 'mScale13', self.paramtree[self.name]['mScales'][1]) + self.fftree.set_attrib(f'{self.name}', 'mScale14', self.paramtree[self.name]['mScales'][2]) + self.fftree.set_attrib(f'{self.name}', 'mScale15', self.paramtree[self.name]['mScales'][3]) + self.fftree.set_attrib(f'{self.name}', 'mScale16', self.paramtree[self.name]['mScales'][4]) + + self.fftree.set_attrib(f'{self.name}', 'pScale12', self.paramtree[self.name]['pScales'][0]) + self.fftree.set_attrib(f'{self.name}', 'pScale13', self.paramtree[self.name]['pScales'][1]) + self.fftree.set_attrib(f'{self.name}', 'pScale14', self.paramtree[self.name]['pScales'][2]) + self.fftree.set_attrib(f'{self.name}', 'pScale15', self.paramtree[self.name]['pScales'][3]) + self.fftree.set_attrib(f'{self.name}', 'pScale16', self.paramtree[self.name]['pScales'][4]) + + self.fftree.set_attrib(f'{self.name}', 'dScale12', self.paramtree[self.name]['dScales'][0]) + self.fftree.set_attrib(f'{self.name}', 'dScale13', self.paramtree[self.name]['dScales'][1]) + self.fftree.set_attrib(f'{self.name}', 'dScale14', self.paramtree[self.name]['dScales'][2]) + self.fftree.set_attrib(f'{self.name}', 'dScale15', self.paramtree[self.name]['dScales'][3]) + self.fftree.set_attrib(f'{self.name}', 'dScale16', self.paramtree[self.name]['dScales'][4]) + + Q_global = convert_harm2cart(self.paramtree[self.name]['Q_local'], self.lmax) + + + self.fftree.set_attrib(f'{self.name}/Atom', 'c0', Q_global[:, 0]) + self.fftree.set_attrib(f'{self.name}/Atom', 'dX', Q_global[:, 1]) + self.fftree.set_attrib(f'{self.name}/Atom', 'dY', Q_global[:, 2]) + self.fftree.set_attrib(f'{self.name}/Atom', 'dZ', Q_global[:, 3]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qXX', Q_global[:, 4]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qYY', Q_global[:, 5]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qZZ', Q_global[:, 6]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qXY', Q_global[:, 7]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qYZ', Q_global[:, 8]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qYZ', Q_global[:, 9]) + + if self.lpol: + self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityXX', self.paramtree[self.name]['pol'][:, 0]) + self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityYY', self.paramtree[self.name]['pol'][:, 1]) + self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityZZ', self.paramtree[self.name]['pol'][:, 2]) + self.fftree.set_attrib(f'{self.name}/Polarize', 'thole', self.paramtree[self.name]['tholes']) + + + + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): diff --git a/dmff/fftree.py b/dmff/fftree.py index a03094d89..45cd77b31 100644 --- a/dmff/fftree.py +++ b/dmff/fftree.py @@ -78,7 +78,7 @@ def get_attribs(self, parser, attrname): ret.append(vals) return ret else: - attrs = [convertStr2Float(n.attrs[attrname]) for n in sel] + attrs = [convertStr2Float(n.attrs[attrname]) if attrname in n.attrs else None for n in sel] return attrs def set_node(self, parser, values): diff --git a/tests/data/admp.xml b/tests/data/admp.xml index 94b78b628..53e25cb14 100644 --- a/tests/data/admp.xml +++ b/tests/data/admp.xml @@ -21,12 +21,12 @@ - - Date: Mon, 4 Jul 2022 23:12:59 +0800 Subject: [PATCH 4/5] fix: fix bug in admp api --- dmff/api.py | 80 +++++++++++++++-------------- dmff/fftree.py | 90 ++++++++++++++++++++++++++++++--- tests/data/tip3p.xml | 25 +++++++++ tests/test_admp/test_compute.py | 22 ++++---- tests/test_api.py | 90 ++++++++++++++++----------------- tests/test_classical/test_lj.py | 2 +- 6 files changed, 208 insertions(+), 101 deletions(-) create mode 100644 tests/data/tip3p.xml diff --git a/dmff/api.py b/dmff/api.py index 8469f53e9..7c2e11a7d 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -308,18 +308,18 @@ def potential_fn(positions, box, pairs, params): def overwrite(self): - self.fftree.set_attrib(f'{self.name}', 'mScale12', self.paramtree[self.name]['mScales'][0]) - self.fftree.set_attrib(f'{self.name}', 'mScale13', self.paramtree[self.name]['mScales'][1]) - self.fftree.set_attrib(f'{self.name}', 'mScale14', self.paramtree[self.name]['mScales'][2]) - self.fftree.set_attrib(f'{self.name}', 'mScale15', self.paramtree[self.name]['mScales'][3]) - self.fftree.set_attrib(f'{self.name}', 'mScale16', self.paramtree[self.name]['mScales'][4]) - - self.fftree.set_attrib(f'{self.name}/Atom', 'A', self.paramtree[self.name]['A']) - self.fftree.set_attrib(f'{self.name}/Atom', 'B', self.paramtree[self.name]['B']) - self.fftree.set_attrib(f'{self.name}/Atom', 'Q', self.paramtree[self.name]['Q']) - self.fftree.set_attrib(f'{self.name}/Atom', 'C6', self.paramtree[self.name]['C6']) - self.fftree.set_attrib(f'{self.name}/Atom', 'C8', self.paramtree[self.name]['C8']) - self.fftree.set_attrib(f'{self.name}/Atom', 'C10', self.paramtree[self.name]['C10']) + self.fftree.set_attrib(f'{self.name}', 'mScale12', [self.paramtree[self.name]['mScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'mScale13', [self.paramtree[self.name]['mScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'mScale14', [self.paramtree[self.name]['mScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'mScale15', [self.paramtree[self.name]['mScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'mScale16', [self.paramtree[self.name]['mScales'][4]]) + + self.fftree.set_attrib(f'{self.name}/Atom', 'A', [self.paramtree[self.name]['A']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'B', [self.paramtree[self.name]['B']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'Q', [self.paramtree[self.name]['Q']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'C6', [self.paramtree[self.name]['C6']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'C8', [self.paramtree[self.name]['C8']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'C10', [self.paramtree[self.name]['C10']]) def getJaxPotential(self): @@ -463,11 +463,11 @@ def extract(self): def overwrite(self): - self.fftree.set_attrib(f'{self.name}', 'mScale12', self.paramtree[self.name]['mScales'][0]) - self.fftree.set_attrib(f'{self.name}', 'mScale13', self.paramtree[self.name]['mScales'][1]) - self.fftree.set_attrib(f'{self.name}', 'mScale14', self.paramtree[self.name]['mScales'][2]) - self.fftree.set_attrib(f'{self.name}', 'mScale15', self.paramtree[self.name]['mScales'][3]) - self.fftree.set_attrib(f'{self.name}', 'mScale16', self.paramtree[self.name]['mScales'][4]) + self.fftree.set_attrib(f'{self.name}', 'mScale12', [self.paramtree[self.name]['mScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'mScale13', [self.paramtree[self.name]['mScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'mScale14', [self.paramtree[self.name]['mScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'mScale15', [self.paramtree[self.name]['mScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'mScale16', [self.paramtree[self.name]['mScales'][4]]) self.fftree.set_attrib(f'{self.name}/Atom', 'C6', self.paramtree[self.name]['C6']) self.fftree.set_attrib(f'{self.name}/Atom', 'C8', self.paramtree[self.name]['C8']) @@ -1380,23 +1380,23 @@ def extract(self): def overwrite(self): - self.fftree.set_attrib(f'{self.name}', 'mScale12', self.paramtree[self.name]['mScales'][0]) - self.fftree.set_attrib(f'{self.name}', 'mScale13', self.paramtree[self.name]['mScales'][1]) - self.fftree.set_attrib(f'{self.name}', 'mScale14', self.paramtree[self.name]['mScales'][2]) - self.fftree.set_attrib(f'{self.name}', 'mScale15', self.paramtree[self.name]['mScales'][3]) - self.fftree.set_attrib(f'{self.name}', 'mScale16', self.paramtree[self.name]['mScales'][4]) + self.fftree.set_attrib(f'{self.name}', 'mScale12', [self.paramtree[self.name]['mScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'mScale13', [self.paramtree[self.name]['mScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'mScale14', [self.paramtree[self.name]['mScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'mScale15', [self.paramtree[self.name]['mScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'mScale16', [self.paramtree[self.name]['mScales'][4]]) - self.fftree.set_attrib(f'{self.name}', 'pScale12', self.paramtree[self.name]['pScales'][0]) - self.fftree.set_attrib(f'{self.name}', 'pScale13', self.paramtree[self.name]['pScales'][1]) - self.fftree.set_attrib(f'{self.name}', 'pScale14', self.paramtree[self.name]['pScales'][2]) - self.fftree.set_attrib(f'{self.name}', 'pScale15', self.paramtree[self.name]['pScales'][3]) - self.fftree.set_attrib(f'{self.name}', 'pScale16', self.paramtree[self.name]['pScales'][4]) + self.fftree.set_attrib(f'{self.name}', 'pScale12', [self.paramtree[self.name]['pScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'pScale13', [self.paramtree[self.name]['pScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'pScale14', [self.paramtree[self.name]['pScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'pScale15', [self.paramtree[self.name]['pScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'pScale16', [self.paramtree[self.name]['pScales'][4]]) - self.fftree.set_attrib(f'{self.name}', 'dScale12', self.paramtree[self.name]['dScales'][0]) - self.fftree.set_attrib(f'{self.name}', 'dScale13', self.paramtree[self.name]['dScales'][1]) - self.fftree.set_attrib(f'{self.name}', 'dScale14', self.paramtree[self.name]['dScales'][2]) - self.fftree.set_attrib(f'{self.name}', 'dScale15', self.paramtree[self.name]['dScales'][3]) - self.fftree.set_attrib(f'{self.name}', 'dScale16', self.paramtree[self.name]['dScales'][4]) + self.fftree.set_attrib(f'{self.name}', 'dScale12', [self.paramtree[self.name]['dScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'dScale13', [self.paramtree[self.name]['dScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'dScale14', [self.paramtree[self.name]['dScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'dScale15', [self.paramtree[self.name]['dScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'dScale16', [self.paramtree[self.name]['dScales'][4]]) Q_global = convert_harm2cart(self.paramtree[self.name]['Q_local'], self.lmax) @@ -1413,9 +1413,12 @@ def overwrite(self): self.fftree.set_attrib(f'{self.name}/Atom', 'qYZ', Q_global[:, 9]) if self.lpol: - self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityXX', self.paramtree[self.name]['pol'][:, 0]) - self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityYY', self.paramtree[self.name]['pol'][:, 1]) - self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityZZ', self.paramtree[self.name]['pol'][:, 2]) + # self.paramtree[self.name]['pol']: every element is the mean value of XX YY ZZ + # get the number of polarize element + n_pol = len(self.paramtree[self.name]['pol']) + self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityXX', [self.paramtree[self.name]['pol'][0]] * n_pol) + self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityYY', [self.paramtree[self.name]['pol'][1]] * n_pol) + self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityZZ', [self.paramtree[self.name]['pol'][2]] * n_pol) self.fftree.set_attrib(f'{self.name}/Polarize', 'thole', self.paramtree[self.name]['tholes']) @@ -1672,7 +1675,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, self.pme_force = pme_force def potential_fn(positions, box, pairs, params): - + params = params['ADMPPmeForce'] mScales = params["mScales"] Q_local = params["Q_local"][map_atomtype] if self.lpol: @@ -1727,10 +1730,11 @@ def getPotentialFunc(self, names=[]): raise DMFFException("No DMFF function in this potential object.") def totalPE(positions, box, pairs, params): - totale = sum([ + totale_list = [ self.dmff_potentials[k](positions, box, pairs, params) for k in self.dmff_potentials.keys() if (len(names) == 0 or k in names) - ]) + ] + totale = jnp.sum(jnp.array(totale_list)) return totale return totalPE diff --git a/dmff/fftree.py b/dmff/fftree.py index 45cd77b31..be63defa1 100644 --- a/dmff/fftree.py +++ b/dmff/fftree.py @@ -1,9 +1,10 @@ import xml.etree.ElementTree as ET import xml.dom.minidom from dmff.utils import convertStr2Float, DMFFException -from typing import List +from typing import Dict, List, Union, TypeVar from itertools import permutations +value = TypeVar('value') # generic type: interpreted as either a number or str class SelectError(BaseException): pass @@ -54,7 +55,25 @@ def __init__(self, tag, **attrs): super().__init__(tag, **attrs) - def get_nodes(self, parser): + def get_nodes(self, parser:str)->List[Node]: + """ + get all nodes of a certain path + + Examples + -------- + >>> fftree.get_nodes('HarmonicBondForce/Bond') + >>> [, , ...] + + Parameters + ---------- + parser : str + a path to locate nodes + + Returns + ------- + List[Node] + a list of Node + """ steps = parser.split("/") val = self for nstep, step in enumerate(steps): @@ -69,7 +88,29 @@ def get_nodes(self, parser): val = val[0] return val - def get_attribs(self, parser, attrname): + def get_attribs(self, parser:str, attrname:Union[str, List[str]])->List[Union[value, List[value]]]: + """ + get all values of attributes of nodes which nodes matching certain path + + Examples: + --------- + >>> fftree.get_attribs('HarmonicBondForce/Bond', 'k') + >>> [2.0, 2.0, 2.0, 2.0, 2.0, ...] + >>> fftree.get_attribs('HarmonicBondForce/Bond', ['k', 'r0']) + >>> [[2.0, 1.53], [2.0, 1.53], ...] + + Parameters + ---------- + parser : str + a path to locate nodes + attrname : _type_ + attribute name or a list of attribute names of a node + + Returns + ------- + List[Union[float, str]] + a list of values of attributes + """ sel = self.get_nodes(parser) if isinstance(attrname, list): ret = [] @@ -81,14 +122,51 @@ def get_attribs(self, parser, attrname): attrs = [convertStr2Float(n.attrs[attrname]) if attrname in n.attrs else None for n in sel] return attrs - def set_node(self, parser, values): + def set_node(self, parser:str, values:List[Dict[str, value]])->None: + """ + set attributes of nodes which nodes matching certain path + + Parameters + ---------- + parser : str + path to locate nodes + values : List[Dict[str, value]] + a list of Dict[str, value], where value is any type can be convert to str of a number. + + Examples + -------- + >>> fftree.set_node('HarmonicBondForce/Bond', + [{'k': 2.0, 'r0': 1.53}, + {'k': 2.0, 'r0': 1.53}]) + """ nodes = self.get_nodes(parser) for nit in range(len(values)): for key in values[nit]: nodes[nit].attrs[key] = f"{values[nit][key]}" - def set_attrib(self, parser, attrname, values): - valdicts = [{attrname: i} for i in values] + def set_attrib(self, parser:str, attrname:str, values:Union[value, List[value]])->None: + """ + set ONE Attribute of nodes which nodes matching certain path + + Parameters + ---------- + parser : str + path to locate nodes + attrname : str + attribute name + values : Union[float, str, List[float, str]] + attribute value or a list of attribute values of a node + + Examples + -------- + >>> fftree.set_attrib('HarmonicBondForce/Bond', 'k', 2.0) + >>> fftree.set_attrib('HarmonicBondForce/Bond', 'k', [2.0, 2.0, 2.0, 2.0, 2.0]) + + """ + if len(values) == 0: + valdicts = [{attrname: values}] + else: + valdicts = [{attrname: i} for i in values] self.set_node(parser, valdicts) diff --git a/tests/data/tip3p.xml b/tests/data/tip3p.xml new file mode 100644 index 000000000..77de2c6cb --- /dev/null +++ b/tests/data/tip3p.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/test_admp/test_compute.py b/tests/test_admp/test_compute.py index 444cf8d5e..79c060dee 100644 --- a/tests/test_admp/test_compute.py +++ b/tests/test_admp/test_compute.py @@ -12,7 +12,7 @@ class TestADMPAPI: """ Test ADMP related generators """ - @pytest.fixture(scope='class', name='potentials') + @pytest.fixture(scope='class', name='generators') def test_init(self): """load generators from XML file @@ -26,10 +26,11 @@ def test_init(self): H = Hamiltonian('tests/data/admp.xml') pdb = app.PDBFile('tests/data/water_dimer.pdb') potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + generators = H.getGenerators() - yield potential + yield generators - def test_ADMPPmeForce(self, potentials): + def test_ADMPPmeForce(self, generators): rc = 4.0 pdb = app.PDBFile('tests/data/water_dimer.pdb') @@ -41,8 +42,9 @@ def test_ADMPPmeForce(self, potentials): nblist.allocate(positions) pairs = nblist.pairs - pot = potentials.getPotentialFunc(names=['ADMPPmeForce']) - energy = pot(positions, box, pairs, potentials.params) + gen = generators[1] + pot = gen.getJaxPotential() + energy = pot(positions, box, pairs, gen.paramtree) def test_ADMPPmeForce_jit(self, generators): @@ -57,8 +59,8 @@ def test_ADMPPmeForce_jit(self, generators): nblist = NeighborList(box, rc) nblist.allocate(positions) pairs = nblist.pairs - - pot_pme = gen.getJaxPotential() - j_pot_pme = jit(value_and_grad(pot_pme)) - - E_pme, F_pme = j_pot_pme(positions, box, pairs, gen.params) \ No newline at end of file + + gen = generators[1] + pot = gen.getJaxPotential() + j_pot_pme = jit(value_and_grad(pot)) + energy = j_pot_pme(positions, box, pairs, gen.paramtree) diff --git a/tests/test_api.py b/tests/test_api.py index f47cb02ba..31786910d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -40,13 +40,13 @@ def test_ADMPDispForce_parseXML(self, generators): def test_ADMPDispForce_renderXML(self, generators): gen = generators[0] - xml = gen.renderXML() + params = gen.paramtree['ADMPDispForce'] + gen.overwrite() - assert xml.name == 'ADMPDispForce' - npt.assert_allclose(float(xml[0]['type']), 380) - npt.assert_allclose(float(xml[0]['A']), 1203470.743) - npt.assert_allclose(float(xml[1]['B']), 37.78544799) - npt.assert_allclose(float(xml[1]['Q']), 0.370853) + assert gen.name == 'ADMPDispForce' + npt.assert_allclose(params['mScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + npt.assert_allclose(params['A'], [1203470.743, 83.2283563]) + npt.assert_allclose(params['B'], [37.81265679, 37.78544799]) def test_ADMPPmeForce_parseXML(self, generators): @@ -64,17 +64,16 @@ def test_ADMPPmeForce_parseXML(self, generators): def test_ADMPPmeForce_renderXML(self, generators): gen = generators[1] - xml = gen.renderXML() - - assert xml.name == 'ADMPPmeForce' - assert xml.attributes['lmax'] == '2' - assert xml.attributes['mScale12'] == '0.0' - assert xml.attributes['mScale15'] == '1.0' - assert xml.elements[0].name == 'Atom' - assert xml.elements[0].attributes['qXZ'] == '-0.07141020' - assert xml.elements[2].name == 'Polarize' - assert xml.elements[2].attributes['polarizabilityXX'][:6] == '0.8800' - assert xml[3]['type'] == '381' + params = gen.paramtree['ADMPPmeForce'] + gen.overwrite() + + npt.assert_allclose(params['mScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + npt.assert_allclose(params['pScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + npt.assert_allclose(params['dScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + # Q_local is already converted to local frame + # npt.assert_allclose(params['Q_local'][0][:4], [-1.0614, 0.0, 0.0, -0.023671684]) + npt.assert_allclose(params['pol'], [0.88000005, 0]) + npt.assert_allclose(params['tholes'], [8., 0.]) class TestClassicalAPI: @@ -99,64 +98,63 @@ def test_init(self): yield H.getGenerators() - def test_NonBond_parseXML(self, generators): + def test_Nonbond_parseXML(self, generators): gen = generators[0] - params = gen.params + params = gen.paramtree['NonbondedForce'] npt.assert_allclose(params['sigma'], [1.0, 1.0, -1.0, -1.0]) def test_NonBond_renderXML(self, generators): gen = generators[0] - xml = gen.renderXML() + params = gen.paramtree['NonbondedForce'] + gen.overwrite() - assert xml.name == 'NonbondedForce' - assert xml.attributes['lj14scale'] == '0.5' - assert xml[0]['type'] == 'n1' - assert xml[1]['sigma'] == '1.0' + assert gen.name == 'NonbondedForce' + # npt.assert_allclose(params['type'], ['n1', 'n2', 'n3', 'n4']) + # type is asigned to the generator themself as a member variable + npt.assert_allclose(params['charge'], [1.0, -1.0, 1.0, -1.0]) + npt.assert_allclose(params['sigma'], [1.0, 1.0, -1.0, -1.0]) + npt.assert_allclose(params['epsilon'], [0.0, 0.0, 0.0, 0.0]) + npt.assert_allclose(params['lj14scale'], [0.5]) + def test_HarmonicAngle_parseXML(self, generators): gen = generators[1] - params = gen.params + params = gen.paramtree['HarmonicAngleForce'] npt.assert_allclose(params['k'], 836.8) npt.assert_allclose(params['angle'], 1.8242181341844732) def test_HarmonicAngle_renderXML(self, generators): gen = generators[1] - xml = gen.renderXML() + params = gen.paramtree['HarmonicAngleForce'] + gen.overwrite() - assert xml.name == 'HarmonicAngleForce' - assert xml[0]['type1'] == 'n1' - assert xml[0]['type2'] == 'n2' - assert xml[0]['type3'] == 'n3' - assert xml[0]['angle'][:7] == '1.82421' - assert xml[0]['k'] == '836.8' + assert gen.name == 'HarmonicAngleForce' + npt.assert_allclose(params['angle'], [1.8242181341844732] * 2) + npt.assert_allclose(params['k'], [836.8] * 2) def test_PeriodicTorsion_parseXML(self, generators): gen = generators[2] - params = gen.params - npt.assert_allclose(params['psi1_p'], 0) - npt.assert_allclose(params['k1_p'], 2.092) + params = gen.paramtree['PeriodicTorsionForce'] + npt.assert_allclose(params['prop_phase']['1'], [0]) + npt.assert_allclose(params['prop_k']['1'], [2.092]) def test_PeriodicTorsion_renderXML(self, generators): gen = generators[2] - xml = gen.renderXML() - assert xml.name == 'PeriodicTorsionForce' - assert xml[0].name == 'Proper' - assert xml[0]['type1'] == 'n1' - assert xml[1].name == 'Improper' - assert xml[1]['type1'] == 'n1' + params = gen.paramtree['PeriodicTorsionForce'] + gen.overwrite() + assert gen.name == 'PeriodicTorsionForce' + npt.assert_allclose(params['prop_phase']['1'], [0]) + npt.assert_allclose(params['prop_k']['1'], [2.092]) def test_parse_multiple_files(self): pdb = app.PDBFile("tests/data/methane_water.pdb") - h = Hamiltonian("tests/data/methane.xml", "tip3p.xml") + h = Hamiltonian("tests/data/methane.xml", "tests/data/tip3p.xml") potentials = h.createPotential(pdb.topology) - npt.assert_allclose( - h.getGenerators()[-1].params["charge"], - [-0.1068, 0.0267, 0.0267, 0.0267, 0.0267, -0.834, 0.417] - ) + diff --git a/tests/test_classical/test_lj.py b/tests/test_classical/test_lj.py index 7bb8ba7eb..3f560b0d6 100644 --- a/tests/test_classical/test_lj.py +++ b/tests/test_classical/test_lj.py @@ -74,7 +74,7 @@ def test_lj_params_check(self): pairs = np.array(pairs, dtype=int) ljE = potential.getPotentialFunc() with pytest.raises(TypeError): - energy = ljE(pos, box, pairs, h.getGenerators()[0].params) + energy = ljE(pos, box, pairs, h.getGenerators()[0].paramtree) energy = jax.jit(ljE)(pos, box, pairs, h.paramtree) # jit will optimized away type check force = jax.grad(jax.jit(ljE))(pos, box, pairs, h.paramtree) \ No newline at end of file From e6346c9a36648528a5567a08280cd2b83163b8b9 Mon Sep 17 00:00:00 2001 From: Roy-Kid Date: Mon, 4 Jul 2022 23:21:19 +0800 Subject: [PATCH 5/5] doc: programming style convention about typing and numpy style docstring --- docs/dev_guide/convention.md | 363 ++++++++++++++++++++++++++++++++++- 1 file changed, 362 insertions(+), 1 deletion(-) diff --git a/docs/dev_guide/convention.md b/docs/dev_guide/convention.md index d4d18f455..bf114dbb2 100644 --- a/docs/dev_guide/convention.md +++ b/docs/dev_guide/convention.md @@ -23,4 +23,365 @@ Under the `dmff`, there are several files and sub-directory: ## 3.2 Programming Style -TBA \ No newline at end of file +In the project DMFF, the programming style is followed numpy docstring. A proper docstring of methods and classes would be generated pretty api doc. In order to make other developers understand the methods you provide and further maintain them, you'd better add type annotations to the externally provided APIs. Here is the doc of [python typing](https://docs.python.org/3/library/typing.html). + +Here is an brife exampl from [napoleon](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_numpy.html). + +```python +# -*- coding: utf-8 -*- +"""Example NumPy style docstrings. + +This module demonstrates documentation as specified by the `NumPy +Documentation HOWTO`_. Docstrings may extend over multiple lines. Sections +are created with a section header followed by an underline of equal length. + +Example +------- +Examples can be given using either the ``Example`` or ``Examples`` +sections. Sections support any reStructuredText formatting, including +literal blocks:: + + $ python example_numpy.py + + +Section breaks are created with two blank lines. Section breaks are also +implicitly created anytime a new section starts. Section bodies *may* be +indented: + +Notes +----- + This is an example of an indented section. It's like any other section, + but the body is indented to help it stand out from surrounding text. + +If a section is indented, then a section break is created by +resuming unindented text. + +Attributes +---------- +module_level_variable1 : int + Module level variables may be documented in either the ``Attributes`` + section of the module docstring, or in an inline docstring immediately + following the variable. + + Either form is acceptable, but the two should not be mixed. Choose + one convention to document module level variables and be consistent + with it. + + +.. _NumPy Documentation HOWTO: + https://github.com/numpy/numpy/blob/master/doc/HOWTO_DOCUMENT.rst.txt + +""" + +module_level_variable1 = 12345 + +module_level_variable2 = 98765 +"""int: Module level variable documented inline. + +The docstring may span multiple lines. The type may optionally be specified +on the first line, separated by a colon. +""" + + +def function_with_types_in_docstring(param1, param2): + """Example function with types documented in the docstring. + + `PEP 484`_ type annotations are supported. If attribute, parameter, and + return types are annotated according to `PEP 484`_, they do not need to be + included in the docstring: + + Parameters + ---------- + param1 : int + The first parameter. + param2 : str + The second parameter. + + Returns + ------- + bool + True if successful, False otherwise. + + .. _PEP 484: + https://www.python.org/dev/peps/pep-0484/ + + """ + + +def function_with_pep484_type_annotations(param1: int, param2: str) -> bool: + """Example function with PEP 484 type annotations. + + The return type must be duplicated in the docstring to comply + with the NumPy docstring style. + + Parameters + ---------- + param1 + The first parameter. + param2 + The second parameter. + + Returns + ------- + bool + True if successful, False otherwise. + + """ + + +def module_level_function(param1, param2=None, *args, **kwargs): + """This is an example of a module level function. + + Function parameters should be documented in the ``Parameters`` section. + The name of each parameter is required. The type and description of each + parameter is optional, but should be included if not obvious. + + If \*args or \*\*kwargs are accepted, + they should be listed as ``*args`` and ``**kwargs``. + + The format for a parameter is:: + + name : type + description + + The description may span multiple lines. Following lines + should be indented to match the first line of the description. + The ": type" is optional. + + Multiple paragraphs are supported in parameter + descriptions. + + Parameters + ---------- + param1 : int + The first parameter. + param2 : :obj:`str`, optional + The second parameter. + *args + Variable length argument list. + **kwargs + Arbitrary keyword arguments. + + Returns + ------- + bool + True if successful, False otherwise. + + The return type is not optional. The ``Returns`` section may span + multiple lines and paragraphs. Following lines should be indented to + match the first line of the description. + + The ``Returns`` section supports any reStructuredText formatting, + including literal blocks:: + + { + 'param1': param1, + 'param2': param2 + } + + Raises + ------ + AttributeError + The ``Raises`` section is a list of all exceptions + that are relevant to the interface. + ValueError + If `param2` is equal to `param1`. + + """ + if param1 == param2: + raise ValueError('param1 may not be equal to param2') + return True + + +def example_generator(n): + """Generators have a ``Yields`` section instead of a ``Returns`` section. + + Parameters + ---------- + n : int + The upper limit of the range to generate, from 0 to `n` - 1. + + Yields + ------ + int + The next number in the range of 0 to `n` - 1. + + Examples + -------- + Examples should be written in doctest format, and should illustrate how + to use the function. + + >>> print([i for i in example_generator(4)]) + [0, 1, 2, 3] + + """ + for i in range(n): + yield i + + +class ExampleError(Exception): + """Exceptions are documented in the same way as classes. + + The __init__ method may be documented in either the class level + docstring, or as a docstring on the __init__ method itself. + + Either form is acceptable, but the two should not be mixed. Choose one + convention to document the __init__ method and be consistent with it. + + Note + ---- + Do not include the `self` parameter in the ``Parameters`` section. + + Parameters + ---------- + msg : str + Human readable string describing the exception. + code : :obj:`int`, optional + Numeric error code. + + Attributes + ---------- + msg : str + Human readable string describing the exception. + code : int + Numeric error code. + + """ + + def __init__(self, msg, code): + self.msg = msg + self.code = code + + +class ExampleClass(object): + """The summary line for a class docstring should fit on one line. + + If the class has public attributes, they may be documented here + in an ``Attributes`` section and follow the same formatting as a + function's ``Args`` section. Alternatively, attributes may be documented + inline with the attribute's declaration (see __init__ method below). + + Properties created with the ``@property`` decorator should be documented + in the property's getter method. + + Attributes + ---------- + attr1 : str + Description of `attr1`. + attr2 : :obj:`int`, optional + Description of `attr2`. + + """ + + def __init__(self, param1, param2, param3): + """Example of docstring on the __init__ method. + + The __init__ method may be documented in either the class level + docstring, or as a docstring on the __init__ method itself. + + Either form is acceptable, but the two should not be mixed. Choose one + convention to document the __init__ method and be consistent with it. + + Note + ---- + Do not include the `self` parameter in the ``Parameters`` section. + + Parameters + ---------- + param1 : str + Description of `param1`. + param2 : :obj:`list` of :obj:`str` + Description of `param2`. Multiple + lines are supported. + param3 : :obj:`int`, optional + Description of `param3`. + + """ + self.attr1 = param1 + self.attr2 = param2 + self.attr3 = param3 #: Doc comment *inline* with attribute + + #: list of str: Doc comment *before* attribute, with type specified + self.attr4 = ["attr4"] + + self.attr5 = None + """str: Docstring *after* attribute, with type specified.""" + + @property + def readonly_property(self): + """str: Properties should be documented in their getter method.""" + return "readonly_property" + + @property + def readwrite_property(self): + """:obj:`list` of :obj:`str`: Properties with both a getter and setter + should only be documented in their getter method. + + If the setter method contains notable behavior, it should be + mentioned here. + """ + return ["readwrite_property"] + + @readwrite_property.setter + def readwrite_property(self, value): + value + + def example_method(self, param1, param2): + """Class methods are similar to regular functions. + + Note + ---- + Do not include the `self` parameter in the ``Parameters`` section. + + Parameters + ---------- + param1 + The first parameter. + param2 + The second parameter. + + Returns + ------- + bool + True if successful, False otherwise. + + """ + return True + + def __special__(self): + """By default special members with docstrings are not included. + + Special members are any methods or attributes that start with and + end with a double underscore. Any special member with a docstring + will be included in the output, if + ``napoleon_include_special_with_doc`` is set to True. + + This behavior can be enabled by changing the following setting in + Sphinx's conf.py:: + + napoleon_include_special_with_doc = True + + """ + pass + + def __special_without_docstring__(self): + pass + + def _private(self): + """By default private members are not included. + + Private members are any methods or attributes that start with an + underscore and are *not* special. By default they are not included + in the output. + + This behavior can be changed such that private members *are* included + by changing the following setting in Sphinx's conf.py:: + + napoleon_include_private_with_doc = True + + """ + pass + + def _private_without_docstring(self): + pass +``` \ No newline at end of file