Skip to content

Commit

Permalink
add asyncio support: .amap and .aforeach operations
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Jun 9, 2024
1 parent 3eef2e8 commit 16d0b82
Show file tree
Hide file tree
Showing 10 changed files with 419 additions and 75 deletions.
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ integer_strings: Stream[str] = integers.map(str)

It has an optional `concurrency: int` parameter to execute the function concurrently (threads) while preserving the order.

There is also an async counterpart operation called `.amap` to run coroutines concurrently (`asyncio`).

## `.foreach`
Applies a function on elements like `.map` but yields the elements instead of the results.

Expand All @@ -74,6 +76,8 @@ printed_integers: Stream[int] = integers.foreach(print)

It has an optional `concurrency: int` parameter to execute the function concurrently (threads) while preserving the order.

There is also an async counterpart operation called `.aforeach` to run coroutines concurrently (`asyncio`).

## `.filter`
Filters elements based on a predicate function.

Expand Down Expand Up @@ -189,6 +193,37 @@ stream: Stream[str] = (
)
```

## asyncio
The majority of the use cases should find convenient the threads-based concurrency available when using `.map` or `.foreach`, but as an alternative there are the `.amap` and `.aforeach` operations that **allow to apply `async` functions** concurrently on your stream:

```python
import asyncio
import time

async def slow_async_square(n: int) -> int:
await asyncio.sleep(3)
return n ** 2

def slow_str(n: int) -> str:
time.sleep(3)
return str(n)

print(
", ".join(
integers
# coroutines-based concurrency
.amap(slow_async_square, concurrency=8)
# threads-based concurrency
.map(slow_str, concurrency=8)
.limit(5)
)
)
```
this prints (in 6s):
```bash
0, 1, 4, 9, 16
```

## functions
The `Stream`'s methods are also exposed as functions:
```python
Expand Down
19 changes: 17 additions & 2 deletions streamable/_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import sys
from typing import Any, Callable, Optional, Type, TypeVar
from typing import Any, Callable, Coroutine, Optional, Type, TypeVar

LOGGER = logging.getLogger("streamable")
LOGGER.propagate = False
Expand All @@ -16,13 +16,28 @@


def sidify(func: Callable[[T], Any]) -> Callable[[T], T]:
def wrap(arg):
def wrap(arg: T):
func(arg)
return arg

return wrap


def async_sidify(
func: Callable[[T], Coroutine]
) -> Callable[[T], Coroutine[Any, Any, T]]:
async def wrap(arg: T) -> T:
coroutine = func(arg)
if not isinstance(coroutine, Coroutine):
raise TypeError(
f"`func` is expected to return a Coroutine but got a {type(coroutine)}."
)
await coroutine
return arg

return wrap


def reraise_as(
func: Callable[[T], R], source: Type[Exception], target: Type[Exception]
) -> Callable[[T], R]:
Expand Down
66 changes: 66 additions & 0 deletions streamable/functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import builtins
import itertools
import time
Expand All @@ -8,6 +9,7 @@
from typing import (
Any,
Callable,
Coroutine,
Deque,
Dict,
Iterable,
Expand Down Expand Up @@ -294,6 +296,53 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]:
yield _RaisingIterator.ExceptionContainer(e)


class _AsyncConcurrentMappingIterable(
Iterable[Union[U, _RaisingIterator.ExceptionContainer]]
):
_LOOP = asyncio.new_event_loop()

def __init__(
self,
iterator: Iterator[T],
func: Callable[[T], Coroutine[Any, Any, U]],
buffer_size: int,
) -> None:
self.iterator = iterator
self.func = func
self.buffer_size = buffer_size

async def _safe_func(
self, elem: T
) -> Union[U, _RaisingIterator.ExceptionContainer]:
try:
coroutine = self.func(elem)
if not isinstance(coroutine, Coroutine):
raise TypeError(
f"The `func` passed to `amap` or `aforeach` must return a Coroutine object, but got a {type(coroutine)}."
)
return await coroutine
except Exception as e:
return _RaisingIterator.ExceptionContainer(e)

