Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: ensure ack() doesn't wait on stream messages #234

Merged
merged 4 commits into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def track(self, offset: int):
"""

@abstractmethod
async def ack(self, offset: int):
def ack(self, offset: int):
"""
Acknowledge the message with the provided offset. The offset must have previously been tracked.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Optional

from google.api_core.exceptions import FailedPrecondition

from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.internal.wire.committer import Committer
from google.cloud.pubsublite_v1 import Cursor
Expand All @@ -43,9 +44,7 @@ def track(self, offset: int):
)
self._receipts.append(offset)

async def ack(self, offset: int):
# Note: put_nowait is used here and below to ensure that the below logic is executed without yielding
# to another coroutine in the event loop. The queue is unbounded so it will never throw.
def ack(self, offset: int):
self._acks.put_nowait(offset)
prefix_acked_offset: Optional[int] = None
while len(self._receipts) != 0 and not self._acks.empty():
Expand All @@ -60,7 +59,7 @@ async def ack(self, offset: int):
if prefix_acked_offset is None:
return
# Convert from last acked to first unacked.
await self._committer.commit(Cursor(offset=prefix_acked_offset + 1))
self._committer.commit(Cursor(offset=prefix_acked_offset + 1))

async def clear_and_commit(self):
self._receipts.clear()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import (
AsyncSingleSubscriber,
)
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
from google.cloud.pubsublite.internal.wait_ignore_cancelled import (
wait_ignore_cancelled,
wait_ignore_errors,
)
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
from google.cloud.pubsublite.types import Partition
Expand Down Expand Up @@ -100,8 +103,10 @@ async def __aenter__(self):

async def __aexit__(self, exc_type, exc_value, traceback):
self._assign_poller.cancel()
await wait_ignore_cancelled(self._assign_poller)
await self._assigner.__aexit__(exc_type, exc_value, traceback)
await wait_ignore_errors(self._assign_poller)
await wait_ignore_errors(
self._assigner.__aexit__(exc_type, exc_value, traceback)
)
for running in self._subscribers.values():
await self._stop_subscriber(running)
await wait_ignore_errors(self._stop_subscriber(running))
pass
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ async def read(self) -> List[Message]:
self.fail(e)
raise e

