diff --git a/pynumaflow/sourcer/servicer/async_servicer.py b/pynumaflow/sourcer/servicer/async_servicer.py index f5e4b0f9..7939d75b 100644 --- a/pynumaflow/sourcer/servicer/async_servicer.py +++ b/pynumaflow/sourcer/servicer/async_servicer.py @@ -1,5 +1,5 @@ import asyncio -from collections.abc import AsyncIterable +from collections.abc import AsyncIterator from google.protobuf import timestamp_pb2 as _timestamp_pb2 from google.protobuf import empty_pb2 as _empty_pb2 @@ -80,9 +80,9 @@ def __initialize_handlers(self): async def ReadFn( self, - request_iterator: AsyncIterable[source_pb2.ReadRequest], + request_iterator: AsyncIterator[source_pb2.ReadRequest], context: NumaflowServicerContext, - ) -> AsyncIterable[source_pb2.ReadResponse]: + ) -> AsyncIterator[source_pb2.ReadResponse]: """ Handles the Read function, processing incoming requests and sending responses. """ @@ -108,7 +108,7 @@ async def ReadFn( async for resp in riter: if isinstance(resp, BaseException): - await handle_async_error(context, resp) + await handle_async_error(context, resp, ERR_UDF_EXCEPTION_STRING) return yield _create_read_response(resp) @@ -139,9 +139,9 @@ async def __invoke_read(self, req, niter): async def AckFn( self, - request_iterator: AsyncIterable[source_pb2.AckRequest], + request_iterator: AsyncIterator[source_pb2.AckRequest], context: NumaflowServicerContext, - ) -> AsyncIterable[source_pb2.AckResponse]: + ) -> AsyncIterator[source_pb2.AckResponse]: """ Handles the Ack function for user-defined source. """ diff --git a/tests/source/test_async_source_err.py b/tests/source/test_async_source_err.py index c72060bd..7f4ab002 100644 --- a/tests/source/test_async_source_err.py +++ b/tests/source/test_async_source_err.py @@ -9,11 +9,10 @@ from grpc.aio._server import Server from pynumaflow import setup_logging -from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING -from pynumaflow.sourcer import SourceAsyncServer from pynumaflow.proto.sourcer import source_pb2_grpc from google.protobuf import empty_pb2 as _empty_pb2 +from pynumaflow.sourcer import SourceAsyncServer from tests.source.test_async_source import request_generator from tests.source.utils import ( read_req_source_fn, @@ -92,20 +91,12 @@ def test_read_error(self) -> None: ) for _ in generator_response: pass - except BaseException as e: - self.assertTrue( - f"{ERR_UDF_EXCEPTION_STRING}: TypeError(" - '"handle_async_error() missing 1 required positional argument: ' - "'exception_type'\")" in e.__str__() - ) - return except grpc.RpcError as e: grpc_exception = e - self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) + self.assertEqual(grpc.StatusCode.INTERNAL, e.code()) print(e.details()) self.assertIsNotNone(grpc_exception) - self.fail("Expected an exception.") def test_read_handshake_error(self) -> None: grpc_exception = None