Skip to content

Commit

Permalink
add asyncio support: .amap and .aforeach operations
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Jun 9, 2024
1 parent 3eef2e8 commit e088216
Show file tree
Hide file tree
Showing 10 changed files with 359 additions and 61 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ integer_strings: Stream[str] = integers.map(str)

It has an optional `concurrency: int` parameter to execute the function concurrently (threads) while preserving the order.

There is also an async counterpart operation called `.amap` to run coroutines concurrently (`asyncio`).

## `.foreach`
Applies a function on elements like `.map` but yields the elements instead of the results.

Expand All @@ -74,6 +76,8 @@ printed_integers: Stream[int] = integers.foreach(print)

It has an optional `concurrency: int` parameter to execute the function concurrently (threads) while preserving the order.

There is also an async counterpart operation called `.aforeach` to run coroutines concurrently (`asyncio`).

## `.filter`
Filters elements based on a predicate function.

Expand Down
19 changes: 17 additions & 2 deletions streamable/_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import sys
from typing import Any, Callable, Optional, Type, TypeVar
from typing import Any, Callable, Coroutine, Optional, Type, TypeVar

LOGGER = logging.getLogger("streamable")
LOGGER.propagate = False
Expand All @@ -16,13 +16,28 @@


def sidify(func: Callable[[T], Any]) -> Callable[[T], T]:
def wrap(arg):
def wrap(arg: T):
func(arg)
return arg

return wrap


def async_sidify(
func: Callable[[T], Coroutine]
) -> Callable[[T], Coroutine[Any, Any, T]]:
async def wrap(arg: T) -> T:
coroutine = func(arg)
if not isinstance(coroutine, Coroutine):
raise TypeError(
f"`func` is expected to return a Coroutine but got a {type(coroutine)}."
)
await coroutine
return arg

return wrap


def reraise_as(
func: Callable[[T], R], source: Type[Exception], target: Type[Exception]
) -> Callable[[T], R]:
Expand Down
66 changes: 66 additions & 0 deletions streamable/functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import builtins
import itertools
import time
Expand All @@ -8,6 +9,7 @@
from typing import (
Any,
Callable,
Coroutine,
Deque,
Dict,
Iterable,
Expand Down Expand Up @@ -294,6 +296,52 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]:
yield _RaisingIterator.ExceptionContainer(e)


class _AsyncConcurrentMappingIterable(
Iterable[Union[U, _RaisingIterator.ExceptionContainer]]
):
_LOOP = asyncio.new_event_loop()

def __init__(
self,
iterator: Iterator[T],
func: Callable[[T], Coroutine[Any, Any, U]],
concurrency: int,
buffer_size: int,
) -> None:
self.iterator = iterator
self.func = func
self.concurrency = concurrency
self.buffer_size = buffer_size

async def _safe_func(self, elem: T) -> U:
try:
coroutine = self.func(elem)
if not isinstance(coroutine, Coroutine):
raise TypeError(
f"The `func` passed to `amap` or `aforeach` must return a Coroutine object, but got a {type(coroutine)}."
)
return await coroutine
except Exception as e:
return _RaisingIterator.ExceptionContainer(e)

def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]:
awaitables: Deque[Future] = deque()
# queue and yield (FIFO)
while True:
# queue tasks up to buffer_size
while len(awaitables) < self.buffer_size:
try:
elem = next(self.iterator)
except StopIteration:
# the upstream iterator is exhausted
break
awaitables.append(self._LOOP.create_task(self._safe_func(elem)))
if not awaitables:
break
while len(awaitables):
yield self._LOOP.run_until_complete(awaitables.popleft())


class _ConcurrentFlatteningIterable(
Iterable[Union[T, _RaisingIterator.ExceptionContainer]]
):
Expand Down Expand Up @@ -413,6 +461,24 @@ def map(
)


def amap(
func: Callable[[T], Coroutine[Any, Any, U]],
iterator: Iterator[T],
concurrency: int = 1,
) -> Iterator[U]:
_util.validate_concurrency(concurrency)
return _RaisingIterator(
iter(
_AsyncConcurrentMappingIterable(
iterator,
_util.reraise_as(func, StopIteration, WrappedStopIteration),
concurrency=concurrency,
buffer_size=concurrency,
)
)
)


