Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added type annotations to streams.py and message.py #231

Merged
merged 13 commits into from Jun 29, 2022
1 change: 1 addition & 0 deletions python/doc/source/conf.py
Expand Up @@ -80,6 +80,7 @@ def mock_internal_type(qualname: str) -> mock.Mock:
("py:data", "typing.Optional"),
("py:data", "typing.Tuple"),
("py:data", "typing.Callable"),
("py:obj", "erdos.streams.T"),
]

# Add any paths that contain templates here, relative to this directory.
Expand Down
12 changes: 7 additions & 5 deletions python/erdos/message.py
@@ -1,11 +1,13 @@
import pickle
from typing import Any
from typing import Generic, TypeVar

from erdos.internal import PyMessage
from erdos.timestamp import Timestamp

T = TypeVar("T")

class Message:

class Message(Generic[T]):
"""A :py:class:`Message` allows an operator to send timestamped data to
other operators via a :py:class:`WriteStream` or an
:py:class:`IngestStream`.
Expand All @@ -15,7 +17,7 @@ class Message:
data: The data of the message.
"""

def __init__(self, timestamp: Timestamp, data: Any):
def __init__(self, timestamp: Timestamp, data: T) -> None:
"""Constructs a :py:class:`Message` with the given `data` and
`timestamp`.

Expand Down Expand Up @@ -57,7 +59,7 @@ def __str__(self):
return "{{timestamp: {}, data: {}}}".format(self.timestamp, self.data)


class WatermarkMessage(Message):
class WatermarkMessage(Message[None]):
"""A :py:class:`WatermarkMessage` allows an operator to convey the
completion of all outgoing data for a given timestamp on a
:py:class:`WriteStream`.
Expand All @@ -66,7 +68,7 @@ class WatermarkMessage(Message):
timestamp: The timestamp for which this is a watermark.
"""

def __init__(self, timestamp: Timestamp):
def __init__(self, timestamp: Timestamp) -> None:
super(WatermarkMessage, self).__init__(timestamp, None)

def __str__(self):
Expand Down
75 changes: 41 additions & 34 deletions python/erdos/streams.py
Expand Up @@ -3,7 +3,7 @@
import uuid
from abc import ABC
from itertools import zip_longest
from typing import Any, Callable, Sequence, Tuple, Type, Union
from typing import Callable, Generic, Sequence, Tuple, Type, TypeVar, Union

from erdos.internal import (
PyExtractStream,
Expand Down Expand Up @@ -36,7 +36,11 @@ def _parse_message(internal_msg: PyMessage):
raise Exception("Unable to parse message")


class Stream(ABC):
T = TypeVar("T")
U = TypeVar("U")


class Stream(ABC, Generic[T]):
"""Base class representing a stream to operators can be connected.
from which is subclassed by streams that are used to
connect operators in the driver.
Expand All @@ -45,8 +49,8 @@ class Stream(ABC):
This class should never be initialized manually.
"""

def __init__(self, internal_stream: PyStream):
self._internal_stream = internal_stream
def __init__(self, internal_stream: PyStream) -> None:
self._internal_stream: PyStream = internal_stream

@property
def id(self) -> str:
Expand All @@ -59,10 +63,10 @@ def name(self) -> str:
return self._internal_stream.name()

@name.setter
def name(self, name: str):
def name(self, name: str) -> None:
self._internal_stream.set_name(name)

def map(self, function: Callable[[Any], Any]) -> "OperatorStream":
def map(self, function: Callable[[T], U]) -> "OperatorStream[U]":
"""Applies the given function to each value sent on the stream, and outputs the
results on the returned stream.

Expand All @@ -79,7 +83,7 @@ def map_fn(serialized_data: bytes) -> bytes:

return OperatorStream(self._internal_stream._map(map_fn))

