diff --git a/dataflows/processors/parallelize.py b/dataflows/processors/parallelize.py index 79f4aec..f731800 100644 --- a/dataflows/processors/parallelize.py +++ b/dataflows/processors/parallelize.py @@ -1,15 +1,12 @@ import itertools -from re import match -from build.lib.dataflows.helpers.resource_matcher import ResourceMatcher -from build.lib.dataflows.base.package_wrapper import PackageWrapper -from build.lib.dataflows.base.resource_wrapper import ResourceWrapper import os import multiprocessing as mp import threading import queue -from .. import Flow from ..helpers import ResourceMatcher +from .. import PackageWrapper, ResourceWrapper + def init_mp(num_processors, row_func, q_in, q_internal): q_out = mp.Queue() @@ -22,13 +19,13 @@ def init_mp(num_processors, row_func, q_in, q_internal): def fini_mp(processes, t_fetch): - for i, process in enumerate(processes): + for process in processes: try: process.join(timeout=10) - except Exception as e: + except Exception: try: process.kill() - except: + except Exception: pass finally: process.close() @@ -44,7 +41,7 @@ def producer(res, q_in, q_internal, num_processors, predicate): q_internal.put(row) for _ in range(num_processors): q_in.put(None) - except Exception as e: + except Exception: q_internal.put(None) return 1 return 0 @@ -64,21 +61,19 @@ def fetcher(q_out, q_internal, num_processors): def work(q_in: mp.Queue, q_out: mp.Queue, row_func): - count = 0 pid = os.getpid() try: while True: row = q_in.get() if row is None: break - count += 1 try: row_func(row) except Exception as e: print(pid, 'FAILED TO RUN row_func {}\n'.format(e)) pass q_out.put(row) - except Exception as e: + except Exception: pass finally: q_out.put(None) @@ -88,25 +83,21 @@ def fork(res, row_func, num_processors, predicate): predicate = predicate or (lambda x: True) for row in res: if predicate(row): - res = itertools.chain([row], res) - q_in = mp.Queue() - q_internal = queue.Queue() - t_prod = threading.Thread(target=producer, args=(res, q_in, q_internal, num_processors, predicate)) - t_prod.start() - - processes, t_fetch = init_mp(num_processors, row_func, q_in, q_internal) - - count = 0 - while True: - row = q_internal.get() - if row is None: - break - count += 1 - if count % 100 == 0: - print('fork got %d rows' % count) - yield row - t_prod.join() - fini_mp(processes, t_fetch) + res = itertools.chain([row], res) + q_in = mp.Queue() + q_internal = queue.Queue() + t_prod = threading.Thread(target=producer, args=(res, q_in, q_internal, num_processors, predicate)) + t_prod.start() + + processes, t_fetch = init_mp(num_processors, row_func, q_in, q_internal) + + while True: + row = q_internal.get() + if row is None: + break + yield row + t_prod.join() + fini_mp(processes, t_fetch) else: yield row @@ -117,12 +108,12 @@ def parallelize(row_func, num_processors=None, resources=None, predicate=None): def func(package: PackageWrapper): yield package.pkg matcher = ResourceMatcher(resources, package.pkg) - + res: ResourceWrapper for res in package: if matcher.match(res.res.name): yield fork(res, row_func, num_processors, predicate) else: yield res - + return func