def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]:
awaitables: Deque[
asyncio.Task[Union[U, _RaisingIterator.ExceptionContainer]]
] = deque()
# queue and yield (FIFO)
while True:
# queue tasks up to buffer_size
while len(awaitables) < self.buffer_size:
try:
elem = next(self.iterator)
except StopIteration:
# the upstream iterator is exhausted
break
awaitables.append(self._LOOP.create_task(self._safe_func(elem)))
if not awaitables:
break
yield self._LOOP.run_until_complete(awaitables.popleft())


class _ConcurrentFlatteningIterable(
Iterable[Union[T, _RaisingIterator.ExceptionContainer]]
):
Expand Down Expand Up @@ -413,6 +462,23 @@ def map(
)


def amap(
func: Callable[[T], Coroutine[Any, Any, U]],
iterator: Iterator[T],
concurrency: int = 1,
) -> Iterator[U]:
_util.validate_concurrency(concurrency)
return _RaisingIterator(
iter(
_AsyncConcurrentMappingIterable(
iterator,
_util.reraise_as(func, StopIteration, WrappedStopIteration),
buffer_size=concurrency,
)
)
)


def limit(iterator: Iterator[T], count: int) -> Iterator[T]:
_util.validate_limit_count(count)
return _LimitingIterator(iterator, count)
Expand Down
63 changes: 63 additions & 0 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Any,
Callable,
Collection,
Coroutine,
Generic,
Iterable,
Iterator,
Expand Down Expand Up @@ -237,6 +238,24 @@ def foreach(
validate_concurrency(concurrency)
return ForeachStream(self, func, concurrency)

def aforeach(
self,
func: Callable[[T], Coroutine],
concurrency: int = 1,
) -> "Stream[T]":
"""
Call the asynchronous `func` on upstream elements and yield them in order.
If the `func(elem)` coroutine throws an exception then it will be thrown and `elem` will not be yielded.
Args:
func (Callable[[T], Any]): The asynchronous function to be applied to each element.
concurrency (int): How many asyncio tasks will run at the same time.
Returns:
Stream[T]: A stream of upstream elements, unchanged.
"""
validate_concurrency(concurrency)
return AForeachStream(self, func, concurrency)

def group(
self,
size: Optional[int] = None,
Expand Down Expand Up @@ -292,6 +311,23 @@ def map(
validate_concurrency(concurrency)
return MapStream(self, func, concurrency)

def amap(
self,
func: Callable[[T], Coroutine[Any, Any, U]],
concurrency: int = 1,
) -> "Stream[U]":
"""
Apply an asynchronous `func` on upstream elements and yield the results in order.
Args:
func (Callable[[T], Coroutine[Any, Any, U]]): The asynchronous function to be applied to each element.
concurrency (int): How many asyncio tasks will run at the same time.
Returns:
Stream[R]: A stream of results of `func` applied to upstream elements.
"""
validate_concurrency(concurrency)
return AMapStream(self, func, concurrency)

def observe(self, what: str = "elements", colored: bool = False) -> "Stream[T]":
"""
Log the progress of any iteration over this stream's elements.
Expand Down Expand Up @@ -389,6 +425,18 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_foreach_stream(self)


class AForeachStream(DownStream[T, T]):
def __init__(
self, upstream: Stream[T], func: Callable[[T], Coroutine], concurrency: int
) -> None:
super().__init__(upstream)
self.func = func
self.concurrency = concurrency

def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_aforeach_stream(self)


class GroupStream(DownStream[T, List[T]]):
def __init__(
self,
Expand Down Expand Up @@ -427,6 +475,21 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_map_stream(self)


class AMapStream(DownStream[T, U]):
def __init__(
self,
upstream: Stream[T],
func: Callable[[T], Coroutine[Any, Any, U]],
concurrency: int,
) -> None:
super().__init__(upstream)
self.func = func
self.concurrency = concurrency

def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_amap_stream(self)


class ObserveStream(DownStream[T, T]):
def __init__(self, upstream: Stream[T], what: str, colored: bool) -> None:
super().__init__(upstream)
Expand Down
6 changes: 6 additions & 0 deletions streamable/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def visit_flatten_stream(self, stream: stream.FlattenStream) -> V:
def visit_foreach_stream(self, stream: stream.ForeachStream) -> V:
return self.visit_stream(stream)

def visit_aforeach_stream(self, stream: stream.AForeachStream) -> V:
return self.visit_stream(stream)

def visit_group_stream(self, stream: stream.GroupStream) -> V:
return self.visit_stream(stream)

Expand All @@ -36,5 +39,8 @@ def visit_observe_stream(self, stream: stream.ObserveStream) -> V:
def visit_map_stream(self, stream: stream.MapStream) -> V:
return self.visit_stream(stream)

def visit_amap_stream(self, stream: stream.AMapStream) -> V:
return self.visit_stream(stream)

def visit_slow_stream(self, stream: stream.SlowStream) -> V:
return self.visit_stream(stream)
46 changes: 32 additions & 14 deletions streamable/visitors/explanation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
import textwrap
from typing import cast

from streamable import _util, stream
from streamable.stream import (
AForeachStream,
AMapStream,
CatchStream,
FilterStream,
FlattenStream,
ForeachStream,
GroupStream,
LimitStream,
MapStream,
ObserveStream,
SlowStream,
Stream,
)
from streamable.visitor import Visitor


Expand Down Expand Up @@ -45,47 +60,50 @@ def _explanation(self, stream: stream.Stream, attributes_repr: str) -> str:

return explanation

def visit_stream(self, stream: stream.Stream) -> str:
def visit_stream(self, stream: Stream) -> str:
return self._explanation(stream, f"source={_util.get_name(stream.source)}")

def visit_catch_stream(self, stream: stream.CatchStream) -> str:
def visit_catch_stream(self, stream: CatchStream) -> str:
return self._explanation(
stream,
f"predicate={_util.get_name(stream.predicate)}, raise_at_exhaustion={stream.raise_at_exhaustion}",
)

def visit_filter_stream(self, stream: stream.FilterStream) -> str:
def visit_filter_stream(self, stream: FilterStream) -> str:
return self._explanation(
stream, f"predicate={_util.get_name(stream.predicate)}"
)

def visit_flatten_stream(self, stream: stream.FlattenStream) -> str:
def visit_flatten_stream(self, stream: FlattenStream) -> str:
return self._explanation(stream, f"concurrency={stream.concurrency}")

def visit_foreach_stream(self, stream: stream.ForeachStream) -> str:
return self._explanation(
stream,
f"func={_util.get_name(stream.func)}, concurrency={stream.concurrency}",
)
def visit_foreach_stream(self, stream: ForeachStream) -> str:
return self.visit_map_stream(cast(MapStream, stream))

def visit_aforeach_stream(self, stream: AForeachStream) -> str:
return self.visit_map_stream(cast(MapStream, stream))

def visit_group_stream(self, stream: stream.GroupStream) -> str:
def visit_group_stream(self, stream: GroupStream) -> str:
return self._explanation(
stream, f"size={stream.size}, seconds={stream.seconds}, by={stream.by}"
)

def visit_limit_stream(self, stream: stream.LimitStream) -> str:
def visit_limit_stream(self, stream: LimitStream) -> str:
return self._explanation(stream, f"count={stream.count}")

def visit_map_stream(self, stream: stream.MapStream) -> str:
def visit_map_stream(self, stream: MapStream) -> str:
return self._explanation(
stream,
f"func={_util.get_name(stream.func)}, concurrency={stream.concurrency}",
)

def visit_observe_stream(self, stream: stream.ObserveStream) -> str:
def visit_amap_stream(self, stream: AMapStream) -> str:
return self.visit_map_stream(cast(MapStream, stream))

def visit_observe_stream(self, stream: ObserveStream) -> str:
return self._explanation(
stream, f"what='{stream.what}', colored={stream.colored}"
)

def visit_slow_stream(self, stream: stream.SlowStream) -> str:
def visit_slow_stream(self, stream: SlowStream) -> str:
return self._explanation(stream, f"frequency={stream.frequency}")
Loading

0 comments on commit 16d0b82

Please sign in to comment.