async def _handle_ack(self, message: requests.AckRequest):
await self._underlying.allow_flow(
def _handle_ack(self, message: requests.AckRequest):
self._underlying.allow_flow(
FlowControlRequest(
allowed_messages=1,
allowed_bytes=self._messages_by_ack_id[message.ack_id].size_bytes,
Expand All @@ -138,7 +138,7 @@ async def _handle_ack(self, message: requests.AckRequest):
ack_id = _AckId.parse(message.ack_id)
if ack_id.generation == self._ack_generation_id:
try:
await self._ack_set_tracker.ack(ack_id.offset)
self._ack_set_tracker.ack(ack_id.offset)
except GoogleAPICallError as e:
self.fail(e)

Expand Down Expand Up @@ -179,7 +179,7 @@ async def _handle_queue_message(
)
)
elif isinstance(message, requests.AckRequest):
await self._handle_ack(message)
self._handle_ack(message)
else:
self._handle_nack(message)

Expand All @@ -198,7 +198,7 @@ async def __aenter__(self):
await self._ack_set_tracker.__aenter__()
await self._underlying.__aenter__()
self._looper_future = asyncio.ensure_future(self._looper())
await self._underlying.allow_flow(
self._underlying.allow_flow(
FlowControlRequest(
allowed_messages=self._flow_control_settings.messages_outstanding,
allowed_bytes=self._flow_control_settings.bytes_outstanding,
Expand Down
17 changes: 11 additions & 6 deletions google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from concurrent.futures.thread import ThreadPoolExecutor
from typing import ContextManager, Optional
from google.api_core.exceptions import GoogleAPICallError
from functools import partial

from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
from google.cloud.pubsublite.cloudpubsub.internal.managed_event_loop import (
ManagedEventLoop,
)
Expand Down Expand Up @@ -86,8 +89,8 @@ async def _poller(self):
while True:
batch = await self._underlying.read()
self._unowned_executor.map(self._callback, batch)
except GoogleAPICallError as e: # noqa: F841 Flake8 thinks e is unused
self._unowned_executor.submit(lambda: self._fail(e)) # noqa: F821
except GoogleAPICallError as e:
self._unowned_executor.submit(partial(self._fail, e))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can't capture the exception variable in a lambda — Python deletes the `as e` binding when the except block exits, so the lambda would raise NameError when it later runs


def __enter__(self):
assert self._close_callback is not None
Expand All @@ -97,13 +100,15 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self._poller_future.cancel()
try:
self._poller_future.cancel()
self._poller_future.result()
except concurrent.futures.CancelledError:
self._poller_future.result() # Ignore error.
except: # noqa: E722
pass
self._event_loop.submit(
self._underlying.__aexit__(exc_type, exc_value, traceback)
wait_ignore_errors(
self._underlying.__aexit__(exc_type, exc_value, traceback)
)
).result()
self._event_loop.__exit__(exc_type, exc_value, traceback)
assert self._close_callback is not None
Expand Down
8 changes: 7 additions & 1 deletion google/cloud/pubsublite/internal/wire/committer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ class Committer(AsyncContextManager, metaclass=ABCMeta):
"""

@abstractmethod
async def commit(self, cursor: Cursor) -> None:
def commit(self, cursor: Cursor) -> None:
    """
    Start the commit for a cursor.

    The commit happens asynchronously in the background; this call only
    records the cursor and returns immediately without waiting for the
    commit to be acknowledged by the server.

    Args:
        cursor: The cursor to commit. Supersedes any previously recorded,
            not-yet-flushed cursor.

    Raises:
        GoogleAPICallError: When the committer terminates in failure.
    """
    pass

@abstractmethod
Expand Down
55 changes: 18 additions & 37 deletions google/cloud/pubsublite/internal/wire/committer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,12 @@
ConnectionReinitializer,
)
from google.cloud.pubsublite.internal.wire.connection import Connection
from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher
from google.cloud.pubsublite_v1 import Cursor
from google.cloud.pubsublite_v1.types import (
StreamingCommitCursorRequest,
StreamingCommitCursorResponse,
InitialCommitCursorRequest,
)
from google.cloud.pubsublite.internal.wire.work_item import WorkItem


_LOGGER = logging.getLogger(__name__)
Expand All @@ -53,9 +51,8 @@ class CommitterImpl(
StreamingCommitCursorRequest, StreamingCommitCursorResponse
]

_batcher: SerialBatcher[Cursor, None]

_outstanding_commits: List[List[WorkItem[Cursor, None]]]
_next_to_commit: Optional[Cursor]
_outstanding_commits: List[Cursor]

_receiver: Optional[asyncio.Future]
_flusher: Optional[asyncio.Future]
Expand All @@ -72,7 +69,7 @@ def __init__(
self._initial = initial
self._flush_seconds = flush_seconds
self._connection = RetryingConnection(factory, self)
self._batcher = SerialBatcher()
self._next_to_commit = None
self._outstanding_commits = []
self._receiver = None
self._flusher = None
Expand Down Expand Up @@ -113,9 +110,7 @@ def _handle_response(self, response: StreamingCommitCursorResponse):
)
)
for _ in range(response.commit.acknowledged_commits):
batch = self._outstanding_commits.pop(0)
for item in batch:
item.response_future.set_result(None)
self._outstanding_commits.pop(0)
if len(self._outstanding_commits) == 0:
self._empty.set()

Expand All @@ -131,39 +126,31 @@ async def _flush_loop(self):

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._stop_loopers()
if self._connection.error():
self._fail_if_retrying_failed()
else:
if not self._connection.error():
await self._flush()
await self._connection.__aexit__(exc_type, exc_val, exc_tb)

def _fail_if_retrying_failed(self):
if self._connection.error():
for batch in self._outstanding_commits:
for item in batch:
item.response_future.set_exception(self._connection.error())

async def _flush(self):
batch = self._batcher.flush()
if not batch:
if self._next_to_commit is None:
return
self._outstanding_commits.append(batch)
self._empty.clear()
req = StreamingCommitCursorRequest()
req.commit.cursor = batch[-1].request
req.commit.cursor = self._next_to_commit
self._outstanding_commits.append(self._next_to_commit)
self._next_to_commit = None
self._empty.clear()
try:
await self._connection.write(req)
except GoogleAPICallError as e:
_LOGGER.debug(f"Failed commit on stream: {e}")
self._fail_if_retrying_failed()

async def wait_until_empty(self):
    """Flush any pending cursor and block until every outstanding commit
    has been acknowledged by the server (or the connection fails)."""
    await self._flush()
    empty_signal = self._empty.wait()
    await self._connection.await_unless_failed(empty_signal)

async def commit(self, cursor: Cursor) -> None:
future = self._batcher.add(cursor)
await future
def commit(self, cursor: Cursor) -> None:
    """
    Record `cursor` as the next cursor to commit, without waiting.

    Commits are coalesced: only the most recent cursor is kept and is sent
    on the next flush cycle, so this call never blocks on the stream.

    Raises:
        GoogleAPICallError: If the underlying connection has already
            terminated in failure.
    """
    if self._connection.error():
        raise self._connection.error()
    self._next_to_commit = cursor

async def reinitialize(
self,
Expand All @@ -181,14 +168,8 @@ async def reinitialize(
"Received an invalid initial response on the publish stream."
)
)
if self._outstanding_commits:
# Roll up outstanding commits
rollup: List[WorkItem[Cursor, None]] = []
for batch in self._outstanding_commits:
for item in batch:
rollup.append(item)
self._outstanding_commits = [rollup]
req = StreamingCommitCursorRequest()
req.commit.cursor = rollup[-1].request
await connection.write(req)
if self._next_to_commit is None:
if self._outstanding_commits:
self._next_to_commit = self._outstanding_commits[-1]
self._outstanding_commits = []
self._start_loopers()
22 changes: 6 additions & 16 deletions google/cloud/pubsublite/internal/wire/flow_control_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ class _AggregateRequest:
def __init__(self):
self._request = FlowControlRequest.meta.pb()

def __add__(self, other: FlowControlRequest.meta.pb):
self._request.allowed_bytes = self._request.allowed_bytes + other.allowed_bytes
def __add__(self, other: FlowControlRequest):
    """Accumulate `other`'s token counts into this aggregate, in place.

    Both counters saturate at _MAX_INT64 so the underlying proto int64
    fields can never overflow. Returns self to support `agg += req`.
    """
    # Work on the raw proto of the wrapper type (proto-plus exposes it
    # via `_pb`).
    other_pb = other._pb
    self._request.allowed_bytes = (
        self._request.allowed_bytes + other_pb.allowed_bytes
    )
    self._request.allowed_bytes = min(self._request.allowed_bytes, _MAX_INT64)
    self._request.allowed_messages = (
        self._request.allowed_messages + other_pb.allowed_messages
    )
    self._request.allowed_messages = min(self._request.allowed_messages, _MAX_INT64)
    return self
Expand Down Expand Up @@ -77,16 +80,3 @@ def release_pending_request(self) -> Optional[FlowControlRequest]:
request = self._pending_tokens
self._pending_tokens = _AggregateRequest()
return request.to_optional()

def should_expedite(self):
pending_request = self._pending_tokens._request
client_request = self._client_tokens._request
if _exceeds_expedite_ratio(
pending_request.allowed_bytes, client_request.allowed_bytes
):
return True
if _exceeds_expedite_ratio(
pending_request.allowed_messages, client_request.allowed_messages
):
return True
return False
2 changes: 1 addition & 1 deletion google/cloud/pubsublite/internal/wire/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def read(self) -> List[SequencedMessage.meta.pb]:
raise NotImplementedError()

@abstractmethod
async def allow_flow(self, request: FlowControlRequest):
def allow_flow(self, request: FlowControlRequest):
"""
Allow an additional amount of messages and bytes to be sent to this client.
"""
Expand Down
7 changes: 1 addition & 6 deletions google/cloud/pubsublite/internal/wire/subscriber_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,5 @@ async def reinitialize(
async def read(self) -> List[SequencedMessage.meta.pb]:
    """Return the next batch of messages, raising if the connection fails first."""
    next_batch = self._message_queue.get()
    return await self._connection.await_unless_failed(next_batch)

async def allow_flow(self, request: FlowControlRequest):
def allow_flow(self, request: FlowControlRequest):
    """Record additional message/byte flow allowance to be granted to the server.

    Purely local bookkeeping; the tokens are sent on the regular flow-control
    cycle rather than immediately.
    """
    batcher = self._outstanding_flow_control
    batcher.add(request)
if (
not self._reinitializing
and self._outstanding_flow_control.should_expedite()
):
await self._try_send_tokens()
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ async def test_track_and_aggregate_acks(committer, tracker: AckSetTracker):
tracker.track(offset=7)

committer.commit.assert_has_calls([])
await tracker.ack(offset=3)
tracker.ack(offset=3)
committer.commit.assert_has_calls([])
await tracker.ack(offset=1)
tracker.ack(offset=1)
committer.commit.assert_has_calls([call(Cursor(offset=4))])
await tracker.ack(offset=5)
tracker.ack(offset=5)
committer.commit.assert_has_calls(
[call(Cursor(offset=4)), call(Cursor(offset=6))]
)

tracker.track(offset=8)
await tracker.ack(offset=7)
tracker.ack(offset=7)
committer.commit.assert_has_calls(
[call(Cursor(offset=4)), call(Cursor(offset=6)), call(Cursor(offset=8))]
)
Expand All @@ -74,14 +74,14 @@ async def test_clear_and_commit(committer, tracker: AckSetTracker):

with pytest.raises(FailedPrecondition):
tracker.track(offset=1)
await tracker.ack(offset=5)
tracker.ack(offset=5)
committer.commit.assert_has_calls([])

await tracker.clear_and_commit()
committer.wait_until_empty.assert_called_once()

# After clearing, it should be possible to track earlier offsets.
tracker.track(offset=1)
await tracker.ack(offset=1)
tracker.ack(offset=1)
committer.commit.assert_has_calls([call(Cursor(offset=2))])
committer.__aexit__.assert_called_once()