diff --git a/keras/utils/data_utils.py b/keras/utils/data_utils.py
index 179e8861f24a..238daaaa7502 100644
--- a/keras/utils/data_utils.py
+++ b/keras/utils/data_utils.py
@@ -608,7 +608,6 @@ def get(self):
                 try:
                     future = self.queue.get(block=True)
                     inputs = future.get(timeout=30)
-                    self.queue.task_done()
                 except mp.TimeoutError:
                     idx = future.idx
                     warnings.warn(
@@ -616,6 +615,11 @@ def get(self):
                         ' It could be because a worker has died.'.format(idx),
                         UserWarning)
                     inputs = self.sequence[idx]
+                finally:
+                    # Runs on success and on timeout alike, so an item taken
+                    # from the queue is marked done even if its future timed out.
+                    self.queue.task_done()
+
                 if inputs is not None:
                     yield inputs
         except Exception:
diff --git a/tests/keras/utils/data_utils_test.py b/tests/keras/utils/data_utils_test.py
index ff87efe76575..5dea7fe644b9 100644
--- a/tests/keras/utils/data_utils_test.py
+++ b/tests/keras/utils/data_utils_test.py
@@ -5,6 +5,7 @@
 import sys
 import tarfile
 import threading
+import signal
 import shutil
 import zipfile
 from itertools import cycle
@@ -211,6 +212,31 @@ def on_epoch_end(self):
         pass
 
 
+class SlowSequence(Sequence):
+    """Sequence whose first `__getitem__` call sleeps for 40 seconds.
+
+    Every later call returns immediately.
+    """
+
+    def __init__(self, shape, value=1.0):
+        self.shape = shape
+        self.inner = value
+        # True until the first `__getitem__` call has started.
+        self.wait = True
+
+    def __getitem__(self, item):
+        if self.wait:
+            self.wait = False
+            time.sleep(40)
+        return np.ones(self.shape, dtype=np.uint32) * item * self.inner
+
+    def __len__(self):
+        return 10
+
+    def on_epoch_end(self):
+        pass
+
+
 @threadsafe_generator
 def create_generator_from_sequence_threads(ds):
     for i in cycle(range(len(ds))):
@@ -335,6 +361,42 @@ def test_ordered_enqueuer_fail_threads():
         next(gen_output)
 
 
+def test_ordered_enqueuer_timeout_threads():
+    """Check that a batch which times out is recomputed, in order.
+
+    A SIGALRM watchdog fails the test after 60 seconds so that a deadlock
+    cannot hang the whole run.
+    """
+    enqueuer = OrderedEnqueuer(SlowSequence([3, 10, 10, 3]),
+                               use_multiprocessing=False)
+
+    def handler(signum, frame):
+        raise TimeoutError('Sequence deadlocked')
+
+    old = signal.signal(signal.SIGALRM, handler)
+    signal.setitimer(signal.ITIMER_REAL, 60)
+    try:
+        with pytest.warns(UserWarning) as record:
+            enqueuer.start(5, 10)
+            gen_output = enqueuer.get()
+            for epoch_num in range(2):
+                acc = []
+                for i in range(10):
+                    acc.append(next(gen_output)[0, 0, 0, 0])
+                assert acc == list(range(10)), 'Order was not keep in ' \
+                                               'OrderedEnqueuer with threads'
+            enqueuer.stop()
+        assert len(record) == 1
+        expected = ('The input 0 could not be retrieved. '
+                    'It could be because a worker has died.')
+        assert str(record[0].message) == expected
+    finally:
+        # Disarm the watchdog and restore the previous handler even when an
+        # assertion fails, so a stray SIGALRM cannot hit a later test.
+        signal.setitimer(signal.ITIMER_REAL, 0)
+        signal.signal(signal.SIGALRM, old)
+
+
 @use_spawn
 def test_on_epoch_end_processes():
     enqueuer = OrderedEnqueuer(DummySequence([3, 10, 10, 3]),