Skip to content

Commit

Permalink
revise concurrent [async_]mapping/flattening flow: wait for a result …
Browse files Browse the repository at this point in the history
…then queue tasks, then yield back control
  • Loading branch information
ebonnal committed Jun 9, 2024
1 parent 40e2c2d commit d07ed88
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 21 deletions.
55 changes: 36 additions & 19 deletions streamable/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,15 @@ def __init__(
def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]:
with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
futures: Deque[Future] = deque()
# queue and yield (FIFO)
element_to_yield: List[Union[U, _RaisingIterator.ExceptionContainer]] = []
# wait, queue, yield (FIFO)
while True:
if futures:
try:
element_to_yield.append(futures.popleft().result())
except Exception as e:
element_to_yield.append(_RaisingIterator.ExceptionContainer(e))

# queue tasks up to buffer_size
while len(futures) < self.buffer_size:
try:
Expand All @@ -288,12 +295,10 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]:
# the upstream iterator is exhausted
break
futures.append(executor.submit(self.func, elem))
if element_to_yield:
yield element_to_yield.pop()
if not futures:
break
try:
yield futures.popleft().result()
except Exception as e:
yield _RaisingIterator.ExceptionContainer(e)


class _AsyncConcurrentMappingIterable(
Expand Down Expand Up @@ -328,8 +333,13 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]:
awaitables: Deque[
asyncio.Task[Union[U, _RaisingIterator.ExceptionContainer]]
] = deque()
# queue and yield (FIFO)
element_to_yield: List[Union[U, _RaisingIterator.ExceptionContainer]] = []
# wait, queue, yield (FIFO)
while True:
if awaitables:
element_to_yield.append(
self._LOOP.run_until_complete(awaitables.popleft())
)
# queue tasks up to buffer_size
while len(awaitables) < self.buffer_size:
try:
Expand All @@ -338,9 +348,10 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]:
# the upstream iterator is exhausted
break
awaitables.append(self._LOOP.create_task(self._safe_func(elem)))
if element_to_yield:
yield element_to_yield.pop()
if not awaitables:
break
yield self._LOOP.run_until_complete(awaitables.popleft())


class _ConcurrentFlatteningIterable(
Expand All @@ -356,11 +367,26 @@ def __init__(
self.concurrency = concurrency
self.buffer_size = buffer_size

def _requeue(self, iterator: Iterator[T]) -> None:
    """Put *iterator* back at the end of the pending-iterables queue.

    Used by ``__iter__`` after one element has been pulled from *iterator*
    (successfully or with a non-StopIteration exception) so that the
    iterator's remaining elements can be scheduled again later.
    """
    # itertools.chain is lazy: this appends without copying the pending queue.
    # NOTE(review): repeated calls nest chain objects — presumably fine at the
    # expected buffer sizes; confirm if buffer_size can grow very large.
    self.iterables_iterator = itertools.chain(self.iterables_iterator, [iterator])

def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]:
with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
iterator_and_future_pairs: Deque[Tuple[Iterator[T], Future]] = deque()
# queue and yield (FIFO)
element_to_yield: List[Union[T, _RaisingIterator.ExceptionContainer]] = []
# wait, queue, yield (FIFO)
while True:
if iterator_and_future_pairs:
iterator, future = iterator_and_future_pairs.popleft()
try:
element_to_yield.append(future.result())
self._requeue(iterator)
except StopIteration:
pass
except Exception as e:
element_to_yield.append(_RaisingIterator.ExceptionContainer(e))
self._requeue(iterator)

# queue tasks up to buffer_size
while len(iterator_and_future_pairs) < self.buffer_size:
try:
Expand All @@ -372,19 +398,10 @@ def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]:
cast(Callable[[Iterable[T]], T], next), iterator
)
iterator_and_future_pairs.append((iterator, future))

if element_to_yield:
yield element_to_yield.pop()
if not iterator_and_future_pairs:
break
iterator, future = iterator_and_future_pairs.popleft()
try:
yield future.result()
except StopIteration:
continue
except Exception as e:
yield _RaisingIterator.ExceptionContainer(e)
self.iterables_iterator = itertools.chain(
self.iterables_iterator, [iterator]
)


# functions
Expand Down
6 changes: 4 additions & 2 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,9 @@ def remembering_src() -> Iterator[int]:

for stream in [
Stream(remembering_src).map(identity, concurrency=concurrency),
Stream(remembering_src).amap(async_identity, concurrency=concurrency),
Stream(remembering_src).foreach(identity, concurrency=concurrency),
Stream(remembering_src).aforeach(async_identity, concurrency=concurrency),
Stream(remembering_src).group(1).flatten(concurrency=concurrency),
]:
yielded_elems = []
Expand All @@ -519,8 +521,8 @@ def remembering_src() -> Iterator[int]:
time.sleep(0.5)
self.assertEqual(
len(yielded_elems),
concurrency,
msg=f"after the first call to `next` a concurrent {type(stream)} should have pulled only {concurrency} (=concurrency) upstream elements.",
concurrency + 1,
msg=f"after the first call to `next` a concurrent {type(stream)} should have pulled only {concurrency + 1} (== concurrency + 1) upstream elements.",
)

def test_filter(self) -> None:
Expand Down

0 comments on commit d07ed88

Please sign in to comment.