Skip to content

Commit

Permalink
allow to instantiate a from an
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Jun 7, 2024
1 parent 4e07f4c commit 9680333
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 50 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ from streamable import Stream
## 3. init

```python
integers: Stream[int] = Stream(lambda: range(10))
integers: Stream[int] = Stream(range(10))
```

Instantiate a `Stream[T]` by providing a function that returns a fresh `Iterable[T]` (the data source).
Instantiate a `Stream[T]` from an `Iterable[T]` (the data source).

## 4. operate

Expand Down Expand Up @@ -179,7 +179,7 @@ Tip: enclose operations in parentheses to avoid trailing backslashes `\`.

```python
stream: Stream[str] = (
Stream(lambda: range(10))
Stream(range(10))
.map(str)
.group(2)
.foreach(print)
Expand Down
8 changes: 4 additions & 4 deletions streamable/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ def wrap(arg):
return wrap


def validate_iterable(expected_iterator: Any) -> bool:
def validate_iterable(expected_iterable: Any) -> bool:
"""
Raises:
TypeError: If the expected_iterator does not implement __iter__ and __next__ methods.
TypeError: If the expected_iterable does not implement __iter__ and __next__ methods.
"""
try:
expected_iterator.__iter__
expected_iterable.__iter__
except AttributeError:
raise TypeError(
f"Provided object is not an iterator because it does not implement the __iter__ methods."
f"Provided object is not an iterable because it does not implement the __iter__ methods."
)
return True

Expand Down
15 changes: 11 additions & 4 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Sequence,
Set,
TypeVar,
Union,
cast,
overload,
)
Expand All @@ -20,6 +21,7 @@
validate_concurrency,
validate_group_seconds,
validate_group_size,
validate_iterable,
validate_limit_count,
validate_slow_frequency,
)
Expand All @@ -35,15 +37,20 @@


class Stream(Iterable[T]):
def __init__(self, source: Callable[[], Iterable[T]]) -> None:
def __init__(self, source: Union[Iterable[T], Callable[[], Iterator[T]]]) -> None:
"""
Initialize a Stream with a source iterable.
Args:
source (Callable[[], Iterable[T]]): Function to be called at iteration to get the stream's source iterable.
source (Union[Iterable[T], Callable[[], Iterator[T]]]): Either an iterable or a function to be called at iteration time to get a fresh source iterator.
"""
if not callable(source):
raise TypeError(f"`source` must be a callable but got a {type(source)}")
try:
validate_iterable(source)
except TypeError:
raise TypeError(
"`source` must be either a Callable[[], Iterator] or an Iterable, but got a <class 'int'>"
)
self._source = source
self._upstream: "Optional[Stream]" = None

Expand All @@ -56,7 +63,7 @@ def upstream(self) -> "Optional[Stream]":
return self._upstream

