diff --git a/Makefile b/Makefile
index 7ff26a37..b0127755 100644
--- a/Makefile
+++ b/Makefile
@@ -15,7 +15,7 @@ lint: format
 
 test:
-	poetry run pytest tests/
+	poetry run pytest tests/ -rA
 
 requirements:
diff --git a/examples/source/simple_source/example.py b/examples/source/simple_source/example.py
index 2dfd65c4..56a1f785 100644
--- a/examples/source/simple_source/example.py
+++ b/examples/source/simple_source/example.py
@@ -1,5 +1,6 @@
 import uuid
 from datetime import datetime
+import logging
 
 from pynumaflow.shared.asynciter import NonBlockingIterator
 from pynumaflow.sourcer import (
@@ -12,8 +13,16 @@
     get_default_partitions,
     Sourcer,
     SourceAsyncServer,
+    NackRequest,
 )
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)-8s %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+logger = logging.getLogger(__name__)
+
 
 class AsyncSource(Sourcer):
     """
@@ -21,12 +30,12 @@ class AsyncSource(Sourcer):
     """
 
     def __init__(self):
-        """
-        to_ack_set: Set to maintain a track of the offsets yet to be acknowledged
-        read_idx : the offset idx till where the messages have been read
-        """
-        self.to_ack_set = set()
-        self.read_idx = 0
+        # The offset idx till where the messages have been read
+        self.read_idx: int = 0
+        # Set to keep track of the offsets yet to be acknowledged
+        self.to_ack_set: set[int] = set()
+        # Set to keep track of the offsets that have been negatively acknowledged
+        self.nacked: set[int] = set()
 
     async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator):
         """
@@ -38,17 +47,22 @@ async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator):
             return
 
         for x in range(datum.num_records):
+            # If there are any nacked offsets, re-deliver them
+            if self.nacked:
+                idx = self.nacked.pop()
+            else:
+                idx = self.read_idx
+                self.read_idx += 1
             headers = {"x-txn-id": str(uuid.uuid4())}
             await output.put(
                 Message(
-                    payload=str(self.read_idx).encode(),
-                    offset=Offset.offset_with_default_partition_id(str(self.read_idx).encode()),
+                    payload=str(idx).encode(),
+                    offset=Offset.offset_with_default_partition_id(str(idx).encode()),
                     event_time=datetime.now(),
                     headers=headers,
                 )
             )
-            self.to_ack_set.add(str(self.read_idx))
-            self.read_idx += 1
+            self.to_ack_set.add(idx)
 
     async def ack_handler(self, ack_request: AckRequest):
         """
@@ -56,7 +70,19 @@ async def ack_handler(self, ack_request: AckRequest):
         from the to_ack_set
         """
         for req in ack_request.offsets:
-            self.to_ack_set.remove(str(req.offset, "utf-8"))
+            offset = int(req.offset)
+            self.to_ack_set.remove(offset)
+
+    async def nack_handler(self, nack_request: NackRequest):
+        """
+        Add the offsets that have been negatively acknowledged to the nacked set
+        """
+
+        for req in nack_request.offsets:
+            offset = int(req.offset)
+            self.to_ack_set.remove(offset)
+            self.nacked.add(offset)
+            logger.info("Negatively acknowledged offsets: %s", self.nacked)
 
     async def pending_handler(self) -> PendingResponse:
         """
@@ -74,4 +100,5 @@ async def partitions_handler(self) -> PartitionsResponse:
 if __name__ == "__main__":
     ud_source = AsyncSource()
     grpc_server = SourceAsyncServer(ud_source)
+    logger.info("Starting grpc server")
     grpc_server.start()
