diff --git a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py index 18bc3775..8a37ad89 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py @@ -13,7 +13,7 @@ # limitations under the License. from asyncio import Future, Queue, ensure_future -from typing import Callable, NamedTuple, Dict, Set, Optional +from typing import Callable, NamedTuple, Dict, List, Set, Optional from google.cloud.pubsub_v1.subscriber.message import Message @@ -41,7 +41,7 @@ class AssigningSingleSubscriber(AsyncSingleSubscriber, PermanentFailable): # Lazily initialized to ensure they are initialized on the thread where __aenter__ is called. _assigner: Optional[Assigner] - _messages: Optional["Queue[Message]"] + _batches: Optional["Queue[List[Message]]"] _assign_poller: Future def __init__( @@ -58,14 +58,14 @@ def __init__( self._assigner = None self._subscriber_factory = subscriber_factory self._subscribers = {} - self._messages = None + self._batches = None - async def read(self) -> Message: - return await self.await_unless_failed(self._messages.get()) + async def read(self) -> List[Message]: + return await self.await_unless_failed(self._batches.get()) async def _subscribe_action(self, subscriber: AsyncSingleSubscriber): - message = await subscriber.read() - await self._messages.put(message) + batch = await subscriber.read() + await self._batches.put(batch) async def _start_subscriber(self, partition: Partition): new_subscriber = self._subscriber_factory(partition) @@ -92,7 +92,7 @@ async def _assign_action(self): await self._stop_subscriber(subscriber) async def __aenter__(self): - self._messages = Queue() + self._batches = Queue() self._assigner = self._assigner_factory() await self._assigner.__aenter__() self._assign_poller = ensure_future(self.run_poller(self._assign_action)) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client.py b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client.py index 4187af33..6becdfc4 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client.py @@ -38,27 +38,17 @@ from overrides import overrides -class _SubscriberAsyncIterator(AsyncIterator): - _subscriber: AsyncSingleSubscriber - _on_failure: Callable[[], Awaitable[None]] - - def __init__( - self, - subscriber: AsyncSingleSubscriber, - on_failure: Callable[[], Awaitable[None]], - ): - self._subscriber = subscriber - self._on_failure = on_failure - - async def __anext__(self) -> Message: - try: - return await self._subscriber.read() - except: # noqa: E722 - await self._on_failure() - raise - - def __aiter__(self): - return self +async def _iterate_subscriber( + subscriber: AsyncSingleSubscriber, on_failure: Callable[[], Awaitable[None]] +) -> AsyncIterator[Message]: + try: + while True: + batch = await subscriber.read() + for message in batch: + yield message + except: # noqa: E722 + await on_failure() + raise class MultiplexedAsyncSubscriberClient(AsyncSubscriberClientInterface): @@ -85,7 +75,7 @@ async def subscribe( await subscriber.__aenter__() self._live_clients.add(subscriber) - return _SubscriberAsyncIterator( + return _iterate_subscriber( subscriber, lambda: self._try_remove_client(subscriber) ) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py index 85cb864a..00290889 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py @@ -13,8 +13,7 @@ # limitations under the License. import asyncio -import json -from typing import Callable, Union, Dict, NamedTuple +from typing import Callable, Union, List, Dict, NamedTuple import queue from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError @@ -22,6 +21,8 @@ from google.pubsub_v1 import PubsubMessage from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled +from google.cloud.pubsublite.internal.wire.permanent_failable import adapt_error +from google.cloud.pubsublite.internal import fast_serialize from google.cloud.pubsublite.types import FlowControlSettings from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer @@ -47,15 +48,13 @@ class _AckId(NamedTuple): generation: int offset: int - def str(self) -> str: - return json.dumps({"generation": self.generation, "offset": self.offset}) + def encode(self) -> str: + return fast_serialize.dump([self.generation, self.offset]) @staticmethod def parse(payload: str) -> "_AckId": # pytype: disable=invalid-annotation - loaded = json.loads(payload) - return _AckId( - generation=int(loaded["generation"]), offset=int(loaded["offset"]), - ) + loaded = fast_serialize.load(payload) + return _AckId(generation=loaded[0], offset=loaded[1]) ResettableSubscriberFactory = Callable[[SubscriberResetHandler], Subscriber] @@ -99,26 +98,31 @@ async def handle_reset(self): self._ack_generation_id += 1 await self._ack_set_tracker.clear_and_commit() - async def read(self) -> Message: + def _wrap_message(self, message: SequencedMessage.meta.pb) -> Message: + # Rewrap in the proto-plus-python wrapper for passing to the transform + rewrapped = SequencedMessage() + rewrapped._pb = message + cps_message = self._transformer.transform(rewrapped) + offset = message.cursor.offset + ack_id_str = _AckId(self._ack_generation_id, offset).encode() + self._ack_set_tracker.track(offset) + self._messages_by_ack_id[ack_id_str] = _SizedMessage( + cps_message, message.size_bytes + ) + wrapped_message = Message( + cps_message._pb, + ack_id=ack_id_str, + delivery_attempt=0, + request_queue=self._queue, + ) + return wrapped_message + + async def read(self) -> List[Message]: try: - message: SequencedMessage = await self.await_unless_failed( - self._underlying.read() - ) - cps_message = self._transformer.transform(message) - offset = message.cursor.offset - ack_id = _AckId(self._ack_generation_id, offset) - self._ack_set_tracker.track(offset) - self._messages_by_ack_id[ack_id.str()] = _SizedMessage( - cps_message, message.size_bytes - ) - wrapped_message = Message( - cps_message._pb, - ack_id=ack_id.str(), - delivery_attempt=0, - request_queue=self._queue, - ) - return wrapped_message - except GoogleAPICallError as e: + latest_batch = await self.await_unless_failed(self._underlying.read()) + return [self._wrap_message(message) for message in latest_batch] + except Exception as e: + e = adapt_error(e) # This could be from user code self.fail(e) raise e diff --git a/google/cloud/pubsublite/cloudpubsub/internal/single_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/single_subscriber.py index ab787bf3..314d77c4 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/single_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/single_subscriber.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import abstractmethod, ABCMeta -from typing import AsyncContextManager, Callable, Set, Optional +from typing import AsyncContextManager, Callable, List, Set, Optional from google.cloud.pubsub_v1.subscriber.message import Message @@ -32,12 +32,13 @@ class AsyncSingleSubscriber(AsyncContextManager, metaclass=ABCMeta): """ @abstractmethod - async def read(self) -> Message: + async def read(self) -> List[Message]: """ - Read the next message off of the stream. + Read the next batch off of the stream. Returns: - The next message. ack() or nack() must eventually be called exactly once. + The next batch of messages. ack() or nack() must eventually be called + exactly once on each message. Pub/Sub Lite does not support nack() by default- if you do call nack(), it will immediately fail the client unless you have a NackHandler installed. diff --git a/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py b/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py index ed6c8368..05a58c13 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py @@ -84,8 +84,8 @@ def _fail(self, error: GoogleAPICallError): async def _poller(self): try: while True: - message = await self._underlying.read() - self._unowned_executor.submit(self._callback, message) + batch = await self._underlying.read() + self._unowned_executor.map(self._callback, batch) except GoogleAPICallError as e: # noqa: F841 Flake8 thinks e is unused self._unowned_executor.submit(lambda: self._fail(e)) # noqa: F821 diff --git a/google/cloud/pubsublite/cloudpubsub/message_transforms.py b/google/cloud/pubsublite/cloudpubsub/message_transforms.py index 381cf5cd..7f56711b 100644 --- a/google/cloud/pubsublite/cloudpubsub/message_transforms.py +++ b/google/cloud/pubsublite/cloudpubsub/message_transforms.py @@ -19,27 +19,42 @@ from google.pubsub_v1 import PubsubMessage from google.cloud.pubsublite.cloudpubsub import MessageTransformer +from google.cloud.pubsublite.internal import fast_serialize from google.cloud.pubsublite.types import Partition, MessageMetadata from google.cloud.pubsublite_v1 import AttributeValues, SequencedMessage, PubSubMessage PUBSUB_LITE_EVENT_TIME = "x-goog-pubsublite-event-time" -def encode_attribute_event_time(dt: datetime.datetime) -> str: - ts = Timestamp() - ts.FromDatetime(dt) - return ts.ToJsonString() +def _encode_attribute_event_time_proto(ts: Timestamp) -> str: + return fast_serialize.dump([ts.seconds, ts.nanos]) -def decode_attribute_event_time(attr: str) -> datetime.datetime: +def _decode_attribute_event_time_proto(attr: str) -> Timestamp: try: ts = Timestamp() - ts.FromJsonString(attr) - return ts.ToDatetime() - except ValueError: + loaded = fast_serialize.load(attr) + ts.seconds = loaded[0] + ts.nanos = loaded[1] + return ts + except Exception: # noqa: E722 raise InvalidArgument("Invalid value for event time attribute.") +def encode_attribute_event_time(dt: datetime.datetime) -> str: + ts = Timestamp() + ts.FromDatetime(dt.astimezone(datetime.timezone.utc)) + return _encode_attribute_event_time_proto(ts) + + +def decode_attribute_event_time(attr: str) -> datetime.datetime: + return ( + _decode_attribute_event_time_proto(attr) + .ToDatetime() + .replace(tzinfo=datetime.timezone.utc) + ) + + def _parse_attributes(values: AttributeValues) -> str: if not len(values.values) == 1: raise InvalidArgument( @@ -58,25 +73,34 @@ def add_id_to_cps_subscribe_transformer( partition: Partition, transformer: MessageTransformer ) -> MessageTransformer: def add_id_to_message(source: SequencedMessage): + source_pb = source._pb message: PubsubMessage = transformer.transform(source) - if message.message_id: + message_pb = message._pb + if message_pb.message_id: raise InvalidArgument( "Message after transforming has the message_id field set." ) - message.message_id = MessageMetadata(partition, source.cursor).encode() + message_pb.message_id = MessageMetadata._encode_parts( + partition.value, source_pb.cursor.offset + ) return message return MessageTransformer.of_callable(add_id_to_message) def to_cps_subscribe_message(source: SequencedMessage) -> PubsubMessage: - message: PubsubMessage = to_cps_publish_message(source.message) - message.publish_time = source.publish_time - return message + source_pb = source._pb + out_pb = _to_cps_publish_message_proto(source_pb.message) + out_pb.publish_time.CopyFrom(source_pb.publish_time) + out = PubsubMessage() + out._pb = out_pb + return out -def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage: - out = PubsubMessage() +def _to_cps_publish_message_proto( + source: PubSubMessage.meta.pb, +) -> PubsubMessage.meta.pb: + out = PubsubMessage.meta.pb() try: out.ordering_key = source.key.decode("utf-8") except UnicodeError: @@ -88,22 +112,32 @@ def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage: out.data = source.data for key, values in source.attributes.items(): out.attributes[key] = _parse_attributes(values) - if "event_time" in source: - out.attributes[PUBSUB_LITE_EVENT_TIME] = encode_attribute_event_time( + if source.HasField("event_time"): + out.attributes[PUBSUB_LITE_EVENT_TIME] = _encode_attribute_event_time_proto( source.event_time ) return out +def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage: + out = PubsubMessage() + out._pb = _to_cps_publish_message_proto(source._pb) + return out + + def from_cps_publish_message(source: PubsubMessage) -> PubSubMessage: + source_pb = source._pb out = PubSubMessage() - if PUBSUB_LITE_EVENT_TIME in source.attributes: - out.event_time = decode_attribute_event_time( - source.attributes[PUBSUB_LITE_EVENT_TIME] + out_pb = out._pb + if PUBSUB_LITE_EVENT_TIME in source_pb.attributes: + out_pb.event_time.CopyFrom( + _decode_attribute_event_time_proto( + source_pb.attributes[PUBSUB_LITE_EVENT_TIME] + ) ) - out.data = source.data - out.key = source.ordering_key.encode("utf-8") - for key, value in source.attributes.items(): + out_pb.data = source_pb.data + out_pb.key = source_pb.ordering_key.encode("utf-8") + for key, value in source_pb.attributes.items(): if key != PUBSUB_LITE_EVENT_TIME: - out.attributes[key] = AttributeValues(values=[value.encode("utf-8")]) + out_pb.attributes[key].values.append(value.encode("utf-8")) return out diff --git a/google/cloud/pubsublite/internal/fast_serialize.py b/google/cloud/pubsublite/internal/fast_serialize.py new file mode 100644 index 00000000..c07236c0 --- /dev/null +++ b/google/cloud/pubsublite/internal/fast_serialize.py @@ -0,0 +1,13 @@ +""" +A fast serialization method for lists of integers. +""" + +from typing import List + + +def dump(data: List[int]) -> str: + return ",".join(str(x) for x in data) + + +def load(source: str) -> List[int]: + return [int(x) for x in source.split(",")] diff --git a/google/cloud/pubsublite/internal/wire/flow_control_batcher.py b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py index 7e99fb58..ee442808 100644 --- a/google/cloud/pubsublite/internal/wire/flow_control_batcher.py +++ b/google/cloud/pubsublite/internal/wire/flow_control_batcher.py @@ -21,18 +21,27 @@ class _AggregateRequest: - request: FlowControlRequest + _request: FlowControlRequest.meta.pb def __init__(self): - self.request = FlowControlRequest() + self._request = FlowControlRequest.meta.pb() - def __add__(self, other: FlowControlRequest): - self.request.allowed_bytes += other.allowed_bytes - self.request.allowed_bytes = min(self.request.allowed_bytes, _MAX_INT64) - self.request.allowed_messages += other.allowed_messages - self.request.allowed_messages = min(self.request.allowed_messages, _MAX_INT64) + def __add__(self, other: FlowControlRequest.meta.pb): + self._request.allowed_bytes = self._request.allowed_bytes + other.allowed_bytes + self._request.allowed_bytes = min(self._request.allowed_bytes, _MAX_INT64) + self._request.allowed_messages = ( + self._request.allowed_messages + other.allowed_messages + ) + self._request.allowed_messages = min(self._request.allowed_messages, _MAX_INT64) return self + def to_optional(self) -> Optional[FlowControlRequest]: + if self._request.allowed_messages == 0 and self._request.allowed_bytes == 0: + return None + request = FlowControlRequest() + request._pb = self._request + return request + def _exceeds_expedite_ratio(pending: int, client: int): if client <= 0: @@ -40,12 +49,6 @@ def _exceeds_expedite_ratio(pending: int, client: int): return (pending / client) >= _EXPEDITE_BATCH_REQUEST_RATIO -def _to_optional(req: FlowControlRequest) -> Optional[FlowControlRequest]: - if req.allowed_messages == 0 and req.allowed_bytes == 0: - return None - return req - - class FlowControlBatcher: _client_tokens: _AggregateRequest _pending_tokens: _AggregateRequest @@ -59,23 +62,25 @@ def add(self, request: FlowControlRequest): self._pending_tokens += request def on_messages(self, messages: List[SequencedMessage]): - byte_size = sum(message.size_bytes for message in messages) + byte_size = 0 + for message in messages: + byte_size += message.size_bytes self._client_tokens += FlowControlRequest( allowed_bytes=-byte_size, allowed_messages=-len(messages) ) def request_for_restart(self) -> Optional[FlowControlRequest]: self._pending_tokens = _AggregateRequest() - return _to_optional(self._client_tokens.request) + return self._client_tokens.to_optional() def release_pending_request(self) -> Optional[FlowControlRequest]: - request = self._pending_tokens.request + request = self._pending_tokens self._pending_tokens = _AggregateRequest() - return _to_optional(request) + return request.to_optional() def should_expedite(self): - pending_request = self._pending_tokens.request - client_request = self._client_tokens.request + pending_request = self._pending_tokens._request + client_request = self._client_tokens._request if _exceeds_expedite_ratio( pending_request.allowed_bytes, client_request.allowed_bytes ): diff --git a/google/cloud/pubsublite/internal/wire/permanent_failable.py b/google/cloud/pubsublite/internal/wire/permanent_failable.py index 7522864e..d96c4c9d 100644 --- a/google/cloud/pubsublite/internal/wire/permanent_failable.py +++ b/google/cloud/pubsublite/internal/wire/permanent_failable.py @@ -15,13 +15,19 @@ import asyncio from typing import Awaitable, TypeVar, Optional, Callable -from google.api_core.exceptions import GoogleAPICallError +from google.api_core.exceptions import GoogleAPICallError, Unknown from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors T = TypeVar("T") +def adapt_error(e: Exception) -> GoogleAPICallError: + if isinstance(e, GoogleAPICallError): + return e + return Unknown("Had an unknown error", errors=[e]) + + class _TaskWithCleanup: def __init__(self, a: Awaitable): self._task = asyncio.ensure_future(a) diff --git a/google/cloud/pubsublite/internal/wire/subscriber.py b/google/cloud/pubsublite/internal/wire/subscriber.py index 610ba9d1..dec650f6 100644 --- a/google/cloud/pubsublite/internal/wire/subscriber.py +++ b/google/cloud/pubsublite/internal/wire/subscriber.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import abstractmethod, ABCMeta -from typing import AsyncContextManager +from typing import AsyncContextManager, List from google.cloud.pubsublite_v1.types import SequencedMessage, FlowControlRequest @@ -23,12 +23,12 @@ class Subscriber(AsyncContextManager, metaclass=ABCMeta): """ @abstractmethod - async def read(self) -> SequencedMessage: + async def read(self) -> List[SequencedMessage.meta.pb]: """ - Read the next message off of the stream. + Read a batch of messages off of the stream. Returns: - The next message. + The next batch of messages. Raises: GoogleAPICallError: On a permanent error. diff --git a/google/cloud/pubsublite/internal/wire/subscriber_impl.py b/google/cloud/pubsublite/internal/wire/subscriber_impl.py index 89b9a2a2..02466bee 100644 --- a/google/cloud/pubsublite/internal/wire/subscriber_impl.py +++ b/google/cloud/pubsublite/internal/wire/subscriber_impl.py @@ -14,7 +14,7 @@ import asyncio from copy import deepcopy -from typing import Optional +from typing import Optional, List from google.api_core.exceptions import GoogleAPICallError, FailedPrecondition @@ -59,7 +59,7 @@ class SubscriberImpl( _reinitializing: bool _last_received_offset: Optional[int] - _message_queue: "asyncio.Queue[SequencedMessage]" + _message_queue: "asyncio.Queue[List[SequencedMessage.meta.pb]]" _receiver: Optional[asyncio.Future] _flusher: Optional[asyncio.Future] @@ -110,8 +110,10 @@ def _handle_response(self, response: SubscribeResponse): ) ) return - self._outstanding_flow_control.on_messages(response.messages.messages) - for message in response.messages.messages: + # Workaround for incredibly slow proto-plus-python accesses + messages = list(response.messages.messages._pb) + self._outstanding_flow_control.on_messages(messages) + for message in messages: if ( self._last_received_offset is not None and message.cursor.offset <= self._last_received_offset @@ -125,9 +127,8 @@ def _handle_response(self, response: SubscribeResponse): ) return self._last_received_offset = message.cursor.offset - for message in response.messages.messages: - # queue is unbounded. - self._message_queue.put_nowait(message) + # queue is unbounded. + self._message_queue.put_nowait(messages) async def _receive_loop(self): while True: @@ -163,12 +164,14 @@ async def reinitialize( if last_error and is_reset_signal(last_error): # Discard undelivered messages and refill flow control tokens. while not self._message_queue.empty(): - msg = self._message_queue.get_nowait() + batch: List[SequencedMessage.meta.pb] = self._message_queue.get_nowait() + allowed_bytes = sum(message.size_bytes for message in batch) self._outstanding_flow_control.add( FlowControlRequest( - allowed_messages=1, allowed_bytes=msg.size_bytes, + allowed_messages=len(batch), allowed_bytes=allowed_bytes, ) ) + await self._reset_handler.handle_reset() self._last_received_offset = None initial = deepcopy(self._base_initial) @@ -195,7 +198,7 @@ async def reinitialize( self._reinitializing = False self._start_loopers() - async def read(self) -> SequencedMessage: + async def read(self) -> List[SequencedMessage.meta.pb]: return await self._connection.await_unless_failed(self._message_queue.get()) async def allow_flow(self, request: FlowControlRequest): diff --git a/google/cloud/pubsublite/types/message_metadata.py b/google/cloud/pubsublite/types/message_metadata.py index cd6b2dc5..3a9d1437 100644 --- a/google/cloud/pubsublite/types/message_metadata.py +++ b/google/cloud/pubsublite/types/message_metadata.py @@ -13,8 +13,8 @@ # limitations under the License. from typing import NamedTuple -import json +from google.cloud.pubsublite.internal import fast_serialize from google.cloud.pubsublite_v1.types.common import Cursor from google.cloud.pubsublite.types.partition import Partition @@ -24,14 +24,15 @@ class MessageMetadata(NamedTuple): cursor: Cursor def encode(self) -> str: - return json.dumps( - {"partition": self.partition.value, "offset": self.cursor.offset} - ) + return self._encode_parts(self.partition.value, self.cursor._pb.offset) + + @staticmethod + def _encode_parts(partition: int, offset: int) -> str: + return fast_serialize.dump([partition, offset]) @staticmethod def decode(source: str) -> "MessageMetadata": - loaded = json.loads(source) - return MessageMetadata( - partition=Partition(loaded["partition"]), - cursor=Cursor(offset=loaded["offset"]), - ) + loaded = fast_serialize.load(source) + cursor = Cursor() + cursor._pb.offset = loaded[1] + return MessageMetadata(partition=Partition(loaded[0]), cursor=cursor) diff --git a/noxfile.py b/noxfile.py index ff06f0e6..e2cd0e64 100644 --- a/noxfile.py +++ b/noxfile.py @@ -112,6 +112,7 @@ def default(session): "--cov-config=.coveragerc", "--cov-report=", "--cov-fail-under=0", + "-v", os.path.join("tests", "unit"), *session.posargs, ) diff --git a/tests/unit/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client_test.py b/tests/unit/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client_test.py index 7283f057..4fe0a9bc 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/multiplexed_async_subscriber_client_test.py @@ -66,7 +66,8 @@ async def test_iterator( ): read_queues = wire_queues(default_subscriber.read) subscription = SubscriptionPath(1, CloudZone.parse("us-central1-a"), "abc") - message = Message(PubsubMessage(message_id="1")._pb, "", 0, None) + message1 = Message(PubsubMessage(message_id="1")._pb, "", 0, None) + message2 = Message(PubsubMessage(message_id="2")._pb, "", 0, None) async with multiplexed_client: iterator = await multiplexed_client.subscribe( subscription, DISABLED_FLOW_CONTROL @@ -78,8 +79,9 @@ async def test_iterator( assert not read_fut_1.done() await read_queues.called.get() default_subscriber.read.assert_has_calls([call()]) - await read_queues.results.put(message) - assert await read_fut_1 is message + await read_queues.results.put([message1, message2]) + assert await read_fut_1 is message1 + assert await iterator.__anext__() is message2 read_fut_2 = asyncio.ensure_future(iterator.__anext__()) assert not read_fut_2.done() await read_queues.called.get() diff --git a/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py index 4a151484..6949fddc 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py @@ -13,8 +13,7 @@ # limitations under the License. import asyncio -import json -from typing import Callable +from typing import Callable, List from asynctest.mock import MagicMock, call import pytest @@ -26,6 +25,7 @@ from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker from google.cloud.pubsublite.cloudpubsub.internal.single_partition_subscriber import ( SinglePartitionSingleSubscriber, + _AckId, ) from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer from google.cloud.pubsublite.cloudpubsub.nack_handler import NackHandler @@ -49,7 +49,7 @@ def mock_async_context_manager(cm): def ack_id(generation, offset) -> str: - return json.dumps({"generation": generation, "offset": offset}) + return _AckId(generation, offset).encode() @pytest.fixture() @@ -80,12 +80,14 @@ def nack_handler(): return MagicMock(spec=NackHandler) +def return_message(source: SequencedMessage): + return PubsubMessage(message_id=str(source.cursor.offset)) + + @pytest.fixture() def transformer(): result = MagicMock(spec=MessageTransformer) - result.transform.side_effect = lambda source: PubsubMessage( - message_id=str(source.cursor.offset) - ) + result.transform.side_effect = return_message return result @@ -117,7 +119,7 @@ async def test_init(subscriber, underlying, ack_set_tracker, initial_flow_reques async def test_failed_transform(subscriber, underlying, transformer): async with subscriber: transformer.transform.side_effect = FailedPrecondition("Bad message") - underlying.read.return_value = SequencedMessage() + underlying.read.return_value = [SequencedMessage()._pb] with pytest.raises(FailedPrecondition): await subscriber.read() @@ -131,15 +133,15 @@ async def test_ack( ack_called_queue, ack_result_queue ) async with subscriber: - message_1 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) - message_2 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=10) - underlying.read.return_value = message_1 - read_1: Message = await subscriber.read() - ack_set_tracker.track.assert_has_calls([call(1)]) - assert read_1.message_id == "1" - underlying.read.return_value = message_2 - read_2: Message = await subscriber.read() + message_1 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb + message_2 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=10)._pb + underlying.read.return_value = [message_1, message_2] + read: List[Message] = await subscriber.read() + assert len(read) == 2 + read_1 = read[0] + read_2 = read[1] ack_set_tracker.track.assert_has_calls([call(1), call(2)]) + assert read_1.message_id == "1" assert read_2.message_id == "2" read_2.ack() await ack_called_queue.get() @@ -159,8 +161,8 @@ async def test_track_failure( ): async with subscriber: ack_set_tracker.track.side_effect = FailedPrecondition("Bad track") - message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) - underlying.read.return_value = message + message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb + underlying.read.return_value = [message] with pytest.raises(FailedPrecondition): await subscriber.read() ack_set_tracker.track.assert_has_calls([call(1)]) @@ -178,11 +180,12 @@ async def test_ack_failure( ack_called_queue, ack_result_queue ) async with subscriber: - message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) - underlying.read.return_value = message - read: Message = await subscriber.read() + message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb + underlying.read.return_value = [message] + read: List[Message] = await subscriber.read() + assert len(read) == 1 ack_set_tracker.track.assert_has_calls([call(1)]) - read.ack() + read[0].ack() await ack_called_queue.get() ack_set_tracker.ack.assert_has_calls([call(1)]) await ack_result_queue.put(FailedPrecondition("Bad ack")) @@ -203,12 +206,13 @@ async def test_nack_failure( nack_handler, ): async with subscriber: - message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) - underlying.read.return_value = message - read: Message = await subscriber.read() + message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb + underlying.read.return_value = [message] + read: List[Message] = await subscriber.read() + assert len(read) == 1 ack_set_tracker.track.assert_has_calls([call(1)]) nack_handler.on_nack.side_effect = FailedPrecondition("Bad nack") - read.nack() + read[0].nack() async def sleep_forever(): await asyncio.sleep(float("inf")) @@ -231,9 +235,10 @@ async def test_nack_calls_ack( ack_called_queue, ack_result_queue ) async with subscriber: - message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) - underlying.read.return_value = message - read: Message = await subscriber.read() + message = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb + underlying.read.return_value = [message] + read: List[Message] = await subscriber.read() + assert len(read) == 1 ack_set_tracker.track.assert_has_calls([call(1)]) def on_nack(nacked: PubsubMessage, ack: Callable[[], None]): @@ -241,7 +246,7 @@ def on_nack(nacked: PubsubMessage, ack: Callable[[], None]): ack() nack_handler.on_nack.side_effect = on_nack - read.nack() + read[0].nack() await ack_called_queue.get() await ack_result_queue.put(None) ack_set_tracker.ack.assert_has_calls([call(1)]) @@ -259,27 +264,29 @@ async def test_handle_reset( ack_called_queue, ack_result_queue ) async with subscriber: - message_1 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5) - underlying.read.return_value = message_1 - read_1: Message = await subscriber.read() + message_1 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)._pb + underlying.read.return_value = [message_1] + read_1: List[Message] = await subscriber.read() + assert len(read_1) == 1 ack_set_tracker.track.assert_has_calls([call(1)]) - assert read_1.message_id == "1" - assert read_1.ack_id == ack_id(0, 1) + assert read_1[0].message_id == "1" + assert read_1[0].ack_id == ack_id(0, 1) await subscriber.handle_reset() ack_set_tracker.clear_and_commit.assert_called_once() # Message ACKed after reset. Its flow control tokens are refilled # but offset not committed (verified below after message 2). - read_1.ack() + read_1[0].ack() - message_2 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=10) - underlying.read.return_value = message_2 - read_2: Message = await subscriber.read() + message_2 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=10)._pb + underlying.read.return_value = [message_2] + read_2: List[Message] = await subscriber.read() + assert len(read_2) == 1 ack_set_tracker.track.assert_has_calls([call(1), call(2)]) - assert read_2.message_id == "2" - assert read_2.ack_id == ack_id(1, 2) - read_2.ack() + assert read_2[0].message_id == "2" + assert read_2[0].ack_id == ack_id(1, 2) + read_2[0].ack() await ack_called_queue.get() await ack_result_queue.put(None) underlying.allow_flow.assert_has_calls( diff --git a/tests/unit/pubsublite/cloudpubsub/internal/subscriber_impl_test.py b/tests/unit/pubsublite/cloudpubsub/internal/subscriber_impl_test.py index f3899719..67b25b2a 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/subscriber_impl_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/subscriber_impl_test.py @@ -16,6 +16,7 @@ import concurrent from concurrent.futures.thread import ThreadPoolExecutor from queue import Queue +from typing import List from asynctest.mock import MagicMock import pytest @@ -94,16 +95,17 @@ def test_messages_received( ): message1 = Message(PubsubMessage(message_id="1")._pb, "", 0, None) message2 = Message(PubsubMessage(message_id="2")._pb, "", 0, None) + message3 = Message(PubsubMessage(message_id="3")._pb, "", 0, None) counter = Box[int]() counter.val = 0 - async def on_read() -> Message: + async def on_read() -> List[Message]: counter.val += 1 if counter.val == 1: - return message1 + return [message1, message2] if counter.val == 2: - return message2 + return [message3] await sleep_forever() async_subscriber.read.side_effect = on_read @@ -115,4 +117,5 @@ async def on_read() -> Message: subscriber.__enter__() assert results.get() == "1" assert results.get() == "2" + assert results.get() == "3" subscriber.close() diff --git a/tests/unit/pubsublite/cloudpubsub/message_transforms_test.py b/tests/unit/pubsublite/cloudpubsub/message_transforms_test.py index 206a228d..564db385 100644 --- a/tests/unit/pubsublite/cloudpubsub/message_transforms_test.py +++ b/tests/unit/pubsublite/cloudpubsub/message_transforms_test.py @@ -24,6 +24,7 @@ PUBSUB_LITE_EVENT_TIME, to_cps_subscribe_message, encode_attribute_event_time, + decode_attribute_event_time, from_cps_publish_message, add_id_to_cps_subscribe_transformer, ) @@ -104,7 +105,7 @@ def test_subscribe_transform_correct(): "x": "abc", "y": "abc", PUBSUB_LITE_EVENT_TIME: encode_attribute_event_time( - Timestamp(seconds=55).ToDatetime() + Timestamp(seconds=55).ToDatetime().replace(tzinfo=datetime.timezone.utc) ), }, publish_time=Timestamp(seconds=10), @@ -163,7 +164,7 @@ def test_wrapped_successful(): "x": "abc", "y": "abc", PUBSUB_LITE_EVENT_TIME: encode_attribute_event_time( - Timestamp(seconds=55).ToDatetime() + Timestamp(seconds=55).ToDatetime().replace(tzinfo=datetime.timezone.utc) ), }, message_id=MessageMetadata(Partition(1), Cursor(offset=10)).encode(), @@ -199,6 +200,10 @@ def test_publish_invalid_event_time(): def test_publish_valid_transform(): now = datetime.datetime.now() + encoded_event_time = encode_attribute_event_time(now) + assert decode_attribute_event_time(encoded_event_time) == now.astimezone( + datetime.timezone.utc + ) expected = PubSubMessage( data=b"xyz", key=b"def", @@ -208,15 +213,14 @@ def test_publish_valid_transform(): "y": AttributeValues(values=[b"abc"]), }, ) - result = from_cps_publish_message( - PubsubMessage( - data=b"xyz", - ordering_key="def", - attributes={ - "x": "abc", - "y": "abc", - PUBSUB_LITE_EVENT_TIME: encode_attribute_event_time(now), - }, - ) + initial = PubsubMessage( + data=b"xyz", + ordering_key="def", + attributes={ + "x": "abc", + "y": "abc", + PUBSUB_LITE_EVENT_TIME: encode_attribute_event_time(now), + }, ) + result = from_cps_publish_message(initial) assert result == expected diff --git a/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py index 8656aa07..82df4b2c 100644 --- a/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py +++ b/tests/unit/pubsublite/internal/wire/subscriber_impl_test.py @@ -287,8 +287,8 @@ async def test_message_receipt( write_called_queue = asyncio.Queue() write_result_queue = asyncio.Queue() flow = FlowControlRequest(allowed_messages=100, allowed_bytes=100) - message_1 = SequencedMessage(cursor=Cursor(offset=3), size_bytes=5) - message_2 = SequencedMessage(cursor=Cursor(offset=5), size_bytes=10) + message_1 = SequencedMessage(cursor=Cursor(offset=3), size_bytes=5)._pb + message_2 = SequencedMessage(cursor=Cursor(offset=5), size_bytes=10)._pb default_connection.write.side_effect = make_queue_waiter( write_called_queue, write_result_queue ) @@ -317,15 +317,16 @@ async def test_message_receipt( [call(initial_request), call(as_request(flow))] ) - message1_fut = asyncio.ensure_future(subscriber.read()) + batch1_fut = asyncio.ensure_future(subscriber.read()) # Send messages to the subscriber. await read_result_queue.put(as_response([message_1, message_2])) # Wait for the next read call await read_called_queue.get() - assert (await message1_fut) == message_1 - assert (await subscriber.read()) == message_2 + batch1 = await batch1_fut + assert batch1[0].SerializeToString() == message_1.SerializeToString() + assert batch1[1].SerializeToString() == message_2.SerializeToString() # Fail the connection with a retryable error await read_called_queue.get() @@ -375,8 +376,8 @@ async def test_out_of_order_receipt_failure( write_called_queue = asyncio.Queue() write_result_queue = asyncio.Queue() flow = FlowControlRequest(allowed_messages=100, allowed_bytes=100) - message_1 = SequencedMessage(cursor=Cursor(offset=3), size_bytes=5) - message_2 = SequencedMessage(cursor=Cursor(offset=5), size_bytes=10) + message_1 = SequencedMessage(cursor=Cursor(offset=3), size_bytes=5)._pb + message_2 = SequencedMessage(cursor=Cursor(offset=5), size_bytes=10)._pb default_connection.write.side_effect = make_queue_waiter( write_called_queue, write_result_queue ) @@ -431,10 +432,10 @@ async def test_handle_reset_signal( write_called_queue = asyncio.Queue() write_result_queue = asyncio.Queue() flow = FlowControlRequest(allowed_messages=100, allowed_bytes=100) - message_1 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=5) - message_2 = SequencedMessage(cursor=Cursor(offset=4), size_bytes=10) + message_1 = SequencedMessage(cursor=Cursor(offset=2), size_bytes=5)._pb + message_2 = SequencedMessage(cursor=Cursor(offset=4), size_bytes=10)._pb # Ensure messages with earlier offsets can be handled post-reset. - message_3 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=20) + message_3 = SequencedMessage(cursor=Cursor(offset=1), size_bytes=20)._pb default_connection.write.side_effect = make_queue_waiter( write_called_queue, write_result_queue ) @@ -464,11 +465,13 @@ async def test_handle_reset_signal( ) # Send messages to the subscriber. - await read_result_queue.put(as_response([message_1, message_2])) + await read_result_queue.put(as_response([message_1])) + await read_result_queue.put(as_response([message_2])) # Read one message. await read_called_queue.get() - assert (await subscriber.read()) == message_1 + batch1 = await subscriber.read() + assert batch1[0].SerializeToString() == message_1.SerializeToString() # Fail the connection with an error containing the RESET signal. await read_called_queue.get() @@ -501,4 +504,5 @@ async def test_handle_reset_signal( # Ensure the subscriber accepts an earlier message. await read_result_queue.put(as_response([message_3])) await read_called_queue.get() - assert (await subscriber.read()) == message_3 + batch2 = await subscriber.read() + assert batch2[0].SerializeToString() == message_3.SerializeToString()