diff --git a/examples/sink/async_log/pyproject.toml b/examples/sink/async_log/pyproject.toml index 9c1c6b92..2ff3e97a 100644 --- a/examples/sink/async_log/pyproject.toml +++ b/examples/sink/async_log/pyproject.toml @@ -5,7 +5,7 @@ description = "" authors = ["Numaflow developers"] [tool.poetry.dependencies] -python = "~3.10" +python = ">=3.10,<3.13" pynumaflow = { path = "../../../"} [tool.poetry.dev-dependencies] diff --git a/examples/sink/log/pyproject.toml b/examples/sink/log/pyproject.toml index 9c1c6b92..2ff3e97a 100644 --- a/examples/sink/log/pyproject.toml +++ b/examples/sink/log/pyproject.toml @@ -5,7 +5,7 @@ description = "" authors = ["Numaflow developers"] [tool.poetry.dependencies] -python = "~3.10" +python = ">=3.10,<3.13" pynumaflow = { path = "../../../"} [tool.poetry.dev-dependencies] diff --git a/pynumaflow/proto/sinker/sink.proto b/pynumaflow/proto/sinker/sink.proto index df599f03..0cb2f69b 100644 --- a/pynumaflow/proto/sinker/sink.proto +++ b/pynumaflow/proto/sinker/sink.proto @@ -1,4 +1,5 @@ syntax = "proto3"; + import "google/protobuf/empty.proto"; import "google/protobuf/timestamp.proto"; @@ -7,7 +8,7 @@ package sink.v1; service Sink { // SinkFn writes the request to a user defined sink. - rpc SinkFn(stream SinkRequest) returns (SinkResponse); + rpc SinkFn(stream SinkRequest) returns (stream SinkResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); @@ -17,12 +18,29 @@ service Sink { * SinkRequest represents a request element. */ message SinkRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - string id = 5; - map headers = 6; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + string id = 5; + map headers = 6; + } + // Required field indicating the request. + Request request = 1; + // Required field indicating the status of the request. + // If eot is set to true, it indicates the end of transmission. + TransmissionStatus status = 2; + // optional field indicating the handshake message. + optional Handshake handshake = 3; +} + +/* + * Handshake message between client and server to indicate the start of transmission. + */ +message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; } /** @@ -32,6 +50,13 @@ message ReadyResponse { bool ready = 1; } +/** + * TransmissionStatus is the status of the transmission. + */ +message TransmissionStatus { + bool eot = 1; +} + /* * Status is the status of the response. */ @@ -53,5 +78,7 @@ message SinkResponse { // err_msg is the error message, set it if success is set to false. string err_msg = 3; } - repeated Result results = 1; + Result result = 1; + optional Handshake handshake = 2; + optional TransmissionStatus status = 3; } \ No newline at end of file diff --git a/pynumaflow/proto/sinker/sink_pb2.py b/pynumaflow/proto/sinker/sink_pb2.py index 00b8326e..3dcfcd9b 100644 --- a/pynumaflow/proto/sinker/sink_pb2.py +++ b/pynumaflow/proto/sinker/sink_pb2.py @@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\nsink.proto\x12\x07sink.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xf9\x01\n\x0bSinkRequest\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\n\n\x02id\x18\x05 \x01(\t\x12\x32\n\x07headers\x18\x06 \x03(\x0b\x32!.sink.v1.SinkRequest.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"\x85\x01\n\x0cSinkResponse\x12-\n\x07results\x18\x01 \x03(\x0b\x32\x1c.sink.v1.SinkResponse.Result\x1a\x46\n\x06Result\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1f\n\x06status\x18\x02 \x01(\x0e\x32\x0f.sink.v1.Status\x12\x0f\n\x07\x65rr_msg\x18\x03 \x01(\t*0\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07\x46\x41ILURE\x10\x01\x12\x0c\n\x08\x46\x41LLBACK\x10\x02\x32z\n\x04Sink\x12\x37\n\x06SinkFn\x12\x14.sink.v1.SinkRequest\x1a\x15.sink.v1.SinkResponse(\x01\x12\x39\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x16.sink.v1.ReadyResponseb\x06proto3' + b'\n\nsink.proto\x12\x07sink.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xa3\x03\n\x0bSinkRequest\x12-\n\x07request\x18\x01 \x01(\x0b\x32\x1c.sink.v1.SinkRequest.Request\x12+\n\x06status\x18\x02 \x01(\x0b\x32\x1b.sink.v1.TransmissionStatus\x12*\n\thandshake\x18\x03 \x01(\x0b\x32\x12.sink.v1.HandshakeH\x00\x88\x01\x01\x1a\xfd\x01\n\x07Request\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\n\n\x02id\x18\x05 \x01(\t\x12:\n\x07headers\x18\x06 \x03(\x0b\x32).sink.v1.SinkRequest.Request.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\x0c\n\n_handshake"\x18\n\tHandshake\x12\x0b\n\x03sot\x18\x01 \x01(\x08"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"!\n\x12TransmissionStatus\x12\x0b\n\x03\x65ot\x18\x01 \x01(\x08"\xfb\x01\n\x0cSinkResponse\x12,\n\x06result\x18\x01 \x01(\x0b\x32\x1c.sink.v1.SinkResponse.Result\x12*\n\thandshake\x18\x02 \x01(\x0b\x32\x12.sink.v1.HandshakeH\x00\x88\x01\x01\x12\x30\n\x06status\x18\x03 \x01(\x0b\x32\x1b.sink.v1.TransmissionStatusH\x01\x88\x01\x01\x1a\x46\n\x06Result\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1f\n\x06status\x18\x02 \x01(\x0e\x32\x0f.sink.v1.Status\x12\x0f\n\x07\x65rr_msg\x18\x03 \x01(\tB\x0c\n\n_handshakeB\t\n\x07_status*0\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07\x46\x41ILURE\x10\x01\x12\x0c\n\x08\x46\x41LLBACK\x10\x02\x32|\n\x04Sink\x12\x39\n\x06SinkFn\x12\x14.sink.v1.SinkRequest\x1a\x15.sink.v1.SinkResponse(\x01\x30\x01\x12\x39\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x16.sink.v1.ReadyResponseb\x06proto3' ) _globals = globals() @@ -26,20 +26,26 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sink_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals["_SINKREQUEST_HEADERSENTRY"]._options = None - _globals["_SINKREQUEST_HEADERSENTRY"]._serialized_options = b"8\001" - _globals["_STATUS"]._serialized_start = 505 - _globals["_STATUS"]._serialized_end = 553 + _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._options = None + _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._serialized_options = b"8\001" + _globals["_STATUS"]._serialized_start = 854 + _globals["_STATUS"]._serialized_end = 902 _globals["_SINKREQUEST"]._serialized_start = 86 - _globals["_SINKREQUEST"]._serialized_end = 335 - _globals["_SINKREQUEST_HEADERSENTRY"]._serialized_start = 289 - _globals["_SINKREQUEST_HEADERSENTRY"]._serialized_end = 335 - _globals["_READYRESPONSE"]._serialized_start = 337 - _globals["_READYRESPONSE"]._serialized_end = 367 - _globals["_SINKRESPONSE"]._serialized_start = 370 - _globals["_SINKRESPONSE"]._serialized_end = 503 - _globals["_SINKRESPONSE_RESULT"]._serialized_start = 433 - _globals["_SINKRESPONSE_RESULT"]._serialized_end = 503 - _globals["_SINK"]._serialized_start = 555 - _globals["_SINK"]._serialized_end = 677 + _globals["_SINKREQUEST"]._serialized_end = 505 + _globals["_SINKREQUEST_REQUEST"]._serialized_start = 238 + _globals["_SINKREQUEST_REQUEST"]._serialized_end = 491 + _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._serialized_start = 445 + _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._serialized_end = 491 + _globals["_HANDSHAKE"]._serialized_start = 507 + _globals["_HANDSHAKE"]._serialized_end = 531 + _globals["_READYRESPONSE"]._serialized_start = 533 + _globals["_READYRESPONSE"]._serialized_end = 563 + _globals["_TRANSMISSIONSTATUS"]._serialized_start = 565 + _globals["_TRANSMISSIONSTATUS"]._serialized_end = 598 + _globals["_SINKRESPONSE"]._serialized_start = 601 + _globals["_SINKRESPONSE"]._serialized_end = 852 + _globals["_SINKRESPONSE_RESULT"]._serialized_start = 757 + _globals["_SINKRESPONSE_RESULT"]._serialized_end = 827 + _globals["_SINK"]._serialized_start = 904 + _globals["_SINK"]._serialized_end = 1028 # @@protoc_insertion_point(module_scope) diff --git a/pynumaflow/proto/sinker/sink_pb2.pyi b/pynumaflow/proto/sinker/sink_pb2.pyi index 71dcdf69..70b24c22 100644 --- a/pynumaflow/proto/sinker/sink_pb2.pyi +++ b/pynumaflow/proto/sinker/sink_pb2.pyi @@ -25,45 +25,72 @@ FAILURE: Status FALLBACK: Status class SinkRequest(_message.Message): - __slots__ = ("keys", "value", "event_time", "watermark", "id", "headers") + __slots__ = ("request", "status", "handshake") - class HeadersEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] + class Request(_message.Message): + __slots__ = ("keys", "value", "event_time", "watermark", "id", "headers") + + class HeadersEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + KEYS_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - KEYS_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - EVENT_TIME_FIELD_NUMBER: _ClassVar[int] - WATERMARK_FIELD_NUMBER: _ClassVar[int] - ID_FIELD_NUMBER: _ClassVar[int] - HEADERS_FIELD_NUMBER: _ClassVar[int] - keys: _containers.RepeatedScalarFieldContainer[str] - value: bytes - event_time: _timestamp_pb2.Timestamp - watermark: _timestamp_pb2.Timestamp - id: str - headers: _containers.ScalarMap[str, str] + EVENT_TIME_FIELD_NUMBER: _ClassVar[int] + WATERMARK_FIELD_NUMBER: _ClassVar[int] + ID_FIELD_NUMBER: _ClassVar[int] + HEADERS_FIELD_NUMBER: _ClassVar[int] + keys: _containers.RepeatedScalarFieldContainer[str] + value: bytes + event_time: _timestamp_pb2.Timestamp + watermark: _timestamp_pb2.Timestamp + id: str + headers: _containers.ScalarMap[str, str] + def __init__( + self, + keys: _Optional[_Iterable[str]] = ..., + value: _Optional[bytes] = ..., + event_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., + watermark: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., + id: _Optional[str] = ..., + headers: _Optional[_Mapping[str, str]] = ..., + ) -> None: ... + REQUEST_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + HANDSHAKE_FIELD_NUMBER: _ClassVar[int] + request: SinkRequest.Request + status: TransmissionStatus + handshake: Handshake def __init__( self, - keys: _Optional[_Iterable[str]] = ..., - value: _Optional[bytes] = ..., - event_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., - watermark: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., - id: _Optional[str] = ..., - headers: _Optional[_Mapping[str, str]] = ..., + request: _Optional[_Union[SinkRequest.Request, _Mapping]] = ..., + status: _Optional[_Union[TransmissionStatus, _Mapping]] = ..., + handshake: _Optional[_Union[Handshake, _Mapping]] = ..., ) -> None: ... +class Handshake(_message.Message): + __slots__ = ("sot",) + SOT_FIELD_NUMBER: _ClassVar[int] + sot: bool + def __init__(self, sot: bool = ...) -> None: ... + class ReadyResponse(_message.Message): __slots__ = ("ready",) READY_FIELD_NUMBER: _ClassVar[int] ready: bool def __init__(self, ready: bool = ...) -> None: ... +class TransmissionStatus(_message.Message): + __slots__ = ("eot",) + EOT_FIELD_NUMBER: _ClassVar[int] + eot: bool + def __init__(self, eot: bool = ...) -> None: ... + class SinkResponse(_message.Message): - __slots__ = ("results",) + __slots__ = ("result", "handshake", "status") class Result(_message.Message): __slots__ = ("id", "status", "err_msg") @@ -79,8 +106,15 @@ class SinkResponse(_message.Message): status: _Optional[_Union[Status, str]] = ..., err_msg: _Optional[str] = ..., ) -> None: ... - RESULTS_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[SinkResponse.Result] + RESULT_FIELD_NUMBER: _ClassVar[int] + HANDSHAKE_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + result: SinkResponse.Result + handshake: Handshake + status: TransmissionStatus def __init__( - self, results: _Optional[_Iterable[_Union[SinkResponse.Result, _Mapping]]] = ... + self, + result: _Optional[_Union[SinkResponse.Result, _Mapping]] = ..., + handshake: _Optional[_Union[Handshake, _Mapping]] = ..., + status: _Optional[_Union[TransmissionStatus, _Mapping]] = ..., ) -> None: ... diff --git a/pynumaflow/proto/sinker/sink_pb2_grpc.py b/pynumaflow/proto/sinker/sink_pb2_grpc.py index ef673e9d..9609c76e 100644 --- a/pynumaflow/proto/sinker/sink_pb2_grpc.py +++ b/pynumaflow/proto/sinker/sink_pb2_grpc.py @@ -15,7 +15,7 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.SinkFn = channel.stream_unary( + self.SinkFn = channel.stream_stream( "/sink.v1.Sink/SinkFn", request_serializer=sink__pb2.SinkRequest.SerializeToString, response_deserializer=sink__pb2.SinkResponse.FromString, @@ -45,7 +45,7 @@ def IsReady(self, request, context): def add_SinkServicer_to_server(servicer, server): rpc_method_handlers = { - "SinkFn": grpc.stream_unary_rpc_method_handler( + "SinkFn": grpc.stream_stream_rpc_method_handler( servicer.SinkFn, request_deserializer=sink__pb2.SinkRequest.FromString, response_serializer=sink__pb2.SinkResponse.SerializeToString, @@ -77,7 +77,7 @@ def SinkFn( timeout=None, metadata=None, ): - return grpc.experimental.stream_unary( + return grpc.experimental.stream_stream( request_iterator, target, "/sink.v1.Sink/SinkFn", diff --git a/pynumaflow/reducestreamer/servicer/async_servicer.py b/pynumaflow/reducestreamer/servicer/async_servicer.py index 0242deab..43b91986 100644 --- a/pynumaflow/reducestreamer/servicer/async_servicer.py +++ b/pynumaflow/reducestreamer/servicer/async_servicer.py @@ -2,7 +2,6 @@ from collections.abc import AsyncIterable from typing import Union -import grpc from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc @@ -13,7 +12,7 @@ ReduceRequest, ) from pynumaflow.reducestreamer.servicer.task_manager import TaskManager -from pynumaflow.shared.server import exit_on_error, handle_error +from pynumaflow.shared.server import handle_async_error from pynumaflow.types import NumaflowServicerContext @@ -95,35 +94,20 @@ async def ReduceFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - handle_error(context, msg) - await asyncio.gather( - context.abort(grpc.StatusCode.UNKNOWN, details=repr(msg)), - return_exceptions=True, - ) - exit_on_error( - err=repr(msg), parent=False, context=context, update_context=False - ) + await handle_async_error(context, msg) return # Send window EOF response or Window result response # back to the client else: yield msg except BaseException as e: - handle_error(context, e) - await asyncio.gather( - context.abort(grpc.StatusCode.UNKNOWN, details=repr(e)), return_exceptions=True - ) - exit_on_error(err=repr(e), parent=False, context=context, update_context=False) + await handle_async_error(context, e) return # Wait for the process_input_stream task to finish for a clean exit try: await producer except BaseException as e: - handle_error(context, e) - await asyncio.gather( - context.abort(grpc.StatusCode.UNKNOWN, details=repr(e)), return_exceptions=True - ) - exit_on_error(err=repr(e), parent=False, context=context, update_context=False) + await handle_async_error(context, e) return async def IsReady( diff --git a/pynumaflow/shared/asynciter.py b/pynumaflow/shared/asynciter.py index 3ab6135b..91155b93 100644 --- a/pynumaflow/shared/asynciter.py +++ b/pynumaflow/shared/asynciter.py @@ -8,8 +8,8 @@ class NonBlockingIterator: __slots__ = "_queue" - def __init__(self): - self._queue = asyncio.Queue() + def __init__(self, size=0): + self._queue = asyncio.Queue(maxsize=size) async def read_iterator(self): item = await self._queue.get() diff --git a/pynumaflow/shared/server.py b/pynumaflow/shared/server.py index 2e9de168..ab86c9f0 100644 --- a/pynumaflow/shared/server.py +++ b/pynumaflow/shared/server.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import io import multiprocessing @@ -266,7 +267,10 @@ def exit_on_error( p.kill() -def handle_error(context: NumaflowServicerContext, e: BaseException): +def update_context_err(context: NumaflowServicerContext, e: BaseException): + """ + Update the context with the error and log the exception. + """ trace = get_exception_traceback_str(e) _LOGGER.critical(trace) _LOGGER.critical(e.__str__()) @@ -278,3 +282,14 @@ def get_exception_traceback_str(exc) -> str: file = io.StringIO() traceback.print_exception(exc, value=exc, tb=exc.__traceback__, file=file) return file.getvalue().rstrip() + + +async def handle_async_error(context: NumaflowServicerContext, exception: BaseException): + """ + Handle exceptions for async servers by updating the context and exiting. + """ + update_context_err(context, exception) + await asyncio.gather( + context.abort(grpc.StatusCode.UNKNOWN, details=repr(exception)), return_exceptions=True + ) + exit_on_error(err=repr(exception), parent=False, context=context, update_context=False) diff --git a/pynumaflow/shared/synciter.py b/pynumaflow/shared/synciter.py new file mode 100644 index 00000000..b7c38455 --- /dev/null +++ b/pynumaflow/shared/synciter.py @@ -0,0 +1,23 @@ +from queue import Queue + +from pynumaflow._constants import STREAM_EOF + + +class SyncIterator: + """A Sync Interator backed by a queue""" + + __slots__ = "_queue" + + def __init__(self, size=0): + self._queue = Queue(maxsize=size) + + def read_iterator(self): + item = self._queue.get() + while True: + if item == STREAM_EOF: + break + yield item + item = self._queue.get() + + def put(self, item): + self._queue.put(item) diff --git a/pynumaflow/shared/thread_with_return.py b/pynumaflow/shared/thread_with_return.py new file mode 100644 index 00000000..9b9a7643 --- /dev/null +++ b/pynumaflow/shared/thread_with_return.py @@ -0,0 +1,54 @@ +from threading import Thread + + +class ThreadWithReturnValue(Thread): + """ + A custom Thread class that allows the target function to return a value. + This class extends the built-in threading.Thread class. + """ + + def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbose=None): + """ + Initializes the ThreadWithReturnValue object. + + Parameters: + group (threading.ThreadGroup) + target (callable): The callable object to be invoked by the run() method. Defaults to None. + name (str): The thread name. Defaults to None. + args (tuple): The argument tuple for the target invocation. Defaults to (). + kwargs (dict): The dictionary of keyword arguments for the target invocation. + Defaults to {}. + verbose: Not used, defaults to None. + """ + Thread.__init__(self, group, target, name, args, kwargs) + # Variable to store the return value of the target function + self._return = None + + def run(self): + """ + Run the thread. + + This method is overridden from the Thread class. + It calls the target function and saves the return value. + """ + if self._target is not None: + # Execute target and store the result + self._return = self._target(*self._args, **self._kwargs) + + def join(self, *args): + """ + Wait for the thread to complete and return the result. + + This method is overridden from the Thread class. + It calls the parent class's join() method and then returns the stored return value. + + Parameters: + *args: Variable length argument list to pass to the join() method. + + Returns: + The return value from the target function. + """ + # Call the parent class's join() method to wait for the thread to finish + Thread.join(self, *args) + # Return the result of the target function + return self._return diff --git a/pynumaflow/sinker/__init__.py b/pynumaflow/sinker/__init__.py index 4df6f270..322b5e81 100644 --- a/pynumaflow/sinker/__init__.py +++ b/pynumaflow/sinker/__init__.py @@ -1,6 +1,7 @@ from pynumaflow.sinker.async_server import SinkAsyncServer + from pynumaflow.sinker.server import SinkServer from pynumaflow.sinker._dtypes import Response, Responses, Datum, Sinker -__all__ = ["Response", "Responses", "Datum", "Sinker", "SinkServer", "SinkAsyncServer"] +__all__ = ["Response", "Responses", "Datum", "Sinker", "SinkAsyncServer", "SinkServer"] diff --git a/pynumaflow/sinker/servicer/async_servicer.py b/pynumaflow/sinker/servicer/async_servicer.py index 9f02d005..96d6a62f 100644 --- a/pynumaflow/sinker/servicer/async_servicer.py +++ b/pynumaflow/sinker/servicer/async_servicer.py @@ -1,29 +1,19 @@ +import asyncio from collections.abc import AsyncIterable from google.protobuf import empty_pb2 as _empty_pb2 +from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow.shared.server import exit_on_error -from pynumaflow.sinker._dtypes import Datum -from pynumaflow.sinker._dtypes import SyncSinkCallable +from pynumaflow.sinker._dtypes import Datum, AsyncSinkCallable from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 -from pynumaflow.sinker.servicer.utils import build_sink_response +from pynumaflow.sinker.servicer.utils import ( + datum_from_sink_req, + _create_read_handshake_response, + build_sink_resp_results, +) from pynumaflow.types import NumaflowServicerContext -from pynumaflow._constants import _LOGGER - - -async def datum_generator( - request_iterator: AsyncIterable[sink_pb2.SinkRequest], -) -> AsyncIterable[Datum]: - async for d in request_iterator: - datum = Datum( - keys=list(d.keys), - sink_msg_id=d.id, - value=d.value, - event_time=d.event_time.ToDatetime(), - watermark=d.watermark.ToDatetime(), - headers=dict(d.headers), - ) - yield datum +from pynumaflow._constants import _LOGGER, STREAM_EOF class AsyncSinkServicer(sink_pb2_grpc.SinkServicer): @@ -35,9 +25,10 @@ class AsyncSinkServicer(sink_pb2_grpc.SinkServicer): def __init__( self, - handler: SyncSinkCallable, + handler: AsyncSinkCallable, ): - self.__sink_handler: SyncSinkCallable = handler + self.background_tasks = set() + self.__sink_handler: AsyncSinkCallable = handler self.cleanup_coroutines = [] async def SinkFn( @@ -49,32 +40,67 @@ async def SinkFn( Applies a sink function to a list of datum elements. The pascal case function name comes from the proto sink_pb2_grpc.py file. """ - # if there is an exception, we will mark all the responses as a failure - datum_iterator = datum_generator(request_iterator=request_iterator) try: - results = await self.__invoke_sink(datum_iterator, context) + # The first message to be received should be a valid handshake + req = await request_iterator.__anext__() + # check if it is a valid handshake req + if not (req.handshake and req.handshake.sot): + raise Exception("ReadFn: expected handshake message") + yield _create_read_handshake_response() + + # cur_task is used to track the task (coroutine) processing + # the current batch of messages. + cur_task = None + # iterate of the incoming messages ot the sink + async for d in request_iterator: + # if we do not have any active task currently processing the batch + # we need to create one and call the User function for processing the same. + if cur_task is None: + req_queue = NonBlockingIterator() + cur_task = asyncio.create_task( + self.__invoke_sink(req_queue.read_iterator(), context) + ) + self.background_tasks.add(cur_task) + cur_task.add_done_callback(self.background_tasks.discard) + + # when we have end of transmission message, we need to stop the processing the + # current batch and wait for the next batch of messages. + # We will also wait for the current task to finish processing the current batch. + # We mark the current task as None to indicate that we are + # ready to process the next batch. + if d.status and d.status.eot: + await req_queue.put(STREAM_EOF) + await cur_task + ret = cur_task.result() + for r in ret: + yield sink_pb2.SinkResponse(result=r) + # send EOT after each finishing sink responses + yield sink_pb2.SinkResponse(status=sink_pb2.TransmissionStatus(eot=True)) + cur_task = None + continue + + # if we have a valid message, we will add it to the request queue for processing. + datum = datum_from_sink_req(d) + await req_queue.put(datum) except BaseException as err: + # if there is an exception, we will mark all the responses as a failure err_msg = f"UDSinkError: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) exit_on_error(context, err_msg) return - return sink_pb2.SinkResponse(results=results) - async def __invoke_sink( - self, datum_iterator: AsyncIterable[Datum], context: NumaflowServicerContext + self, request_queue: AsyncIterable[Datum], context: NumaflowServicerContext ): try: - rspns = await self.__sink_handler(datum_iterator) + # invoke the user function with the request queue + rspns = await self.__sink_handler(request_queue) + return build_sink_resp_results(rspns) except BaseException as err: err_msg = f"UDSinkError: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) exit_on_error(context, err_msg) raise err - responses = [] - for rspn in rspns: - responses.append(build_sink_response(rspn)) - return responses async def IsReady( self, request: _empty_pb2.Empty, context: NumaflowServicerContext diff --git a/pynumaflow/sinker/servicer/sync_servicer.py b/pynumaflow/sinker/servicer/sync_servicer.py index a1f307d1..629ad96b 100644 --- a/pynumaflow/sinker/servicer/sync_servicer.py +++ b/pynumaflow/sinker/servicer/sync_servicer.py @@ -1,28 +1,20 @@ -from collections.abc import Iterator, Iterable +from collections.abc import Iterable -from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow._constants import _LOGGER + +from pynumaflow._constants import _LOGGER, STREAM_EOF +from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 from pynumaflow.shared.server import exit_on_error -from pynumaflow.sinker._dtypes import Datum +from pynumaflow.shared.synciter import SyncIterator +from pynumaflow.shared.thread_with_return import ThreadWithReturnValue from pynumaflow.sinker._dtypes import SyncSinkCallable -from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 -from pynumaflow.sinker.servicer.utils import build_sink_response +from pynumaflow.sinker.servicer.utils import ( + datum_from_sink_req, + _create_read_handshake_response, + build_sink_resp_results, +) from pynumaflow.types import NumaflowServicerContext -def datum_generator(request_iterator: Iterable[sink_pb2.SinkRequest]) -> Iterable[Datum]: - for d in request_iterator: - datum = Datum( - keys=list(d.keys), - sink_msg_id=d.id, - value=d.value, - event_time=d.event_time.ToDatetime(), - watermark=d.watermark.ToDatetime(), - headers=dict(d.headers), - ) - yield datum - - class SyncSinkServicer(sink_pb2_grpc.SinkServicer): """ This class is used to create a new grpc Sink servicer instance. @@ -30,40 +22,83 @@ class SyncSinkServicer(sink_pb2_grpc.SinkServicer): Provides the functionality for the required rpc methods. """ - def __init__( - self, - handler: SyncSinkCallable, - ): - self.__sink_handler: SyncSinkCallable = handler + def __init__(self, handler: SyncSinkCallable): + self.handler: SyncSinkCallable = handler def SinkFn( - self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext - ) -> sink_pb2.SinkResponse: + self, request_iterator: Iterable[sink_pb2.SinkRequest], context: NumaflowServicerContext + ) -> Iterable[sink_pb2.SinkResponse]: """ - Applies a sink function to a list of datum elements. - The pascal case function name comes from the proto sink_pb2_grpc.py file. + Applies a sink function to datum elements. """ - # if there is an exception, we will mark all the responses as a failure - datum_iterator = datum_generator(request_iterator) + try: - rspns = self.__sink_handler(datum_iterator) + # The first message to be received should be a valid handshake + req = next(request_iterator) + # check if it is a valid handshake req + if not (req.handshake and req.handshake.sot): + raise Exception("SinkFn: expected handshake message") + yield _create_read_handshake_response() + # cur_task is used to track the thread processing + # the current batch of messages. + cur_task = None + # Use a queue backed to handle request batches + req_queue = SyncIterator() + + # iterate of the incoming messages ot the sink + for d in request_iterator: + # if we do not have any active thread currently processing the batch + # we need to create one and call the User function for processing the same. + if cur_task is None: + # Use a queue to handle request batches + req_queue = SyncIterator() + cur_task = ThreadWithReturnValue( + target=self._invoke_sink, args=(req_queue, context) + ) + cur_task.start() + + # when we have end of transmission message, we need to stop the processing the + # current batch and wait for the next batch of messages. + # We will also wait for the current task to finish processing the current batch. + # We mark the current task as None to indicate that we are + # ready to process the next batch. + if d.status and d.status.eot: + req_queue.put(STREAM_EOF) + ret = cur_task.join() + for resp in ret: + yield sink_pb2.SinkResponse(result=resp) + # send EOT after each finishing sink responses + yield sink_pb2.SinkResponse(status=sink_pb2.TransmissionStatus(eot=True)) + cur_task = None + continue + + # if we have a valid message, we will add it to the request queue for processing. + datum = datum_from_sink_req(d) + req_queue.put(datum) + + if cur_task: + cur_task.join() + except BaseException as err: + # Handle exceptions err_msg = f"UDSinkError: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) exit_on_error(context, err_msg) return - responses = [] - for rspn in rspns: - responses.append(build_sink_response(rspn)) - - return sink_pb2.SinkResponse(results=responses) + def _invoke_sink(self, request_queue: SyncIterator, context: NumaflowServicerContext): + try: + # Invoke the handler function with the request queue + rspns = self.handler(request_queue.read_iterator()) + return build_sink_resp_results(rspns) + except BaseException as err: + err_msg = f"UDSinkError: {repr(err)}" + _LOGGER.critical(err_msg, exc_info=True) + exit_on_error(context, err_msg) + raise err - def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> sink_pb2.ReadyResponse: + def IsReady(self, request, context: NumaflowServicerContext) -> sink_pb2.ReadyResponse: """ IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto sink_pb2_grpc.py file. """ return sink_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/sinker/servicer/utils.py b/pynumaflow/sinker/servicer/utils.py index 2de6f93c..ddd6b863 100644 --- a/pynumaflow/sinker/servicer/utils.py +++ b/pynumaflow/sinker/servicer/utils.py @@ -1,13 +1,71 @@ from pynumaflow.proto.sinker import sink_pb2 -from pynumaflow.sinker._dtypes import Response +from pynumaflow.sinker._dtypes import Response, Datum, Responses -def build_sink_response(rspn: Response): +def build_sink_resp_results(responses: Responses) -> list[sink_pb2.SinkResponse.Result]: + """ + Given a list of Response objects, build the corresponding list of SinkResponse.Result objects. + + Parameters: + responses (Responses): A list of Response objects containing the results of sink operations. + + Returns: + list[sink_pb2.SinkResponse.Result]: A list of SinkResponse.Result objects. + """ + return [build_sink_response(rspn) for rspn in responses] + + +def build_sink_response(rspn: Response) -> sink_pb2.SinkResponse.Result: + """ + Build a SinkResponse.Result object from a Response object. + + Parameters: + rspn (Response): A Response object containing the result information of a single sink operation. + + Returns: + sink_pb2.SinkResponse.Result: A SinkResponse.Result + object populated with the status and id of the response. + """ + rid = rspn.id if rspn.success: - return sink_pb2.SinkResponse.Result(id=rspn.id, status=sink_pb2.Status.SUCCESS) + return sink_pb2.SinkResponse.Result(id=rid, status=sink_pb2.Status.SUCCESS) elif rspn.fallback: - return sink_pb2.SinkResponse.Result(id=rspn.id, status=sink_pb2.Status.FALLBACK) + return sink_pb2.SinkResponse.Result(id=rid, status=sink_pb2.Status.FALLBACK) else: return sink_pb2.SinkResponse.Result( - id=rspn.id, status=sink_pb2.Status.FAILURE, err_msg=rspn.err + id=rid, status=sink_pb2.Status.FAILURE, err_msg=rspn.err ) + + +def datum_from_sink_req(d: sink_pb2.SinkRequest) -> Datum: + """ + Convert a SinkRequest object to a Datum object. + + Parameters: + d (sink_pb2.SinkRequest): A SinkRequest object containing the input data. + + Returns: + Datum: A Datum object populated with the data from the input SinkRequest object. + """ + datum = Datum( + keys=list(d.request.keys), + sink_msg_id=d.request.id, + value=d.request.value, + event_time=d.request.event_time.ToDatetime(), + watermark=d.request.watermark.ToDatetime(), + headers=dict(d.request.headers), + ) + return datum + + +def _create_read_handshake_response() -> sink_pb2.SinkResponse: + """ + Create a handshake response for the Sink function. + + Returns: + sink_pb2.SinkResponse: A SinkResponse object indicating a successful handshake. + """ + return sink_pb2.SinkResponse( + result=sink_pb2.SinkResponse.Result(status=sink_pb2.SUCCESS), + handshake=sink_pb2.Handshake(sot=True), + ) diff --git a/pynumaflow/sourcer/servicer/async_servicer.py b/pynumaflow/sourcer/servicer/async_servicer.py index dd07478b..e0d9c61d 100644 --- a/pynumaflow/sourcer/servicer/async_servicer.py +++ b/pynumaflow/sourcer/servicer/async_servicer.py @@ -1,12 +1,11 @@ import asyncio from collections.abc import AsyncIterable -import grpc from google.protobuf import timestamp_pb2 as _timestamp_pb2 from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import exit_on_error, handle_error +from pynumaflow.shared.server import exit_on_error, handle_async_error from pynumaflow.sourcer._dtypes import ReadRequest from pynumaflow.sourcer._dtypes import AckRequest, SourceCallable from pynumaflow.proto.sourcer import source_pb2 @@ -15,15 +14,6 @@ from pynumaflow._constants import _LOGGER, STREAM_EOF -async def _handle_exception(context, exception): - """Handle exceptions by updating the context and exiting.""" - handle_error(context, exception) - await asyncio.gather( - context.abort(grpc.StatusCode.UNKNOWN, details=repr(exception)), return_exceptions=True - ) - exit_on_error(err=repr(exception), parent=False, context=context, update_context=False) - - def _create_read_handshake_response(): """Create a handshake response for the Read function.""" return source_pb2.ReadResponse( @@ -99,7 +89,8 @@ async def ReadFn( try: # The first message to be received should be a valid handshake req = await request_iterator.__anext__() - if not _is_valid_handshake(req): + # check if it is a valid handshake req + if not (req.handshake and req.handshake.sot): raise Exception("ReadFn: expected handshake message") yield _create_read_handshake_response() @@ -117,7 +108,7 @@ async def ReadFn( async for resp in riter: if isinstance(resp, BaseException): - await _handle_exception(context, resp) + await handle_async_error(context, resp) return yield _create_read_response(resp) @@ -157,7 +148,8 @@ async def AckFn( try: # The first message to be received should be a valid handshake req = await request_iterator.__anext__() - if not _is_valid_handshake(req): + # check if it is a valid handshake req + if not (req.handshake and req.handshake.sot): raise Exception("AckFn: expected handshake message") yield _create_ack_handshake_response() @@ -214,8 +206,3 @@ def clean_background(self, task): Remove the task from the background tasks collection """ self.background_tasks.remove(task) - - -def _is_valid_handshake(req): - """Check if the handshake message is valid.""" - return req.handshake and req.handshake.sot diff --git a/tests/sink/test_async_sink.py b/tests/sink/test_async_sink.py index f04230cd..14cc2008 100644 --- a/tests/sink/test_async_sink.py +++ b/tests/sink/test_async_sink.py @@ -44,13 +44,7 @@ async def udsink_handler(datums: AsyncIterable[Datum]) -> Responses: return responses -def request_generator(count, request): - for i in range(count): - request.id = str(i) - yield request - - -def start_sink_streaming_request(req_type="success") -> (Datum, tuple): +def start_sink_streaming_request(_id: str, req_type) -> (Datum, tuple): event_time_timestamp, watermark_timestamp = get_time_args() value = mock_message() if req_type == "err": @@ -59,12 +53,21 @@ def start_sink_streaming_request(req_type="success") -> (Datum, tuple): if req_type == "fallback": value = mock_fallback_message() - request = sink_pb2.SinkRequest( - value=value, - event_time=event_time_timestamp, - watermark=watermark_timestamp, + request = sink_pb2.SinkRequest.Request( + value=value, event_time=event_time_timestamp, watermark=watermark_timestamp, id=_id ) - return request + return sink_pb2.SinkRequest(request=request) + + +def request_generator(count, req_type="success", session=1, handshake=True): + if handshake: + yield sink_pb2.SinkRequest(handshake=sink_pb2.Handshake(sot=True)) + + for j in range(session): + for i in range(count): + yield start_sink_streaming_request(str(i), req_type) + + yield sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)) _s: Server = None @@ -137,50 +140,96 @@ def test_run_server(self) -> None: def test_sink(self) -> None: stub = self.__stub() - request = start_sink_streaming_request() - print(request) generator_response = None + grpc_exception = None try: generator_response = stub.SinkFn( - request_iterator=request_generator(count=10, request=request) + request_iterator=request_generator(count=10, req_type="success", session=1) ) + handshake = next(generator_response) + # assert that handshake response is received. + self.assertTrue(handshake.handshake.sot) + data_resp = [] + for r in generator_response: + data_resp.append(r) + idx = 0 + while idx < len(data_resp) - 1: + # capture the output from the SinkFn generator and assert. + self.assertEqual(data_resp[idx].result.status, sink_pb2.Status.SUCCESS) + idx += 1 + # EOT Response + self.assertEqual(data_resp[len(data_resp) - 1].status.eot, True) + # 10 sink responses + 1 EOT response + self.assertEqual(11, len(data_resp)) except grpc.RpcError as e: logging.error(e) + grpc_exception = e - # capture the output from the ReduceFn generator and assert. - self.assertEqual(10, len(generator_response.results)) - for x in generator_response.results: - self.assertEqual(x.status, sink_pb2.Status.SUCCESS) + self.assertIsNone(grpc_exception) def test_sink_err(self) -> None: stub = self.__stub() - request = start_sink_streaming_request(req_type="err") - grpcException = None + grpc_exception = None try: - stub.SinkFn(request_iterator=request_generator(count=10, request=request)) + generator_response = stub.SinkFn( + request_iterator=request_generator(count=10, req_type="err") + ) + for _ in generator_response: + pass + except BaseException as e: + self.assertTrue("UDSinkError: ValueError('test_mock_err_message')" in e.__str__()) + return except grpc.RpcError as e: - grpcException = e + grpc_exception = e self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) - logging.error(e) + print(e.details()) - self.assertIsNotNone(grpcException) + self.assertIsNotNone(grpc_exception) + + def test_sink_err_handshake(self) -> None: + stub = self.__stub() + grpc_exception = None + try: + generator_response = stub.SinkFn( + request_iterator=request_generator(count=10, req_type="success", handshake=False) + ) + for _ in generator_response: + pass + except BaseException as e: + self.assertTrue("ReadFn: expected handshake message" in e.__str__()) + return + except grpc.RpcError as e: + grpc_exception = e + self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) + print(e.details()) + + self.assertIsNotNone(grpc_exception) def test_sink_fallback(self) -> None: stub = self.__stub() - request = start_sink_streaming_request(req_type="fallback") - generator_response = None try: generator_response = stub.SinkFn( - request_iterator=request_generator(count=10, request=request) + request_iterator=request_generator(count=10, req_type="fallback", session=1) ) + handshake = next(generator_response) + # assert that handshake response is received. + self.assertTrue(handshake.handshake.sot) + data_resp = [] + for r in generator_response: + data_resp.append(r) + + idx = 0 + while idx < len(data_resp) - 1: + # capture the output from the SinkFn generator and assert. + self.assertEqual(data_resp[idx].result.status, sink_pb2.Status.FALLBACK) + idx += 1 + # EOT Response + self.assertEqual(data_resp[len(data_resp) - 1].status.eot, True) + # 10 sink responses + 1 EOT response + self.assertEqual(11, len(data_resp)) except grpc.RpcError as e: logging.error(e) - # capture the output from the ReduceFn generator and assert. - self.assertEqual(10, len(generator_response.results)) - for x in generator_response.results: - self.assertEqual(x.status, sink_pb2.Status.FALLBACK) - def __stub(self): return sink_pb2_grpc.SinkStub(_channel) diff --git a/tests/sink/test_server.py b/tests/sink/test_server.py index d5fa8b17..f0e0af07 100644 --- a/tests/sink/test_server.py +++ b/tests/sink/test_server.py @@ -1,22 +1,22 @@ import os import unittest -from datetime import datetime, timezone from collections.abc import Iterator +from datetime import datetime, timezone from unittest import mock from unittest.mock import patch -from google.protobuf import empty_pb2 as _empty_pb2 -from google.protobuf import timestamp_pb2 as _timestamp_pb2 from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time +from google.protobuf import empty_pb2 as _empty_pb2 +from google.protobuf import timestamp_pb2 as _timestamp_pb2 from pynumaflow._constants import ( UD_CONTAINER_FALLBACK_SINK, - FALLBACK_SINK_SERVER_INFO_FILE_PATH, FALLBACK_SINK_SOCK_PATH, + FALLBACK_SINK_SERVER_INFO_FILE_PATH, ) -from pynumaflow.sinker import Responses, Datum, Response, SinkServer from pynumaflow.proto.sinker import sink_pb2 +from pynumaflow.sinker import Responses, Datum, Response, SinkServer from tests.testing_utils import mock_terminate_on_stop @@ -89,6 +89,44 @@ def test_is_ready(self): self.assertEqual(expected, response) self.assertEqual(code, StatusCode.OK) + def test_udsink_err_handshake(self): + server = SinkServer(sinker_instance=err_udsink_handler) + my_servicer = server.servicer + services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} + self.test_server = server_from_dictionary(services, strict_real_time()) + + event_time_timestamp = _timestamp_pb2.Timestamp() + event_time_timestamp.FromDatetime(dt=mock_event_time()) + watermark_timestamp = _timestamp_pb2.Timestamp() + watermark_timestamp.FromDatetime(dt=mock_watermark()) + + test_datums = [ + sink_pb2.SinkRequest( + request=sink_pb2.SinkRequest.Request( + id="test_id_0", + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + ) + ), + sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), + ] + + method = self.test_server.invoke_stream_stream( + method_descriptor=( + sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] + ), + invocation_metadata={}, + timeout=1, + ) + + method.send_request(test_datums[0]) + + metadata, code, details = method.termination() + print("HERE", details) + self.assertTrue("UDSinkError: Exception('SinkFn: expected handshake message')" in details) + self.assertEqual(StatusCode.UNKNOWN, code) + def test_udsink_err(self): server = SinkServer(sinker_instance=err_udsink_handler) my_servicer = server.servicer @@ -102,26 +140,28 @@ def test_udsink_err(self): test_datums = [ sink_pb2.SinkRequest( - id="test_id_0", - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, + handshake=sink_pb2.Handshake(sot=True), ), sink_pb2.SinkRequest( - id="test_id_1", - value=mock_err_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, + request=sink_pb2.SinkRequest.Request( + id="test_id_0", + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + ) ), sink_pb2.SinkRequest( - id="test_id_2", - value=mock_fallback_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, + request=sink_pb2.SinkRequest.Request( + id="test_id_1", + value=mock_err_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + ) ), + sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), ] - method = self.test_server.invoke_stream_unary( + method = self.test_server.invoke_stream_stream( method_descriptor=( sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] ), @@ -133,8 +173,17 @@ def test_udsink_err(self): method.send_request(test_datums[1]) method.send_request(test_datums[2]) method.requests_closed() - - response, metadata, code, details = method.termination() + responses = [] + while True: + try: + resp = method.take_response() + responses.append(resp) + except ValueError as err: + if "No more responses!" in err.__str__(): + break + + metadata, code, details = method.termination() + print(code) self.assertEqual(StatusCode.UNKNOWN, code) def test_forward_message(self): @@ -145,49 +194,64 @@ def test_forward_message(self): test_datums = [ sink_pb2.SinkRequest( - id="test_id_0", - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, + handshake=sink_pb2.Handshake(sot=True), ), sink_pb2.SinkRequest( - id="test_id_1", - value=mock_err_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, + request=sink_pb2.SinkRequest.Request( + id="test_id_0", + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + ) ), sink_pb2.SinkRequest( - id="test_id_2", - value=mock_fallback_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, + request=sink_pb2.SinkRequest.Request( + id="test_id_1", + value=mock_err_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + ) ), + sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)), ] - method = self.test_server.invoke_stream_unary( + method = self.test_server.invoke_stream_stream( method_descriptor=( sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] ), invocation_metadata={}, timeout=1, ) - - method.send_request(test_datums[0]) - method.send_request(test_datums[1]) - method.send_request(test_datums[2]) + for x in test_datums: + method.send_request(x) method.requests_closed() - response, metadata, code, details = method.termination() - self.assertEqual(3, len(response.results)) - self.assertEqual("test_id_0", response.results[0].id) - self.assertEqual("test_id_1", response.results[1].id) - self.assertEqual("test_id_2", response.results[2].id) - self.assertEqual(response.results[0].status, sink_pb2.Status.SUCCESS) - self.assertEqual(response.results[1].status, sink_pb2.Status.FAILURE) - self.assertEqual(response.results[2].status, sink_pb2.Status.FALLBACK) - self.assertEqual("", response.results[0].err_msg) - self.assertEqual("mock sink message error", response.results[1].err_msg) - self.assertEqual("", response.results[2].err_msg) + responses = [] + while True: + try: + resp = method.take_response() + responses.append(resp) + except ValueError as err: + if "No more responses!" in err.__str__(): + break + + # 1 handshake + 2 data messages + 1 EOT + self.assertEqual(4, len(responses)) + # first message should be handshake response + self.assertTrue(responses[0].handshake.sot) + + # assert the values for the corresponding messages + self.assertEqual("test_id_0", responses[1].result.id) + self.assertEqual("test_id_1", responses[2].result.id) + self.assertEqual(responses[1].result.status, sink_pb2.Status.SUCCESS) + self.assertEqual(responses[2].result.status, sink_pb2.Status.FAILURE) + self.assertEqual("", responses[1].result.err_msg) + self.assertEqual("mock sink message error", responses[2].result.err_msg) + + # last message should be EOT response + self.assertTrue(responses[3].status.eot) + + _, code, _ = method.termination() self.assertEqual(code, StatusCode.OK) def test_invalid_init(self):