def flat_map(self, function: Callable[[Any], Sequence[Any]]) -> "OperatorStream":
def flat_map(self, function: Callable[[T], Sequence[U]]) -> "OperatorStream[U]":
"""Applies the given function to each value sent on the stream, and outputs the
sequence of received outputs as individual messages.

Expand All @@ -102,7 +106,7 @@ def flat_map_fn(serialized_data: bytes) -> Sequence[bytes]:

return OperatorStream(self._internal_stream._flat_map(flat_map_fn))

def filter(self, function: Callable[[Any], bool]) -> "OperatorStream":
def filter(self, function: Callable[[T], bool]) -> "OperatorStream[T]":
"""Applies the given function to each value sent on the stream, and sends the
value on the returned stream if the function evaluates to `True`.

Expand All @@ -120,8 +124,8 @@ def filter_fn(serialized_data: bytes) -> bool:
return OperatorStream(self._internal_stream._filter(filter_fn))

def split(
self, function: Callable[[Any], bool]
) -> Tuple["OperatorStream", "OperatorStream"]:
self, function: Callable[[T], bool]
) -> Tuple["OperatorStream[T]", "OperatorStream[T]"]:
"""Applies the given function to each value sent on the stream, and outputs the
value to either the left or the right stream depending on if the returned
boolean value is `True` or `False` respectively.
Expand Down Expand Up @@ -170,7 +174,7 @@ def split_by_type(self, *data_type: Type) -> Tuple["OperatorStream"]:

return streams + (last_stream,)

def timestamp_join(self, other: "Stream") -> "OperatorStream":
def timestamp_join(self, other: "Stream[U]") -> "OperatorStream[Tuple[T,U]]":
"""Joins the data with matching timestamps from the two different streams.

Args:
Expand All @@ -190,7 +194,7 @@ def join_fn(serialized_data_left: bytes, serialized_data_right: bytes) -> bytes:
self._internal_stream._timestamp_join(other._internal_stream, join_fn)
)

def concat(self, *other: "Stream") -> "OperatorStream":
def concat(self, *other: "Stream[T]") -> "OperatorStream[T]":
"""Merges the data messages from the given streams into a single stream and
forwards a watermark when a minimum watermark on the streams is achieved.

Expand Down Expand Up @@ -225,7 +229,8 @@ def concat(self, *other: "Stream") -> "OperatorStream":
return streams_to_be_merged[0]


class ReadStream:
class ReadStream(Generic[T]):

"""A :py:class:`ReadStream` allows an operator to read and do work on
data sent by other operators on a corresponding :py:class:`WriteStream`.

Expand All @@ -240,13 +245,13 @@ class ReadStream:
in :code:`run`.
"""

def __init__(self, _py_read_stream: PyReadStream):
def __init__(self, _py_read_stream: PyReadStream) -> None:
logger.debug(
"Initializing ReadStream with the name: {}, and ID: {}.".format(
_py_read_stream.name(), _py_read_stream.id
)
)
self._py_read_stream = _py_read_stream
self._py_read_stream: PyReadStream = _py_read_stream

@property
def name(self) -> str:
Expand All @@ -263,11 +268,11 @@ def is_closed(self) -> bool:
"""Whether a top watermark message has been received."""
return self._py_read_stream.is_closed()

def read(self) -> Message:
def read(self) -> Message[T]:
"""Blocks until a message is read from the stream."""
return _parse_message(self._py_read_stream.read())

def try_read(self) -> Union[Message, None]:
def try_read(self) -> Union[Message[T], None]:
"""Tries to read a mesage from the stream.

Returns None if no messages are available at the moment.
Expand All @@ -278,7 +283,7 @@ def try_read(self) -> Union[Message, None]:
return _parse_message(internal_msg)


class WriteStream:
class WriteStream(Generic[T]):
"""A :py:class:`WriteStream` allows an operator to send messages and
watermarks to other operators that connect to the corresponding
:py:class:`ReadStream`.
Expand All @@ -288,13 +293,13 @@ class WriteStream:
and should never be initialized manually.
"""

