Skip to content

Commit

Permalink
Change syntax to remove ghost ports (#995)
Browse files Browse the repository at this point in the history
* Change syntax to remove ghost ports

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* also remove label from when removing port

* add unit test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
daico007 and pre-commit-ci[bot] committed Mar 23, 2022
1 parent 41025a2 commit 2725815
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
19 changes: 13 additions & 6 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,17 +833,24 @@ def _check_if_empty(child):
removed_part.parent.children.remove(removed_part)
self._remove_references(removed_part)

# Remove ghost ports
all_ports_list = list(self.all_ports())
for port in all_ports_list:
if port.anchor not in [i for i in self.particles()]:
port.parent.children.remove(port)

# Check and reorder rigid id
for _ in particles_to_remove:
if self.contains_rigid:
self.root._reorder_rigid_ids()

# Remove ghsot ports
self._prune_ghost_ports()

def _prune_ghost_ports(self):
"""Worker for remove(). Remove all ports whose anchor has been deleted."""
all_ports_list = list(self.all_ports())
particles = list(self.particles())
for port in all_ports_list:
if port.anchor not in particles:
self._remove(port)
port.parent.children.remove(port)
self._remove_references(port)

def _remove(self, removed_part):
"""Worker for remove(). Fixes rigid IDs and removes bonds."""
if removed_part.rigid_id is not None:
Expand Down
14 changes: 14 additions & 0 deletions mbuild/tests/test_compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,20 @@ def test_add_bond_remove_ports(self, hydrogen):
assert len(hydrogen.all_ports()) == 0
assert len(h_clone.all_ports()) == 0

def test_pruning_ghost_ports(self, ethane):
eth1 = ethane
h1 = eth1.particles_by_name("H")
eth1.remove(h1)
assert len(eth1.all_ports()) == 6
for port in eth1.all_ports():
assert port.anchor in list(eth1.particles())

eth2 = mb.load("CC", smiles=True)
h2 = eth2[-1]
eth2.remove(h2)
assert len(eth2.all_ports()) == len(eth2.available_ports()) == 1
assert eth2.all_ports()[0] is eth2.available_ports()[0]

def test_remove_bond_add_ports(self, hydrogen):
h_clone = mb.clone(hydrogen)
h2 = Compound(subcompounds=(hydrogen, h_clone))
Expand Down

0 comments on commit 2725815

Please sign in to comment.