Skip to content

Commit

Permalink
Merge pull request #418 from martindurant/example
Browse files Browse the repository at this point in the history
Allow example for dataframe to not be a dataframe
  • Loading branch information
martindurant committed May 19, 2021
2 parents 0bd2f50 + b52dbca commit fc379a0
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 16 deletions.
5 changes: 1 addition & 4 deletions streamz/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ def __init__(self, stream=None, example=None, stream_type=None):
assert example is not None
self.example = example
if not isinstance(self.example, self._subtype):
msg = ("For streaming type %s we expect an example of type %s. "
"Got %s") % (type(self).__name__, self._subtype.__name__,
str(self.example))
raise TypeError(msg)
self.example = self._subtype(example)
assert isinstance(self.example, self._subtype)
self.stream = stream or Stream()
if stream_type:
Expand Down
1 change: 0 additions & 1 deletion streamz/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from __future__ import absolute_import, division, print_function
from collections import deque, defaultdict
from datetime import timedelta
import functools
Expand Down
4 changes: 1 addition & 3 deletions streamz/dataframe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import pandas as pd
import toolz

from tornado import gen

from ..collection import Streaming, _stream_types, OperatorMixin
from ..sources import Source
from ..utils import M
Expand Down Expand Up @@ -1049,7 +1047,7 @@ def stop(self):
async def _cb(interval, source, continue_):
last = pd.Timestamp.now()
while continue_[0]:
await gen.sleep(interval)
await asyncio.sleep(interval)
now = pd.Timestamp.now()
await asyncio.gather(*source._emit(dict(last=last, now=now)))
last = now
Expand Down
8 changes: 4 additions & 4 deletions streamz/dataframe/tests/test_dataframe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ def test_utils_get_base_frame_type_pandas():
with pytest.raises(TypeError):
get_base_frame_type("Index", is_index_like, df)

with pytest.raises(TypeError):
get_base_frame_type("DataFrame", is_dataframe_like, df.x)
# casts Series to DataFrame, if that's what we ask for
assert pd.DataFrame == get_base_frame_type("DataFrame", is_dataframe_like, df.x)
assert pd.Series == get_base_frame_type("Series", is_series_like, df.x)
with pytest.raises(TypeError):
get_base_frame_type("Index", is_index_like, df.x)

with pytest.raises(TypeError):
get_base_frame_type("DataFrame", is_dataframe_like, df.index)
# casts Index to DataFrame, if that's what we ask for
assert pd.DataFrame == get_base_frame_type("DataFrame", is_dataframe_like, df.index)
with pytest.raises(TypeError):
get_base_frame_type("Series", is_series_like, df.index)
assert issubclass(get_base_frame_type("Index", is_index_like, df.index), pd.Index)
Expand Down
6 changes: 5 additions & 1 deletion streamz/dataframe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def get_base_frame_type(frame_name, is_frame_like, example=None):
Returns the base type of streaming objects if type checks pass."""
if example is None:
raise TypeError("Missing required argument:'example'")
if not is_frame_like(example):
if is_frame_like is is_dataframe_like and not is_frame_like(example):
import pandas as pd
example = pd.DataFrame(example)

elif not is_frame_like(example):
msg = "Streaming {0} expects an example of {0} like objects. Got: {1}."\
.format(frame_name, example)
raise TypeError(msg)
Expand Down
6 changes: 3 additions & 3 deletions streamz/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from contextlib import contextmanager
import logging
import os
Expand Down Expand Up @@ -116,11 +117,10 @@ def wait_for(predicate, timeout, fail_func=None, period=0.001):
pytest.fail("condition not reached within %s seconds" % timeout)


@gen.coroutine
def await_for(predicate, timeout, fail_func=None, period=0.001):
async def await_for(predicate, timeout, fail_func=None, period=0.001):
deadline = time() + timeout
while not predicate():
yield gen.sleep(period)
await asyncio.sleep(period)
if time() > deadline: # pragma: no cover
if fail_func is not None:
fail_func()
Expand Down

0 comments on commit fc379a0

Please sign in to comment.