Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,14 +608,16 @@ def get(self):
try:
future = self.queue.get(block=True)
inputs = future.get(timeout=30)
self.queue.task_done()
except mp.TimeoutError:
idx = future.idx
warnings.warn(
'The input {} could not be retrieved.'
' It could be because a worker has died.'.format(idx),
UserWarning)
inputs = self.sequence[idx]
finally:
self.queue.task_done()

if inputs is not None:
yield inputs
except Exception:
Expand Down
46 changes: 46 additions & 0 deletions tests/keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import tarfile
import threading
import signal
import shutil
import zipfile
from itertools import cycle
Expand Down Expand Up @@ -211,6 +212,25 @@ def on_epoch_end(self):
pass


class SlowSequence(Sequence):
def __init__(self, shape, value=1.0):
self.shape = shape
self.inner = value
self.wait = True

def __getitem__(self, item):
if self.wait:
self.wait = False
time.sleep(40)
return np.ones(self.shape, dtype=np.uint32) * item * self.inner

def __len__(self):
return 10

def on_epoch_end(self):
pass


@threadsafe_generator
def create_generator_from_sequence_threads(ds):
for i in cycle(range(len(ds))):
Expand Down Expand Up @@ -335,6 +355,32 @@ def test_ordered_enqueuer_fail_threads():
next(gen_output)


def test_ordered_enqueuer_timeout_threads():
enqueuer = OrderedEnqueuer(SlowSequence([3, 10, 10, 3]),
use_multiprocessing=False)

def handler(signum, frame):
raise TimeoutError('Sequence deadlocked')

old = signal.signal(signal.SIGALRM, handler)
signal.setitimer(signal.ITIMER_REAL, 60)
with pytest.warns(UserWarning) as record:
enqueuer.start(5, 10)
gen_output = enqueuer.get()
for epoch_num in range(2):
acc = []
for i in range(10):
acc.append(next(gen_output)[0, 0, 0, 0])
assert acc == list(range(10)), 'Order was not keep in ' \
'OrderedEnqueuer with threads'
enqueuer.stop()
assert len(record) == 1
assert str(record[0].message) == 'The input 0 could not be retrieved. ' \
'It could be because a worker has died.'
signal.setitimer(signal.ITIMER_REAL, 0)
signal.signal(signal.SIGALRM, old)


@use_spawn
def test_on_epoch_end_processes():
enqueuer = OrderedEnqueuer(DummySequence([3, 10, 10, 3]),
Expand Down