@property
def source(self) -> Callable[[], Iterable]:
def source(self) -> Union[Iterable, Callable[[], Iterator]]:
"""
Returns:
Callable[[], Iterable]: Function to be called at iteration to get the stream's source iterable.
Expand Down
2 changes: 1 addition & 1 deletion streamable/visitors/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ def visit_slow_stream(self, stream: SlowStream[T]) -> Iterator[T]:
return functions.slow(stream.upstream.accept(self), stream.frequency)

def visit_stream(self, stream: Stream[T]) -> Iterator[T]:
iterable = stream._source()
iterable = stream._source() if callable(stream._source) else stream._source
_util.validate_iterable(iterable)
return iter(iterable)
69 changes: 31 additions & 38 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,19 @@ class TestError(Exception):
pass


DELTA_RATE = 0.3
DELTA_RATE = 0.4
# size of the test collections
N = 256

src = range(N).__iter__

def src() -> Iterable[int]:
return range(N)


def less_and_less_slow_src() -> Iterable[int]:
def less_and_less_slow_src() -> Iterator[int]:
"""
Same as `src` but each element is yielded after a sleep time that gets shorter and shorter.
"""
time.sleep(0.1 / N)
return range(N)
return iter(range(N))


def pair_src() -> Iterable[int]:
Expand Down Expand Up @@ -90,10 +88,10 @@ def test_init(self) -> None:

with self.assertRaisesRegex(
TypeError,
"`source` must be a callable but got a <class 'range'>",
msg="Instantiating a Stream with a source not being a callable must raise TypeError.",
"`source` must be either a Callable\[\[\], Iterator\] or an Iterable, but got a <class 'int'>",
msg="Getting an Iterator from a Stream with a source not being a Union[Callable[[], Iterator], ITerable] must raise TypeError.",
):
Stream(range(N)) # type: ignore
Stream(1) # type: ignore

self.assertIs(
Stream(src)
Expand Down Expand Up @@ -186,9 +184,9 @@ def test_add(self) -> None:
msg="stream addition must return a FlattenStream.",
)

stream_a = Stream(lambda: range(10))
stream_b = Stream(lambda: range(10, 20))
stream_c = Stream(lambda: range(20, 30))
stream_a = Stream(range(10))
stream_b = Stream(range(10, 20))
stream_c = Stream(range(20, 30))
self.assertListEqual(
list(stream_a + stream_b + stream_c),
list(range(30)),
Expand Down Expand Up @@ -337,9 +335,7 @@ def test_flatten(self, concurrency) -> None:
)
self.assertListEqual(
list(
Stream(lambda: [iter([]) for _ in range(2000)]).flatten(
concurrency=concurrency
)
Stream([iter([]) for _ in range(2000)]).flatten(concurrency=concurrency)
),
[],
msg="`flatten` should not yield any element if upstream elements are empty iterables, and be resilient to recursion issue in case of successive empty upstream iterables.",
Expand All @@ -349,21 +345,17 @@ def test_flatten(self, concurrency) -> None:
TypeError,
msg="`flatten` should raise if an upstream element is not iterable.",
):
next(iter(Stream(cast(Callable[[], Iterable], src)).flatten()))
next(iter(Stream(cast(Iterable, src)).flatten()))

def test_flatten_typing(self) -> None:
flattened_iterator_stream: Stream[str] = (
Stream(lambda: "abc").map(iter).flatten()
)
flattened_list_stream: Stream[str] = Stream(lambda: "abc").map(list).flatten()
flattened_set_stream: Stream[str] = Stream(lambda: "abc").map(set).flatten()
flattened_iterator_stream: Stream[str] = Stream("abc").map(iter).flatten()
flattened_list_stream: Stream[str] = Stream("abc").map(list).flatten()
flattened_set_stream: Stream[str] = Stream("abc").map(set).flatten()
flattened_map_stream: Stream[str] = (
Stream(lambda: "abc").map(lambda char: map(lambda x: x, char)).flatten()
Stream("abc").map(lambda char: map(lambda x: x, char)).flatten()
)
flattened_filter_stream: Stream[str] = (
Stream(lambda: "abc")
.map(lambda char: filter(lambda _: True, char))
.flatten()
Stream("abc").map(lambda char: filter(lambda _: True, char)).flatten()
)

@parameterized.expand(
Expand All @@ -376,7 +368,7 @@ def test_flatten_typing(self) -> None:
def test_flatten_concurrency(self, concurrency) -> None:
expected_iteration_duration = N * slow_identity_duration / concurrency
n_iterables = 32
iterables_stream = Stream(lambda: range(n_iterables)).map(
iterables_stream = Stream(range(n_iterables)).map(
lambda _: map(slow_identity, range(N // n_iterables))
)
self.assertAlmostEqual(
Expand Down Expand Up @@ -416,7 +408,7 @@ def __iter__(self) -> Iterator[int]:

self.assertSetEqual(
set(
Stream(lambda: range(n_iterables))
Stream(range(n_iterables))
.map(lambda i: cast(Iterable[int], odd_iterable(i, raised_exc)))
.flatten(concurrency=concurrency)
.catch(catched_exc)
Expand Down Expand Up @@ -532,27 +524,27 @@ def test_group(self) -> None:
ValueError,
msg="`group` should raise error when called with `seconds` <= 0.",
):
list(Stream(lambda: [1]).group(size=100, seconds=seconds)),
list(Stream([1]).group(size=100, seconds=seconds)),
for size in [-1, 0]:
with self.assertRaises(
ValueError,
msg="`group` should raise error when called with `size` < 1.",
):
list(Stream(lambda: [1]).group(size=size)),
list(Stream([1]).group(size=size)),

# group size
self.assertListEqual(
list(Stream(lambda: range(6)).group(size=4)),
list(Stream(range(6)).group(size=4)),
[[0, 1, 2, 3], [4, 5]],
msg="",
)
self.assertListEqual(
list(Stream(lambda: range(6)).group(size=2)),
list(Stream(range(6)).group(size=2)),
[[0, 1], [2, 3], [4, 5]],
msg="",
)
self.assertListEqual(
list(Stream(lambda: []).group(size=2)),
list(Stream([]).group(size=2)),
[],
msg="",
)
Expand Down Expand Up @@ -634,7 +626,7 @@ def f(i):
)

self.assertListEqual(
list(Stream(lambda: range(10)).group(by=lambda n: n % 4 == 0)),
list(Stream(range(10)).group(by=lambda n: n % 4 == 0)),
[[0, 4, 8], [1, 2, 3, 5, 6, 7, 9]],
msg="`group` called with a `by` function and reaching exhaustion must cogroup elements and yield uncomplete groups starting with the group containing the oldest element, even though it's not the largest.",
)
Expand Down Expand Up @@ -690,7 +682,7 @@ def test_slow(self) -> None:
ValueError,
msg="`slow` should raise error when called with `frequency` <= 0.",
):
list(Stream(lambda: [1]).slow(frequency=frequency))
list(Stream([1]).slow(frequency=frequency))

frequency = 3
period = 1 / frequency
Expand All @@ -699,7 +691,7 @@ def test_slow(self) -> None:
expected_duration = (N - 1) * period + super_slow_elem_pull_seconds
self.assertAlmostEqual(
timestream(
Stream(lambda: range(N))
Stream(range(N))
.foreach(
lambda e: time.sleep(super_slow_elem_pull_seconds)
if e == 0
Expand Down Expand Up @@ -821,7 +813,7 @@ def f(i):
)

only_catched_errors_stream = (
Stream(lambda: range(2000))
Stream(range(2000))
.map(lambda i: throw(TestError))
.catch(lambda e: isinstance(e, TestError))
)
Expand All @@ -838,7 +830,8 @@ def f(i):

def test_observe(self) -> None:
value_error_rainsing_stream: Stream[List[int]] = (
Stream(lambda: "123--567")
Stream("123--567")
.slow(1)
.observe("chars")
.map(int)
.observe("ints", colored=True)
Expand Down Expand Up @@ -878,7 +871,7 @@ def effect(x: int) -> None:
)

def test_multiple_iterations(self):
stream = Stream(lambda: map(identity, src()))
stream = Stream(src)
for _ in range(3):
self.assertEqual(
list(stream),
Expand Down

0 comments on commit 9680333

Please sign in to comment.