From 47c91496dcdbeedb888f74edcc9981216b038d75 Mon Sep 17 00:00:00 2001 From: jeanollion Date: Tue, 23 Apr 2024 09:59:10 +0200 Subject: [PATCH] - ordered_enqueuer_cf.py : use semaphore instead of queue and modified pool shutdown --- dataset_iterator/ordered_enqueuer_cf.py | 72 +++++++++++++++---------- setup.py | 2 +- 2 files changed, 44 insertions(+), 30 deletions(-) diff --git a/dataset_iterator/ordered_enqueuer_cf.py b/dataset_iterator/ordered_enqueuer_cf.py index 0b89600..042600d 100644 --- a/dataset_iterator/ordered_enqueuer_cf.py +++ b/dataset_iterator/ordered_enqueuer_cf.py @@ -7,6 +7,7 @@ import threading import time from multiprocessing import managers, shared_memory +from threading import BoundedSemaphore # adapted from https://github.com/keras-team/keras/blob/v2.13.1/keras/utils/data_utils.py#L651-L776 # uses concurrent.futures, solves a memory leak in case of hard sample mining run as callback with regular orderedEnqueur. Option to pass tensors through shared memory @@ -49,6 +50,7 @@ def __init__(self, sequence, shuffle=False, single_epoch:bool=False, use_shm:boo self.run_thread = None self.stop_signal = None self.shm_manager = None + self.semaphore = None def is_running(self): return self.stop_signal is not None and not self.stop_signal.is_set() @@ -62,7 +64,10 @@ def start(self, workers=1, max_queue_size=10): (when full, workers could block on `put()`) """ self.workers = workers - self.queue = queue.Queue(max_queue_size) + if max_queue_size <= 0: + max_queue_size = self.workers + self.semaphore = BoundedSemaphore(max_queue_size) + self.queue = [] self.stop_signal = threading.Event() if self.use_shm: self.shm_manager = managers.SharedMemoryManager() @@ -71,12 +76,16 @@ def start(self, workers=1, max_queue_size=10): self.run_thread.daemon = True self.run_thread.start() - def _wait_queue(self): + def _wait_queue(self, empty:bool): """Wait for the queue to be empty.""" while True: - time.sleep(0.1) - if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set(): + if (empty and len(self.queue) == 0) or (not empty and len(self.queue) > 0) or self.stop_signal.is_set(): return + time.sleep(0.1) + + def _task_done(self, _): + """Called once task is done, releases the queue if blocked.""" + self.semaphore.release() def _run(self): """Submits request to the executor and queue the `Future` objects.""" @@ -87,15 +96,19 @@ def _run(self): if self.shuffle: random.shuffle(sequence) task = get_item_shm if self.use_shm else get_item - with ProcessPoolExecutor(max_workers=self.workers, initializer=init_pool_generator, initargs=(self.sequence, self.uid, self.shm_manager)) as executor: - for i in sequence: - if self.stop_signal.is_set(): - return - future = executor.submit(task, self.uid, i) - self.queue.put(future, block=True) - # Done with the current epoch, waiting for the final batches - self._wait_queue() - + executor = ProcessPoolExecutor(max_workers=self.workers, initializer=init_pool_generator, initargs=(self.sequence, self.uid, self.shm_manager)) + for idx, i in enumerate(sequence): + if self.stop_signal.is_set(): + return + self.semaphore.acquire() + future = executor.submit(task, self.uid, i) + self.queue.append((future, i)) + # Done with the current epoch, waiting for the final batches + self._wait_queue(True) # safer to wait before calling shutdown than calling directly shutdown with wait=True + print("exiting from ProcessPoolExecutor...", flush=True) + time.sleep(0.1) + executor.shutdown(wait=False, cancel_futures=True) + print("exiting from ProcessPoolExecutor done", flush=True) if self.stop_signal.is_set() or self.single_epoch: # We're done return @@ -124,20 +137,23 @@ def get(self): `(inputs, targets, sample_weights)`. """ while self.is_running(): - try: - inputs = self.queue.get(block=True, timeout=5).result() - if self.is_running(): - self.queue.task_done() - if inputs is not None: + self._wait_queue(False) + if len(self.queue) > 0: + future, i = self.queue[0] + try: + inputs = future.result() + self.queue.pop(0) # only remove after result() is called to avoid terminating pool while a process is still running if self.use_shm: inputs = from_shm(*inputs) - yield inputs - except queue.Empty: - pass - except Exception as e: - self.stop() - print("Exception raised while getting future", flush=True) - raise e + self.semaphore.release() # release is done here and not as a future callback to limit effective number of samples in memory + except Exception as e: + self.stop() + print(f"Exception raised while getting future result from task: {i}", flush=True) + raise e + finally: + future.cancel() + del future + yield inputs def stop(self, timeout=None): """Stops running threads and wait for them to exit, if necessary. @@ -148,14 +164,12 @@ def stop(self, timeout=None): timeout: maximum time to wait on `thread.join()` """ self.stop_signal.set() - with self.queue.mutex: - self.queue.queue.clear() - self.queue.unfinished_tasks = 0 - self.queue.not_full.notify() self.run_thread.join(timeout) if self.use_shm is not None: self.shm_manager.shutdown() self.shm_manager.join() + self.queue = None + self.semaphore = None global _SHARED_SHM_MANAGER _SHARED_SHM_MANAGER[self.uid] = None global _SHARED_SEQUENCES diff --git a/setup.py b/setup.py index a4167ae..dc9bad2 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/jeanollion/dataset_iterator.git", - download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.0/dataset_iterator-0.4.1.tar.gz', + download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.1/dataset_iterator-0.4.1.tar.gz', keywords=['Iterator', 'Dataset', 'Image', 'Numpy'], packages=setuptools.find_packages(), classifiers=[ #https://pypi.org/classifiers/