def __init__(self, _py_write_stream: PyWriteStream):
def __init__(self, _py_write_stream: PyWriteStream) -> None:
logger.debug(
"Initializing WriteStream with the name: {}, and ID: {}.".format(
_py_write_stream.name(), _py_write_stream.id
)
)
self._py_write_stream = (
self._py_write_stream: PyWriteStream = (
PyWriteStream() if _py_write_stream is None else _py_write_stream
)

Expand All @@ -313,7 +318,7 @@ def is_closed(self) -> bool:
"""Whether a top watermark message has been sent."""
return self._py_write_stream.is_closed()

def send(self, msg: Message):
def send(self, msg: Message[T]) -> None:
"""Sends a message on the stream.

Args:
Expand All @@ -334,36 +339,36 @@ def send(self, msg: Message):
) from e


class OperatorStream(Stream):
class OperatorStream(Stream[T]):
"""Returned when connecting an operator to the dataflow graph.

Note:
This class is created automatically by the `connect` functions, and
should never be initialized manually.
"""

def __init__(self, operator_stream: PyOperatorStream):
def __init__(self, operator_stream: PyOperatorStream) -> None:
super().__init__(operator_stream)


class LoopStream(Stream):
class LoopStream(Stream[T]):
"""Stream placeholder used to construct loops in the dataflow graph.

Note:
Must call `connect_loop` with a valid :py:class:`OperatorStream` to
complete the loop.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__(PyLoopStream())

def connect_loop(self, stream: OperatorStream):
def connect_loop(self, stream: OperatorStream[T]) -> None:
if not isinstance(stream, OperatorStream):
raise TypeError("Loop must be connected to an `OperatorStream`")
self._internal_stream.connect_loop(stream._internal_stream)


class IngestStream(Stream):
class IngestStream(Stream[T]):
"""An :py:class:`IngestStream` enables drivers to inject data into a
running ERDOS application.

Expand All @@ -374,7 +379,7 @@ class IngestStream(Stream):
operator to which it was connected.
"""

def __init__(self, name: Union[str, None] = None):
def __init__(self, name: Union[str, None] = None) -> None:
super().__init__(PyIngestStream(name))

def is_closed(self) -> bool:
Expand All @@ -385,7 +390,7 @@ def is_closed(self) -> bool:
"""
return self._internal_stream.is_closed()

def send(self, msg: Message):
def send(self, msg: Message[T]) -> None:
"""Sends a message on the stream.

Args:
Expand All @@ -403,7 +408,7 @@ def send(self, msg: Message):
self._internal_stream.send(internal_msg)


class ExtractStream:
class ExtractStream(Stream[T]):
"""An :py:class:`ExtractStream` enables drivers to read data from a
running ERDOS applications.

Expand All @@ -418,13 +423,15 @@ class ExtractStream:
stream: The stream from which to read messages.
"""

def __init__(self, stream: OperatorStream):
def __init__(self, stream: OperatorStream[T]) -> None:
if not isinstance(stream, OperatorStream):
raise ValueError(
"ExtractStream needs to be initialized with a Stream. "
"Received a {}".format(type(stream))
)
self._py_extract_stream = PyExtractStream(stream._internal_stream)
self._py_extract_stream: PyExtractStream = PyExtractStream(
stream._internal_stream
)

@property
def name(self) -> str:
Expand All @@ -444,11 +451,11 @@ def is_closed(self) -> bool:
"""
return self._py_extract_stream.is_closed()

def read(self) -> Message:
def read(self) -> Message[T]:
"""Blocks until a message is read from the stream."""
return _parse_message(self._py_extract_stream.read())

def try_read(self) -> Union[Message, None]:
def try_read(self) -> Union[Message[T], None]:
"""Tries to read a mesage from the stream.

Returns :code:`None` if no messages are available at the moment.
Expand Down