diff --git a/README.md b/README.md index 898384d..20f7ff3 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,8 @@ integer_strings: Stream[str] = integers.map(str) It has an optional `concurrency: int` parameter to execute the function concurrently (threads) while preserving the order. +There is also an async counterpart operation called `.amap` to run coroutines concurrently (`asyncio`). + ## `.foreach` Applies a function on elements like `.map` but yields the elements instead of the results. @@ -74,6 +76,8 @@ printed_integers: Stream[int] = integers.foreach(print) It has an optional `concurrency: int` parameter to execute the function concurrently (threads) while preserving the order. +There is also an async counterpart operation called `.aforeach` to run coroutines concurrently (`asyncio`). + ## `.filter` Filters elements based on a predicate function. @@ -189,6 +193,38 @@ stream: Stream[str] = ( ) ``` +## asyncio +Even though the vast majority of use cases should find the thread-based concurrency convenient (by setting the `concurrency` parameter when using `.map` / `.foreach`), there are 2 alternative operations, `.amap` and `.aforeach`, that allow you to apply `async` functions concurrently on your stream, e.g.: + +```python +import asyncio + +async def slow_format(i: int) -> str: + await asyncio.sleep(10) + return f"This is {i}." + +async def slow_print(o) -> None: + await asyncio.sleep(10) + print(o) + +( + integers + .map(lambda n: n**2) + .amap(slow_format, concurrency=8) + .limit(5) + .aforeach(slow_print, concurrency=8) + .exhaust() +) +``` +This prints (in 20 seconds): +```bash +This is 0. +This is 1. +This is 4. +This is 9. +This is 16. +``` + ## functions The `Stream`'s methods are also exposed as functions: ```python diff --git a/streamable/_util.py b/streamable/_util.py index 68509ba..cfb3ee6 100644 --- a/streamable/_util.py +++ b/streamable/_util.py @@ -1,6 +1,6 @@ import logging import sys -from typing import Any, Callable, Optional, Type, TypeVar +from typing import Any, Callable, Coroutine, Optional, Type, TypeVar LOGGER = logging.getLogger("streamable") LOGGER.propagate = False @@ -16,13 +16,28 @@ def sidify(func: Callable[[T], Any]) -> Callable[[T], T]: - def wrap(arg): + def wrap(arg: T): func(arg) return arg return wrap +def async_sidify( + func: Callable[[T], Coroutine] +) -> Callable[[T], Coroutine[Any, Any, T]]: + async def wrap(arg: T) -> T: + coroutine = func(arg) + if not isinstance(coroutine, Coroutine): + raise TypeError( + f"`func` is expected to return a Coroutine but got a {type(coroutine)}."
+ ) + await coroutine + return arg + + return wrap + + def reraise_as( func: Callable[[T], R], source: Type[Exception], target: Type[Exception] ) -> Callable[[T], R]: diff --git a/streamable/functions.py b/streamable/functions.py index 76cfe5e..72bce3d 100644 --- a/streamable/functions.py +++ b/streamable/functions.py @@ -1,3 +1,4 @@ +import asyncio import builtins import itertools import time @@ -8,6 +9,7 @@ from typing import ( Any, Callable, + Coroutine, Deque, Dict, Iterable, @@ -294,6 +296,56 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]: yield _RaisingIterator.ExceptionContainer(e) +class _AsyncConcurrentMappingIterable( + Iterable[Union[U, _RaisingIterator.ExceptionContainer]] +): + _LOOP = asyncio.new_event_loop() + + def __init__( + self, + iterator: Iterator[T], + func: Callable[[T], Coroutine[Any, Any, U]], + concurrency: int, + buffer_size: int, + ) -> None: + self.iterator = iterator + self.func = func + self.concurrency = concurrency + self.buffer_size = buffer_size + + async def _safe_func( + self, elem: T + ) -> Union[U, _RaisingIterator.ExceptionContainer]: + try: + coroutine = self.func(elem) + if not isinstance(coroutine, Coroutine): + raise TypeError( + f"The `func` passed to `amap` or `aforeach` must return a Coroutine object, but got a {type(coroutine)}." + ) + return await coroutine + except Exception as e: + return _RaisingIterator.ExceptionContainer(e) + + def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]: + awaitables: Deque[ + asyncio.Task[Union[U, _RaisingIterator.ExceptionContainer]] + ] = deque() + # queue and yield (FIFO) + while True: + # queue tasks up to buffer_size + while len(awaitables) < self.buffer_size: + try: + elem = next(self.iterator) + except StopIteration: + # the upstream iterator is exhausted + break + awaitables.append(self._LOOP.create_task(self._safe_func(elem))) + if not awaitables: + break + while len(awaitables): + yield self._LOOP.run_until_complete(awaitables.popleft()) + + class _ConcurrentFlatteningIterable( Iterable[Union[T, _RaisingIterator.ExceptionContainer]] ): @@ -413,6 +465,24 @@ def map( ) +def amap( + func: Callable[[T], Coroutine[Any, Any, U]], + iterator: Iterator[T], + concurrency: int = 1, +) -> Iterator[U]: + _util.validate_concurrency(concurrency) + return _RaisingIterator( + iter( + _AsyncConcurrentMappingIterable( + iterator, + _util.reraise_as(func, StopIteration, WrappedStopIteration), + concurrency=concurrency, + buffer_size=concurrency, + ) + ) + ) + + def limit(iterator: Iterator[T], count: int) -> Iterator[T]: _util.validate_limit_count(count) return _LimitingIterator(iterator, count) diff --git a/streamable/stream.py b/streamable/stream.py index c2955d8..a55dd27 100644 --- a/streamable/stream.py +++ b/streamable/stream.py @@ -3,6 +3,7 @@ Any, Callable, Collection, + Coroutine, Generic, Iterable, Iterator, @@ -237,6 +238,24 @@ def foreach( validate_concurrency(concurrency) return ForeachStream(self, func, concurrency) + def aforeach( + self, + func: Callable[[T], Coroutine], + concurrency: int = 1, + ) -> "Stream[T]": + """ + Call the asynchronous `func` on upstream elements and yield them in order. + If the `func(elem)` coroutine throws an exception then it will be thrown and `elem` will not be yielded. + + Args: + func (Callable[[T], Any]): The asynchronous function to be applied to each element. + concurrency (int): How many asyncio tasks will run at the same time. + Returns: + Stream[T]: A stream of upstream elements, unchanged. 
+ """ + validate_concurrency(concurrency) + return AForeachStream(self, func, concurrency) + def group( self, size: Optional[int] = None, @@ -292,6 +311,23 @@ def map( validate_concurrency(concurrency) return MapStream(self, func, concurrency) + def amap( + self, + func: Callable[[T], Coroutine[Any, Any, U]], + concurrency: int = 1, + ) -> "Stream[U]": + """ + Apply an asynchronous `func` on upstream elements and yield the results in order. + + Args: + func (Callable[[T], Coroutine[Any, Any, U]]): The asynchronous function to be applied to each element. + concurrency (int): How many asyncio tasks will run at the same time. + Returns: + Stream[R]: A stream of results of `func` applied to upstream elements. + """ + validate_concurrency(concurrency) + return AMapStream(self, func, concurrency) + def observe(self, what: str = "elements", colored: bool = False) -> "Stream[T]": """ Log the progress of any iteration over this stream's elements. @@ -389,6 +425,18 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_foreach_stream(self) +class AForeachStream(DownStream[T, T]): + def __init__( + self, upstream: Stream[T], func: Callable[[T], Coroutine], concurrency: int + ) -> None: + super().__init__(upstream) + self.func = func + self.concurrency = concurrency + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_aforeach_stream(self) + + class GroupStream(DownStream[T, List[T]]): def __init__( self, @@ -427,6 +475,21 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_map_stream(self) +class AMapStream(DownStream[T, U]): + def __init__( + self, + upstream: Stream[T], + func: Callable[[T], Coroutine[Any, Any, U]], + concurrency: int, + ) -> None: + super().__init__(upstream) + self.func = func + self.concurrency = concurrency + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_amap_stream(self) + + class ObserveStream(DownStream[T, T]): def __init__(self, upstream: Stream[T], what: str, colored: bool) -> None: super().__init__(upstream) diff --git a/streamable/visitor.py b/streamable/visitor.py index 56d6a4c..be9400f 100644 --- a/streamable/visitor.py +++ b/streamable/visitor.py @@ -24,6 +24,9 @@ def visit_flatten_stream(self, stream: stream.FlattenStream) -> V: def visit_foreach_stream(self, stream: stream.ForeachStream) -> V: return self.visit_stream(stream) + def visit_aforeach_stream(self, stream: stream.AForeachStream) -> V: + return self.visit_stream(stream) + def visit_group_stream(self, stream: stream.GroupStream) -> V: return self.visit_stream(stream) @@ -36,5 +39,8 @@ def visit_observe_stream(self, stream: stream.ObserveStream) -> V: def visit_map_stream(self, stream: stream.MapStream) -> V: return self.visit_stream(stream) + def visit_amap_stream(self, stream: stream.AMapStream) -> V: + return self.visit_stream(stream) + def visit_slow_stream(self, stream: stream.SlowStream) -> V: return self.visit_stream(stream) diff --git a/streamable/visitors/explanation.py b/streamable/visitors/explanation.py index 0458e22..1eee9a9 100644 --- a/streamable/visitors/explanation.py +++ b/streamable/visitors/explanation.py @@ -1,6 +1,21 @@ import textwrap +from typing import cast from streamable import _util, stream +from streamable.stream import ( + AForeachStream, + AMapStream, + CatchStream, + FilterStream, + FlattenStream, + ForeachStream, + GroupStream, + LimitStream, + MapStream, + ObserveStream, + SlowStream, + Stream, +) from streamable.visitor import Visitor @@ -45,47 +60,50 @@ def _explanation(self, stream: 
stream.Stream, attributes_repr: str) -> str: return explanation - def visit_stream(self, stream: stream.Stream) -> str: + def visit_stream(self, stream: Stream) -> str: return self._explanation(stream, f"source={_util.get_name(stream.source)}") - def visit_catch_stream(self, stream: stream.CatchStream) -> str: + def visit_catch_stream(self, stream: CatchStream) -> str: return self._explanation( stream, f"predicate={_util.get_name(stream.predicate)}, raise_at_exhaustion={stream.raise_at_exhaustion}", ) - def visit_filter_stream(self, stream: stream.FilterStream) -> str: + def visit_filter_stream(self, stream: FilterStream) -> str: return self._explanation( stream, f"predicate={_util.get_name(stream.predicate)}" ) - def visit_flatten_stream(self, stream: stream.FlattenStream) -> str: + def visit_flatten_stream(self, stream: FlattenStream) -> str: return self._explanation(stream, f"concurrency={stream.concurrency}") - def visit_foreach_stream(self, stream: stream.ForeachStream) -> str: - return self._explanation( - stream, - f"func={_util.get_name(stream.func)}, concurrency={stream.concurrency}", - ) + def visit_foreach_stream(self, stream: ForeachStream) -> str: + return self.visit_map_stream(cast(MapStream, stream)) + + def visit_aforeach_stream(self, stream: AForeachStream) -> str: + return self.visit_map_stream(cast(MapStream, stream)) - def visit_group_stream(self, stream: stream.GroupStream) -> str: + def visit_group_stream(self, stream: GroupStream) -> str: return self._explanation( stream, f"size={stream.size}, seconds={stream.seconds}, by={stream.by}" ) - def visit_limit_stream(self, stream: stream.LimitStream) -> str: + def visit_limit_stream(self, stream: LimitStream) -> str: return self._explanation(stream, f"count={stream.count}") - def visit_map_stream(self, stream: stream.MapStream) -> str: + def visit_map_stream(self, stream: MapStream) -> str: return self._explanation( stream, f"func={_util.get_name(stream.func)}, concurrency={stream.concurrency}", ) - def visit_observe_stream(self, stream: stream.ObserveStream) -> str: + def visit_amap_stream(self, stream: AMapStream) -> str: + return self.visit_map_stream(cast(MapStream, stream)) + + def visit_observe_stream(self, stream: ObserveStream) -> str: return self._explanation( stream, f"what='{stream.what}', colored={stream.colored}" ) - def visit_slow_stream(self, stream: stream.SlowStream) -> str: + def visit_slow_stream(self, stream: SlowStream) -> str: return self._explanation(stream, f"frequency={stream.frequency}") diff --git a/streamable/visitors/iterator.py b/streamable/visitors/iterator.py index a883476..f1e132d 100644 --- a/streamable/visitors/iterator.py +++ b/streamable/visitors/iterator.py @@ -2,6 +2,8 @@ from streamable import _util, functions from streamable.stream import ( + AForeachStream, + AMapStream, CatchStream, FilterStream, FlattenStream, @@ -50,6 +52,15 @@ def visit_foreach_stream(self, stream: ForeachStream[T]) -> Iterator[T]: ) ) + def visit_aforeach_stream(self, stream: AForeachStream[T]) -> Iterator[T]: + return self.visit_amap_stream( + AMapStream( + stream.upstream, + _util.async_sidify(stream.func), + stream.concurrency, + ) + ) + def visit_group_stream(self, stream: GroupStream[U]) -> Iterator[T]: return cast( Iterator[T], @@ -74,6 +85,13 @@ def visit_map_stream(self, stream: MapStream[U, T]) -> Iterator[T]: concurrency=stream.concurrency, ) + def visit_amap_stream(self, stream: AMapStream[U, T]) -> Iterator[T]: + return functions.amap( + stream.func, + stream.upstream.accept(IteratorVisitor[U]()), 
+ concurrency=stream.concurrency, + ) + def visit_observe_stream(self, stream: ObserveStream[T]) -> Iterator[T]: return functions.observe( stream.upstream.accept(self), diff --git a/tests/test_functions.py b/tests/test_functions.py index 87aee9a..fb60748 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1,5 +1,5 @@ import unittest -from typing import Callable, Iterable, Iterator, List, TypeVar, cast +from typing import Callable, Iterator, List, TypeVar, cast from streamable.functions import catch, flatten, group, limit, map, observe, slow @@ -10,13 +10,12 @@ N = 256 -def src() -> Iterable[int]: - return range(N) +src = range(N) class TestFunctions(unittest.TestCase): def test_signatures(self) -> None: - it = iter(src()) + it = iter(src) func = cast(Callable[[int], int], ...) mapped_it_1: Iterator[int] = map(func, it) mapped_it_2: Iterator[int] = map(func, it, concurrency=1) diff --git a/tests/test_stream.py b/tests/test_stream.py index c5473c9..3db66cc 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,7 +1,20 @@ +import asyncio import time import timeit import unittest -from typing import Any, Callable, Iterable, Iterator, List, Set, Type, TypeVar, cast +from typing import ( + Any, + Callable, + Coroutine, + Iterable, + Iterator, + List, + Set, + Tuple, + Type, + TypeVar, + cast, +) from parameterized import parameterized # type: ignore @@ -11,12 +24,14 @@ T = TypeVar("T") -def timestream(stream: Stream): +def timestream(stream: Stream[T]) -> Tuple[float, List[T]]: + res: List[T] = [] + def iterate(): - for _ in stream: - pass + nonlocal res + res = list(stream) - return timeit.timeit(iterate, number=1) + return timeit.timeit(iterate, number=1), res # simulates an I/0 bound function @@ -28,18 +43,54 @@ def slow_identity(x: T) -> T: return x +async def async_slow_identity(x: T) -> T: + await asyncio.sleep(slow_identity_duration) + return x + + def identity(x: T) -> T: return x +# fmt: off +async def async_identity(x: T) -> T: return x +# fmt: on + + def square(x): return x**2 +async def async_square(x): + return x**2 + + def throw(exc: Type[Exception]): raise exc() +def throw_func(exc: Type[Exception]) -> Callable[[Any], None]: + return lambda _: throw(exc) + + +def async_throw_func(exc: Type[Exception]) -> Callable[[Any], Coroutine]: + async def f(_: Any) -> None: + raise exc + + return f + + +def throw_for_odd_func(exc): + return lambda i: throw(exc) if i % 2 == 1 else i + + +def async_throw_for_odd_func(exc): + async def f(i): + return throw(exc) if i % 2 == 1 else i + + return f + + class TestError(Exception): pass @@ -48,19 +99,18 @@ class TestError(Exception): # size of the test collections N = 256 -src = range(N).__iter__ +src = range(N) + +pair_src = range(0, N, 2) def less_and_less_slow_src() -> Iterator[int]: """ Same as `src` but each element is yielded after a sleep time that gets shorter and shorter. 
""" - time.sleep(0.1 / N) - return iter(range(N)) - - -def pair_src() -> Iterable[int]: - return range(0, N, 2) + for i, elem in enumerate(src): + time.sleep(0.1 / (i + 1)) + yield elem def range_raising_at_exhaustion( @@ -98,8 +148,10 @@ def test_init(self) -> None: .group(100) .flatten() .map(identity) + .amap(async_identity) .filter() .foreach(identity) + .aforeach(async_identity) .catch() .observe() .slow(1) @@ -129,7 +181,9 @@ class CustomCallable: .limit(1024) .filter() .foreach(lambda _: _) + .aforeach(async_identity) .map(cast(Callable[[Any], Any], CustomCallable())) + .amap(async_identity) .group(100) .observe("groups") .flatten(concurrency=4) @@ -196,6 +250,7 @@ def test_add(self) -> None: @parameterized.expand( [ [Stream.map, [identity]], + [Stream.amap, [async_identity]], [Stream.foreach, [identity]], [Stream.flatten, []], ] @@ -229,7 +284,7 @@ def test_sanitize_concurrency(self, method, args) -> None: def test_map(self, concurrency) -> None: self.assertListEqual( list(Stream(less_and_less_slow_src).map(square, concurrency=concurrency)), - list(map(square, src())), + list(map(square, src)), msg="At any concurrency the `map` method should act as the builtin map function, transforming elements while preserving input elements order.", ) @@ -254,24 +309,35 @@ def side_effect(x: int, func: Callable[[int], int]): self.assertListEqual( res, - list(src()), + list(src), msg="At any concurrency the `foreach` method should return the upstream elements in order.", ) self.assertSetEqual( side_collection, - set(map(square, src())), + set(map(square, src)), msg="At any concurrency the `foreach` method should call func on upstream elements (in any order).", ) @parameterized.expand( [ - [raised_exc, catched_exc, concurrency, method] + [ + raised_exc, + catched_exc, + concurrency, + method, + throw_func_, + throw_for_odd_func_, + ] for raised_exc, catched_exc in [ (TestError, TestError), - (StopIteration, WrappedStopIteration), + (StopIteration, (WrappedStopIteration, RuntimeError)), ] for concurrency in [1, 2] - for method in [Stream.foreach, Stream.map] + for method, throw_func_, throw_for_odd_func_ in [ + (Stream.foreach, throw_func, throw_for_odd_func), + (Stream.map, throw_func, throw_for_odd_func), + (Stream.amap, async_throw_func, async_throw_for_odd_func), + ] ] ) def test_map_or_foreach_with_exception( @@ -280,36 +346,42 @@ def test_map_or_foreach_with_exception( catched_exc: Type[Exception], concurrency: int, method: Callable[[Stream, Callable[[Any], int], int], Stream], + throw_func: Callable[[Exception], Callable[[Any], int]], + throw_for_odd_func: Callable[[Exception], Callable[[Any], int]], ) -> None: with self.assertRaises( catched_exc, msg="At any concurrency, `map`and `foreach` must raise", ): - list(method(Stream(src), lambda _: throw(raised_exc), concurrency)) + list(method(Stream(src), throw_func(raised_exc), concurrency)) # type: ignore self.assertListEqual( list( - method( - Stream(src), - lambda i: throw(raised_exc) if i % 2 == 1 else i, - concurrency, - ).catch(catched_exc) + method(Stream(src), throw_for_odd_func(raised_exc), concurrency).catch( # type: ignore + lambda exc: isinstance(exc, catched_exc) + ) ), - list(pair_src()), + list(pair_src), msg="At any concurrency, `map`and `foreach` must not stop after one exception occured.", ) @parameterized.expand( [ - [method, concurrency] - for method in [Stream.foreach, Stream.map] + [method, func, concurrency] + for method, func in [ + (Stream.foreach, slow_identity), + (Stream.map, slow_identity), + (Stream.amap, 
async_slow_identity), + ] for concurrency in [1, 2, 4] ] ) - def test_map_and_foreach_concurrency(self, method, concurrency) -> None: + def test_map_and_foreach_concurrency(self, method, func, concurrency) -> None: expected_iteration_duration = N * slow_identity_duration / concurrency + duration, res = timestream(method(Stream(src), func, concurrency=concurrency)) + self.assertListEqual(res, list(src)) self.assertAlmostEqual( - timestream(method(Stream(src), slow_identity, concurrency=concurrency)), + duration, expected_iteration_duration, delta=expected_iteration_duration * DELTA_RATE, msg="Increasing the concurrency of mapping should decrease proportionnally the iteration's duration.", @@ -371,8 +443,9 @@ def test_flatten_concurrency(self, concurrency) -> None: iterables_stream = Stream(range(n_iterables)).map( lambda _: map(slow_identity, range(N // n_iterables)) ) + duration, _ = timestream(iterables_stream.flatten(concurrency=concurrency)) self.assertAlmostEqual( - timestream(iterables_stream.flatten(concurrency=concurrency)), + duration, expected_iteration_duration, delta=expected_iteration_duration * DELTA_RATE, msg="Increasing the concurrency of mapping should decrease proportionnally the iteration's duration.", @@ -425,7 +498,7 @@ def test_partial_iteration_on_streams_using_concurrency( def remembering_src() -> Iterator[int]: nonlocal yielded_elems - for elem in src(): + for elem in src: yielded_elems.append(elem) yield elem @@ -456,19 +529,19 @@ def predicate(x) -> Any: self.assertListEqual( list(Stream(src).filter(predicate)), - list(filter(predicate, src())), + list(filter(predicate, src)), msg="`filter` must act like builtin filter", ) self.assertListEqual( list(Stream(src).filter()), - list(filter(None, src())), + list(filter(None, src)), msg="`filter` without predicate must act like builtin filter with None predicate.", ) def test_limit(self) -> None: self.assertEqual( list(Stream(src).limit(N * 2)), - list(src()), + list(src), msg="`limit` must be ok with count >= stream length", ) self.assertEqual( @@ -502,7 +575,7 @@ def test_limit(self) -> None: n_iterations = 0 count = N // 2 raising_stream_iterator = iter( - Stream(lambda: map(lambda x: x / 0, src())).limit(count) + Stream(lambda: map(lambda x: x / 0, src)).limit(count) ) while True: try: @@ -553,7 +626,7 @@ def test_group(self) -> None: def f(i): return i / (110 - i) - stream_iterator = iter(Stream(lambda: map(f, src())).group(100)) + stream_iterator = iter(Stream(lambda: map(f, src)).group(100)) next(stream_iterator) self.assertListEqual( next(stream_iterator), @@ -576,26 +649,26 @@ def f(i): # behavior of the `seconds` parameter self.assertListEqual( list( - Stream(lambda: map(slow_identity, src())).group( + Stream(lambda: map(slow_identity, src)).group( size=100, seconds=0.9 * slow_identity_duration ) ), - list(map(lambda e: [e], src())), + list(map(lambda e: [e], src)), msg="`group` should yield each upstream element alone in a single-element group if `seconds` inferior to the upstream yield period", ) self.assertListEqual( list( - Stream(lambda: map(slow_identity, src())).group( + Stream(lambda: map(slow_identity, src)).group( size=100, seconds=1.8 * slow_identity_duration ) ), - list(map(lambda e: [e, e + 1], pair_src())), + list(map(lambda e: [e, e + 1], pair_src)), msg="`group` should yield upstream elements in a two-element group if `seconds` inferior to twice the upstream yield period", ) self.assertListEqual( next(iter(Stream(src).group())), - list(src()), + list(src), msg="`group` without arguments should 
group the elements all together", ) @@ -689,16 +762,15 @@ def test_slow(self) -> None: super_slow_elem_pull_seconds = 1 N = 10 expected_duration = (N - 1) * period + super_slow_elem_pull_seconds + duration, _ = timestream( + Stream(range(N)) + .foreach( + lambda e: time.sleep(super_slow_elem_pull_seconds) if e == 0 else None + ) + .slow(frequency=frequency) + ) self.assertAlmostEqual( - timestream( - Stream(range(N)) - .foreach( - lambda e: time.sleep(super_slow_elem_pull_seconds) - if e == 0 - else None - ) - .slow(frequency=frequency) - ), + duration, expected_duration, delta=0.1 * expected_duration, msg="avoid bursts after very slow particular upstream elements", @@ -748,8 +820,8 @@ def test_catch(self) -> None: def f(i): return i / (3 - i) - stream = Stream(lambda: map(f, src())) - safe_src = list(src()) + stream = Stream(lambda: map(f, src)) + safe_src = list(src) del safe_src[3] self.assertListEqual( list(stream.catch(lambda e: isinstance(e, ZeroDivisionError))), @@ -862,12 +934,12 @@ def effect(x: int) -> None: l.append(x) self.assertEqual( - Stream(lambda: map(effect, src())).exhaust(), + Stream(lambda: map(effect, src)).exhaust(), N, msg="`__len__` should return the number of iterated elements.", ) self.assertListEqual( - l, list(src()), msg="`__len__` should iterate over the entire stream." + l, list(src), msg="`__len__` should iterate over the entire stream." ) def test_multiple_iterations(self) -> None: @@ -875,6 +947,54 @@ def test_multiple_iterations(self) -> None: for _ in range(3): self.assertEqual( list(stream), - list(src()), + list(src), msg="The first iteration over a stream should yield the same elements as any subsequent iteration on the same stream, even if it is based on a `source` returning an iterator that only support 1 iteration.", ) + + @parameterized.expand( + [ + [1], + [100], + ] + ) + def test_amap(self, concurrency) -> None: + self.assertListEqual( + list( + Stream(less_and_less_slow_src).amap( + async_square, concurrency=concurrency + ) + ), + list(map(square, src)), + msg="At any concurrency the `amap` method should act as the builtin map function, transforming elements while preserving input elements order.", + ) + stream = Stream(src).amap(identity) # type: ignore + with self.assertRaisesRegex( + TypeError, + "The `func` passed to `amap` or `aforeach` must return a Coroutine object, but got a .", + msg="`amap` should raise a TypeError if a non async function is passed to it.", + ): + next(iter(stream)) + + @parameterized.expand( + [ + [1], + [100], + ] + ) + def test_aforeach(self, concurrency) -> None: + self.assertListEqual( + list( + Stream(less_and_less_slow_src).aforeach( + async_square, concurrency=concurrency + ) + ), + list(src), + msg="At any concurrency the `foreach` method must preserve input elements order.", + ) + stream = Stream(src).aforeach(identity) # type: ignore + with self.assertRaisesRegex( + TypeError, + "`func` is expected to return a Coroutine but got a .", + msg="`aforeach` should raise a TypeError if a non async function is passed to it.", + ): + next(iter(stream)) diff --git a/tests/test_visitor.py b/tests/test_visitor.py index 6dd9e85..d903f07 100644 --- a/tests/test_visitor.py +++ b/tests/test_visitor.py @@ -2,6 +2,8 @@ from typing import cast from streamable.stream import ( + AForeachStream, + AMapStream, CatchStream, FilterStream, FlattenStream, @@ -28,8 +30,10 @@ def visit_stream(self, stream: Stream) -> None: visitor.visit_filter_stream(cast(FilterStream, ...)) visitor.visit_flatten_stream(cast(FlattenStream, ...)) 
visitor.visit_foreach_stream(cast(ForeachStream, ...)) + visitor.visit_aforeach_stream(cast(AForeachStream, ...)) visitor.visit_limit_stream(cast(LimitStream, ...)) visitor.visit_map_stream(cast(MapStream, ...)) + visitor.visit_amap_stream(cast(AMapStream, ...)) visitor.visit_observe_stream(cast(ObserveStream, ...)) visitor.visit_slow_stream(cast(SlowStream, ...)) visitor.visit_stream(cast(Stream, ...))
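Below is a short, self-contained sketch of how the `.amap` / `.aforeach` operations introduced by this diff are meant to be used. It assumes the package as patched above: `Stream` importable from the package root (as in the README examples) and accepting an iterable such as `range`, which the updated tests rely on. `fetch_square` and `log_result` are hypothetical stand-ins for real async I/O calls.

```python
import asyncio

from streamable import Stream

# Hypothetical async I/O-bound functions (placeholders for real network calls).
async def fetch_square(i: int) -> int:
    await asyncio.sleep(0.01)  # simulate I/O latency
    return i**2

async def log_result(result: int) -> None:
    await asyncio.sleep(0.01)  # simulate I/O latency
    print("got", result)

squares = (
    Stream(range(10))
    # apply the coroutine function on up to 4 elements at a time, keeping order
    .amap(fetch_square, concurrency=4)
    # run the async side effect concurrently too; elements pass through unchanged
    .aforeach(log_result, concurrency=4)
)

print(list(squares))  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

As with the thread-based `concurrency` parameter of `.map` / `.foreach`, output order is preserved and the stream stays lazy: nothing runs until it is iterated.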
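The ordered concurrency in `_AsyncConcurrentMappingIterable` boils down to one pattern: schedule up to `buffer_size` coroutines as tasks on a dedicated event loop, then await and yield them in FIFO order so the input order is preserved. Here is a minimal standalone sketch of that loop, leaving out the library's `_RaisingIterator.ExceptionContainer` machinery and the coroutine type check:

```python
import asyncio
from collections import deque
from typing import Any, Callable, Coroutine, Deque, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")

def ordered_async_map(
    func: Callable[[T], Coroutine[Any, Any, U]],
    iterator: Iterator[T],
    buffer_size: int,
) -> Iterator[U]:
    loop = asyncio.new_event_loop()
    tasks: Deque["asyncio.Task[U]"] = deque()
    try:
        while True:
            # schedule tasks until the buffer is full or the upstream is exhausted
            while len(tasks) < buffer_size:
                try:
                    elem = next(iterator)
                except StopIteration:
                    break
                tasks.append(loop.create_task(func(elem)))
            if not tasks:
                break
            # drain the buffer in FIFO order: while the oldest task is awaited,
            # the event loop keeps advancing the other scheduled tasks
            while tasks:
                yield loop.run_until_complete(tasks.popleft())
    finally:
        loop.close()

# e.g. ordered_async_map(some_coroutine_function, iter(range(100)), buffer_size=8)
```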
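Like their thread-based counterparts, `.amap` and `.aforeach` raise per element rather than aborting the whole iteration, which is what the updated `test_map_or_foreach_with_exception` checks. The following is a hedged sketch of that behavior, assuming the predicate-based `.catch` signature exercised in those tests; `UpstreamError` and `fail_on_odd` are illustrative names.

```python
import asyncio

from streamable import Stream

class UpstreamError(Exception):
    pass

# Hypothetical async function that fails on odd inputs.
async def fail_on_odd(i: int) -> int:
    await asyncio.sleep(0)
    if i % 2:
        raise UpstreamError(f"odd element: {i}")
    return i

evens = (
    Stream(range(6))
    .amap(fail_on_odd, concurrency=2)
    # a raising element is dropped by `.catch`; the remaining elements still flow
    .catch(lambda exc: isinstance(exc, UpstreamError))
)

print(list(evens))  # [0, 2, 4]
```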
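Because the base `Visitor` routes the new `visit_amap_stream` / `visit_aforeach_stream` hooks to `visit_stream` by default, existing visitors keep working on pipelines that contain the new node types. Below is a small hypothetical visitor illustrating that fallback; the `DepthVisitor` name, the depth-counting logic, and the `getattr`-based upstream lookup are mine, not part of the library.

```python
import asyncio

from streamable.stream import Stream
from streamable.visitor import Visitor

async def async_upper(s: str) -> str:
    await asyncio.sleep(0)
    return s.upper()

async def async_noop(s: str) -> None:
    await asyncio.sleep(0)

class DepthVisitor(Visitor):
    """Counts how many operations are chained on top of the source stream."""

    def visit_stream(self, stream: Stream) -> int:
        # AMapStream and AForeachStream reach this default fallback too,
        # so no async-specific visit method has to be overridden here.
        upstream = getattr(stream, "upstream", None)
        return 0 if upstream is None else 1 + upstream.accept(self)

pipeline = Stream(range(3)).map(str).amap(async_upper).aforeach(async_noop)
print(pipeline.accept(DepthVisitor()))  # 3
```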