diff --git a/pynumaflow/proto/sourcer/source.proto b/pynumaflow/proto/sourcer/source.proto
index 33f73104..eab85847 100644
--- a/pynumaflow/proto/sourcer/source.proto
+++ b/pynumaflow/proto/sourcer/source.proto
@@ -21,6 +21,10 @@ service Source {
   // Clients sends n requests and expects n responses.
   rpc AckFn(stream AckRequest) returns (stream AckResponse);
 
+  // NackFn negatively acknowledges a batch of offsets.
Invoked during a critical error in the monovertex or pipeline. + // Unlike AckFn its not a streaming rpc because this is only invoked when there is a critical error (error path). + rpc NackFn(NackRequest) returns (NackResponse); + // PendingFn returns the number of pending records at the user defined source. rpc PendingFn(google.protobuf.Empty) returns (PendingResponse); @@ -139,6 +143,24 @@ message AckResponse { optional Handshake handshake = 2; } +message NackRequest { + message Request { + // Required field holding the offset to be nacked + repeated Offset offsets = 1; + } + // Required field holding the request. The list will be ordered and will have the same order as the original Read response. + Request request = 1; +} + +message NackResponse { + message Result { + // Required field indicating the nack request is successful. + google.protobuf.Empty success = 1; + } + // Required field holding the result. + Result result = 1; +} + /* * ReadyResponse is the health check result for user defined source. */ diff --git a/pynumaflow/proto/sourcer/source_pb2.py b/pynumaflow/proto/sourcer/source_pb2.py index 8e059564..f9645ce8 100644 --- a/pynumaflow/proto/sourcer/source_pb2.py +++ b/pynumaflow/proto/sourcer/source_pb2.py @@ -26,7 +26,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0csource.proto\x12\tsource.v1\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1bgoogle/protobuf/empty.proto\"\x18\n\tHandshake\x12\x0b\n\x03sot\x18\x01 \x01(\x08\"\xb1\x01\n\x0bReadRequest\x12/\n\x07request\x18\x01 \x01(\x0b\x32\x1e.source.v1.ReadRequest.Request\x12,\n\thandshake\x18\x02 \x01(\x0b\x32\x14.source.v1.HandshakeH\x00\x88\x01\x01\x1a\x35\n\x07Request\x12\x13\n\x0bnum_records\x18\x01 \x01(\x04\x12\x15\n\rtimeout_in_ms\x18\x02 \x01(\rB\x0c\n\n_handshake\"\x81\x05\n\x0cReadResponse\x12.\n\x06result\x18\x01 \x01(\x0b\x32\x1e.source.v1.ReadResponse.Result\x12.\n\x06status\x18\x02 \x01(\x0b\x32\x1e.source.v1.ReadResponse.Status\x12,\n\thandshake\x18\x03 \x01(\x0b\x32\x14.source.v1.HandshakeH\x00\x88\x01\x01\x1a\xe8\x01\n\x06Result\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12!\n\x06offset\x18\x02 \x01(\x0b\x32\x11.source.v1.Offset\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04keys\x18\x04 \x03(\t\x12<\n\x07headers\x18\x05 \x03(\x0b\x32+.source.v1.ReadResponse.Result.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xe9\x01\n\x06Status\x12\x0b\n\x03\x65ot\x18\x01 \x01(\x08\x12\x31\n\x04\x63ode\x18\x02 \x01(\x0e\x32#.source.v1.ReadResponse.Status.Code\x12\x38\n\x05\x65rror\x18\x03 \x01(\x0e\x32$.source.v1.ReadResponse.Status.ErrorH\x00\x88\x01\x01\x12\x10\n\x03msg\x18\x04 \x01(\tH\x01\x88\x01\x01\" \n\x04\x43ode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07\x46\x41ILURE\x10\x01\"\x1f\n\x05\x45rror\x12\x0b\n\x07UNACKED\x10\x00\x12\t\n\x05OTHER\x10\x01\x42\x08\n\x06_errorB\x06\n\x04_msgB\x0c\n\n_handshake\"\xa7\x01\n\nAckRequest\x12.\n\x07request\x18\x01 \x01(\x0b\x32\x1d.source.v1.AckRequest.Request\x12,\n\thandshake\x18\x02 \x01(\x0b\x32\x14.source.v1.HandshakeH\x00\x88\x01\x01\x1a-\n\x07Request\x12\"\n\x07offsets\x18\x01 \x03(\x0b\x32\x11.source.v1.OffsetB\x0c\n\n_handshake\"\xab\x01\n\x0b\x41\x63kResponse\x12-\n\x06result\x18\x01 \x01(\x0b\x32\x1d.source.v1.AckResponse.Result\x12,\n\thandshake\x18\x02 \x01(\x0b\x32\x14.source.v1.HandshakeH\x00\x88\x01\x01\x1a\x31\n\x06Result\x12\'\n\x07success\x18\x01 
\x01(\x0b\x32\x16.google.protobuf.EmptyB\x0c\n\n_handshake\"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08\"]\n\x0fPendingResponse\x12\x31\n\x06result\x18\x01 \x01(\x0b\x32!.source.v1.PendingResponse.Result\x1a\x17\n\x06Result\x12\r\n\x05\x63ount\x18\x01 \x01(\x03\"h\n\x12PartitionsResponse\x12\x34\n\x06result\x18\x01 \x01(\x0b\x32$.source.v1.PartitionsResponse.Result\x1a\x1c\n\x06Result\x12\x12\n\npartitions\x18\x01 \x03(\x05\".\n\x06Offset\x12\x0e\n\x06offset\x18\x01 \x01(\x0c\x12\x14\n\x0cpartition_id\x18\x02 \x01(\x05\x32\xc8\x02\n\x06Source\x12=\n\x06ReadFn\x12\x16.source.v1.ReadRequest\x1a\x17.source.v1.ReadResponse(\x01\x30\x01\x12:\n\x05\x41\x63kFn\x12\x15.source.v1.AckRequest\x1a\x16.source.v1.AckResponse(\x01\x30\x01\x12?\n\tPendingFn\x12\x16.google.protobuf.Empty\x1a\x1a.source.v1.PendingResponse\x12\x45\n\x0cPartitionsFn\x12\x16.google.protobuf.Empty\x1a\x1d.source.v1.PartitionsResponse\x12;\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x18.source.v1.ReadyResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0csource.proto\x12\tsource.v1\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1bgoogle/protobuf/empty.proto\"\x18\n\tHandshake\x12\x0b\n\x03sot\x18\x01 \x01(\x08\"\xb1\x01\n\x0bReadRequest\x12/\n\x07request\x18\x01 \x01(\x0b\x32\x1e.source.v1.ReadRequest.Request\x12,\n\thandshake\x18\x02 \x01(\x0b\x32\x14.source.v1.HandshakeH\x00\x88\x01\x01\x1a\x35\n\x07Request\x12\x13\n\x0bnum_records\x18\x01 \x01(\x04\x12\x15\n\rtimeout_in_ms\x18\x02 \x01(\rB\x0c\n\n_handshake\"\x81\x05\n\x0cReadResponse\x12.\n\x06result\x18\x01 \x01(\x0b\x32\x1e.source.v1.ReadResponse.Result\x12.\n\x06status\x18\x02 \x01(\x0b\x32\x1e.source.v1.ReadResponse.Status\x12,\n\thandshake\x18\x03 \x01(\x0b\x32\x14.source.v1.HandshakeH\x00\x88\x01\x01\x1a\xe8\x01\n\x06Result\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12!\n\x06offset\x18\x02 \x01(\x0b\x32\x11.source.v1.Offset\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04keys\x18\x04 \x03(\t\x12<\n\x07headers\x18\x05 \x03(\x0b\x32+.source.v1.ReadResponse.Result.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xe9\x01\n\x06Status\x12\x0b\n\x03\x65ot\x18\x01 \x01(\x08\x12\x31\n\x04\x63ode\x18\x02 \x01(\x0e\x32#.source.v1.ReadResponse.Status.Code\x12\x38\n\x05\x65rror\x18\x03 \x01(\x0e\x32$.source.v1.ReadResponse.Status.ErrorH\x00\x88\x01\x01\x12\x10\n\x03msg\x18\x04 \x01(\tH\x01\x88\x01\x01\" \n\x04\x43ode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07\x46\x41ILURE\x10\x01\"\x1f\n\x05\x45rror\x12\x0b\n\x07UNACKED\x10\x00\x12\t\n\x05OTHER\x10\x01\x42\x08\n\x06_errorB\x06\n\x04_msgB\x0c\n\n_handshake\"\xa7\x01\n\nAckRequest\x12.\n\x07request\x18\x01 \x01(\x0b\x32\x1d.source.v1.AckRequest.Request\x12,\n\thandshake\x18\x02 \x01(\x0b\x32\x14.source.v1.HandshakeH\x00\x88\x01\x01\x1a-\n\x07Request\x12\"\n\x07offsets\x18\x01 \x03(\x0b\x32\x11.source.v1.OffsetB\x0c\n\n_handshake\"\xab\x01\n\x0b\x41\x63kResponse\x12-\n\x06result\x18\x01 \x01(\x0b\x32\x1d.source.v1.AckResponse.Result\x12,\n\thandshake\x18\x02 \x01(\x0b\x32\x14.source.v1.HandshakeH\x00\x88\x01\x01\x1a\x31\n\x06Result\x12\'\n\x07success\x18\x01 \x01(\x0b\x32\x16.google.protobuf.EmptyB\x0c\n\n_handshake\"m\n\x0bNackRequest\x12/\n\x07request\x18\x01 \x01(\x0b\x32\x1e.source.v1.NackRequest.Request\x1a-\n\x07Request\x12\"\n\x07offsets\x18\x01 \x03(\x0b\x32\x11.source.v1.Offset\"q\n\x0cNackResponse\x12.\n\x06result\x18\x01 
\x01(\x0b\x32\x1e.source.v1.NackResponse.Result\x1a\x31\n\x06Result\x12\'\n\x07success\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Empty\"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08\"]\n\x0fPendingResponse\x12\x31\n\x06result\x18\x01 \x01(\x0b\x32!.source.v1.PendingResponse.Result\x1a\x17\n\x06Result\x12\r\n\x05\x63ount\x18\x01 \x01(\x03\"h\n\x12PartitionsResponse\x12\x34\n\x06result\x18\x01 \x01(\x0b\x32$.source.v1.PartitionsResponse.Result\x1a\x1c\n\x06Result\x12\x12\n\npartitions\x18\x01 \x03(\x05\".\n\x06Offset\x12\x0e\n\x06offset\x18\x01 \x01(\x0c\x12\x14\n\x0cpartition_id\x18\x02 \x01(\x05\x32\x83\x03\n\x06Source\x12=\n\x06ReadFn\x12\x16.source.v1.ReadRequest\x1a\x17.source.v1.ReadResponse(\x01\x30\x01\x12:\n\x05\x41\x63kFn\x12\x15.source.v1.AckRequest\x1a\x16.source.v1.AckResponse(\x01\x30\x01\x12\x39\n\x06NackFn\x12\x16.source.v1.NackRequest\x1a\x17.source.v1.NackResponse\x12?\n\tPendingFn\x12\x16.google.protobuf.Empty\x1a\x1a.source.v1.PendingResponse\x12\x45\n\x0cPartitionsFn\x12\x16.google.protobuf.Empty\x1a\x1d.source.v1.PartitionsResponse\x12;\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x18.source.v1.ReadyResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -61,18 +61,26 @@ _globals['_ACKRESPONSE']._serialized_end=1281 _globals['_ACKRESPONSE_RESULT']._serialized_start=1218 _globals['_ACKRESPONSE_RESULT']._serialized_end=1267 - _globals['_READYRESPONSE']._serialized_start=1283 - _globals['_READYRESPONSE']._serialized_end=1313 - _globals['_PENDINGRESPONSE']._serialized_start=1315 - _globals['_PENDINGRESPONSE']._serialized_end=1408 - _globals['_PENDINGRESPONSE_RESULT']._serialized_start=1385 - _globals['_PENDINGRESPONSE_RESULT']._serialized_end=1408 - _globals['_PARTITIONSRESPONSE']._serialized_start=1410 - _globals['_PARTITIONSRESPONSE']._serialized_end=1514 - _globals['_PARTITIONSRESPONSE_RESULT']._serialized_start=1486 - _globals['_PARTITIONSRESPONSE_RESULT']._serialized_end=1514 - _globals['_OFFSET']._serialized_start=1516 - _globals['_OFFSET']._serialized_end=1562 - _globals['_SOURCE']._serialized_start=1565 - _globals['_SOURCE']._serialized_end=1893 + _globals['_NACKREQUEST']._serialized_start=1283 + _globals['_NACKREQUEST']._serialized_end=1392 + _globals['_NACKREQUEST_REQUEST']._serialized_start=1048 + _globals['_NACKREQUEST_REQUEST']._serialized_end=1093 + _globals['_NACKRESPONSE']._serialized_start=1394 + _globals['_NACKRESPONSE']._serialized_end=1507 + _globals['_NACKRESPONSE_RESULT']._serialized_start=1218 + _globals['_NACKRESPONSE_RESULT']._serialized_end=1267 + _globals['_READYRESPONSE']._serialized_start=1509 + _globals['_READYRESPONSE']._serialized_end=1539 + _globals['_PENDINGRESPONSE']._serialized_start=1541 + _globals['_PENDINGRESPONSE']._serialized_end=1634 + _globals['_PENDINGRESPONSE_RESULT']._serialized_start=1611 + _globals['_PENDINGRESPONSE_RESULT']._serialized_end=1634 + _globals['_PARTITIONSRESPONSE']._serialized_start=1636 + _globals['_PARTITIONSRESPONSE']._serialized_end=1740 + _globals['_PARTITIONSRESPONSE_RESULT']._serialized_start=1712 + _globals['_PARTITIONSRESPONSE_RESULT']._serialized_end=1740 + _globals['_OFFSET']._serialized_start=1742 + _globals['_OFFSET']._serialized_end=1788 + _globals['_SOURCE']._serialized_start=1791 + _globals['_SOURCE']._serialized_end=2178 # @@protoc_insertion_point(module_scope) diff --git a/pynumaflow/proto/sourcer/source_pb2.pyi b/pynumaflow/proto/sourcer/source_pb2.pyi index c6a1b449..0e158815 100644 --- a/pynumaflow/proto/sourcer/source_pb2.pyi 
+++ b/pynumaflow/proto/sourcer/source_pb2.pyi @@ -111,6 +111,28 @@ class AckResponse(_message.Message): handshake: Handshake def __init__(self, result: _Optional[_Union[AckResponse.Result, _Mapping]] = ..., handshake: _Optional[_Union[Handshake, _Mapping]] = ...) -> None: ... +class NackRequest(_message.Message): + __slots__ = ("request",) + class Request(_message.Message): + __slots__ = ("offsets",) + OFFSETS_FIELD_NUMBER: _ClassVar[int] + offsets: _containers.RepeatedCompositeFieldContainer[Offset] + def __init__(self, offsets: _Optional[_Iterable[_Union[Offset, _Mapping]]] = ...) -> None: ... + REQUEST_FIELD_NUMBER: _ClassVar[int] + request: NackRequest.Request + def __init__(self, request: _Optional[_Union[NackRequest.Request, _Mapping]] = ...) -> None: ... + +class NackResponse(_message.Message): + __slots__ = ("result",) + class Result(_message.Message): + __slots__ = ("success",) + SUCCESS_FIELD_NUMBER: _ClassVar[int] + success: _empty_pb2.Empty + def __init__(self, success: _Optional[_Union[_empty_pb2.Empty, _Mapping]] = ...) -> None: ... + RESULT_FIELD_NUMBER: _ClassVar[int] + result: NackResponse.Result + def __init__(self, result: _Optional[_Union[NackResponse.Result, _Mapping]] = ...) -> None: ... + class ReadyResponse(_message.Message): __slots__ = ("ready",) READY_FIELD_NUMBER: _ClassVar[int] diff --git a/pynumaflow/proto/sourcer/source_pb2_grpc.py b/pynumaflow/proto/sourcer/source_pb2_grpc.py index 8f646444..1e4f6c15 100644 --- a/pynumaflow/proto/sourcer/source_pb2_grpc.py +++ b/pynumaflow/proto/sourcer/source_pb2_grpc.py @@ -45,6 +45,11 @@ def __init__(self, channel): request_serializer=source__pb2.AckRequest.SerializeToString, response_deserializer=source__pb2.AckResponse.FromString, _registered_method=True) + self.NackFn = channel.unary_unary( + '/source.v1.Source/NackFn', + request_serializer=source__pb2.NackRequest.SerializeToString, + response_deserializer=source__pb2.NackResponse.FromString, + _registered_method=True) self.PendingFn = channel.unary_unary( '/source.v1.Source/PendingFn', request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, @@ -88,6 +93,14 @@ def AckFn(self, request_iterator, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def NackFn(self, request, context): + """NackFn negatively acknowledges a batch of offsets. Invoked during a critical error in the monovertex or pipeline. + Unlike AckFn its not a streaming rpc because this is only invoked when there is a critical error (error path). + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def PendingFn(self, request, context): """PendingFn returns the number of pending records at the user defined source. 
""" @@ -122,6 +135,11 @@ def add_SourceServicer_to_server(servicer, server): request_deserializer=source__pb2.AckRequest.FromString, response_serializer=source__pb2.AckResponse.SerializeToString, ), + 'NackFn': grpc.unary_unary_rpc_method_handler( + servicer.NackFn, + request_deserializer=source__pb2.NackRequest.FromString, + response_serializer=source__pb2.NackResponse.SerializeToString, + ), 'PendingFn': grpc.unary_unary_rpc_method_handler( servicer.PendingFn, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, @@ -202,6 +220,33 @@ def AckFn(request_iterator, metadata, _registered_method=True) + @staticmethod + def NackFn(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/source.v1.Source/NackFn', + source__pb2.NackRequest.SerializeToString, + source__pb2.NackResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def PendingFn(request, target, diff --git a/pynumaflow/sourcer/__init__.py b/pynumaflow/sourcer/__init__.py index fcd11d61..013fac4e 100644 --- a/pynumaflow/sourcer/__init__.py +++ b/pynumaflow/sourcer/__init__.py @@ -3,10 +3,12 @@ ReadRequest, PendingResponse, AckRequest, + NackRequest, Offset, PartitionsResponse, get_default_partitions, Sourcer, + SourceCallable, ) from pynumaflow.sourcer.async_server import SourceAsyncServer @@ -15,9 +17,11 @@ "ReadRequest", "PendingResponse", "AckRequest", + "NackRequest", "Offset", "PartitionsResponse", "get_default_partitions", "Sourcer", "SourceAsyncServer", + "SourceCallable", ] diff --git a/pynumaflow/sourcer/_dtypes.py b/pynumaflow/sourcer/_dtypes.py index b2b34810..42c30515 100644 --- a/pynumaflow/sourcer/_dtypes.py +++ b/pynumaflow/sourcer/_dtypes.py @@ -1,10 +1,9 @@ import os from abc import ABCMeta, abstractmethod -from collections.abc import Iterable +from collections.abc import Iterator, AsyncIterator from dataclasses import dataclass from datetime import datetime from typing import Callable, Optional -from collections.abc import AsyncIterable from pynumaflow.shared.asynciter import NonBlockingIterator @@ -170,6 +169,31 @@ def offsets(self) -> list[Offset]: return self._offsets +@dataclass +class NackRequest: + """ + Class for defining the request for negatively acknowledging an offset. + It takes a list of offsets that need to be negatively acknowledged on the source. + Args: + offsets: the offsets to be negatively acknowledged. 
+ >>> # Example usage + >>> from pynumaflow.sourcer import NackRequest, Offset + >>> offset_val = Offset(offset=b"123", partition_id=0) + >>> nack_request = NackRequest(offsets=[offset_val, offset_val]) + """ + + __slots__ = ("_offsets",) + _offsets: list[Offset] + + def __init__(self, offsets: list[Offset]): + self._offsets = offsets + + @property + def offsets(self) -> list[Offset]: + """Returns the offsets to be negatively acknowledged.""" + return self._offsets + + @dataclass(init=False) class PendingResponse: """ @@ -246,8 +270,14 @@ async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator): @abstractmethod def ack_handler(self, ack_request: AckRequest): """ - The ack handler is used acknowledge the offsets that have been read, and remove them - from the to_ack_set + The ack handler is used to acknowledge the offsets that have been read + """ + pass + + @abstractmethod + def nack_handler(self, nack_request: NackRequest): + """ + The nack handler is used to negatively acknowledge the offsets on the source """ pass @@ -268,8 +298,8 @@ def partitions_handler(self) -> PartitionsResponse: # Create default partition id from the environment variable "NUMAFLOW_REPLICA" DefaultPartitionId = int(os.getenv("NUMAFLOW_REPLICA", "0")) -SourceReadCallable = Callable[[ReadRequest], Iterable[Message]] -AsyncSourceReadCallable = Callable[[ReadRequest], AsyncIterable[Message]] +SourceReadCallable = Callable[[ReadRequest], Iterator[Message]] +AsyncSourceReadCallable = Callable[[ReadRequest], AsyncIterator[Message]] SourceAckCallable = Callable[[AckRequest], None] SourceCallable = Sourcer diff --git a/pynumaflow/sourcer/async_server.py b/pynumaflow/sourcer/async_server.py index a2051990..264558b9 100644 --- a/pynumaflow/sourcer/async_server.py +++ b/pynumaflow/sourcer/async_server.py @@ -42,8 +42,8 @@ def __init__( defaults to 4 and max capped at 16 Example invocation: - from collections.abc import AsyncIterable from datetime import datetime + from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow.sourcer import ( ReadRequest, Message, @@ -54,49 +54,78 @@ def __init__( get_default_partitions, Sourcer, SourceAsyncServer, + NackRequest, ) class AsyncSource(Sourcer): # AsyncSource is a class for User Defined Source implementation. 
             def __init__(self):
-                # to_ack_set: Set to maintain a track of the offsets yet to be acknowledged
-                # read_idx : the offset idx till where the messages have been read
-                self.to_ack_set = set()
-                self.read_idx = 0
-
-            async def read_handler(self, datum: ReadRequest) -> AsyncIterable[Message]:
-                # read_handler is used to read the data from the source and send
-                # the data forward
-                # for each read request we process num_records and increment
-                # the read_idx to indicate that
-                # the message has been read and the same is added to the ack set
+                # The offset idx till where the messages have been read
+                self.read_idx: int = 0
+                # Set to keep track of the offsets yet to be acknowledged
+                self.to_ack_set: set[int] = set()
+                # Set to keep track of the offsets that have been negatively acknowledged
+                self.nacked: set[int] = set()
+
+            async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator):
+                '''
+                read_handler is used to read the data from the source and send the data forward
+                for each read request we process num_records and increment the read_idx to
+                indicate that the message has been read and the same is added to the ack set
+                '''
                 if self.to_ack_set:
                     return
                 for x in range(datum.num_records):
-                    yield Message(
-                        payload=str(self.read_idx).encode(),
-                        offset=Offset.offset_with_default_partition_id(str(self.read_idx).encode()),
-                        event_time=datetime.now(),
+                    # If there are any nacked offsets, re-deliver them
+                    if self.nacked:
+                        idx = self.nacked.pop()
+                    else:
+                        idx = self.read_idx
+                        self.read_idx += 1
+                    headers = {"x-txn-id": str(uuid.uuid4())}
+                    await output.put(
+                        Message(
+                            payload=str(idx).encode(),
+                            offset=Offset.offset_with_default_partition_id(str(idx).encode()),
+                            event_time=datetime.now(),
+                            headers=headers,
+                        )
                     )
-                    self.to_ack_set.add(str(self.read_idx))
-                    self.read_idx += 1
+                    self.to_ack_set.add(idx)
 
             async def ack_handler(self, ack_request: AckRequest):
-                # The ack handler is used acknowledge the offsets that have been read,
-                # and remove them from the to_ack_set
-                for offset in ack_request.offset:
-                    self.to_ack_set.remove(str(offset.offset, "utf-8"))
+                '''
+                The ack handler is used to acknowledge the offsets that have been read, and remove
+                them from the to_ack_set
+                '''
+                for req in ack_request.offsets:
+                    offset = int(req.offset)
+                    self.to_ack_set.remove(offset)
+
+            async def nack_handler(self, nack_request: NackRequest):
+                '''
+                Add the offsets that have been negatively acknowledged to the nacked set
+                '''
+                for req in nack_request.offsets:
+                    offset = int(req.offset)
+                    self.to_ack_set.remove(offset)
+                    self.nacked.add(offset)
 
             async def pending_handler(self) -> PendingResponse:
-                # The simple source always returns zero to indicate there is no pending record.
+                '''
+                The simple source always returns zero to indicate there is no pending record.
+                '''
                 return PendingResponse(count=0)
 
             async def partitions_handler(self) -> PartitionsResponse:
-                # The simple source always returns default partitions.
+                '''
+                The simple source always returns default partitions.
+ ''' return PartitionsResponse(partitions=get_default_partitions()) + if __name__ == "__main__": ud_source = AsyncSource() grpc_server = SourceAsyncServer(ud_source) diff --git a/pynumaflow/sourcer/servicer/async_servicer.py b/pynumaflow/sourcer/servicer/async_servicer.py index 7939d75b..bb8e58fd 100644 --- a/pynumaflow/sourcer/servicer/async_servicer.py +++ b/pynumaflow/sourcer/servicer/async_servicer.py @@ -6,8 +6,7 @@ from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow.shared.server import handle_async_error -from pynumaflow.sourcer._dtypes import ReadRequest, Offset -from pynumaflow.sourcer._dtypes import AckRequest, SourceCallable +from pynumaflow.sourcer import ReadRequest, Offset, NackRequest, AckRequest, SourceCallable from pynumaflow.proto.sourcer import source_pb2 from pynumaflow.proto.sourcer import source_pb2_grpc from pynumaflow.types import NumaflowServicerContext @@ -75,6 +74,7 @@ def __initialize_handlers(self): """Initialize handler methods from the provided source handler.""" self.__source_read_handler = self.source_handler.read_handler self.__source_ack_handler = self.source_handler.ack_handler + self.__source_nack_handler = self.source_handler.nack_handler self.__source_pending_handler = self.source_handler.pending_handler self.__source_partitions_handler = self.source_handler.partitions_handler @@ -167,6 +167,26 @@ async def AckFn( _LOGGER.critical("User-Defined Source AckFn error", exc_info=True) await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + async def NackFn( + self, + request: source_pb2.NackRequest, + context: NumaflowServicerContext, + ) -> source_pb2.NackResponse: + """ + Handles the Nack function for user-defined source. + """ + try: + offsets = [ + Offset(offset.offset, offset.partition_id) for offset in request.request.offsets + ] + await self.__source_nack_handler(NackRequest(offsets=offsets)) + except BaseException as err: + _LOGGER.critical("User-Defined Source NackFn error", exc_info=True) + await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + return source_pb2.NackResponse( + result=source_pb2.NackResponse.Result(success=_empty_pb2.Empty()) + ) + async def IsReady( self, request: _empty_pb2.Empty, context: NumaflowServicerContext ) -> source_pb2.ReadyResponse: diff --git a/tests/source/test_async_source.py b/tests/source/test_async_source.py index 332ce7c3..e255b7c8 100644 --- a/tests/source/test_async_source.py +++ b/tests/source/test_async_source.py @@ -18,6 +18,7 @@ mock_partitions, AsyncSource, mock_offset, + nack_req_source_fn, ) LOGGER = setup_logging(__name__) @@ -180,6 +181,13 @@ def test_ack(self) -> None: self.assertEqual(count, 2) self.assertFalse(first) + def test_nack(self) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = nack_req_source_fn() + response = stub.NackFn(request=request) + self.assertTrue(response.result.success) + def test_pending(self) -> None: with grpc.insecure_channel(server_port) as channel: stub = source_pb2_grpc.SourceStub(channel) diff --git a/tests/source/test_async_source_err.py b/tests/source/test_async_source_err.py index 7f4ab002..0ce4fa3d 100644 --- a/tests/source/test_async_source_err.py +++ b/tests/source/test_async_source_err.py @@ -18,6 +18,7 @@ read_req_source_fn, ack_req_source_fn, AsyncSourceError, + nack_req_source_fn, ) from tests.testing_utils import mock_terminate_on_stop @@ -137,6 +138,17 @@ def test_ack_error(self) -> None: print(e.details()) self.fail("Expected an 
exception.") + def test_nack_error(self): + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = nack_req_source_fn() + with self.assertRaisesRegex( + grpc.RpcError, "Got a runtime error from nack handler." + ) as resp: + stub.NackFn(request=request) + + self.assertEqual(grpc.StatusCode.INTERNAL, resp.exception.code()) + def test_ack_no_handshake_error(self) -> None: with grpc.insecure_channel(server_port) as channel: stub = source_pb2_grpc.SourceStub(channel) diff --git a/tests/source/utils.py b/tests/source/utils.py index 5a63683d..7baa91b7 100644 --- a/tests/source/utils.py +++ b/tests/source/utils.py @@ -3,12 +3,13 @@ from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow.sourcer import ReadRequest, Message -from pynumaflow.sourcer._dtypes import ( +from pynumaflow.sourcer import ( AckRequest, PendingResponse, Offset, PartitionsResponse, Sourcer, + NackRequest, ) from pynumaflow.proto.sourcer import source_pb2 from tests.testing_utils import mock_event_time @@ -36,6 +37,9 @@ async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator): async def ack_handler(self, ack_request: AckRequest): return + async def nack_handler(self, nack_request: NackRequest): + return + async def pending_handler(self) -> PendingResponse: return PendingResponse(count=10) @@ -55,6 +59,9 @@ def read_handler(self, datum: ReadRequest) -> Iterable[Message]: def ack_handler(self, ack_request: AckRequest): return + def nack_handler(self, nack_request: NackRequest): + return + def pending_handler(self) -> PendingResponse: return PendingResponse(count=10) @@ -76,6 +83,12 @@ def ack_req_source_fn(): return request +def nack_req_source_fn(): + msg = source_pb2.Offset(offset=mock_offset().offset, partition_id=mock_offset().partition_id) + request = source_pb2.NackRequest.Request(offsets=[msg]) + return source_pb2.NackRequest(request=request) + + class AsyncSourceError(Sourcer): # This handler mimics the scenario where map stream UDF throws a runtime error. async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator): @@ -92,6 +105,9 @@ async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator): async def ack_handler(self, ack_request: AckRequest): raise RuntimeError("Got a runtime error from ack handler.") + async def nack_handler(self, nack_request: NackRequest): + raise RuntimeError("Got a runtime error from nack handler.") + async def pending_handler(self) -> PendingResponse: raise RuntimeError("Got a runtime error from pending handler.") @@ -106,6 +122,9 @@ def read_handler(self, datum: ReadRequest) -> Iterable[Message]: def ack_handler(self, ack_request: AckRequest): raise RuntimeError("Got a runtime error from ack handler.") + def nack_handler(self, nack_request: NackRequest): + raise RuntimeError("Got a runtime error from nack handler.") + def pending_handler(self) -> PendingResponse: raise RuntimeError("Got a runtime error from pending handler.")