From 2d26b9da56826dd0b9b6f7ef2ee5afb9fba062ea Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Mon, 23 May 2022 09:20:51 +0800 Subject: [PATCH 1/5] feat:rewrite renderXML with very clean commit history --- dmff/__init__.py | 5 +- dmff/api.py | 182 ++++++++++++++++++++++-- dmff/settings.py | 1 + tests/data/admp.xml | 44 ++++++ tests/data/classical.xml | 34 +++++ tests/data/linear.pdb | 7 + tests/data/water_dimer.pdb | 13 ++ tests/{ => test_admp}/test_multipole.py | 0 tests/{ => test_admp}/test_sptial.py | 0 tests/test_api.py | 153 ++++++++++++++++++++ tests/{ => test_common}/test_nblist.py | 0 11 files changed, 425 insertions(+), 14 deletions(-) create mode 100644 tests/data/admp.xml create mode 100644 tests/data/classical.xml create mode 100644 tests/data/linear.pdb create mode 100644 tests/data/water_dimer.pdb rename tests/{ => test_admp}/test_multipole.py (100%) rename tests/{ => test_admp}/test_sptial.py (100%) create mode 100644 tests/test_api.py rename tests/{ => test_common}/test_nblist.py (100%) diff --git a/dmff/__init__.py b/dmff/__init__.py index e44570345..cfc4bd2e3 100644 --- a/dmff/__init__.py +++ b/dmff/__init__.py @@ -1,2 +1,3 @@ -import dmff.settings -from dmff.common.nblist import NeighborList \ No newline at end of file +from .settings import * +from .common.nblist import NeighborList +from .api import Hamiltonian \ No newline at end of file diff --git a/dmff/api.py b/dmff/api.py index d8c7c24a5..a0100d47d 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -10,7 +10,7 @@ from dmff.utils import isinstance_jnp from .admp.disp_pme import ADMPDispPmeForce -from .admp.multipole import convert_cart2harm +from .admp.multipole import convert_cart2harm, convert_harm2cart from .admp.pairwise import TT_damping_qq_c6_kernel, generate_pairwise_interaction from .admp.pairwise import slater_disp_damping_kernel, slater_sr_kernel, TT_damping_qq_kernel from .admp.pme import ADMPPmeForce @@ -30,9 +30,26 @@ CoulReactionFieldForce, ) import sys +from copy import deepcopy class XMLNodeInfo: + + @staticmethod + def to_str(value)->str: + """ convert value to string if it can + """ + if isinstance(value, str): + return value + elif isinstance(value, (jnp.ndarray, np.ndarray)): + if value.ndim == 0: + return str(value) + else: + return str(value[0]) + elif isinstance(value, list): + return value[0] # strip [] of value + else: + return str(value) class XMLElementInfo: @@ -41,29 +58,49 @@ def __init__(self, name): self.attributes = {} def addAttribute(self, key, value): - self.attributes[key] = value + self.attributes[key] = XMLNodeInfo.to_str(value) + + def __repr__(self): + return f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}>' + + def __getitem__(self, name): + return self.attributes[name] def __init__(self, name): self.name = name self.attributes = {} self.elements = [] + + def __getitem__(self, name): + if isinstance(name, str): + return self.attributes[name] + elif isinstance(name, int): + return self.elements[name] def addAttribute(self, key, value): - self.attributes[key] = value + self.attributes[key] = XMLNodeInfo.to_str(value) def addElement(self, name, info): element = self.XMLElementInfo(name) for k, v in info.items(): element.addAttribute(k, v) - self.elements.append(element) + self.elements.append(element) def modResidue(self, residue, atom, key, value): pass + def __repr__(self): + # tricy string formatting + left = f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}> \n\t' + right = f'<\\{self.name}>' + content = '\n\t'.join([repr(e) for e in self.elements]) + return left + content + '\n' + right + + def get_line_context(file_path, line_number): return linecache.getline(file_path, line_number).strip() @@ -202,7 +239,18 @@ def getJaxPotential(self): def renderXML(self): # generate xml force field file - pass + finfo = XMLNodeInfo('ADMPDispForce') + finfo.addAttribute('mScale12', self.params["mScales"][0]) + finfo.addAttribute('mScale13', self.params["mScales"][1]) + finfo.addAttribute('mScale14', self.params["mScales"][2]) + finfo.addAttribute('mScale15', self.params["mScales"][3]) + finfo.addAttribute('mScale16', self.params["mScales"][4]) + + for i in range(len(self.types)): + ainfo = {'type': self.types[i], 'A': self.params["A"][i], 'B': self.params["B"][i], 'Q': self.params["Q"][i], 'C6': self.params["C6"][i], 'C8': self.params["C8"][i], 'C10': self.params["C10"][i]} + finfo.addElement('Atom', ainfo) + + return finfo # register all parsers app.forcefield.parsers["ADMPDispForce"] = ADMPDispGenerator.parseElement @@ -700,6 +748,7 @@ def parseElement(element, hamiltonian): generator.types = np.array(generator.types) n_atoms = len(element.findall("Atom")) + generator.n_atoms = n_atoms # map atom multipole moments if generator.lmax == 0: @@ -1041,7 +1090,41 @@ def getJaxPotential(self): return self._jaxPotential def renderXML(self): - pass + # + + finfo = XMLNodeInfo('ADMPPmeForce') + finfo.addAttribute('lmax', str(self.lmax)) + outputparams = deepcopy(self.params) + mScales = outputparams.pop('mScales') + pScales = outputparams.pop('pScales') + dScales = outputparams.pop('dScales') + for i in range(len(mScales)): + finfo.addAttribute(f'mScale1{i+2}', str(mScales[i])) + for i in range(len(pScales)): + finfo.addAttribute(f'pScale{i+1}', str(pScales[i])) + for i in range(len(dScales)): + finfo.addAttribute(f'dScale{i+1}', str(dScales[i])) + + Q = outputparams['Q_local'] + Q_global = convert_harm2cart(Q, self.lmax) + + # + for atom in range(self.n_atoms): + info = {'type': self.map_atomtype[atom]} + info.update({ktype:self.kStrings[ktype][atom] for ktype in ['kz', 'kx', 'ky']}) + for i, key in enumerate(['c0', 'dX', 'dY', 'dZ', 'qXX', 'qXY', 'qXZ', 'qYY', 'qYZ', 'qZZ']): + info[key] = "%.8f" % Q_global[atom][i] + finfo.addElement('Atom', info) + + # + for t in range(len(self.types)): + info = { + 'type': self.types[t] + } + info.update({p: "%.8f" % self.params['pol'][t] for p in ['polarizabilityXX', 'polarizabilityYY', 'polarizabilityZZ']}) + finfo.addElement('Polarize', info) + + return finfo app.forcefield.parsers["ADMPPmeForce"] = ADMPPmeGenerator.parseElement @@ -1169,8 +1252,8 @@ def parseElement(element, hamiltonian): """ generator = HarmonicAngleJaxGenerator(hamiltonian) hamiltonian.registerGenerator(generator) - for bondtype in element.findall("Angle"): - generator.registerAngleType(bondtype.attrib) + for angletype in element.findall("Angle"): + generator.registerAngleType(angletype.attrib) def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): @@ -1224,13 +1307,20 @@ def getJaxPotential(self): def renderXML(self): # generate xml force field file - pass + finfo = XMLNodeInfo("HarmonicAngleForce") + for i, type in enumerate(self.types): + t1, t2, t3 = type + ainfo = {'type1': t1, 'type2': t2, 'type3': t3, 'k': self.params['k'][i], 'angle': self.params['angle'][i]} + finfo.addElement('Angle', ainfo) + + return finfo # register all parsers app.forcefield.parsers["HarmonicAngleForce"] = HarmonicAngleJaxGenerator.parseElement + def _matchImproper(data, torsion, generator): type1 = data.atomType[data.atoms[torsion[0]]] type2 = data.atomType[data.atoms[torsion[1]]] @@ -1437,7 +1527,7 @@ def parseElement(element, ff): - + """ @@ -1774,7 +1864,51 @@ def getJaxPotential(self): def renderXML(self): # generate xml force field file - pass + finfo = XMLNodeInfo('PeriodicTorsionForce') + for i in range(len(self.proper)): + proper = self.proper[i] + + finfo.addElement('Proper', + {'type1': proper.types1, 'type2': proper.types2, + 'type3': proper.types3, 'type4': proper.types4, + 'periodicity1': proper.periodicity[0], + 'phase1': proper.phase[0], + 'k1': proper.k[0], + 'periodicity2': proper.periodicity[1], + 'phase2': proper.phase[1], + 'k2': proper.k[1], + 'periodicity3': proper.periodicity[2], + 'phase3': proper.phase[2], + 'k3': proper.k[2], + 'periodicity4': proper.periodicity[3], + 'phase4': proper.phase[3], + 'k4': proper.k[3], + } + ) + + for i in range(len(self.improper)): + + improper = self.improper[i] + + finfo.addElement('Improper', + {'type1': improper.types1, 'type2': improper.types2, + 'type3': improper.types3, 'type4': improper.types4, + 'periodicity1': improper.periodicity[0], + 'phase1': improper.phase[0], + 'k1': improper.k[0], + 'periodicity2': improper.periodicity[1], + 'phase2': improper.phase[1], + 'k2': improper.k[1], + 'periodicity3': improper.periodicity[2], + 'phase3': improper.phase[2], + 'k3': improper.k[2], + 'periodicity4': improper.periodicity[3], + 'phase4': improper.phase[3], + 'k4': improper.k[3], + } + ) + + return finfo app.forcefield.parsers[ @@ -1862,6 +1996,13 @@ def parseElement(element, ff): generator.useAttributeFromResidue.append(eprm) for atom in element.findall("Atom"): generator.registerAtom(atom.attrib) + + generator.n_atoms = len(element.findall("Atom")) + + # jax it! + for k in generator.params.keys(): + generator.params[k] = jnp.array(generator.params[k]) + generator.types = np.array(generator.types) def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): @@ -1942,6 +2083,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): map_nbfix = [] # implement it later map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 2)) + colv_map = build_covalent_map(data, 6) @@ -1960,6 +2102,11 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): else: r_switch = r_cut ifSwitch = False + + map_lj = jnp.array(map_lj) + map_nbfix = jnp.array(map_nbfix) + map_charge = jnp.array(map_charge) + ljforce = LennardJonesForce( r_switch, r_cut, @@ -2012,12 +2159,23 @@ def getJaxPotential(self): return self._jaxPotential def renderXML(self): - pass + + # + finfo = XMLNodeInfo('NonbondedForce') + finfo.addAttribute('coulomb14scale', str(self.coulomb14scale)) + finfo.addAttribute('lj14scale', str(self.lj14scale)) + + for atom in range(self.n_atoms): + info = {'type': self.types[atom], 'charge': self.params['charge'][atom], 'sigma': self.params['sigma'][atom], 'epsilon': self.params['epsilon'][atom]} + finfo.addElement('Atom', info) + + return finfo app.forcefield.parsers["NonbondedForce"] = NonbondJaxGenerator.parseElement + class Hamiltonian(app.forcefield.ForceField): def __init__(self, *xmlnames): super().__init__(*xmlnames) diff --git a/dmff/settings.py b/dmff/settings.py index c7965bc2d..22845c92a 100644 --- a/dmff/settings.py +++ b/dmff/settings.py @@ -7,3 +7,4 @@ if PRECISION == 'double': config.update("jax_enable_x64", True) +__all__ = ['PRECISION', 'DO_JIT'] \ No newline at end of file diff --git a/tests/data/admp.xml b/tests/data/admp.xml new file mode 100644 index 000000000..0dc44c980 --- /dev/null +++ b/tests/data/admp.xml @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/classical.xml b/tests/data/classical.xml new file mode 100644 index 000000000..4788423ea --- /dev/null +++ b/tests/data/classical.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/data/linear.pdb b/tests/data/linear.pdb new file mode 100644 index 000000000..b5ef61a07 --- /dev/null +++ b/tests/data/linear.pdb @@ -0,0 +1,7 @@ +HETATM 1 N1 LIG A 1 0.000 1.000 0.000 1.00 0.00 N +HETATM 2 N2 LIG A 1 1.000 0.000 0.000 1.00 0.00 N +HETATM 3 N3 LIG A 1 0.000 1.000 1.000 1.00 0.00 N +HETATM 4 N4 LIG A 1 1.000 0.000 1.000 1.00 0.00 N +CONECT 1 2 +CONECT 2 3 +CONECT 3 4 \ No newline at end of file diff --git a/tests/data/water_dimer.pdb b/tests/data/water_dimer.pdb new file mode 100644 index 000000000..75668c11e --- /dev/null +++ b/tests/data/water_dimer.pdb @@ -0,0 +1,13 @@ +REMARK 1 CREATED WITH OPENMM 7.3, 2018-10-03 +CRYST1 31.289 31.289 31.289 90.00 90.00 90.00 P 1 1 +MODEL 1 +HETATM 1 O HOH A 1 12.434 3.404 1.540 1.00 0.00 O +HETATM 2 H1 HOH A 1 13.030 2.664 1.322 1.00 0.00 H +HETATM 3 H2 HOH A 1 12.312 3.814 0.660 1.00 0.00 H +HETATM 4 O HOH A 2 14.216 1.424 1.103 1.00 0.00 O +HETATM 5 H1 HOH A 2 14.246 1.144 2.054 1.00 0.00 H +HETATM 6 H2 HOH A 2 15.155 1.542 0.910 1.00 0.00 H +TER 7 HOH A 2 +ENDMDL +END + diff --git a/tests/test_multipole.py b/tests/test_admp/test_multipole.py similarity index 100% rename from tests/test_multipole.py rename to tests/test_admp/test_multipole.py diff --git a/tests/test_sptial.py b/tests/test_admp/test_sptial.py similarity index 100% rename from tests/test_sptial.py rename to tests/test_admp/test_sptial.py diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 000000000..8e27e47f7 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,153 @@ +import openmm.app as app +import openmm.unit as unit +import numpy as np +import jax.numpy as jnp +import numpy.testing as npt +import pytest +from dmff import Hamiltonian + +class TestADMPAPI: + + """ Test ADMP related generators + """ + + @pytest.fixture(scope='class', name='generators') + def test_init(self): + """load generators from XML file + + Yields: + Tuple: ( + ADMPDispForce, + ADMPPmeForce, # polarized + ) + """ + rc = 4.0 + H = Hamiltonian('tests/data/admp.xml') + pdb = app.PDBFile('tests/data/water_dimer.pdb') + H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom) + + yield H.getGenerators() + + def test_ADMPDispForce_parseXML(self, generators): + + gen = generators[0] + params = gen.params + + npt.assert_allclose(params['mScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + npt.assert_allclose(params['A'], [1203470.743, 83.2283563]) + npt.assert_allclose(params['B'], [37.81265679, 37.78544799]) + + def test_ADMPDispForce_renderXML(self, generators): + + gen = generators[0] + xml = gen.renderXML() + + assert xml.name == 'ADMPDispForce' + npt.assert_allclose(float(xml[0]['type']), 380) + npt.assert_allclose(float(xml[0]['A']), 1203470.743) + npt.assert_allclose(float(xml[1]['B']), 37.78544799) + npt.assert_allclose(float(xml[1]['Q']), 0.370853) + + def test_ADMPPmeForce_parseXML(self, generators): + + gen = generators[1] + params = gen.params + + npt.assert_allclose(params['mScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + npt.assert_allclose(params['pScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + npt.assert_allclose(params['dScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + # Q_local is already converted to local frame + # npt.assert_allclose(params['Q_local'][0][:4], [-1.0614, 0.0, 0.0, -0.023671684]) + npt.assert_allclose(params['pol'], [0.88000005, 0]) + npt.assert_allclose(params['tholes'], [8., 0.]) + + def test_ADMPPmeForce_renderXML(self, generators): + + gen = generators[1] + xml = gen.renderXML() + + assert xml.name == 'ADMPPmeForce' + assert xml.attributes['lmax'] == '2' + assert xml.attributes['mScale12'] == '0.0' + assert xml.attributes['mScale15'] == '1.0' + assert xml.elements[0].name == 'Atom' + assert xml.elements[0].attributes['qXZ'] == '-0.07141020' + assert xml.elements[2].name == 'Polarize' + assert xml.elements[2].attributes['polarizabilityXX'][:6] == '0.8800' + assert xml[3]['type'] == '381' + +class TestClassicalAPI: + + """ Test classical forcefield generators + """ + + @pytest.fixture(scope='class', name='generators') + def test_init(self): + """load generators from XML file + + Yields: + Tuple: ( + NonBondJaxGenerator, + HarmonicAngle, + PeriodicTorsionForce, + ) + """ + rc = 4.0 + H = Hamiltonian('tests/data/classical.xml') + pdb = app.PDBFile('tests/data/linear.pdb') + H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom) + + yield H.getGenerators() + + def test_NonBond_parseXML(self, generators): + + gen = generators[0] + params = gen.params + npt.assert_allclose(params['sigma'], [1.0, 1.0, -1.0, -1.0]) + + + def test_NonBond_renderXML(self, generators): + + gen = generators[0] + xml = gen.renderXML() + + assert xml.name == 'NonbondedForce' + assert xml.attributes['lj14scale'] == '0.5' + assert xml[0]['type'] == 'n1' + assert xml[1]['sigma'] == '1.0' + + def test_HarmonicAngle_parseXML(self, generators): + + gen = generators[1] + params = gen.params + npt.assert_allclose(params['k'], 836.8) + npt.assert_allclose(params['angle'], 1.8242181341844732) + + def test_HarmonicAngle_renderXML(self, generators): + + gen = generators[1] + xml = gen.renderXML() + + assert xml.name == 'HarmonicAngleForce' + assert xml[0]['type1'] == 'n1' + assert xml[0]['type2'] == 'n2' + assert xml[0]['type3'] == 'n3' + assert xml[0]['angle'][:7] == '1.82421' + assert xml[0]['k'] == '836.8' + + def test_PeriodicTorsion_parseXML(self, generators): + + gen = generators[2] + params = gen.params + npt.assert_allclose(params['psi1_p'], 0) + npt.assert_allclose(params['k1_p'], 2.092) + + def test_PeriodicTorsion_renderXML(self, generators): + + gen = generators[2] + xml = gen.renderXML() + assert xml.name == 'PeriodicTorsionForce' + assert xml[0].name == 'Proper' + assert xml[0]['type1'] == 'n1' + assert xml[1].name == 'Improper' + assert xml[1]['type1'] == 'n1' \ No newline at end of file diff --git a/tests/test_nblist.py b/tests/test_common/test_nblist.py similarity index 100% rename from tests/test_nblist.py rename to tests/test_common/test_nblist.py From 9c4b663aec4c47acd21abc2adc5f331a7eb48d71 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Mon, 23 May 2022 10:54:40 +0800 Subject: [PATCH 2/5] fix bug in Torsion renderXML --- dmff/api.py | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index a0100d47d..4f2a3679f 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -1496,6 +1496,8 @@ def __init__(self, hamiltonian): self.proper = [] self.improper = [] self.propersForAtomType = defaultdict(set) + self.n_proper = 0 + self.n_improper = 0 def registerProperTorsion(self, parameters): torsion = _parseTorsion(self.ff, parameters) @@ -1863,26 +1865,27 @@ def getJaxPotential(self): return self._jaxPotential def renderXML(self): + params = self.params # generate xml force field file finfo = XMLNodeInfo('PeriodicTorsionForce') for i in range(len(self.proper)): proper = self.proper[i] - + finfo.addElement('Proper', {'type1': proper.types1, 'type2': proper.types2, 'type3': proper.types3, 'type4': proper.types4, 'periodicity1': proper.periodicity[0], - 'phase1': proper.phase[0], - 'k1': proper.k[0], + 'phase1': params['psi1_p'][i], + 'k1': params['k1_p'][i], 'periodicity2': proper.periodicity[1], - 'phase2': proper.phase[1], - 'k2': proper.k[1], + 'phase2': params['psi2_p'][i], + 'k2': params['k2_p'][i], 'periodicity3': proper.periodicity[2], - 'phase3': proper.phase[2], - 'k3': proper.k[2], + 'phase3': params['psi3_p'][2], + 'k3': params['k3_p'][2], 'periodicity4': proper.periodicity[3], - 'phase4': proper.phase[3], - 'k4': proper.k[3], + 'phase4': params['psi4_p'][3], + 'k4': params['k4_p'][3], } ) @@ -1894,17 +1897,17 @@ def renderXML(self): {'type1': improper.types1, 'type2': improper.types2, 'type3': improper.types3, 'type4': improper.types4, 'periodicity1': improper.periodicity[0], - 'phase1': improper.phase[0], - 'k1': improper.k[0], - 'periodicity2': improper.periodicity[1], - 'phase2': improper.phase[1], - 'k2': improper.k[1], - 'periodicity3': improper.periodicity[2], - 'phase3': improper.phase[2], - 'k3': improper.k[2], - 'periodicity4': improper.periodicity[3], - 'phase4': improper.phase[3], - 'k4': improper.k[3], + 'phase1': params['psi1_i'][i], + 'k1': params['k1_i'][i], + 'periodicity2': proper.periodicity[1], + 'phase2': params['psi2_i'][i], + 'k2': params['k2_i'][i], + 'periodicity3': proper.periodicity[2], + 'phase3': params['psi3_i'][2], + 'k3': params['k3_i'][2], + 'periodicity4': proper.periodicity[3], + 'phase4': params['psi4_i'][3], + 'k4': params['k4_i'][3], } ) From b6c83b8ffc5bfe78485511ebd36b9e74497e9477 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Mon, 23 May 2022 10:57:01 +0800 Subject: [PATCH 3/5] add test_utils as a placehold --- tests/test_utils.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..e69de29bb From 37e9ccaa58959aba11389a71c2a7f17ecab90cf8 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Mon, 23 May 2022 11:01:02 +0800 Subject: [PATCH 4/5] fix: fix typo in api.py --- dmff/api.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index 4f2a3679f..3431843a1 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -1881,11 +1881,11 @@ def renderXML(self): 'phase2': params['psi2_p'][i], 'k2': params['k2_p'][i], 'periodicity3': proper.periodicity[2], - 'phase3': params['psi3_p'][2], - 'k3': params['k3_p'][2], + 'phase3': params['psi3_p'][i], + 'k3': params['k3_p'][i], 'periodicity4': proper.periodicity[3], - 'phase4': params['psi4_p'][3], - 'k4': params['k4_p'][3], + 'phase4': params['psi4_p'][i], + 'k4': params['k4_p'][i], } ) @@ -1903,11 +1903,11 @@ def renderXML(self): 'phase2': params['psi2_i'][i], 'k2': params['k2_i'][i], 'periodicity3': proper.periodicity[2], - 'phase3': params['psi3_i'][2], - 'k3': params['k3_i'][2], + 'phase3': params['psi3_i'][i], + 'k3': params['k3_i'][i], 'periodicity4': proper.periodicity[3], - 'phase4': params['psi4_i'][3], - 'k4': params['k4_i'][3], + 'phase4': params['psi4_i'][i], + 'k4': params['k4_i'][i], } ) From 1a6b3686ddee2b8d50b7ebc35c784ef4d15e0b43 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Mon, 23 May 2022 11:11:33 +0800 Subject: [PATCH 5/5] docs: add renderXML related api usages --- docs/dev_guide/arch.md | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/dev_guide/arch.md b/docs/dev_guide/arch.md index e7e4811d1..30f5f5b4a 100644 --- a/docs/dev_guide/arch.md +++ b/docs/dev_guide/arch.md @@ -179,7 +179,7 @@ class SimpleJAXGenerator: return self._jaxPotential def renderXML(self): - render_xml_forcefield_from_params + # render_xml_forcefield_from_params app.parsers["SimpleJAXForce"] = SimpleJAXGenerator.parseElement @@ -291,6 +291,39 @@ class HarmonicBondJaxGenerator: app.forcefield.parsers["HarmonicBondForce"] = HarmonicBondJaxGenerator.parseElement ``` + After the calculation and optimization, we need to save the optimized parameters as XML format files for the next calculation. This serialization process is implemented through the `renderXML` method. At the beginning of the `api.py` file, we provide nested helper classes called `XMLNodeInfo` and `XMLElementInfo`. In the XML file, a `` and its close tag is represented by XMLNodeInfo and the content element is controlled by `XMLElementInfo` + +``` + + + + +``` + + When we want to serialize optimized parameters from the generator to a new XML file, we first initialize a `XMLNodeInfo(name:str)` class with the potential name + +```python +finfo = XMLNodeInfo("HarmonicBondForce") +``` + If necessary, you can add attributes to this tag using the `addAttribute(name:str, value:str)` method. Then we add the inner `` tag by invoke `finfo.addElement(name:str, attrib:dict)` method. Here is an example to render `` + +``` + def renderXML(self): + # generate xml force field file + finfo = XMLNodeInfo("HarmonicBondForce") # and <\HarmonicBondForce> + for ntype in range(len(self.types)): + binfo = {} + k1, v1 = self.typetexts[ntype][0] + k2, v2 = self.typetexts[ntype][1] + binfo[k1] = v1 + binfo[k2] = v2 + for key in self.params.keys(): + binfo[key] = "%.8f"%self.params[key][ntype] + finfo.addElement("Bond", binfo) # + return finfo +``` + + ## How Backend Works ### Force Class