diff --git a/foyer/tests/files/ethane-multiple.xml b/foyer/tests/files/ethane-multiple.xml new file mode 100644 index 00000000..88ada44d --- /dev/null +++ b/foyer/tests/files/ethane-multiple.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/foyer/tests/test_forcefield.py b/foyer/tests/test_forcefield.py index fcd6fd74..071495ac 100644 --- a/foyer/tests/test_forcefield.py +++ b/foyer/tests/test_forcefield.py @@ -376,6 +376,18 @@ def test_write_xml_multiple_periodictorsions(filename): assert 'k2' in periodic_element[0].attrib assert 'phase2' in periodic_element[0].attrib +@pytest.mark.parametrize("filename", ['ethane.mol2', 'benzene.mol2']) +def test_load_xml(filename): + mol = pmd.load_file(get_fn(filename), structure=True) + if filename == 'ethane.mol2': + ff = Forcefield(get_fn('ethane-multiple.xml')) + else: + ff = Forcefield(name='oplsaa') + typed = ff.apply(mol) + typed.write_foyer(filename='snippet.xml', forcefield=ff, unique=True) + + generated_ff = Forcefield('snippet.xml') + def test_write_xml_overrides(): #Test xml_writer new overrides and comments features mol = pmd.load_file(get_fn('styrene.mol2'), structure=True) diff --git a/foyer/xml_writer.py b/foyer/xml_writer.py index e0f9ad2c..8eadc500 100644 --- a/foyer/xml_writer.py +++ b/foyer/xml_writer.py @@ -2,6 +2,9 @@ import collections from lxml import etree as ET +from foyer.smarts_graph import SMARTSGraph +import networkx as nx +import warnings import numpy as np @@ -124,6 +127,37 @@ def _write_atoms(self, root, atoms, forcefield, unique): nb_force.set('sigma', str(round(atom.atom_type.sigma/10, 4))) nb_force.set('epsilon', str(round(atom.atom_type.epsilon * 4.184, 6))) + _update_defs(atomtypes, nonbonded, forcefield) + +def _update_defs(atomtypes, nonbonded, forcefield): + def_list = [i.get('def') for i in atomtypes.iterchildren()] + name_list = [i.get('name') for i in atomtypes.iterchildren()] + smarts_list = list() + smarts_parser = forcefield.parser + for smarts_string, name in zip(def_list, name_list): + smarts_graph = SMARTSGraph(smarts_string, parser=smarts_parser, + name=name) + for atom_expr in nx.get_node_attributes(smarts_graph, name='atom').values(): + labels = atom_expr.find_data('has_label') + for label in labels: + atom_type = label.children[0][1:] + smarts_list.append(atom_type) + smarts_list = list(set(smarts_list)) + extra_types = [i for i in smarts_list if i not in name_list] + + for extra in extra_types: + for i, definition in enumerate(def_list): + if extra in definition: + warnings.warn('Removing undefined atom type `{}`' + ' from SMARTS string `{}`'.format( + extra, definition)) + extra_edit = '%' + extra + extra_index = definition.find(extra_edit) + if definition[extra_index-1] == ';': + new_def = definition.replace(extra_edit + ',' , '') + else: + new_def = definition.replace(',' + extra_edit, '') + atomtypes[i].set('def', new_def) def _write_bonds(root, bonds, unique): bond_forces = ET.SubElement(root, 'HarmonicBondForce')