Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to next_msg and tasks cancellation #446

Merged — 11 commits, May 8, 2023
11 changes: 5 additions & 6 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ cache:
- $HOME/nats-server

python:
- "3.8"
- "3.9"
- "3.10"
- "3.11"

before_install:
- bash ./scripts/install_nats.sh
Expand All @@ -29,11 +29,10 @@ dist: focal

jobs:
include:
- name: "Python: 3.9/uvloop"
python: "3.9"
- name: "Python: 3.11/uvloop"
python: "3.11"
install:
- pip install uvloop
allow_failures:
- python: "3.8"
- python: "3.9"
- name: "Python: 3.9/uvloop"
- python: "3.11"
- name: "Python: 3.11/uvloop"
10 changes: 8 additions & 2 deletions nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
)
from .transport import TcpTransport, Transport, WebSocketTransport

__version__ = '2.2.0'
__version__ = '2.3.0'
__lang__ = 'python3'
_logger = logging.getLogger(__name__)
PROTOCOL = 1
Expand Down Expand Up @@ -666,11 +666,17 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
# Cleanup subscriptions since not reconnecting so no need
# to replay the subscriptions anymore.
for sub in self._subs.values():
# FIXME: Should we clear the pending queue here?
# Async subs use join when draining already so just cancel here.
if sub._wait_for_msgs_task and not sub._wait_for_msgs_task.done():
sub._wait_for_msgs_task.cancel()
if sub._message_iterator:
sub._message_iterator._cancel()
# Sync subs may have some inflight next_msg calls that could be blocking
# so cancel them here to unblock them.
if sub._pending_next_msgs_calls:
for fut in sub._pending_next_msgs_calls.values():
fut.cancel("nats: connection is closed")
sub._pending_next_msgs_calls.clear()
self._subs.clear()

