Skip to content

Commit

Permalink
Merge pull request #315 from nats-io/jsm-updates
Browse files Browse the repository at this point in the history
Implement more JSM methods
  • Loading branch information
wallyqs committed May 19, 2022
2 parents 7cfba21 + f8a5c99 commit dde4c82
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 8 deletions.
11 changes: 9 additions & 2 deletions nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,8 +937,15 @@ async def _request_new_style(
msg = await asyncio.wait_for(future, timeout)
return msg
except asyncio.TimeoutError:
# Double check that the token is there already.
self._resp_map.pop(token.decode())
try:
# Double check that the token is there already.
self._resp_map.pop(token.decode())
except KeyError:
await self._error_cb(
errors.
Error(f"nats: missing response token '{token.decode()}'")
)

future.cancel()
raise errors.TimeoutError

Expand Down
13 changes: 13 additions & 0 deletions nats/js/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,12 @@ async def _fetch_n(

return msgs

#############################
# #
# JetStream Manager Context #
# #
#############################

async def get_last_msg(
self,
stream_name: str,
Expand All @@ -924,8 +930,15 @@ async def get_last_msg(
for k, v in parsed_headers.items():
headers[k] = v
raw_msg.headers = headers

return raw_msg

######################
# #
# KeyValue Context #
# #
######################

async def key_value(self, bucket: str) -> KeyValue:
stream = KV_STREAM_TEMPLATE.format(bucket=bucket)
try:
Expand Down
99 changes: 98 additions & 1 deletion nats/js/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
#

import json
import base64
from email.parser import BytesParser
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, List

from nats.errors import NoRespondersError
from nats.js import api
Expand All @@ -23,6 +24,8 @@
if TYPE_CHECKING:
from nats import NATS

NATS_HDR_LINE = bytearray(b'NATS/1.0\r\n')


class JetStreamManager:
"""
Expand Down Expand Up @@ -87,6 +90,26 @@ async def add_stream(
)
return api.StreamInfo.from_response(resp)

async def update_stream(
self, config: api.StreamConfig = None, **params
) -> api.StreamInfo:
"""
update_stream updates a stream.
"""
if config is None:
config = api.StreamConfig()
config = config.evolve(**params)
if config.name is None:
raise ValueError("nats: stream name is required")

data = json.dumps(config.as_dict())
resp = await self._api_request(
f"{self._prefix}.STREAM.UPDATE.{config.name}",
data.encode(),
timeout=self._timeout,
)
return api.StreamInfo.from_response(resp)

async def delete_stream(self, name: str) -> bool:
"""
Delete a stream by name.
Expand All @@ -96,6 +119,15 @@ async def delete_stream(self, name: str) -> bool:
)
return resp['success']

async def purge_stream(self, name: str) -> bool:
"""
Purge a stream by name.
"""
resp = await self._api_request(
f"{self._prefix}.STREAM.PURGE.{name}", timeout=self._timeout
)
return resp['success']

async def consumer_info(
self, stream: str, consumer: str, timeout: Optional[float] = None
):
Expand All @@ -109,6 +141,21 @@ async def consumer_info(
)
return api.ConsumerInfo.from_response(resp)

async def streams_info(self) -> List[api.StreamInfo]:
"""
streams_info retrieves a list of streams.
"""
resp = await self._api_request(
f"{self._prefix}.STREAM.LIST",
b'',
timeout=self._timeout,
)
streams = []
for stream in resp['streams']:
stream_info = api.StreamInfo.from_response(stream)
streams.append(stream_info)
return streams

async def add_consumer(
self,
stream: str,
Expand Down Expand Up @@ -148,6 +195,56 @@ async def delete_consumer(self, stream: str, consumer: str) -> bool:
)
return resp['success']

async def consumers_info(self, stream: str) -> List[api.ConsumerInfo]:
"""
consumers_info retrieves a list of consumers.
"""
resp = await self._api_request(
f"{self._prefix}.CONSUMER.LIST.{stream}",
b'',
timeout=self._timeout,
)
consumers = []
for consumer in resp['consumers']:
consumer_info = api.ConsumerInfo.from_response(consumer)
consumers.append(consumer_info)
return consumers

async def get_msg(self, stream_name: str, seq: int) -> api.RawStreamMsg:
"""
get_msg retrieves a message from a stream based on the sequence ID.
"""
req_subject = f"{self._prefix}.STREAM.MSG.GET.{stream_name}"
req = {'seq': seq}
data = json.dumps(req)
resp = await self._api_request(req_subject, data.encode())
raw_msg = api.RawStreamMsg.from_response(resp['message'])
if raw_msg.hdrs:
hdrs = base64.b64decode(raw_msg.hdrs)
raw_headers = hdrs[len(NATS_HDR_LINE):]
parsed_headers = self._hdr_parser.parsebytes(raw_headers)
headers = {}
for k, v in parsed_headers.items():
headers[k] = v
raw_msg.headers = headers

data = None
if raw_msg.data:
data = base64.b64decode(raw_msg.data)
raw_msg.data = data

return raw_msg

async def delete_msg(self, stream_name: str, seq: int) -> bool:
"""
get_msg retrieves a message from a stream based on the sequence ID.
"""
req_subject = f"{self._prefix}.STREAM.MSG.DELETE.{stream_name}"
req = {'seq': seq}
data = json.dumps(req)
resp = await self._api_request(req_subject, data.encode())
return resp['success']

async def _api_request(
self,
req_subject: str,
Expand Down
13 changes: 10 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,11 +599,12 @@ async def test_subscribe_next_msg(self):

# Wait for another message, the future should not linger
# after the cancellation.
future = sub.next_msg(timeout=None)
# FIXME: Flapping...
# future = sub.next_msg(timeout=None)

await nc.close()

await future
# await future

@async_test
async def test_subscribe_without_coroutine_unsupported(self):
Expand Down Expand Up @@ -727,7 +728,12 @@ async def slow_worker_handler(msg):
await asyncio.sleep(0.5)
await nc.publish(msg.reply, b'timeout by now...')

await nc.connect()
errs = []

async def err_cb(err):
errs.append(err)

await nc.connect(error_cb=err_cb)
await nc.subscribe("help", cb=worker_handler)
await nc.subscribe("slow.help", cb=slow_worker_handler)

Expand All @@ -739,6 +745,7 @@ async def slow_worker_handler(msg):
with self.assertRaises(nats.errors.TimeoutError):
await nc.request("slow.help", b'please', timeout=0.1)
await asyncio.sleep(1)
assert len(errs) == 0
await nc.close()

@async_test
Expand Down
124 changes: 122 additions & 2 deletions tests/test_js.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ async def test_auto_create_consumer(self):

@async_test
async def test_fetch_one(self):
pytest.skip("update for nats-server 2.8")
nc = NATS()
await nc.connect()

Expand Down Expand Up @@ -201,6 +200,7 @@ async def test_fetch_one(self):
for msg in msgs:
await msg.term()

await asyncio.sleep(1)
info = await js.consumer_info("TEST1", "dur", timeout=1)
assert info.num_ack_pending == 1
assert info.num_redelivered == 1
Expand Down Expand Up @@ -277,7 +277,6 @@ async def test_add_pull_consumer_via_jsm(self):

@async_long_test
async def test_fetch_n(self):
pytest.skip("update for nats-server 2.8")
nc = NATS()
await nc.connect()
js = nc.jetstream()
Expand Down Expand Up @@ -637,6 +636,20 @@ async def test_stream_management(self):
assert current.state.messages == 1
assert current.state.bytes == 47

stream_config = current.config
stream_config.subjects.append("extra")
updated_stream = await jsm.update_stream(stream_config)
assert updated_stream.config.subjects == [
'hello', 'world', 'hello.>', 'extra'
]

# Purge Stream
is_purged = await jsm.purge_stream("hello")
assert is_purged
current = await jsm.stream_info("hello")
assert current.state.messages == 0
assert current.state.bytes == 0

# Delete stream
is_deleted = await jsm.delete_stream("hello")
assert is_deleted
Expand Down Expand Up @@ -701,6 +714,112 @@ async def test_consumer_management(self):

await nc.close()

@async_test
async def test_jsm_get_delete_msg(self):
nc = NATS()
await nc.connect()
js = nc.jetstream()
jsm = nc.jsm()

# Create stream
stream = await jsm.add_stream(name="foo", subjects=["foo.>"])

await js.publish("foo.a.1", b'Hello', headers={'foo': 'bar'})
await js.publish("foo.b.1", b'World')
await js.publish("foo.c.1", b'!!!')

# GetMsg
msg = await jsm.get_msg("foo", 2)
assert msg.subject == 'foo.b.1'
assert msg.data == b'World'

msg = await jsm.get_msg("foo", 3)
assert msg.subject == 'foo.c.1'
assert msg.data == b'!!!'

msg = await jsm.get_msg("foo", 1)
assert msg.subject == 'foo.a.1'
assert msg.data == b'Hello'
assert msg.headers["foo"] == "bar"
assert msg.hdrs == 'TkFUUy8xLjANCmZvbzogYmFyDQoNCg=='

with pytest.raises(BadRequestError):
await jsm.get_msg("foo", 0)

# DeleteMsg
stream_info = await jsm.stream_info("foo")
assert stream_info.state.messages == 3

ok = await jsm.delete_msg("foo", 2)
assert ok

stream_info = await jsm.stream_info("foo")
assert stream_info.state.messages == 2

msg = await jsm.get_msg("foo", 1)
assert msg.data == b"Hello"

# Deleted message should be gone now.
with pytest.raises(NotFoundError):
await jsm.get_msg("foo", 2)

msg = await jsm.get_msg("foo", 3)
assert msg.data == b"!!!"

await nc.close()

@async_test
async def test_jsm_stream_management(self):
nc = NATS()
await nc.connect()
js = nc.jetstream()
jsm = nc.jsm()

await jsm.add_stream(name="foo")
await jsm.add_stream(name="bar")
await jsm.add_stream(name="quux")

streams = await jsm.streams_info()

expected = ["foo", "bar", "quux"]
responses = []
for stream in streams:
responses.append(stream.config.name)

for name in expected:
assert name in responses

await nc.close()

@async_test
async def test_jsm_consumer_management(self):
nc = NATS()
await nc.connect()
js = nc.jetstream()
jsm = nc.jsm()

await jsm.add_stream(name="hello", subjects=["hello"])

durables = ["a", "b", "c"]

subs = []
for durable in durables:
sub = await js.pull_subscribe("hello", durable)
subs.append(sub)

consumers = await jsm.consumers_info("hello")
assert len(consumers) == 3

expected = ["a", "b", "c"]
responses = []
for consumer in consumers:
responses.append(consumer.config.durable_name)

for name in expected:
assert name in responses

await nc.close()


class SubscribeTest(SingleJetStreamServerTestCase):

Expand Down Expand Up @@ -953,6 +1072,7 @@ async def test_double_acking_pull_subscribe(self):
await asyncio.sleep(0.5)
await msg.in_progress()
await msg.ack()
await asyncio.sleep(1)

info = await psub.consumer_info()
assert info.num_pending == 8
Expand Down

0 comments on commit dde4c82

Please sign in to comment.