def limit(iterator: Iterator[T], count: int) -> Iterator[T]:
_util.validate_limit_count(count)
return _LimitingIterator(iterator, count)
Expand Down
63 changes: 63 additions & 0 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Any,
Callable,
Collection,
Coroutine,
Generic,
Iterable,
Iterator,
Expand Down Expand Up @@ -237,6 +238,24 @@ def foreach(
validate_concurrency(concurrency)
return ForeachStream(self, func, concurrency)

def aforeach(
self,
func: Callable[[T], Coroutine],
concurrency: int = 1,
) -> "Stream[T]":
"""
Call the asynchronous `func` on upstream elements and yield them in order.
If the `func(elem)` coroutine throws an exception then it will be thrown and `elem` will not be yielded.
Args:
func (Callable[[T], Any]): The asynchronous function to be applied to each element.
concurrency (int): How many asyncio tasks will run at the same time.
Returns:
Stream[T]: A stream of upstream elements, unchanged.
"""
validate_concurrency(concurrency)
return AForeachStream(self, func, concurrency)

def group(
self,
size: Optional[int] = None,
Expand Down Expand Up @@ -292,6 +311,23 @@ def map(
validate_concurrency(concurrency)
return MapStream(self, func, concurrency)

def amap(
self,
func: Callable[[T], Coroutine[Any, Any, U]],
concurrency: int = 1,
) -> "Stream[U]":
"""
Apply an asynchronous `func` on upstream elements and yield the results in order.
Args:
func (Callable[[T], Coroutine[Any, Any, U]]): The asynchronous function to be applied to each element.
concurrency (int): How many asyncio tasks will run at the same time.
Returns:
Stream[R]: A stream of results of `func` applied to upstream elements.
"""
validate_concurrency(concurrency)
return AMapStream(self, func, concurrency)

def observe(self, what: str = "elements", colored: bool = False) -> "Stream[T]":
"""
Log the progress of any iteration over this stream's elements.
Expand Down Expand Up @@ -389,6 +425,18 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_foreach_stream(self)


class AForeachStream(DownStream[T, T]):
def __init__(
self, upstream: Stream[T], func: Callable[[T], Coroutine], concurrency: int
) -> None:
super().__init__(upstream)
self.func = func
self.concurrency = concurrency

def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_aforeach_stream(self)


class GroupStream(DownStream[T, List[T]]):
def __init__(
self,
Expand Down Expand Up @@ -427,6 +475,21 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_map_stream(self)


class AMapStream(DownStream[T, U]):
def __init__(
self,
upstream: Stream[T],
func: Callable[[T], Coroutine[Any, Any, U]],
concurrency: int,
) -> None:
super().__init__(upstream)
self.func = func
self.concurrency = concurrency

def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_amap_stream(self)


class ObserveStream(DownStream[T, T]):
def __init__(self, upstream: Stream[T], what: str, colored: bool) -> None:
super().__init__(upstream)
Expand Down
6 changes: 6 additions & 0 deletions streamable/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def visit_flatten_stream(self, stream: stream.FlattenStream) -> V:
def visit_foreach_stream(self, stream: stream.ForeachStream) -> V:
return self.visit_stream(stream)

def visit_aforeach_stream(self, stream: stream.AForeachStream) -> V:
return self.visit_stream(stream)

def visit_group_stream(self, stream: stream.GroupStream) -> V:
return self.visit_stream(stream)

Expand All @@ -36,5 +39,8 @@ def visit_observe_stream(self, stream: stream.ObserveStream) -> V:
def visit_map_stream(self, stream: stream.MapStream) -> V:
return self.visit_stream(stream)

def visit_amap_stream(self, stream: stream.AMapStream) -> V:
return self.visit_stream(stream)

def visit_slow_stream(self, stream: stream.SlowStream) -> V:
return self.visit_stream(stream)
6 changes: 6 additions & 0 deletions streamable/visitors/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def visit_foreach_stream(self, stream: stream.ForeachStream) -> str:
f"func={_util.get_name(stream.func)}, concurrency={stream.concurrency}",
)

def visit_aforeach_stream(self, stream: stream.ForeachStream) -> str:
return self.visit_foreach_stream(stream)

def visit_group_stream(self, stream: stream.GroupStream) -> str:
return self._explanation(
stream, f"size={stream.size}, seconds={stream.seconds}, by={stream.by}"
Expand All @@ -82,6 +85,9 @@ def visit_map_stream(self, stream: stream.MapStream) -> str:
f"func={_util.get_name(stream.func)}, concurrency={stream.concurrency}",
)

def visit_amap_stream(self, stream: stream.AMapStream) -> str:
return self.visit_map_stream(stream)

def visit_observe_stream(self, stream: stream.ObserveStream) -> str:
return self._explanation(
stream, f"what='{stream.what}', colored={stream.colored}"
Expand Down
18 changes: 18 additions & 0 deletions streamable/visitors/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from streamable import _util, functions
from streamable.stream import (
AForeachStream,
AMapStream,
CatchStream,
FilterStream,
FlattenStream,
Expand Down Expand Up @@ -50,6 +52,15 @@ def visit_foreach_stream(self, stream: ForeachStream[T]) -> Iterator[T]:
)
)

def visit_aforeach_stream(self, stream: AForeachStream[T]) -> Iterator[T]:
return self.visit_amap_stream(
AMapStream(
stream.upstream,
_util.async_sidify(stream.func),
stream.concurrency,
)
)

def visit_group_stream(self, stream: GroupStream[U]) -> Iterator[T]:
return cast(
Iterator[T],
Expand All @@ -74,6 +85,13 @@ def visit_map_stream(self, stream: MapStream[U, T]) -> Iterator[T]:
concurrency=stream.concurrency,
)

def visit_amap_stream(self, stream: AMapStream[U, T]) -> Iterator[T]:
return functions.amap(
stream.func,
stream.upstream.accept(IteratorVisitor[U]()),
concurrency=stream.concurrency,
)

def visit_observe_stream(self, stream: ObserveStream[T]) -> Iterator[T]:
return functions.observe(
stream.upstream.accept(self),
Expand Down
7 changes: 3 additions & 4 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from typing import Callable, Iterable, Iterator, List, TypeVar, cast
from typing import Callable, Iterator, List, TypeVar, cast

from streamable.functions import catch, flatten, group, limit, map, observe, slow

Expand All @@ -10,13 +10,12 @@
N = 256


def src() -> Iterable[int]:
return range(N)
src = range(N)


class TestFunctions(unittest.TestCase):
def test_signatures(self) -> None:
it = iter(src())
it = iter(src)
func = cast(Callable[[int], int], ...)
mapped_it_1: Iterator[int] = map(func, it)
mapped_it_2: Iterator[int] = map(func, it, concurrency=1)
Expand Down
Loading

0 comments on commit e088216

Please sign in to comment.