if self._transport is not None:
Expand Down
60 changes: 39 additions & 21 deletions nats/aio/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def __init__(
self._pending_queue: asyncio.Queue[Msg] = asyncio.Queue(
maxsize=pending_msgs_limit
)
# If no callback, then this is a sync subscription which will
# require tracking the next_msg calls inflight for cancelling.
if cb is None:
self._pending_next_msgs_calls = {}
else:
self._pending_next_msgs_calls = None
self._pending_size = 0
self._wait_for_msgs_task = None
self._message_iterator = None
Expand Down Expand Up @@ -148,34 +154,49 @@ async def next_msg(self, timeout: float | None = 1.0) -> Msg:
:params timeout: Time in seconds to wait for next message before timing out.
:raises nats.errors.TimeoutError:

next_msg can be used to retrieve the next message
from a stream of messages using await syntax, this
only works when not passing a callback on `subscribe`::
next_msg can be used to retrieve the next message from a stream of messages using
await syntax, this only works when not passing a callback on `subscribe`::

sub = await nc.subscribe('hello')
msg = await sub.next_msg(timeout=1)

"""
future: asyncio.Future[Msg] = asyncio.Future()
if self._conn.is_closed:
raise errors.ConnectionClosedError

async def _next_msg() -> None:
msg = await self._pending_queue.get()
self._pending_size -= len(msg.data)
future.set_result(msg)
if self._cb:
raise errors.Error(
'nats: next_msg cannot be used in async subscriptions'
)

task = asyncio.get_running_loop().create_task(_next_msg())
msg = None
future = None
task_name = None
try:
msg = await asyncio.wait_for(future, timeout)
future = asyncio.create_task(
asyncio.wait_for(self._pending_queue.get(), timeout)
)
task_name = future.get_name()
self._pending_next_msgs_calls[task_name] = future
msg = await future
self._pending_size -= len(msg.data)
return msg
except asyncio.TimeoutError:
future.cancel()
task.cancel()
if self._conn.is_closed:
raise errors.ConnectionClosedError
raise errors.TimeoutError
except asyncio.CancelledError:
future.cancel()
task.cancel()
# Call timeout otherwise would get an empty message.
raise errors.TimeoutError
if self._conn.is_closed:
raise errors.ConnectionClosedError
raise
finally:
if self._pending_next_msgs_calls and task_name in self._pending_next_msgs_calls:
del self._pending_next_msgs_calls[task_name]
if msg:
# For sync subscriptions we will consider a message
# to be done once it has been consumed by the client
# regardless of whether it has been processed.
self._pending_queue.task_done()

def _start(self, error_cb):
"""
Expand Down Expand Up @@ -231,9 +252,7 @@ async def _drain(self) -> None:
# messages so can throw it away now.
self._conn._remove_sub(self._id)
except asyncio.CancelledError:
# In case draining of a connection times out then
# the sub per task will be canceled as well.
pass
raise
finally:
self._closed = True

Expand Down Expand Up @@ -298,13 +317,12 @@ async def _wait_for_msgs(self, error_cb) -> None:
if error_cb:
await error_cb(e)
finally:
# indicate the message finished processing so drain can continue
# indicate the message finished processing so drain can continue.
self._pending_queue.task_done()

# Apply auto unsubscribe checks after having processed last msg.
if self._max_msgs > 0 and self._received >= self._max_msgs and self._pending_queue.empty:
self._stop_processing()

except asyncio.CancelledError:
break

Expand Down
10 changes: 8 additions & 2 deletions nats/js/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ def __init__(
self._pending_size = sub._pending_size
self._wait_for_msgs_task = sub._wait_for_msgs_task
self._message_iterator = sub._message_iterator
self._pending_next_msgs_calls = sub._pending_next_msgs_calls

async def consumer_info(self) -> api.ConsumerInfo:
"""
Expand Down Expand Up @@ -881,8 +882,13 @@ async def _fetch_n(
)
await asyncio.sleep(0)

# Wait for first message or timeout.
msg = await self._sub.next_msg(timeout)
try:
msg = await self._sub.next_msg(timeout)
except asyncio.TimeoutError:
if msgs:
return msgs
raise

status = JetStreamContext.is_status_msg(msg)
if JetStreamContext._is_processable_msg(status, msg):
# First processable message received, do not raise error from now.
Expand Down
2 changes: 1 addition & 1 deletion scripts/install_nats.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

set -e

export DEFAULT_NATS_SERVER_VERSION=v2.9.7
export DEFAULT_NATS_SERVER_VERSION=v2.9.16

export NATS_SERVER_VERSION="${NATS_SERVER_VERSION:=$DEFAULT_NATS_SERVER_VERSION}"

Expand Down
52 changes: 47 additions & 5 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,12 +626,16 @@ async def test_subscribe_next_msg(self):

# Wait for another message, the future should not linger
# after the cancellation.
# FIXME: Flapping...
# future = sub.next_msg(timeout=None)

future = sub.next_msg(timeout=2)
task = asyncio.create_task(asyncio.wait_for(future, timeout=2))
await nc.close()

# await future
# Unblocked pending calls get a connection closed errors now.
start = time.time()
with self.assertRaises(nats.errors.ConnectionClosedError):
await task
end = time.time()
assert (end - start) < 0.5

@async_test
async def test_subscribe_next_msg_custom_limits(self):
Expand Down Expand Up @@ -672,6 +676,25 @@ async def error_cb(err):
assert sub.pending_bytes == 0
await nc.close()

@async_test
async def test_subscribe_next_msg_with_cb_not_supported(self):
    """next_msg must raise when the subscription was created with a callback."""
    nc = await nats.connect()

    async def handler(msg):
        await msg.respond(b'OK')

    # Async (callback-based) subscription: next_msg is not allowed on it.
    sub = await nc.subscribe('foo', cb=handler)
    await nc.flush()

    # NOTE(review): these publishes target 'tests.N', which the 'foo'
    # subscription does not match — confirm whether that is intentional
    # (it does keep the handler from firing).
    for subject in ('tests.0', 'tests.1'):
        await nc.publish(subject, b'bar')
    await nc.flush()

    with self.assertRaises(nats.errors.Error):
        await sub.next_msg()

    await nc.close()

@async_test
async def test_subscribe_without_coroutine_unsupported(self):
nc = NATS()
Expand Down Expand Up @@ -770,7 +793,6 @@ async def slow_worker_handler(msg):
msg = await nc.request(
"slow.help", b'please', timeout=0.1, old_style=True
)
print(msg)

with self.assertRaises(nats.errors.NoRespondersError):
await nc.request("nowhere", b'please', timeout=0.1, old_style=True)
Expand Down Expand Up @@ -2639,6 +2661,26 @@ async def test_protocol_mixing(self):
servers=["tls://127.0.0.1:4222", "wss://127.0.0.1:8080"]
)

@async_test
async def test_drain_cancelled_errors_raised(self):
    """drain() must propagate CancelledError instead of swallowing it.

    Regression test for Subscription._drain, which previously caught
    asyncio.CancelledError and passed, hiding cancellations that happen
    while waiting on pending messages.
    """
    nc = NATS()
    await nc.connect()

    async def cb(msg):
        # Slow handler keeps messages pending so drain() has to wait.
        await asyncio.sleep(20)

    # Fix: plain string literal — the original used an f-string with no
    # placeholders (f"test.sub"), pylint W1309.
    sub = await nc.subscribe("test.sub", cb=cb)
    await nc.publish("test.sub")
    await nc.publish("test.sub")
    # Give the server/client loop a moment to deliver the messages.
    await asyncio.sleep(0.1)

    # Force the internal asyncio.wait_for used by drain() to raise
    # CancelledError and assert it is surfaced to the caller.
    patched_wait_for = unittest.mock.patch(
        "asyncio.wait_for",
        unittest.mock.AsyncMock(side_effect=asyncio.CancelledError),
    )
    with self.assertRaises(asyncio.CancelledError):
        with patched_wait_for:
            await sub.drain()
    await nc.close()


if __name__ == '__main__':
import sys
Expand Down
36 changes: 36 additions & 0 deletions tests/test_js.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
import time
import unittest
from unittest import mock
import uuid
import json

Expand Down Expand Up @@ -724,6 +725,40 @@ async def error_cb(err):

await nc.close()

@async_test
async def test_fetch_cancelled_errors_raised(self):
    """fetch() must surface a CancelledError raised while waiting for messages."""
    # tracemalloc provides allocation tracebacks for the RuntimeWarning
    # noted below.
    import tracemalloc
    tracemalloc.start()

    nc = NATS()
    await nc.connect()
    js = nc.jetstream()

    await js.add_stream(name="test", subjects=["test.a"])
    consumer_config = dict(
        durable_name="a",
        deliver_policy=nats.js.api.DeliverPolicy.ALL,
        max_deliver=20,
        max_waiting=512,
        max_ack_pending=1024,
        filter_subject="test.a",
    )
    await js.add_consumer("test", **consumer_config)

    sub = await js.pull_subscribe("test.a", "test", stream="test")

    # FIXME: RuntimeWarning: coroutine 'Queue.get' was never awaited
    # is raised here due to the mock usage.
    patched_wait_for = unittest.mock.patch(
        "asyncio.wait_for",
        unittest.mock.AsyncMock(side_effect=asyncio.CancelledError),
    )
    with self.assertRaises(asyncio.CancelledError):
        with patched_wait_for:
            await sub.fetch(batch=1, timeout=0.1)

    await nc.close()


class JSMTest(SingleJetStreamServerTestCase):

Expand Down Expand Up @@ -760,6 +795,7 @@ async def test_stream_management(self):
# Get info
current = await jsm.stream_info("hello")
stream.did_create = None
current.cluster = None
assert stream == current

assert isinstance(current, nats.js.api.StreamInfo)
Expand Down