diff --git a/vermouth/processors/processor.py b/vermouth/processors/processor.py index 47d172e1..d1cca4aa 100644 --- a/vermouth/processors/processor.py +++ b/vermouth/processors/processor.py @@ -16,13 +16,15 @@ """ Provides an abstract base class for processors. """ - +import multiprocessing class Processor: """ An abstract base class for processors. Subclasses must implement a `run_molecule` method. + Has nproc attribute that is changed by a CLI flag """ + nproc = 1 def run_system(self, system): """ Process `system`. @@ -32,10 +34,14 @@ def run_system(self, system): system: vermouth.system.System The system to process. Is modified in-place. """ - mols = [] - for molecule in system.molecules: - mols.append(self.run_molecule(molecule)) - system.molecules = mols + if hasattr(self, 'nproc') and self.nproc > 1: + pool = multiprocessing.Pool(self.nproc) + system.molecules = pool.map(self.run_molecule, system.molecules) + else: + mols = [] + for molecule in system.molecules: + mols.append(self.run_molecule(molecule)) + system.molecules = mols def run_molecule(self, molecule): """ diff --git a/vermouth/tests/test_make_bonds.py b/vermouth/tests/test_make_bonds.py index 78ce2ec4..6dfef8bc 100644 --- a/vermouth/tests/test_make_bonds.py +++ b/vermouth/tests/test_make_bonds.py @@ -181,9 +181,21 @@ def test_make_bonds(nodes, edges, expected_edges): mol.add_nodes_from(enumerate(node_set)) mol.add_edges_from(edge_set) system.add_molecule(mol) + system_mp = system.copy() + MakeBonds().run_system(system) # Make sure number of connected components is the same assert len(system.molecules) == len(expected_edges) # Make sure that for every molecule found, the edges are correct for found_mol, ref_edges in zip(system.molecules, expected_edges): assert dict(found_mol.edges) == ref_edges + + # Also test making bonds with multiprocessing + mb = MakeBonds() + mb.nproc = 2 + mb.run_system(system) + # Make sure number of connected components is the same + assert len(system.molecules) == len(expected_edges) + # Make sure that for every molecule found, the edges are correct + for found_mol, ref_edges in zip(system.molecules, expected_edges): + assert dict(found_mol.edges) == ref_edges