Skip to content

Commit

Permalink
Convert mBuild custom group names (#686)
Browse files Browse the repository at this point in the history
* [pre-commit.ci] pre-commit autoupdate (#673)

updates:
- [github.com/psf/black: 22.3.0 → 22.6.0](psf/black@22.3.0...22.6.0)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Revert to older versions of forcefield

* Add import for xml_representation in parameteric potential"

* Added tests for generating a forcefield object from a GMSO topology

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* Remove unused import

* pin unyt to version 2.8

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* Custom Groups in convert from_mbuild

* Switch error to warning upon adding extra group labels in mbuild_conversion

* Add test to check for mBuild conversion of a compound that would have no bond graph due to being part of a larger compound hierarchy

* Update gmso/external/convert_mbuild.py

Co-authored-by: Co Quach <43968221+daico007@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Co Quach <43968221+daico007@users.noreply.github.com>
Co-authored-by: Co Quach <daico007@gmail.com>
  • Loading branch information
4 people committed Aug 17, 2022
1 parent 5912c64 commit 76e8f20
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 19 deletions.
99 changes: 80 additions & 19 deletions gmso/external/convert_mbuild.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
element_by_symbol,
)
from gmso.core.topology import Topology
from gmso.exceptions import GMSOError
from gmso.utils.io import has_mbuild

if has_mbuild:
import mbuild as mb


def from_mbuild(
compound, box=None, search_method=element_by_symbol, parse_label=True
compound,
box=None,
search_method=element_by_symbol,
parse_label=True,
custom_groups=None,
):
"""Convert an mbuild.Compound to a gmso.Topology.
Expand Down Expand Up @@ -63,38 +68,43 @@ def from_mbuild(
parse_label : bool, optional, default=True
Option to parse hierarchy info of the compound into system of top label,
including, group, molecule and residue labels.
custom_groups : list or str, optional, default=None
Allows user to identify the groups assigned to each site in the topology
based on the compound.name attributes found traversing down the hierarchy. Be
sure to supply names such that every particle will be pass through one
matching name on the way down from compound.children. Only the first match
while moving downwards will be assigned to the site. If parse_label=False,
this argument does nothing.
Returns
-------
top : gmso.Topology
"""
msg = "Argument compound is not an mbuild.Compound"
assert isinstance(compound, mb.Compound), msg
msg = "Compound is not a top level compound. Make a copy to pass to the `compound` \
argument that has no parents"
assert not compound.parent, msg

top = Topology()
top.typed = False

site_map = {
particle: {"site": None, "residue": None, "molecule": None}
particle: {
"site": None,
"residue": None,
"molecule": None,
"group": None,
}
for particle in compound.particles()
}
if parse_label:
_parse_label(site_map, compound)
_parse_molecule_residue(site_map, compound)
_parse_group(site_map, compound, custom_groups)

if compound.children:
for child in compound.children:
if not child.children:
site = _parse_site(site_map, child, search_method)
site.group = compound.name
top.add_site(site)
else:
for particle in child.particles():
site = _parse_site(site_map, particle, search_method)
site.group = child.name
top.add_site(site)
else:
site = _parse_site(site_map, compound, search_method)
site.group = compound.name
# Use site map to apply Compound info to Topology.
for part in compound.particles():
site = _parse_site(site_map, part, search_method)
top.add_site(site)

for b1, b2 in compound.bonds():
Expand Down Expand Up @@ -256,12 +266,13 @@ def _parse_site(site_map, particle, search_method):
mass=mass,
molecule=site_map[particle]["molecule"],
residue=site_map[particle]["residue"],
group=site_map[particle]["group"],
)
site_map[particle]["site"] = site
return site


def _parse_label(site_map, compound):
def _parse_molecule_residue(site_map, compound):
"""Parse information necessary for residue and molecule labels when converting from mbuild."""
connected_subgraph = compound.bond_graph.connected_components()
molecule_tracker = dict()
Expand Down Expand Up @@ -312,4 +323,54 @@ def _parse_label(site_map, compound):
molecule_number,
)

return site_map

def _parse_group(site_map, compound, custom_groups):
"""Parse group information."""
if custom_groups:
if isinstance(custom_groups, str):
custom_groups = [custom_groups]
elif not hasattr(custom_groups, "__iter__"):
raise TypeError(
f"Please pass groups {custom_groups} as a list of strings."
)
elif not np.all([isinstance(g, str) for g in custom_groups]):
raise TypeError(
f"Please pass groups {custom_groups} as a list of strings."
)
for part in _traverse_down_hierarchy(compound, custom_groups):
for particle in part.particles():
site_map[particle]["group"] = part.name
try:
applied_groups = set(map(lambda x: x["group"], site_map.values()))
assert applied_groups == set(custom_groups)
except AssertionError:
warn(
f"""Not all custom groups ({custom_groups}, is are being used when
traversing compound hierachy. Only {applied_groups} are used.)"""
)
elif not compound.children:
for particle in compound.particles():
site_map[particle]["group"] = compound.name
elif not np.any(
list(map(lambda c: len(c.children), compound.children))
): # compound is a 2 level hierarchy
for particle in compound.particles():
site_map[particle]["group"] = compound.name
else: # set compund name to se
for child in compound.children:
for particle in child.particles():
site_map[particle]["group"] = child.name


def _traverse_down_hierarchy(compound, group_names):
if compound.name in group_names:
yield compound
elif compound.children:
for child in compound.children:
yield from _traverse_down_hierarchy(child, group_names)
else:
raise GMSOError(
f"""A particle named {compound.name} cannot be associated with the
custom_groups {group_names}. Be sure to specify a list of group names that will cover
all particles in the compound. This particle is one level below {compound.parent.name}."""
)
90 changes: 90 additions & 0 deletions gmso/tests/test_convert_mbuild.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import gmso
from gmso.core.atom import Atom
from gmso.core.topology import Topology as Top
from gmso.exceptions import GMSOError
from gmso.external.convert_mbuild import from_mbuild, to_mbuild
from gmso.tests.base_test import BaseTest
from gmso.utils.io import get_fn, has_mbuild
Expand Down Expand Up @@ -183,3 +184,92 @@ def test_group_2_level_compound(self):
top = from_mbuild(filled_box)
for site in top.sites:
assert site.group == filled_box.name

@pytest.mark.skipif(not has_mbuild, reason="mBuild is not installed")
def test_custom_groups_from_compound(self):
mb_cpd1 = mb.Compound(name="_CH4")

first_bead = mb.Compound(name="_CH3")
middle_bead = mb.Compound(name="_CH2")
last_bead = mb.Compound(name="_CH3")
mb_cpd2 = mb.Compound(name="Alkane")
[mb_cpd2.add(cpd) for cpd in [first_bead, middle_bead, last_bead]]
mb_cpd2.add_bond((first_bead, middle_bead))
mb_cpd2.add_bond((last_bead, middle_bead))

mb_cpd3 = mb.load("O", smiles=True)
mb_cpd3.name = "O"

filled_box1 = mb.fill_box(
[mb_cpd1, mb_cpd2], n_compounds=[2, 2], box=[1, 1, 1]
)
filled_box1.name = "box1"
filled_box2 = mb.fill_box(mb_cpd3, n_compounds=2, box=[1, 1, 1])
filled_box2.name = "box2"

top_box = mb.Compound()
top_box.add(filled_box1)
top_box.add(filled_box2)
top_box.name = "top"

list_of_groups = [
(["top"], [14]), # top level of hierarchy
(["box1", "box2"], [8, 6]), # middle level of hierarchy
(["_CH4", "_CH2", "_CH3", "O"], [2, 2, 4, 6]), # particle level
(
["box2", "Alkane", "_CH4"],
[6, 6, 2],
), # multiple different levels
]
for groups, n_groups in list_of_groups:
top = from_mbuild(top_box, custom_groups=groups)
assert np.all([site.group in groups for site in top.sites])
for n, gname in zip(n_groups, groups):
assert (
len([True for site in top.sites if site.group == gname])
== n
)

@pytest.mark.skipif(not has_mbuild, reason="mBuild is not installed")
def test_single_custom_group(self):
mb_cpd1 = mb.Compound(name="_CH4")
mb_cpd2 = mb.Compound(name="_CH3")
filled_box = mb.fill_box(
[mb_cpd1, mb_cpd2], n_compounds=[2, 2], box=[1, 1, 1]
)
filled_box.name = "box1"

top = from_mbuild(filled_box, custom_groups=filled_box.name)
assert (
len([True for site in top.sites if site.group == filled_box.name])
== filled_box.n_particles
)

@pytest.mark.skipif(not has_mbuild, reason="mBuild is not installed")
def test_bad_custom_groups_from_compound(self):
mb_cpd1 = mb.Compound(name="_CH4")
mb_cpd2 = mb.Compound(name="_CH3")
filled_box = mb.fill_box(
[mb_cpd1, mb_cpd2], n_compounds=[2, 2], box=[1, 1, 1]
)

with pytest.warns(Warning):
top = from_mbuild(
filled_box, custom_groups=["_CH4", "_CH3", "_CH5"]
)

with pytest.raises(GMSOError):
top = from_mbuild(filled_box, custom_groups=["_CH4"])

with pytest.raises(TypeError):
top = from_mbuild(filled_box, custom_groups=mb_cpd1)

with pytest.raises(TypeError):
top = from_mbuild(filled_box, custom_groups=[mb_cpd1])

@pytest.mark.skipif(not has_mbuild, reason="mBuild is not installed")
def test_nontop_level_compound(self, mb_ethane):
cpd = mb.Compound(name="top")
cpd.add(mb_ethane)
with pytest.raises(AssertionError):
from_mbuild(mb_ethane)

0 comments on commit 76e8f20

Please sign in to comment.