diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 5cad797fe..629238bdc 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -1,4 +1,4 @@ -name: check +name: Check on: push: branches: @@ -26,8 +26,4 @@ jobs: run: uv sync --dev - name: Run format check - run: | - uv run yapf --diff --recursive nats/ - - - name: Run isort check - run: uv run isort --check-only --diff nats/src + run: uv run ruff format --check diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 165fc0d4c..11071f00f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,7 +42,7 @@ jobs: - name: Run tests run: | - uv run flake8 --ignore="W391, W503, W504, E501" ./nats/src/nats/js/ + uv run flake8 --ignore="W391, W503, W504, E501, E203" ./nats/src/nats/js/ uv run pytest -x -vv -s --continue-on-collection-errors ./nats/tests env: PATH: $HOME/nats-server:$PATH diff --git a/nats-server/src/nats/server/__init__.py b/nats-server/src/nats/server/__init__.py index b80fa79e1..7a50cab32 100644 --- a/nats-server/src/nats/server/__init__.py +++ b/nats-server/src/nats/server/__init__.py @@ -34,9 +34,7 @@ FATAL_PATTERN = re.compile(r"\[FTL\]\s+(.*)") INFO_PATTERN = re.compile(r"\[INF\]\s+(.*)") READY_PATTERN = re.compile(r"Server is ready") -LISTENING_PATTERN = re.compile( - r"Listening for client connections on (.+):(\d+)" -) +LISTENING_PATTERN = re.compile(r"Listening for client connections on (.+):(\d+)") class ServerError(Exception): @@ -279,7 +277,7 @@ async def wait_ready() -> tuple[str, int]: if match := LISTENING_PATTERN.search(stderr_line): host_part = match.group(1) - if host_part.startswith('[') and host_part.endswith(']'): + if host_part.startswith("[") and host_part.endswith("]"): host = host_part[1:-1] else: host = host_part @@ -298,13 +296,11 @@ async def wait_ready() -> tuple[str, int]: if returncode != 0: msg = f"Server exited with code {returncode}" if error_lines: - errors = '\n'.join(error_lines) + errors = "\n".join(error_lines) msg += f"\nErrors:\n{errors}" raise ServerError(msg) - raise ServerError( - "Server ended without becoming ready" - ) # pragma: no cover + raise ServerError("Server ended without becoming ready") # pragma: no cover return await asyncio.wait_for(wait_ready(), timeout=timeout) @@ -413,14 +409,14 @@ async def run_cluster( node_store_dir = None if jetstream and store_dir: # Use as base directory and create subdirectory for each node - node_store_dir = os.path.join(store_dir, f"node{i+1}") + node_store_dir = os.path.join(store_dir, f"node{i + 1}") os.makedirs(node_store_dir, exist_ok=True) server = await _run_cluster_node( config_path=config_path, port=available_ports[i], routes=routes, - name=f"node{i+1}", + name=f"node{i + 1}", cluster_name="cluster", cluster_port=cluster_ports[i], jetstream=jetstream, @@ -465,9 +461,7 @@ async def _run_cluster_node( """ # Build cluster URL and routes string for CLI cluster_url = f"nats://127.0.0.1:{cluster_port}" - routes_str = ",".join( - f"nats://127.0.0.1:{r}" for r in routes - ) if routes else None + routes_str = ",".join(f"nats://127.0.0.1:{r}" for r in routes) if routes else None process = await _create_server_process( port=port, @@ -480,8 +474,6 @@ async def _run_cluster_node( config_path=config_path if config_path else None, ) - assigned_host, assigned_port = await _wait_for_server_ready( - process, timeout=10.0 - ) + assigned_host, assigned_port = await _wait_for_server_ready(process, timeout=10.0) return Server(process, assigned_host, assigned_port) diff --git a/nats-server/tests/conftest.py b/nats-server/tests/conftest.py index 0d245df64..9ad5da3c0 100644 --- a/nats-server/tests/conftest.py +++ b/nats-server/tests/conftest.py @@ -14,10 +14,9 @@ def get_nats_server_version(): """Get the nats-server version or fail if not installed.""" try: - result = subprocess.run(["nats-server", "--version"], - capture_output=True, - check=True, - text=True) + result = subprocess.run( + ["nats-server", "--version"], capture_output=True, check=True, text=True + ) return result.stdout.strip() or result.stderr.strip() except (subprocess.SubprocessError, FileNotFoundError) as e: pytest.fail(f"nats-server is not installed or not in PATH: {e}") diff --git a/nats-server/tests/test_server.py b/nats-server/tests/test_server.py index 1f011eeda..ad7984a4e 100644 --- a/nats-server/tests/test_server.py +++ b/nats-server/tests/test_server.py @@ -18,6 +18,7 @@ class ServerInfo(TypedDict): See: https://docs.nats.io/reference/reference-protocols/nats-protocol#info """ + # Required fields server_id: str server_name: str @@ -264,9 +265,7 @@ async def test_run_with_store_dir_as_file(tmp_path): # Try to start server with JetStream using a file as store_dir with pytest.raises(ServerError) as exc_info: - await run( - port=0, jetstream=True, store_dir=str(store_file), timeout=2.0 - ) + await run(port=0, jetstream=True, store_dir=str(store_file), timeout=2.0) # Verify the error message indicates the storage directory issue error_msg = str(exc_info.value).lower() @@ -522,9 +521,7 @@ async def test_cluster_with_conflicting_config(tmp_path): """Test run_cluster with config that includes cluster settings.""" # The function should still work, merging config with generated cluster setup cluster = await run_cluster( - "tests/configs/jetstream.conf", - jetstream=True, - store_dir=str(tmp_path) + "tests/configs/jetstream.conf", jetstream=True, store_dir=str(tmp_path) ) try: @@ -549,8 +546,10 @@ async def test_run_with_invalid_host(): with pytest.raises(ServerError) as exc_info: await run(host="999.999.999.999", port=0, timeout=2.0) - assert "exited" in str(exc_info.value - ).lower() or "error" in str(exc_info.value).lower() + assert ( + "exited" in str(exc_info.value).lower() + or "error" in str(exc_info.value).lower() + ) async def test_cluster_client_url(): diff --git a/nats/benchmark/latency_perf.py b/nats/benchmark/latency_perf.py index dcfd54018..ae0fb9bd0 100644 --- a/nats/benchmark/latency_perf.py +++ b/nats/benchmark/latency_perf.py @@ -34,9 +34,7 @@ def show_usage_and_die(): async def main(): parser = argparse.ArgumentParser() - parser.add_argument( - "-n", "--iterations", default=DEFAULT_ITERATIONS, type=int - ) + parser.add_argument("-n", "--iterations", default=DEFAULT_ITERATIONS, type=int) parser.add_argument("-S", "--subject", default="test") parser.add_argument("--servers", default=[], action="append") args = parser.parse_args() @@ -60,11 +58,7 @@ async def handler(msg): start = time.monotonic() to_send = args.iterations - print( - "Sending {} request/responses on [{}]".format( - args.iterations, args.subject - ) - ) + print("Sending {} request/responses on [{}]".format(args.iterations, args.subject)) while to_send > 0: to_send -= 1 if to_send == 0: diff --git a/nats/benchmark/parser_perf.py b/nats/benchmark/parser_perf.py index 11e869631..ec6136016 100644 --- a/nats/benchmark/parser_perf.py +++ b/nats/benchmark/parser_perf.py @@ -5,7 +5,6 @@ class DummyNatsClient: - def __init__(self): self._subs = {} self._pongs = [] @@ -40,9 +39,7 @@ async def _process_err(self, err=None): def generate_msg(subject, nbytes, reply=""): msg = [] - protocol_line = "MSG {subject} 1 {reply} {nbytes}\r\n".format( - subject=subject, reply=reply, nbytes=nbytes - ).encode() + protocol_line = "MSG {subject} 1 {reply} {nbytes}\r\n".format(subject=subject, reply=reply, nbytes=nbytes).encode() msg.append(protocol_line) msg.append(b"A" * nbytes) msg.append(b"r\n") diff --git a/nats/benchmark/pub_perf.py b/nats/benchmark/pub_perf.py index 6e2c4f119..ea8e51a80 100644 --- a/nats/benchmark/pub_perf.py +++ b/nats/benchmark/pub_perf.py @@ -68,11 +68,7 @@ async def main(): start = time.time() to_send = args.count - print( - "Sending {} messages of size {} bytes on [{}]".format( - args.count, args.size, args.subject - ) - ) + print("Sending {} messages of size {} bytes on [{}]".format(args.count, args.size, args.subject)) while to_send > 0: for i in range(0, args.batch): to_send -= 1 @@ -94,11 +90,7 @@ async def main(): elapsed = time.time() - start mbytes = "%.1f" % (((args.size * args.count) / elapsed) / (1024 * 1024)) - print( - "\nTest completed : {} msgs/sec ({}) MB/sec".format( - args.count / elapsed, mbytes - ) - ) + print("\nTest completed : {} msgs/sec ({}) MB/sec".format(args.count / elapsed, mbytes)) await nc.close() diff --git a/nats/benchmark/pub_sub_perf.py b/nats/benchmark/pub_sub_perf.py index 87198b14a..f76bb69a3 100644 --- a/nats/benchmark/pub_sub_perf.py +++ b/nats/benchmark/pub_sub_perf.py @@ -79,11 +79,7 @@ async def handler(msg): start = time.time() to_send = args.count - print( - "Sending {} messages of size {} bytes on [{}]".format( - args.count, args.size, args.subject - ) - ) + print("Sending {} messages of size {} bytes on [{}]".format(args.count, args.size, args.subject)) while to_send > 0: for i in range(0, args.batch): to_send -= 1 @@ -107,17 +103,9 @@ async def handler(msg): elapsed = time.time() - start mbytes = "%.1f" % (((args.size * args.count) / elapsed) / (1024 * 1024)) - print( - "\nTest completed : {} msgs/sec sent ({}) MB/sec".format( - args.count / elapsed, mbytes - ) - ) - - print( - "Received {} messages ({} msgs/sec)".format( - received, received / elapsed - ) - ) + print("\nTest completed : {} msgs/sec sent ({}) MB/sec".format(args.count / elapsed, mbytes)) + + print("Received {} messages ({} msgs/sec)".format(received, received / elapsed)) await nc.close() diff --git a/nats/benchmark/sub_perf.py b/nats/benchmark/sub_perf.py index 492665337..7f1b05353 100644 --- a/nats/benchmark/sub_perf.py +++ b/nats/benchmark/sub_perf.py @@ -77,11 +77,7 @@ async def handler(msg): elapsed = time.monotonic() - start print("\nTest completed : {} msgs/sec sent".format(args.count / elapsed)) - print( - "Received {} messages ({} msgs/sec)".format( - received, received / elapsed - ) - ) + print("Received {} messages ({} msgs/sec)".format(received, received / elapsed)) await nc.close() diff --git a/nats/examples/advanced.py b/nats/examples/advanced.py index 8cec761e9..35e55401c 100644 --- a/nats/examples/advanced.py +++ b/nats/examples/advanced.py @@ -5,7 +5,6 @@ async def main(): - async def disconnected_cb(): print("Got disconnected!") @@ -41,11 +40,7 @@ async def request_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) # Signal the server to stop sending messages after we got 10 already. resp = await nc.request("help.please", b"help") diff --git a/nats/examples/basic.py b/nats/examples/basic.py index 959e38498..b2da95b7d 100644 --- a/nats/examples/basic.py +++ b/nats/examples/basic.py @@ -31,17 +31,13 @@ async def message_handler(msg): try: async for msg in sub.messages: - print( - f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}" - ) + print(f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}") await sub.unsubscribe() except Exception as e: pass async def help_request(msg): - print( - f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}" - ) + print(f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}") await nc.publish(msg.reply, b"I can help") # Use queue named 'workers' for distributing requests diff --git a/nats/examples/client.py b/nats/examples/client.py index 9efffad98..245c6ecaa 100644 --- a/nats/examples/client.py +++ b/nats/examples/client.py @@ -7,7 +7,6 @@ class Client: - def __init__(self, nc): self.nc = nc @@ -15,11 +14,7 @@ async def message_handler(self, msg): print(f"[Received on '{msg.subject}']: {msg.data.decode()}") async def request_handler(self, msg): - print( - "[Request on '{} {}']: {}".format( - msg.subject, msg.reply, msg.data.decode() - ) - ) + print("[Request on '{} {}']: {}".format(msg.subject, msg.reply, msg.data.decode())) await self.nc.publish(msg.reply, b"I can help!") async def start(self): diff --git a/nats/examples/clustered.py b/nats/examples/clustered.py index 60150957c..f68041808 100644 --- a/nats/examples/clustered.py +++ b/nats/examples/clustered.py @@ -70,11 +70,7 @@ async def subscribe_handler(msg): print("Connection closed prematurely.") break except ErrTimeout as e: - print( - "Timeout occurred when publishing msg i={}: {}".format( - i, e - ) - ) + print("Timeout occurred when publishing msg i={}: {}".format(i, e)) end_time = datetime.now() await nc.drain() diff --git a/nats/examples/component.py b/nats/examples/component.py index 2e9ebd417..74fdf08d8 100644 --- a/nats/examples/component.py +++ b/nats/examples/component.py @@ -6,7 +6,6 @@ class Component: - def __init__(self): self._nc = None self._done = asyncio.Future() @@ -89,9 +88,7 @@ def signal_handler(): asyncio.create_task(c.close()) for sig in ("SIGINT", "SIGTERM"): - asyncio.get_running_loop().add_signal_handler( - getattr(signal, sig), signal_handler - ) + asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) await c.run_forever() diff --git a/nats/examples/connect.py b/nats/examples/connect.py index 06998a908..9e2da3c44 100644 --- a/nats/examples/connect.py +++ b/nats/examples/connect.py @@ -5,7 +5,6 @@ async def main(): - async def disconnected_cb(): print("Got disconnected!") diff --git a/nats/examples/context-manager.py b/nats/examples/context-manager.py index f5ea0d941..d7138faee 100644 --- a/nats/examples/context-manager.py +++ b/nats/examples/context-manager.py @@ -12,19 +12,14 @@ async def closed_cb(): is_done.set_result(True) arguments, _ = args.get_args("Run a context manager example.") - async with await nats.connect(arguments.servers, - closed_cb=closed_cb) as nc: + async with await nats.connect(arguments.servers, closed_cb=closed_cb) as nc: print(f"Connected to NATS at {nc.connected_url.netloc}...") async def subscribe_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) await nc.subscribe("discover", cb=subscribe_handler) await nc.flush() diff --git a/nats/examples/drain-sub.py b/nats/examples/drain-sub.py index 23bd2a447..073cdff22 100644 --- a/nats/examples/drain-sub.py +++ b/nats/examples/drain-sub.py @@ -14,11 +14,7 @@ async def message_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) # Simple publisher and async subscriber via coroutine. sub = await nc.subscribe("foo", cb=message_handler) @@ -33,11 +29,7 @@ async def help_request(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) await nc.publish(reply, b"I can help") # Use queue named 'workers' for distributing requests @@ -69,11 +61,7 @@ async def drain_sub(): print("Received {count} responses!".format(count=len(responses))) for response in responses[:5]: - print( - "Received response: {message}".format( - message=response.data.decode() - ) - ) + print("Received response: {message}".format(message=response.data.decode())) except: pass diff --git a/nats/examples/example.py b/nats/examples/example.py index e3339f808..cd11d0fa7 100644 --- a/nats/examples/example.py +++ b/nats/examples/example.py @@ -32,11 +32,7 @@ async def message_handler(msg): print("Connection closed prematurely") async def request_handler(msg): - print( - "[Request on '{} {}']: {}".format( - msg.subject, msg.reply, msg.data.decode() - ) - ) + print("[Request on '{} {}']: {}".format(msg.subject, msg.reply, msg.data.decode())) await nc.publish(msg.reply, b"OK") if nc.is_connected: diff --git a/nats/examples/micro/service.py b/nats/examples/micro/service.py index 0e426c517..da4b88804 100644 --- a/nats/examples/micro/service.py +++ b/nats/examples/micro/service.py @@ -27,8 +27,7 @@ async def main(): # Add the service service = await stack.enter_async_context( - await - nats.micro.add_service(nc, name="demo_service", version="0.0.1") + await nats.micro.add_service(nc, name="demo_service", version="0.0.1") ) group = service.add_group(name="demo") diff --git a/nats/examples/nats-sub/__main__.py b/nats/examples/nats-sub/__main__.py index 595f0b287..343cc2d16 100644 --- a/nats/examples/nats-sub/__main__.py +++ b/nats/examples/nats-sub/__main__.py @@ -61,11 +61,7 @@ async def subscribe_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) options = { "error_cb": error_cb, @@ -94,9 +90,7 @@ def signal_handler(): asyncio.create_task(nc.drain()) for sig in ("SIGINT", "SIGTERM"): - asyncio.get_running_loop().add_signal_handler( - getattr(signal, sig), signal_handler - ) + asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) await nc.subscribe(args.subject, args.queue, subscribe_handler) diff --git a/nats/examples/publish.py b/nats/examples/publish.py index 1c44ee724..ad66da73b 100644 --- a/nats/examples/publish.py +++ b/nats/examples/publish.py @@ -5,9 +5,7 @@ async def main(): - arguments, _ = args.get_args( - "Run a publish example.", "Usage: python examples/publish.py" - ) + arguments, _ = args.get_args("Run a publish example.", "Usage: python examples/publish.py") nc = await nats.connect(arguments.servers) # Publish as message with an inbox. diff --git a/nats/examples/service.py b/nats/examples/service.py index 946efa0cc..3fadd8836 100644 --- a/nats/examples/service.py +++ b/nats/examples/service.py @@ -19,9 +19,7 @@ def signal_handler(): asyncio.create_task(stop()) for sig in ("SIGINT", "SIGTERM"): - asyncio.get_running_loop().add_signal_handler( - getattr(signal, sig), signal_handler - ) + asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) async def disconnected_cb(): print("Got disconnected...") @@ -40,11 +38,7 @@ async def help_request(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) await nc.publish(reply, b"I can help") # Use queue named 'workers' for distributing requests diff --git a/nats/examples/subscribe.py b/nats/examples/subscribe.py index 71779bee0..8283b4c85 100644 --- a/nats/examples/subscribe.py +++ b/nats/examples/subscribe.py @@ -23,11 +23,7 @@ async def subscribe_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) await msg.respond(b"I can help!") # Basic subscription to receive all published messages @@ -45,9 +41,7 @@ def signal_handler(): asyncio.create_task(nc.close()) for sig in ("SIGINT", "SIGTERM"): - asyncio.get_running_loop().add_signal_handler( - getattr(signal, sig), signal_handler - ) + asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) await nc.request("help", b"help") diff --git a/nats/examples/tls.py b/nats/examples/tls.py index 7185658ea..7019141d8 100644 --- a/nats/examples/tls.py +++ b/nats/examples/tls.py @@ -22,11 +22,7 @@ async def message_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) # Simple publisher and async subscriber via coroutine. sid = await nc.subscribe("foo", cb=message_handler) @@ -41,11 +37,7 @@ async def help_request(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) await nc.publish(reply, b"I can help") # Use queue named 'workers' for distributing requests @@ -56,11 +48,7 @@ async def help_request(msg): # and trigger timeout if not faster than 50 ms. try: response = await nc.timed_request("help", b"help me", 0.050) - print( - "Received response: {message}".format( - message=response.data.decode() - ) - ) + print("Received response: {message}".format(message=response.data.decode())) except TimeoutError: print("Request timed out") diff --git a/nats/examples/wildcard.py b/nats/examples/wildcard.py index 911884950..86eec3250 100644 --- a/nats/examples/wildcard.py +++ b/nats/examples/wildcard.py @@ -7,20 +7,14 @@ async def run(loop): nc = NATS() - arguments, _ = args.get_args( - "Run the wildcard example.", "Usage: python examples/wildcard.py" - ) + arguments, _ = args.get_args("Run the wildcard example.", "Usage: python examples/wildcard.py") await nc.connect(arguments.servers) async def message_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print( - "Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data - ) - ) + print("Received a message on '{subject} {reply}': {data}".format(subject=subject, reply=reply, data=data)) # "*" matches any token, at any level of the subject. await nc.subscribe("foo.*.baz", cb=message_handler) diff --git a/nats/src/nats/__init__.py b/nats/src/nats/__init__.py index f085b3d68..021f2a442 100644 --- a/nats/src/nats/__init__.py +++ b/nats/src/nats/__init__.py @@ -16,15 +16,12 @@ from typing import List, Union # Extend namespace to allow nats.server and other nats.* packages -__path__ = __import__('pkgutil').extend_path(__path__, __name__) +__path__ = __import__("pkgutil").extend_path(__path__, __name__) from .aio.client import Client as NATS -async def connect( - servers: Union[str, List[str]] = ["nats://localhost:4222"], - **options -) -> NATS: +async def connect(servers: Union[str, List[str]] = ["nats://localhost:4222"], **options) -> NATS: """ :param servers: List of servers to connect. :param options: NATS connect options. diff --git a/nats/src/nats/aio/client.py b/nats/src/nats/aio/client.py index f0677c561..8b4ba906c 100644 --- a/nats/src/nats/aio/client.py +++ b/nats/src/nats/aio/client.py @@ -61,6 +61,7 @@ try: from importlib.metadata import version + __version__ = version("nats-py") except Exception: __version__ = "0.0.0" @@ -133,7 +134,6 @@ class Srv: class ServerVersion: - def __init__(self, server_version: str) -> None: self._server_version = server_version self._major_version: Optional[int] = None @@ -165,9 +165,7 @@ def parse_version(self) -> None: _REGEX = re.compile(_SEMVER_REGEX, re.VERBOSE) match = _REGEX.match(self._server_version) if match is None: - raise ValueError( - f"{self._server_version} is not a valid Semantic Version" - ) + raise ValueError(f"{self._server_version} is not a valid Semantic Version") matches = match.groupdict() self._major_version = int(matches["major"]) self._minor_version = int(matches["minor"]) @@ -175,9 +173,7 @@ def parse_version(self) -> None: self._prerelease_version = matches["prerelease"] or "" self._build_version = matches["buildmetadata"] or "" if self._build_version: - self._dev_version = '+'.join([ - self._prerelease_version, self._build_version - ]) + self._dev_version = "+".join([self._prerelease_version, self._build_version]) else: self._dev_version = self._prerelease_version @@ -454,11 +450,11 @@ async def subscribe_handler(msg): """ for cb in [ - error_cb, - disconnected_cb, - closed_cb, - reconnected_cb, - discovered_server_cb, + error_cb, + disconnected_cb, + closed_cb, + reconnected_cb, + discovered_server_cb, ]: if cb and not asyncio.iscoroutinefunction(cb): raise errors.InvalidCallbackTypeError @@ -516,8 +512,7 @@ async def subscribe_handler(msg): if user or password or token or server_auth_configured: self._auth_configured = True - if (self._user_credentials is not None or self._nkeys_seed is not None - or self._nkeys_seed_str is not None): + if self._user_credentials is not None or self._nkeys_seed is not None or self._nkeys_seed_str is not None: self._auth_configured = True self._setup_nkeys_connect() @@ -537,9 +532,7 @@ async def subscribe_handler(msg): try: await self._select_next_server() await self._process_connect_init() - assert ( - self._current_server - ), "the current server must be set by _select_next_server" + assert self._current_server, "the current server must be set by _select_next_server" self._current_server.reconnects = 0 break except errors.NoServersError as e: @@ -603,8 +596,7 @@ def sig_cb(nonce: str) -> bytes: return sig self._signature_cb = sig_cb - elif (isinstance(creds, str) or isinstance(creds, UserString) - or isinstance(creds, Path)): + elif isinstance(creds, str) or isinstance(creds, UserString) or isinstance(creds, Path): # Define the functions to be able to sign things using nkeys. def user_cb() -> bytearray: return self._read_creds_user_jwt(creds) @@ -625,10 +617,7 @@ def sig_cb(nonce: str) -> bytes: self._signature_cb = sig_cb - def _read_creds_user_nkey( - self, creds: str | UserString | Path - ) -> bytearray: - + def _read_creds_user_nkey(self, creds: str | UserString | Path) -> bytearray: def get_user_seed(f): for line in f: # Detect line where the NKEY would start and end, @@ -657,7 +646,6 @@ def get_user_seed(f): return get_user_seed(f) def _read_creds_user_jwt(self, creds: str | RawCredentials | Path): - def get_user_jwt(f): user_jwt = None while True: @@ -666,7 +654,7 @@ def get_user_jwt(f): user_jwt = bytearray(f.readline()) break # Remove trailing line break but reusing same memory view. - return user_jwt[:len(user_jwt) - 1] + return user_jwt[: len(user_jwt) - 1] if isinstance(creds, UserString): return get_user_jwt(BytesIO(creds.data.encode())) @@ -675,9 +663,7 @@ def get_user_jwt(f): return get_user_jwt(f) def _setup_nkeys_seed_connect(self) -> None: - assert ( - self._nkeys_seed or self._nkeys_seed_str - ), "Client.connect must be called first" + assert self._nkeys_seed or self._nkeys_seed_str, "Client.connect must be called first" import nkeys @@ -729,26 +715,21 @@ async def _close(self, status: int, do_cbs: bool = True) -> None: # Kick the flusher once again so that Task breaks and avoid pending futures. await self._flush_pending() - if self._reading_task is not None and not self._reading_task.cancelled( - ): + if self._reading_task is not None and not self._reading_task.cancelled(): self._reading_task.cancel() - if (self._ping_interval_task is not None - and not self._ping_interval_task.cancelled()): + if self._ping_interval_task is not None and not self._ping_interval_task.cancelled(): self._ping_interval_task.cancel() - if self._flusher_task is not None and not self._flusher_task.cancelled( - ): + if self._flusher_task is not None and not self._flusher_task.cancelled(): self._flusher_task.cancel() - if self._reconnection_task is not None and not self._reconnection_task.done( - ): + if self._reconnection_task is not None and not self._reconnection_task.done(): self._reconnection_task.cancel() # Wait for the reconnection task to be done which should be soon. try: - if (self._reconnection_task_future is not None - and not self._reconnection_task_future.cancelled()): + if self._reconnection_task_future is not None and not self._reconnection_task_future.cancelled(): await asyncio.wait_for( self._reconnection_task_future, self.options["reconnect_time_wait"], @@ -831,9 +812,7 @@ async def drain(self) -> None: self._status = Client.DRAINING_SUBS try: - await asyncio.wait_for( - drain_is_done, self.options["drain_timeout"] - ) + await asyncio.wait_for(drain_is_done, self.options["drain_timeout"]) except asyncio.TimeoutError: drain_is_done.exception() drain_is_done.cancel() @@ -904,18 +883,14 @@ async def main(): payload_size = len(payload) if not self.is_connected: - if (self._max_pending_size <= 0 - or payload_size + self._pending_data_size - > self._max_pending_size): + if self._max_pending_size <= 0 or payload_size + self._pending_data_size > self._max_pending_size: # Cannot publish during a reconnection when the buffering is disabled, # or if pending buffer is already full. raise errors.OutboundBufferLimitError if payload_size > self._max_payload: raise errors.MaxPayloadError - await self._send_publish( - subject, reply, payload, payload_size, headers - ) + await self._send_publish(subject, reply, payload, payload_size, headers) async def _send_publish( self, @@ -1029,12 +1004,10 @@ async def _init_request_sub(self) -> None: self._resp_sub_prefix.extend(b".") resp_mux_subject = self._resp_sub_prefix[:] resp_mux_subject.extend(b"*") - await self.subscribe( - resp_mux_subject.decode(), cb=self._request_sub_callback - ) + await self.subscribe(resp_mux_subject.decode(), cb=self._request_sub_callback) async def _request_sub_callback(self, msg: Msg) -> None: - token = msg.subject[len(self._inbox_prefix) + 22 + 2:] + token = msg.subject[len(self._inbox_prefix) + 22 + 2 :] future = self._resp_map.get(token) if not future: @@ -1059,15 +1032,10 @@ async def request( """ if old_style: # FIXME: Support headers in old style requests. - return await self._request_old_style( - subject, payload, timeout=timeout - ) + return await self._request_old_style(subject, payload, timeout=timeout) else: - msg = await self._request_new_style( - subject, payload, timeout=timeout, headers=headers - ) - if (msg.headers and msg.headers.get(nats.js.api.Header.STATUS) - == NO_RESPONDERS_STATUS): + msg = await self._request_new_style(subject, payload, timeout=timeout, headers=headers) + if msg.headers and msg.headers.get(nats.js.api.Header.STATUS) == NO_RESPONDERS_STATUS: raise errors.NoRespondersError return msg @@ -1093,15 +1061,11 @@ async def _request_new_style( # Then use the future to get the response. future: asyncio.Future = asyncio.Future() - future.add_done_callback( - lambda f: self._resp_map.pop(token.decode(), None) - ) + future.add_done_callback(lambda f: self._resp_map.pop(token.decode(), None)) self._resp_map[token.decode()] = future # Publish the request - await self.publish( - subject, payload, reply=inbox.decode(), headers=headers - ) + await self.publish(subject, payload, reply=inbox.decode(), headers=headers) # Wait for the response or give up on timeout. try: @@ -1125,9 +1089,7 @@ def new_inbox(self) -> str: next_inbox.extend(self._nuid.next()) return next_inbox.decode() - async def _request_old_style( - self, subject: str, payload: bytes, timeout: float = 1 - ) -> Msg: + async def _request_old_style(self, subject: str, payload: bytes, timeout: float = 1) -> Msg: """ Implements the request/response pattern via pub/sub using an ephemeral subscription which will be published @@ -1144,8 +1106,7 @@ async def _request_old_style( try: msg = await asyncio.wait_for(future, timeout) if msg.headers: - if msg.headers.get(nats.js.api.Header.STATUS - ) == NO_RESPONDERS_STATUS: + if msg.headers.get(nats.js.api.Header.STATUS) == NO_RESPONDERS_STATUS: raise errors.NoRespondersError return msg except asyncio.TimeoutError: @@ -1244,10 +1205,7 @@ def is_connecting(self) -> bool: @property def is_draining(self) -> bool: - return ( - self._status == Client.DRAINING_SUBS - or self._status == Client.DRAINING_PUBS - ) + return self._status == Client.DRAINING_SUBS or self._status == Client.DRAINING_PUBS @property def is_draining_pubs(self) -> bool: @@ -1280,8 +1238,7 @@ async def _send_command(self, cmd: bytes, priority: bool = False) -> None: else: self._pending.append(cmd) self._pending_data_size += len(cmd) - if (self._max_pending_size > 0 - and self._pending_data_size > self._max_pending_size): + if self._max_pending_size > 0 and self._pending_data_size > self._max_pending_size: # Only flush force timeout on publish await self._flush_pending(force_flush=True) @@ -1346,13 +1303,11 @@ def _setup_server_pool(self, connect_url: Union[List[str]]) -> None: except ValueError: raise errors.Error("nats: invalid connect url option") # make sure protocols aren't mixed - if not (all(server.uri.scheme in ("nats", "tls") - for server in self._server_pool) - or all(server.uri.scheme in ("ws", "wss") - for server in self._server_pool)): - raise errors.Error( - "nats: mixing of websocket and non websocket URLs is not allowed" - ) + if not ( + all(server.uri.scheme in ("nats", "tls") for server in self._server_pool) + or all(server.uri.scheme in ("ws", "wss") for server in self._server_pool) + ): + raise errors.Error("nats: mixing of websocket and non websocket URLs is not allowed") else: raise errors.Error("nats: invalid connect url option") @@ -1377,8 +1332,7 @@ async def _select_next_server(self) -> None: # Not yet exceeded max_reconnect_attempts so can still use # this server in the future. self._server_pool.append(s) - if (s.last_attempt is not None and now - < s.last_attempt + self.options["reconnect_time_wait"]): + if s.last_attempt is not None and now < s.last_attempt + self.options["reconnect_time_wait"]: # Backoff connecting to server if we attempted recently. await asyncio.sleep(self.options["reconnect_time_wait"]) try: @@ -1457,14 +1411,11 @@ async def _process_op_err(self, e: Exception) -> None: self._status = Client.RECONNECTING self._ps.reset() - if (self._reconnection_task is not None - and not self._reconnection_task.cancelled()): + if self._reconnection_task is not None and not self._reconnection_task.cancelled(): # Cancel the previous task in case it may still be running. self._reconnection_task.cancel() - self._reconnection_task = asyncio.get_running_loop().create_task( - self._attempt_reconnect() - ) + self._reconnection_task = asyncio.get_running_loop().create_task(self._attempt_reconnect()) else: self._process_disconnect() self._err = e @@ -1472,16 +1423,13 @@ async def _process_op_err(self, e: Exception) -> None: async def _attempt_reconnect(self) -> None: assert self._current_server, "Client.connect must be called first" - if self._reading_task is not None and not self._reading_task.cancelled( - ): + if self._reading_task is not None and not self._reading_task.cancelled(): self._reading_task.cancel() - if (self._ping_interval_task is not None - and not self._ping_interval_task.cancelled()): + if self._ping_interval_task is not None and not self._ping_interval_task.cancelled(): self._ping_interval_task.cancel() - if self._flusher_task is not None and not self._flusher_task.cancelled( - ): + if self._flusher_task is not None and not self._flusher_task.cancelled(): self._flusher_task.cancel() if self._transport is not None: @@ -1498,8 +1446,7 @@ async def _attempt_reconnect(self) -> None: if self.is_closed: return - if "dont_randomize" not in self.options or not self.options[ - "dont_randomize"]: + if "dont_randomize" not in self.options or not self.options["dont_randomize"]: shuffle(self._server_pool) # Create a future that the client can use to control waiting @@ -1535,9 +1482,7 @@ async def _attempt_reconnect(self) -> None: # auto unsubscribe the number of messages we have left max_msgs = sub._max_msgs - sub._received - sub_cmd = prot_command.sub_cmd( - sub._subject, sub._queue, sid - ) + sub_cmd = prot_command.sub_cmd(sub._subject, sub._queue, sid) self._transport.write(sub_cmd) if max_msgs > 0: @@ -1573,8 +1518,7 @@ async def _attempt_reconnect(self) -> None: except asyncio.CancelledError: break - if (self._reconnection_task_future is not None - and not self._reconnection_task_future.cancelled()): + if self._reconnection_task_future is not None and not self._reconnection_task_future.cancelled(): self._reconnection_task_future.set_result(True) def _connect_command(self) -> bytes: @@ -1608,8 +1552,7 @@ def _connect_command(self) -> bytes: options["nkey"] = self._public_nkey # In case there is no password, then consider handle # sending a token instead. - elif (self.options["user"] is not None - and self.options["password"] is not None): + elif self.options["user"] is not None and self.options["password"] is not None: options["user"] = self.options["user"] options["pass"] = self.options["password"] elif self.options["token"] is not None: @@ -1646,8 +1589,7 @@ async def _process_pong(self) -> None: self._pongs_received += 1 self._pings_outstanding = 0 - def _is_control_message(self, data, header: Dict[str, - str]) -> Optional[str]: + def _is_control_message(self, data, header: Dict[str, str]) -> Optional[str]: if len(data) > 0: return None status = header.get(nats.js.api.Header.STATUS) @@ -1676,9 +1618,9 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # if raw_headers[0] == _SPC_BYTE_: # Special handling for status messages. - line = headers[len(NATS_HDR_LINE) + 1:] + line = headers[len(NATS_HDR_LINE) + 1 :] status = line[:STATUS_MSG_LEN] - desc = line[STATUS_MSG_LEN + 1:len(line) - _CRLF_LEN_ - _CRLF_LEN_] + desc = line[STATUS_MSG_LEN + 1 : len(line) - _CRLF_LEN_ - _CRLF_LEN_] stripped_status = status.strip().decode() # Process as status only when it is a valid integer. @@ -1688,7 +1630,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # Move the raw_headers to end of line i = raw_headers.find(_CRLF_) - raw_headers = raw_headers[i + _CRLF_LEN_:] + raw_headers = raw_headers[i + _CRLF_LEN_ :] if len(desc) > 0: # Heartbeat messages can have both headers and inline status, @@ -1696,9 +1638,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: i = desc.find(_CRLF_) if i > 0: hdr[nats.js.api.Header.DESCRIPTION] = desc[:i].decode() - parsed_hdr = self._hdr_parser.parsebytes( - desc[i + _CRLF_LEN_:] - ) + parsed_hdr = self._hdr_parser.parsebytes(desc[i + _CRLF_LEN_ :]) for k, v in parsed_hdr.items(): hdr[k] = v else: @@ -1713,16 +1653,12 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # # NATS/1.0\r\nfoo: bar\r\nhello: world # - raw_headers = headers[NATS_HDR_LINE_SIZE + _CRLF_LEN_:] + raw_headers = headers[NATS_HDR_LINE_SIZE + _CRLF_LEN_ :] try: if parse_email: parsed_hdr = parse_email(raw_headers).headers else: - parsed_hdr = { - k.strip(): v.strip() - for k, v in self._hdr_parser.parsebytes(raw_headers - ).items() - } + parsed_hdr = {k.strip(): v.strip() for k, v in self._hdr_parser.parsebytes(raw_headers).items()} if hdr: hdr.update(parsed_hdr) else: @@ -1828,29 +1764,19 @@ async def _process_msg( try: sub._pending_size += payload_size # allow setting pending_bytes_limit to 0 to disable - if (sub._pending_bytes_limit > 0 - and sub._pending_size >= sub._pending_bytes_limit): + if sub._pending_bytes_limit > 0 and sub._pending_size >= sub._pending_bytes_limit: # Subtract the bytes since the message will be thrown away # so it would not be pending data. sub._pending_size -= payload_size await self._error_cb( - errors.SlowConsumerError( - subject=msg.subject, - reply=msg.reply, - sid=sid, - sub=sub - ) + errors.SlowConsumerError(subject=msg.subject, reply=msg.reply, sid=sid, sub=sub) ) return sub._pending_queue.put_nowait(msg) except asyncio.QueueFull: sub._pending_size -= len(msg.data) - await self._error_cb( - errors.SlowConsumerError( - subject=msg.subject, reply=msg.reply, sid=sid, sub=sub - ) - ) + await self._error_cb(errors.SlowConsumerError(subject=msg.subject, reply=msg.reply, sid=sid, sub=sub)) # Store the ACK metadata from the message to # compare later on with the received heartbeat. @@ -1900,9 +1826,7 @@ def _process_disconnect(self) -> None: """ self._status = Client.DISCONNECTED - async def _process_info( - self, info: Dict[str, Any], initial_connection: bool = False - ) -> None: + async def _process_info(self, info: Dict[str, Any], initial_connection: bool = False) -> None: """ Process INFO lines sent by the server to reconfigure client with latest updates from cluster to enable server discovery. @@ -1923,9 +1847,11 @@ async def _process_info( srv.discovered = True # Check whether we should reuse the original hostname. - if ("tls_required" in self._server_info - and self._server_info["tls_required"] - and self._host_is_ip(uri.hostname)): + if ( + "tls_required" in self._server_info + and self._server_info["tls_required"] + and self._host_is_ip(uri.hostname) + ): srv.tls_name = self._current_server.uri.hostname # Filter for any similar server in the server pool already. @@ -1941,8 +1867,7 @@ async def _process_info( for srv in connect_urls: self._server_pool.append(srv) - if (not initial_connection and connect_urls - and self._discovered_server_cb): + if not initial_connection and connect_urls and self._discovered_server_cb: await self._discovered_server_cb() def _host_is_ip(self, connect_url: Optional[str]) -> bool: @@ -1983,14 +1908,10 @@ async def _process_connect_init(self) -> None: ) connection_completed = self._transport.readline() - info_line = await asyncio.wait_for( - connection_completed, self.options["connect_timeout"] - ) + info_line = await asyncio.wait_for(connection_completed, self.options["connect_timeout"]) if INFO_OP not in info_line: # FIXME: Handle PING/PONG arriving first as well. - raise errors.Error( - "nats: empty response from server when expecting INFO message" - ) + raise errors.Error("nats: empty response from server when expecting INFO message") _, info = info_line.split(INFO_OP + _SPC_, 1) @@ -2015,9 +1936,11 @@ async def _process_connect_init(self) -> None: if "client_id" in self._server_info: self._client_id = self._server_info["client_id"] - if ("tls_required" in self._server_info - and self._server_info["tls_required"] - and self._current_server.uri.scheme != "ws"): + if ( + "tls_required" in self._server_info + and self._server_info["tls_required"] + and self._current_server.uri.scheme != "ws" + ): if not handshake_first: await self._transport.drain() # just in case something is left @@ -2039,9 +1962,7 @@ async def _process_connect_init(self) -> None: await self._transport.drain() if self.options["verbose"]: future = self._transport.readline() - next_op = await asyncio.wait_for( - future, self.options["connect_timeout"] - ) + next_op = await asyncio.wait_for(future, self.options["connect_timeout"]) if OK_OP in next_op: # Do nothing pass @@ -2058,9 +1979,7 @@ async def _process_connect_init(self) -> None: await self._transport.drain() future = self._transport.readline() - next_op = await asyncio.wait_for( - future, self.options["connect_timeout"] - ) + next_op = await asyncio.wait_for(future, self.options["connect_timeout"]) if PONG_PROTO in next_op: self._status = Client.CONNECTED @@ -2076,23 +1995,15 @@ async def _process_connect_init(self) -> None: if PONG_PROTO in next_op: self._status = Client.CONNECTED - self._reading_task = asyncio.get_running_loop().create_task( - self._read_loop() - ) + self._reading_task = asyncio.get_running_loop().create_task(self._read_loop()) self._pongs = [] self._pings_outstanding = 0 - self._ping_interval_task = asyncio.get_running_loop().create_task( - self._ping_interval() - ) + self._ping_interval_task = asyncio.get_running_loop().create_task(self._ping_interval()) # Task for kicking the flusher queue - self._flusher_task = asyncio.get_running_loop().create_task( - self._flusher() - ) + self._flusher_task = asyncio.get_running_loop().create_task(self._flusher()) - async def _send_ping( - self, future: Optional[asyncio.Future] = None - ) -> None: + async def _send_ping(self, future: Optional[asyncio.Future] = None) -> None: assert self._transport, "Client.connect must be called first" if future is None: future = asyncio.Future() @@ -2139,8 +2050,7 @@ async def _ping_interval(self) -> None: continue try: self._pings_outstanding += 1 - if self._pings_outstanding > self.options[ - "max_outstanding_pings"]: + if self._pings_outstanding > self.options["max_outstanding_pings"]: await self._process_op_err(ErrStaleConnection()) return await self._send_ping() diff --git a/nats/src/nats/aio/msg.py b/nats/src/nats/aio/msg.py index 5724706aa..767c7cde7 100644 --- a/nats/src/nats/aio/msg.py +++ b/nats/src/nats/aio/msg.py @@ -215,10 +215,11 @@ def _get_metadata_fields(cls, reply: Optional[str]) -> List[str]: if not reply: raise NotJSMessageError tokens = reply.split(".") - if ((len(tokens) == _V1_TOKEN_COUNT - or len(tokens) >= _V2_TOKEN_COUNT - 1) - and tokens[0] == Msg.Ack.Prefix0 - and tokens[1] == Msg.Ack.Prefix1): + if ( + (len(tokens) == _V1_TOKEN_COUNT or len(tokens) >= _V2_TOKEN_COUNT - 1) + and tokens[0] == Msg.Ack.Prefix0 + and tokens[1] == Msg.Ack.Prefix1 + ): return tokens raise NotJSMessageError @@ -227,9 +228,7 @@ def _from_reply(cls, reply: str) -> Msg.Metadata: """Construct the metadata from the reply string""" tokens = cls._get_metadata_fields(reply) if len(tokens) == _V1_TOKEN_COUNT: - t = datetime.datetime.fromtimestamp( - int(tokens[7]) / 1_000_000_000.0, datetime.timezone.utc - ) + t = datetime.datetime.fromtimestamp(int(tokens[7]) / 1_000_000_000.0, datetime.timezone.utc) return cls( sequence=Msg.Metadata.SequencePair( stream=int(tokens[5]), @@ -243,8 +242,7 @@ def _from_reply(cls, reply: str) -> Msg.Metadata: ) else: t = datetime.datetime.fromtimestamp( - int(tokens[Msg.Ack.Timestamp]) / 1_000_000_000.0, - datetime.timezone.utc + int(tokens[Msg.Ack.Timestamp]) / 1_000_000_000.0, datetime.timezone.utc ) # Underscore indicate no domain is set. Expose as empty string diff --git a/nats/src/nats/aio/subscription.py b/nats/src/nats/aio/subscription.py index 19c315b17..76727cc65 100644 --- a/nats/src/nats/aio/subscription.py +++ b/nats/src/nats/aio/subscription.py @@ -26,6 +26,7 @@ from uuid import uuid4 from nats import errors + # Default Pending Limits of Subscriptions from nats.aio.msg import Msg @@ -84,9 +85,7 @@ def __init__( # Per subscription message processor. self._pending_msgs_limit = pending_msgs_limit self._pending_bytes_limit = pending_bytes_limit - self._pending_queue: asyncio.Queue[Msg] = asyncio.Queue( - maxsize=pending_msgs_limit - ) + self._pending_queue: asyncio.Queue[Msg] = asyncio.Queue(maxsize=pending_msgs_limit) # If no callback, then this is a sync subscription which will # require tracking the next_msg calls inflight for cancelling. if cb is None: @@ -131,9 +130,7 @@ def messages(self) -> AsyncIterator[Msg]: print('Received', msg) """ if not self._message_iterator: - raise errors.Error( - "cannot iterate over messages with a non iteration subscription type" - ) + raise errors.Error("cannot iterate over messages with a non iteration subscription type") return self._message_iterator @@ -180,9 +177,7 @@ async def timed_get() -> Msg: raise errors.ConnectionClosedError if self._cb: - raise errors.Error( - "nats: next_msg cannot be used in async subscriptions" - ) + raise errors.Error("nats: next_msg cannot be used in async subscriptions") task_name = str(uuid4()) try: @@ -213,15 +208,11 @@ def _start(self, error_cb): """ if self._cb: if not asyncio.iscoroutinefunction(self._cb) and not ( - hasattr(self._cb, "func") - and asyncio.iscoroutinefunction(self._cb.func)): - raise errors.Error( - "nats: must use coroutine for subscriptions" - ) + hasattr(self._cb, "func") and asyncio.iscoroutinefunction(self._cb.func) + ): + raise errors.Error("nats: must use coroutine for subscriptions") - self._wait_for_msgs_task = asyncio.get_running_loop().create_task( - self._wait_for_msgs(error_cb) - ) + self._wait_for_msgs_task = asyncio.get_running_loop().create_task(self._wait_for_msgs(error_cb)) elif self._future: # Used to handle the single response from a request. @@ -284,8 +275,7 @@ async def unsubscribe(self, limit: int = 0): raise errors.BadSubscriptionError self._max_msgs = limit - if limit == 0 or (self._received >= limit - and self._pending_queue.empty()): + if limit == 0 or (self._received >= limit and self._pending_queue.empty()): self._closed = True self._stop_processing() self._conn._remove_sub(self._id) @@ -331,15 +321,13 @@ async def _wait_for_msgs(self, error_cb) -> None: self._pending_queue.task_done() # Apply auto unsubscribe checks after having processed last msg. - if (self._max_msgs > 0 and self._received >= self._max_msgs - and self._pending_queue.empty): + if self._max_msgs > 0 and self._received >= self._max_msgs and self._pending_queue.empty: self._stop_processing() except asyncio.CancelledError: break class _SubscriptionMessageIterator: - def __init__(self, sub: Subscription) -> None: self._sub: Subscription = sub self._queue: asyncio.Queue[Msg] = sub._pending_queue @@ -355,9 +343,7 @@ def __aiter__(self) -> _SubscriptionMessageIterator: async def __anext__(self) -> Msg: get_task = asyncio.get_running_loop().create_task(self._queue.get()) tasks: List[asyncio.Future] = [get_task, self._unsubscribed_future] - finished, _ = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) + finished, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) sub = self._sub if get_task in finished: diff --git a/nats/src/nats/aio/transport.py b/nats/src/nats/aio/transport.py index 45de64438..aaa85d4b3 100644 --- a/nats/src/nats/aio/transport.py +++ b/nats/src/nats/aio/transport.py @@ -15,11 +15,8 @@ class Transport(abc.ABC): - @abc.abstractmethod - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): + async def connect(self, uri: ParseResult, buffer_size: int, connect_timeout: int): """ Connects to a server using the implemented transport. The uri passed is of type ParseResult that can be obtained calling urllib.parse.urlparse. @@ -108,16 +105,13 @@ def __bool__(self): class TcpTransport(Transport): - def __init__(self): self._bare_io_reader: Optional[asyncio.StreamReader] = None self._io_reader: Optional[asyncio.StreamReader] = None self._bare_io_writer: Optional[asyncio.StreamWriter] = None self._io_writer: Optional[asyncio.StreamWriter] = None - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): + async def connect(self, uri: ParseResult, buffer_size: int, connect_timeout: int): r, w = await asyncio.wait_for( asyncio.open_connection( host=uri.hostname, @@ -156,9 +150,7 @@ async def connect_tls( server_hostname=uri if isinstance(uri, str) else uri.hostname, ) transport = await asyncio.wait_for(transport_future, connect_timeout) - writer = asyncio.StreamWriter( - transport, protocol, reader, asyncio.get_running_loop() - ) + writer = asyncio.StreamWriter(transport, protocol, reader, asyncio.get_running_loop()) self._io_reader, self._io_writer = reader, writer def write(self, payload): @@ -191,25 +183,18 @@ def __bool__(self): class WebSocketTransport(Transport): - def __init__(self): if not aiohttp: - raise ImportError( - "Could not import aiohttp transport, please install it with `pip install aiohttp`" - ) + raise ImportError("Could not import aiohttp transport, please install it with `pip install aiohttp`") self._ws: Optional[aiohttp.ClientWebSocketResponse] = None self._client: aiohttp.ClientSession = aiohttp.ClientSession() self._pending = asyncio.Queue() self._close_task = asyncio.Future() self._using_tls: Optional[bool] = None - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): + async def connect(self, uri: ParseResult, buffer_size: int, connect_timeout: int): # for websocket library, the uri must contain the scheme already - self._ws = await self._client.ws_connect( - uri.geturl(), timeout=connect_timeout - ) + self._ws = await self._client.ws_connect(uri.geturl(), timeout=connect_timeout) self._using_tls = False async def connect_tls( diff --git a/nats/src/nats/errors.py b/nats/src/nats/errors.py index a5d292871..043e401a5 100644 --- a/nats/src/nats/errors.py +++ b/nats/src/nats/errors.py @@ -25,156 +25,127 @@ class Error(Exception): class TimeoutError(Error, asyncio.TimeoutError): - def __str__(self) -> str: return "nats: timeout" class NoRespondersError(Error): - def __str__(self) -> str: return "nats: no responders available for request" class StaleConnectionError(Error): - def __str__(self) -> str: return "nats: stale connection" class OutboundBufferLimitError(Error): - def __str__(self) -> str: return "nats: outbound buffer limit exceeded" class UnexpectedEOF(StaleConnectionError): - def __str__(self) -> str: return "nats: unexpected EOF" class FlushTimeoutError(TimeoutError): - def __str__(self) -> str: return "nats: flush timeout" class ConnectionClosedError(Error): - def __str__(self) -> str: return "nats: connection closed" class SecureConnRequiredError(Error): - def __str__(self) -> str: return "nats: secure connection required" class SecureConnWantedError(Error): - def __str__(self) -> str: return "nats: secure connection not available" class SecureConnFailedError(Error): - def __str__(self) -> str: return "nats: secure connection failed" class BadSubscriptionError(Error): - def __str__(self) -> str: return "nats: invalid subscription" class BadSubjectError(Error): - def __str__(self) -> str: return "nats: invalid subject" class SlowConsumerError(Error): - - def __init__( - self, subject: str, reply: str, sid: int, sub: Subscription - ) -> None: + def __init__(self, subject: str, reply: str, sid: int, sub: Subscription) -> None: self.subject = subject self.reply = reply self.sid = sid self.sub = sub def __str__(self) -> str: - return ( - "nats: slow consumer, messages dropped subject: " - f"{self.subject}, sid: {self.sid}, sub: {self.sub}" - ) + return f"nats: slow consumer, messages dropped subject: {self.subject}, sid: {self.sid}, sub: {self.sub}" class BadTimeoutError(Error): - def __str__(self) -> str: return "nats: timeout invalid" class AuthorizationError(Error): - def __str__(self) -> str: return "nats: authorization failed" class NoServersError(Error): - def __str__(self) -> str: return "nats: no servers available for connection" class JsonParseError(Error): - def __str__(self) -> str: return "nats: connect message, json parse err" class MaxPayloadError(Error): - def __str__(self) -> str: return "nats: maximum payload exceeded" class DrainTimeoutError(TimeoutError): - def __str__(self) -> str: return "nats: draining connection timed out" class ConnectionDrainingError(Error): - def __str__(self) -> str: return "nats: connection draining" class ConnectionReconnectingError(Error): - def __str__(self) -> str: return "nats: connection reconnecting" class InvalidUserCredentialsError(Error): - def __str__(self) -> str: return "nats: invalid user credentials" class InvalidCallbackTypeError(Error): - def __str__(self) -> str: return "nats: callbacks must be coroutine functions" class ProtocolError(Error): - def __str__(self) -> str: return "nats: protocol error" @@ -190,7 +161,6 @@ def __str__(self) -> str: class MsgAlreadyAckdError(Error): - def __init__(self, msg=None) -> None: self._msg = msg diff --git a/nats/src/nats/js/api.py b/nats/src/nats/js/api.py index c456db6c0..170e4a01a 100644 --- a/nats/src/nats/js/api.py +++ b/nats/src/nats/js/api.py @@ -164,9 +164,7 @@ def from_response(cls, resp: Dict[str, Any]): def as_dict(self) -> Dict[str, object]: result = super().as_dict() if self.subject_transforms: - result["subject_transforms"] = [ - tr.as_dict() for tr in self.subject_transforms - ] + result["subject_transforms"] = [tr.as_dict() for tr in self.subject_transforms] return result @@ -321,17 +319,12 @@ def from_response(cls, resp: Dict[str, Any]): def as_dict(self) -> Dict[str, object]: result = super().as_dict() - result["duplicate_window"] = self._to_nanoseconds( - self.duplicate_window - ) + result["duplicate_window"] = self._to_nanoseconds(self.duplicate_window) result["max_age"] = self._to_nanoseconds(self.max_age) if self.sources: result["sources"] = [src.as_dict() for src in self.sources] - if self.compression and (self.compression != StoreCompression.NONE - and self.compression != StoreCompression.S2): - raise ValueError( - "nats: invalid store compression type: %s" % self.compression - ) + if self.compression and (self.compression != StoreCompression.NONE and self.compression != StoreCompression.S2): + raise ValueError("nats: invalid store compression type: %s" % self.compression) if self.metadata and not isinstance(self.metadata, dict): raise ValueError("nats: invalid metadata format") return result @@ -387,9 +380,7 @@ class StreamsListIterator(Iterable): StreamsListIterator is an iterator for streams list responses from JetStream. """ - def __init__( - self, offset: int, total: int, streams: List[Dict[str, any]] - ) -> None: + def __init__(self, offset: int, total: int, streams: List[Dict[str, any]]) -> None: self.offset = offset self.total = total self.streams = streams @@ -513,9 +504,7 @@ def as_dict(self) -> Dict[str, object]: result = super().as_dict() result["ack_wait"] = self._to_nanoseconds(self.ack_wait) result["idle_heartbeat"] = self._to_nanoseconds(self.idle_heartbeat) - result["inactive_threshold"] = self._to_nanoseconds( - self.inactive_threshold - ) + result["inactive_threshold"] = self._to_nanoseconds(self.inactive_threshold) if self.backoff: result["backoff"] = [self._to_nanoseconds(i) for i in self.backoff] return result diff --git a/nats/src/nats/js/client.py b/nats/src/nats/js/client.py index d26413c0b..dc333dea1 100644 --- a/nats/src/nats/js/client.py +++ b/nats/src/nats/js/client.py @@ -123,9 +123,7 @@ def __init__( self._publish_async_completed_event = asyncio.Event() self._publish_async_completed_event.set() - self._publish_async_pending_semaphore = asyncio.Semaphore( - publish_async_max_pending - ) + self._publish_async_pending_semaphore = asyncio.Semaphore(publish_async_max_pending) @property def _jsm(self) -> JetStreamManager: @@ -146,12 +144,10 @@ async def _init_async_reply(self) -> None: async_reply_subject = self._async_reply_prefix[:] async_reply_subject.extend(b"*") - await self._nc.subscribe( - async_reply_subject.decode(), cb=self._handle_async_reply - ) + await self._nc.subscribe(async_reply_subject.decode(), cb=self._handle_async_reply) async def _handle_async_reply(self, msg: Msg) -> None: - token = msg.subject[len(self._nc._inbox_prefix) + 22 + 2:] + token = msg.subject[len(self._nc._inbox_prefix) + 22 + 2 :] future = self._publish_async_futures.get(token) if not future: @@ -161,8 +157,7 @@ async def _handle_async_reply(self, msg: Msg) -> None: return # Handle no responders - if msg.headers and msg.headers.get(api.Header.STATUS - ) == NO_RESPONDERS_STATUS: + if msg.headers and msg.headers.get(api.Header.STATUS) == NO_RESPONDERS_STATUS: future.set_exception(nats.js.errors.NoStreamResponseError) return @@ -234,10 +229,7 @@ async def publish_async( hdr[api.Header.EXPECTED_STREAM] = stream try: - await asyncio.wait_for( - self._publish_async_pending_semaphore.acquire(), - timeout=wait_stall - ) + await asyncio.wait_for(self._publish_async_pending_semaphore.acquire(), timeout=wait_stall) except (asyncio.TimeoutError, asyncio.CancelledError): raise nats.js.errors.TooManyStalledMsgsError @@ -264,9 +256,7 @@ def handle_done(future): if self._publish_async_completed_event.is_set(): self._publish_async_completed_event.clear() - await self._nc.publish( - subject, payload, reply=inbox.decode(), headers=hdr - ) + await self._nc.publish(subject, payload, reply=inbox.decode(), headers=hdr) return future @@ -359,9 +349,7 @@ async def cb(msg): # If using a queue, that will be the consumer/durable name. if queue: if durable and durable != queue: - raise nats.js.errors.Error( - f"cannot create queue subscription '{queue}' to consumer '{durable}'" - ) + raise nats.js.errors.Error(f"cannot create queue subscription '{queue}' to consumer '{durable}'") else: durable = queue @@ -393,9 +381,7 @@ async def cb(msg): elif consumer_info.push_bound: # Need to reject a non queue subscription to a non queue consumer # if the consumer is already bound. - raise nats.js.errors.Error( - "consumer is already bound to a subscription" - ) + raise nats.js.errors.Error("consumer is already bound to a subscription") else: if not queue: raise nats.js.errors.Error( @@ -483,8 +469,7 @@ async def subscribe_bind( # # In case ack policy is none then we also do not require to ack. # - if cb and (not manual_ack) and (config.ack_policy - is not api.AckPolicy.NONE): + if cb and (not manual_ack) and (config.ack_policy is not api.AckPolicy.NONE): cb = self._auto_ack_callback(cb) if config.deliver_subject is None: raise TypeError("config.deliver_subject is required") @@ -512,15 +497,12 @@ async def subscribe_bind( sub._jsi._hbtask = asyncio.create_task(sub._jsi.activity_check()) if ordered_consumer: - sub._jsi._fctask = asyncio.create_task( - sub._jsi.check_flow_control_response() - ) + sub._jsi._fctask = asyncio.create_task(sub._jsi.check_flow_control_response()) return psub @staticmethod def _auto_ack_callback(callback: Callback) -> Callback: - async def new_callback(msg: Msg) -> None: await callback(msg) try: @@ -687,9 +669,11 @@ def _is_processable_msg(cls, status: Optional[str], msg: Msg) -> bool: @classmethod def _is_temporary_error(cls, status: Optional[str]) -> bool: - if (status == api.StatusCode.NO_MESSAGES - or status == api.StatusCode.CONFLICT - or status == api.StatusCode.REQUEST_TIMEOUT): + if ( + status == api.StatusCode.NO_MESSAGES + or status == api.StatusCode.CONFLICT + or status == api.StatusCode.REQUEST_TIMEOUT + ): return True else: return False @@ -702,14 +686,12 @@ def _is_heartbeat(cls, status: Optional[str]) -> bool: return False @classmethod - def _time_until(cls, timeout: Optional[float], - start_time: float) -> Optional[float]: + def _time_until(cls, timeout: Optional[float], start_time: float) -> Optional[float]: if timeout is None: return None return timeout - (time.monotonic() - start_time) class _JSI: - def __init__( self, js: JetStreamContext, @@ -784,8 +766,7 @@ async def check_flow_control_response(self): if self._conn.is_closed: break - if (self._fciseq - - self._psub._pending_queue.qsize()) >= self._fcd: + if (self._fciseq - self._psub._pending_queue.qsize()) >= self._fcd: fc_reply = self._fcr try: if fc_reply: @@ -798,8 +779,7 @@ async def check_flow_control_response(self): except asyncio.CancelledError: break - async def check_for_sequence_mismatch(self, - msg: Msg) -> Optional[bool]: + async def check_for_sequence_mismatch(self, msg: Msg) -> Optional[bool]: self._active = True if not self._cmeta: return None @@ -817,9 +797,7 @@ async def check_for_sequence_mismatch(self, sseq = int(tokens[5]) # stream sequence if self._ordered: - did_reset = await self.reset_ordered_consumer( - self._sseq + 1 - ) + did_reset = await self.reset_ordered_consumer(self._sseq + 1) else: ecs = nats.js.errors.ConsumerSequenceMismatchError( stream_resume_sequence=sseq, @@ -874,11 +852,7 @@ async def reset_ordered_consumer(self, sseq: Optional[int]) -> bool: async def recreate_consumer(self) -> None: try: - cinfo = await self._js._jsm.add_consumer( - self._stream, - config=self._ccreq, - timeout=self._js._timeout - ) + cinfo = await self._js._jsm.add_consumer(self._stream, config=self._ccreq, timeout=self._js._timeout) self._psub._consumer = cinfo.name except Exception as err: await self._conn._error_cb(err) @@ -1043,9 +1017,7 @@ async def consumer_info(self) -> api.ConsumerInfo: """ consumer_info gets the current info of the consumer from this subscription. """ - info = await self._js._jsm.consumer_info( - self._stream, self._consumer - ) + info = await self._js._jsm.consumer_info(self._stream, self._consumer) return info async def fetch( @@ -1091,9 +1063,7 @@ async def main(): if timeout is not None and timeout <= 0: raise ValueError("nats: invalid fetch timeout") - expires = int( - timeout * 1_000_000_000 - ) - 100_000 if timeout else None + expires = int(timeout * 1_000_000_000) - 100_000 if timeout else None if batch == 1: msg = await self._fetch_one(expires, timeout, heartbeat) return [msg] @@ -1129,9 +1099,7 @@ async def _fetch_one( if expires: next_req["expires"] = int(expires) if heartbeat: - next_req["idle_heartbeat"] = int( - heartbeat * 1_000_000_000 - ) # to nanoseconds + next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds await self._nc.publish( self._nms, @@ -1143,9 +1111,7 @@ async def _fetch_one( got_any_response = False while True: try: - deadline = JetStreamContext._time_until( - timeout, start_time - ) + deadline = JetStreamContext._time_until(timeout, start_time) # Wait for the response or raise timeout. msg = await self._sub.next_msg(timeout=deadline) @@ -1165,9 +1131,7 @@ async def _fetch_one( else: return msg except asyncio.TimeoutError: - deadline = JetStreamContext._time_until( - timeout, start_time - ) + deadline = JetStreamContext._time_until(timeout, start_time) if deadline is not None and deadline < 0: # No response from the consumer could have been # due to a reconnect while the fetch request, @@ -1214,9 +1178,7 @@ async def _fetch_n( if expires: next_req["expires"] = expires if heartbeat: - next_req["idle_heartbeat"] = int( - heartbeat * 1_000_000_000 - ) # to nanoseconds + next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds next_req["no_wait"] = True await self._nc.publish( self._nms, @@ -1248,9 +1210,7 @@ async def _fetch_n( try: for i in range(0, needed): - deadline = JetStreamContext._time_until( - timeout, start_time - ) + deadline = JetStreamContext._time_until(timeout, start_time) msg = await self._sub.next_msg(timeout=deadline) status = JetStreamContext.is_status_msg(msg) if status == api.StatusCode.NO_MESSAGES or status == api.StatusCode.REQUEST_TIMEOUT: @@ -1280,9 +1240,7 @@ async def _fetch_n( if expires: next_req["expires"] = expires if heartbeat: - next_req["idle_heartbeat"] = int( - heartbeat * 1_000_000_000 - ) # to nanoseconds + next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds await self._nc.publish( self._nms, @@ -1338,9 +1296,7 @@ async def _fetch_n( # Wait for the rest of the messages to be delivered to the internal pending queue. try: for _ in range(needed): - deadline = JetStreamContext._time_until( - timeout, start_time - ) + deadline = JetStreamContext._time_until(timeout, start_time) if deadline is not None and deadline < 0: return msgs diff --git a/nats/src/nats/js/errors.py b/nats/src/nats/js/errors.py index 10de0eb04..d594587db 100644 --- a/nats/src/nats/js/errors.py +++ b/nats/src/nats/js/errors.py @@ -92,8 +92,7 @@ def from_error(cls, err: Dict[str, Any]): def __str__(self) -> str: return ( - f"nats: {type(self).__name__}: code={self.code} err_code={self.err_code} " - f"description='{self.description}'" + f"nats: {type(self).__name__}: code={self.code} err_code={self.err_code} description='{self.description}'" ) @@ -243,13 +242,11 @@ def __str__(self) -> str: class NoKeysError(KeyValueError): - def __str__(self) -> str: return "nats: no keys found" class KeyHistoryTooLargeError(KeyValueError): - def __str__(self) -> str: return "nats: history limited to a max of 64" @@ -258,6 +255,7 @@ class InvalidKeyError(Error): """ Raised when trying to put an object in Key Value with an invalid key. """ + pass diff --git a/nats/src/nats/js/kv.py b/nats/src/nats/js/kv.py index bed70d940..a765f8c17 100644 --- a/nats/src/nats/js/kv.py +++ b/nats/src/nats/js/kv.py @@ -35,11 +35,11 @@ logger = logging.getLogger(__name__) -VALID_KEY_RE = re.compile(r'^[-/_=\.a-zA-Z0-9]+$') +VALID_KEY_RE = re.compile(r"^[-/_=\.a-zA-Z0-9]+$") def _is_key_valid(key: str) -> bool: - if len(key) == 0 or key[0] == '.' or key[-1] == '.': + if len(key) == 0 or key[0] == "." or key[-1] == ".": return False return bool(VALID_KEY_RE.match(key)) @@ -133,12 +133,7 @@ def __init__( self._js = js self._direct = direct - async def get( - self, - key: str, - revision: Optional[int] = None, - validate_keys: bool = True - ) -> Entry: + async def get(self, key: str, revision: Optional[int] = None, validate_keys: bool = True) -> Entry: """ get returns the latest value for the key. """ @@ -174,9 +169,7 @@ async def _get(self, key: str, revision: Optional[int] = None) -> Entry: # Check whether the revision from the stream does not match the key. if subject != msg.subject: - raise nats.js.errors.KeyNotFoundError( - message=f"expected '{subject}', but got '{msg.subject}'" - ) + raise nats.js.errors.KeyNotFoundError(message=f"expected '{subject}', but got '{msg.subject}'") entry = KeyValue.Entry( bucket=self._name, @@ -196,9 +189,7 @@ async def _get(self, key: str, revision: Optional[int] = None) -> Entry: return entry - async def put( - self, key: str, value: bytes, validate_keys: bool = True - ) -> int: + async def put(self, key: str, value: bytes, validate_keys: bool = True) -> int: """ put will place the new value for the key into the store and return the revision number. @@ -209,9 +200,7 @@ async def put( pa = await self._js.publish(f"{self._pre}{key}", value) return pa.seq - async def create( - self, key: str, value: bytes, validate_keys: bool = True - ) -> int: + async def create(self, key: str, value: bytes, validate_keys: bool = True) -> int: """ create will add the key/value pair iff it does not exist. """ @@ -220,9 +209,7 @@ async def create( pa = None try: - pa = await self.update( - key, value, last=0, validate_keys=validate_keys - ) + pa = await self.update(key, value, last=0, validate_keys=validate_keys) except nats.js.errors.KeyWrongLastSequenceError as err: # In case of attempting to recreate an already deleted key, # the client would get a KeyWrongLastSequenceError. When this happens, @@ -242,22 +229,11 @@ async def create( # to recreate using the last revision. raise err except nats.js.errors.KeyDeletedError as err: - pa = await self.update( - key, - value, - last=err.entry.revision, - validate_keys=validate_keys - ) + pa = await self.update(key, value, last=err.entry.revision, validate_keys=validate_keys) return pa - async def update( - self, - key: str, - value: bytes, - last: Optional[int] = None, - validate_keys: bool = True - ) -> int: + async def update(self, key: str, value: bytes, last: Optional[int] = None, validate_keys: bool = True) -> int: """ update will update the value if the latest revision matches. """ @@ -271,25 +247,16 @@ async def update( pa = None try: - pa = await self._js.publish( - f"{self._pre}{key}", value, headers=hdrs - ) + pa = await self._js.publish(f"{self._pre}{key}", value, headers=hdrs) except nats.js.errors.APIError as err: # Check for a BadRequest::KeyWrongLastSequenceError error code. if err.err_code == 10071: - raise nats.js.errors.KeyWrongLastSequenceError( - description=err.description - ) + raise nats.js.errors.KeyWrongLastSequenceError(description=err.description) else: raise err return pa.seq - async def delete( - self, - key: str, - last: Optional[int] = None, - validate_keys: bool = True - ) -> bool: + async def delete(self, key: str, last: Optional[int] = None, validate_keys: bool = True) -> bool: """ delete will place a delete marker and remove all previous revisions. """ @@ -330,14 +297,10 @@ async def purge_deletes(self, olderthan: int = 30 * 60) -> bool: for entry in delete_markers: keep = 0 subject = f"{self._pre}{entry.key}" - duration = datetime.datetime.now( - datetime.timezone.utc - ) - entry.created + duration = datetime.datetime.now(datetime.timezone.utc) - entry.created if olderthan > 0 and olderthan > duration.total_seconds(): keep = 1 - await self._js.purge_stream( - self._stream, subject=subject, keep=keep - ) + await self._js.purge_stream(self._stream, subject=subject, keep=keep) return True async def status(self) -> BucketStatus: @@ -348,11 +311,9 @@ async def status(self) -> BucketStatus: return KeyValue.BucketStatus(stream_info=info, bucket=self._name) class KeyWatcher: - def __init__(self, js): self._js = js - self._updates: asyncio.Queue[KeyValue.Entry - | None] = asyncio.Queue(maxsize=256) + self._updates: asyncio.Queue[KeyValue.Entry | None] = asyncio.Queue(maxsize=256) self._sub = None self._pending: Optional[int] = None @@ -408,9 +369,7 @@ async def keys(self, filters: List[str] = None, **kwargs) -> List[str]: if consumer_info and filters: # If NATS server < 2.10, filters might be ignored. if consumer_info.config.filter_subject != ">": - logger.warning( - "Server may ignore filters if version is < 2.10." - ) + logger.warning("Server may ignore filters if version is < 2.10.") except Exception as e: raise e @@ -493,7 +452,7 @@ async def watch_updates(msg): entry = KeyValue.Entry( bucket=self._name, - key=msg.subject[len(self._pre):], + key=msg.subject[len(self._pre) :], value=msg.data, revision=meta.sequence.stream, delta=meta.num_pending, diff --git a/nats/src/nats/js/manager.py b/nats/src/nats/js/manager.py index bfd5937f0..33f29170c 100644 --- a/nats/src/nats/js/manager.py +++ b/nats/src/nats/js/manager.py @@ -49,9 +49,7 @@ def __init__( self._hdr_parser = BytesParser() async def account_info(self) -> api.AccountInfo: - resp = await self._api_request( - f"{self._prefix}.INFO", b"", timeout=self._timeout - ) + resp = await self._api_request(f"{self._prefix}.INFO", b"", timeout=self._timeout) return api.AccountInfo.from_response(resp) async def find_stream_name_by_subject(self, subject: str) -> str: @@ -61,18 +59,12 @@ async def find_stream_name_by_subject(self, subject: str) -> str: req_sub = f"{self._prefix}.STREAM.NAMES" req_data = json.dumps({"subject": subject}) - info = await self._api_request( - req_sub, req_data.encode(), timeout=self._timeout - ) + info = await self._api_request(req_sub, req_data.encode(), timeout=self._timeout) if not info["streams"]: raise NotFoundError return info["streams"][0] - async def stream_info( - self, - name: str, - subjects_filter: Optional[str] = None - ) -> api.StreamInfo: + async def stream_info(self, name: str, subjects_filter: Optional[str] = None) -> api.StreamInfo: """ Get the latest StreamInfo by stream name. """ @@ -86,11 +78,7 @@ async def stream_info( ) return api.StreamInfo.from_response(resp) - async def add_stream( - self, - config: Optional[api.StreamConfig] = None, - **params - ) -> api.StreamInfo: + async def add_stream(self, config: Optional[api.StreamConfig] = None, **params) -> api.StreamInfo: """ add_stream creates a stream. """ @@ -122,11 +110,7 @@ async def add_stream( ) return api.StreamInfo.from_response(resp) - async def update_stream( - self, - config: Optional[api.StreamConfig] = None, - **params - ) -> api.StreamInfo: + async def update_stream(self, config: Optional[api.StreamConfig] = None, **params) -> api.StreamInfo: """ update_stream updates a stream. """ @@ -148,9 +132,7 @@ async def delete_stream(self, name: str) -> bool: """ Delete a stream by name. """ - resp = await self._api_request( - f"{self._prefix}.STREAM.DELETE.{name}", timeout=self._timeout - ) + resp = await self._api_request(f"{self._prefix}.STREAM.DELETE.{name}", timeout=self._timeout) return resp["success"] async def purge_stream( @@ -172,24 +154,14 @@ async def purge_stream( stream_req["keep"] = keep req = json.dumps(stream_req) - resp = await self._api_request( - f"{self._prefix}.STREAM.PURGE.{name}", - req.encode(), - timeout=self._timeout - ) + resp = await self._api_request(f"{self._prefix}.STREAM.PURGE.{name}", req.encode(), timeout=self._timeout) return resp["success"] - async def consumer_info( - self, stream: str, consumer: str, timeout: Optional[float] = None - ): + async def consumer_info(self, stream: str, consumer: str, timeout: Optional[float] = None): # TODO: Validate the stream and consumer names. if timeout is None: timeout = self._timeout - resp = await self._api_request( - f"{self._prefix}.CONSUMER.INFO.{stream}.{consumer}", - b"", - timeout=timeout - ) + resp = await self._api_request(f"{self._prefix}.CONSUMER.INFO.{stream}.{consumer}", b"", timeout=timeout) return api.ConsumerInfo.from_response(resp) async def streams_info(self, offset=0) -> List[api.StreamInfo]: @@ -198,9 +170,7 @@ async def streams_info(self, offset=0) -> List[api.StreamInfo]: """ resp = await self._api_request( f"{self._prefix}.STREAM.LIST", - json.dumps({ - "offset": offset - }).encode(), + json.dumps({"offset": offset}).encode(), timeout=self._timeout, ) streams = [] @@ -209,22 +179,17 @@ async def streams_info(self, offset=0) -> List[api.StreamInfo]: streams.append(stream_info) return streams - async def streams_info_iterator(self, - offset=0) -> Iterable[api.StreamInfo]: + async def streams_info_iterator(self, offset=0) -> Iterable[api.StreamInfo]: """ streams_info retrieves a list of streams Iterator. """ resp = await self._api_request( f"{self._prefix}.STREAM.LIST", - json.dumps({ - "offset": offset - }).encode(), + json.dumps({"offset": offset}).encode(), timeout=self._timeout, ) - return api.StreamsListIterator( - resp["offset"], resp["total"], resp["streams"] - ) + return api.StreamsListIterator(resp["offset"], resp["total"], resp["streams"]) async def add_consumer( self, @@ -270,11 +235,7 @@ async def delete_consumer(self, stream: str, consumer: str) -> bool: ) return resp["success"] - async def consumers_info( - self, - stream: str, - offset: Optional[int] = None - ) -> List[api.ConsumerInfo]: + async def consumers_info(self, stream: str, offset: Optional[int] = None) -> List[api.ConsumerInfo]: """ consumers_info retrieves a list of consumers. Consumers list limit is 256 for more consider to use offset @@ -283,9 +244,7 @@ async def consumers_info( """ resp = await self._api_request( f"{self._prefix}.CONSUMER.LIST.{stream}", - b"" if offset is None else json.dumps({ - "offset": offset - }).encode(), + b"" if offset is None else json.dumps({"offset": offset}).encode(), timeout=self._timeout, ) consumers = [] @@ -329,22 +288,18 @@ async def get_msg( else: req_subject = f"{self._prefix}.DIRECT.GET.{stream_name}" - resp = await self._nc.request( - req_subject, data.encode(), timeout=self._timeout - ) + resp = await self._nc.request(req_subject, data.encode(), timeout=self._timeout) raw_msg = JetStreamManager._lift_msg_to_raw_msg(resp) return raw_msg # Non Direct form req_subject = f"{self._prefix}.STREAM.MSG.GET.{stream_name}" - resp_data = await self._api_request( - req_subject, data.encode(), timeout=self._timeout - ) + resp_data = await self._api_request(req_subject, data.encode(), timeout=self._timeout) raw_msg = api.RawStreamMsg.from_response(resp_data["message"]) if raw_msg.hdrs: hdrs = base64.b64decode(raw_msg.hdrs) - raw_headers = hdrs[NATS_HDR_LINE_SIZE + _CRLF_LEN_:] + raw_headers = hdrs[NATS_HDR_LINE_SIZE + _CRLF_LEN_ :] parsed_headers = self._hdr_parser.parsebytes(raw_headers) headers = None if len(parsed_headers.items()) > 0: diff --git a/nats/src/nats/js/object_store.py b/nats/src/nats/js/object_store.py index 70ce3d3bf..b49b0a9ba 100644 --- a/nats/src/nats/js/object_store.py +++ b/nats/src/nats/js/object_store.py @@ -199,12 +199,8 @@ async def get( if info.size == 0: return result - chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format( - bucket=self._name, obj=info.nuid - ) - sub = await self._js.subscribe( - subject=chunk_subj, ordered_consumer=True - ) + chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format(bucket=self._name, obj=info.nuid) + sub = await self._js.subscribe(subject=chunk_subj, ordered_consumer=True) h = sha256() @@ -232,9 +228,7 @@ async def get( # Make sure the digest matches. sha = h.digest() - digest_str = info.digest.replace(OBJ_DIGEST_TYPE, "").replace( - OBJ_DIGEST_TYPE.upper(), "" - ) + digest_str = info.digest.replace(OBJ_DIGEST_TYPE, "").replace(OBJ_DIGEST_TYPE.upper(), "") rsha = base64.urlsafe_b64decode(digest_str) if not sha == rsha: raise DigestMismatchError @@ -267,9 +261,7 @@ async def put( newnuid = self._js._nc._nuid.next() # Create a random subject prefixed with the object stream name. - chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format( - bucket=self._name, obj=newnuid.decode() - ) + chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format(bucket=self._name, obj=newnuid.decode()) # Grab existing meta info (einfo). Ok to be found or not found, any other error is a problem. # Chunks on the old nuid can be cleaned up at the end. @@ -330,9 +322,7 @@ async def put( sha = h.digest() info.size = total info.chunks = sent - info.digest = OBJ_DIGEST_TEMPLATE.format( - digest=base64.urlsafe_b64encode(sha).decode() - ) + info.digest = OBJ_DIGEST_TEMPLATE.format(digest=base64.urlsafe_b64encode(sha).decode()) # Prepare the meta message. meta_subj = OBJ_META_PRE_TEMPLATE.format( @@ -355,9 +345,7 @@ async def put( # Delete any original chunks. if einfo is not None and not einfo.deleted: - chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format( - bucket=self._name, obj=einfo.nuid - ) + chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format(bucket=self._name, obj=einfo.nuid) await self._js.purge_stream(self._stream, subject=chunk_subj) return info @@ -424,7 +412,6 @@ async def update_meta( await self._js.purge_stream(self._stream, subject=meta_subj) class ObjectWatcher: - def __init__(self, js): self._js = js self._updates = asyncio.Queue(maxsize=256) @@ -469,7 +456,9 @@ async def watch( """ watch for changes in the underlying store and receive meta information updates. """ - all_meta = OBJ_ALL_META_PRE_TEMPLATE.format(bucket=self._name, ) + all_meta = OBJ_ALL_META_PRE_TEMPLATE.format( + bucket=self._name, + ) watcher = ObjectStore.ObjectWatcher(self) async def watch_updates(msg): @@ -518,9 +507,7 @@ async def delete(self, name: str) -> ObjectResult: raise BadObjectMetaError # Purge chunks for the object. - chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format( - bucket=self._name, obj=info.nuid - ) + chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format(bucket=self._name, obj=info.nuid) # Reset meta values. info.deleted = True diff --git a/nats/src/nats/micro/__init__.py b/nats/src/nats/micro/__init__.py index ab1701719..7ba415cc6 100644 --- a/nats/src/nats/micro/__init__.py +++ b/nats/src/nats/micro/__init__.py @@ -21,9 +21,7 @@ from .service import Service, ServiceConfig -async def add_service( - nc: Client, config: Optional[ServiceConfig] = None, **kwargs -) -> Service: +async def add_service(nc: Client, config: Optional[ServiceConfig] = None, **kwargs) -> Service: """Add a service.""" if config: config = replace(config, **kwargs) diff --git a/nats/src/nats/micro/request.py b/nats/src/nats/micro/request.py index 7a192f9b7..ff85ee35e 100644 --- a/nats/src/nats/micro/request.py +++ b/nats/src/nats/micro/request.py @@ -45,11 +45,7 @@ def data(self) -> bytes: """The data of the request.""" return self._msg.data - async def respond( - self, - data: bytes = b"", - headers: Optional[Dict[str, str]] = None - ) -> None: + async def respond(self, data: bytes = b"", headers: Optional[Dict[str, str]] = None) -> None: """Send a response to the request. :param data: The response data. @@ -83,10 +79,12 @@ async def respond_error( else: headers = {} - headers.update({ - ERROR_HEADER: description, - ERROR_CODE_HEADER: code, - }) + headers.update( + { + ERROR_HEADER: description, + ERROR_CODE_HEADER: code, + } + ) await self.respond(data, headers=headers) @@ -108,6 +106,4 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> ServiceError: - return cls( - code=data.get("code", ""), description=data.get("description", "") - ) + return cls(code=data.get("code", ""), description=data.get("description", "")) diff --git a/nats/src/nats/micro/service.py b/nats/src/nats/micro/service.py index 159530a39..f62ae823d 100644 --- a/nats/src/nats/micro/service.py +++ b/nats/src/nats/micro/service.py @@ -72,21 +72,15 @@ def __post_init__(self) -> None: raise ValueError("Name cannot be empty.") if not NAME_REGEX.match(self.name): - raise ValueError( - "Invalid name. Name must contain only alphanumeric characters, underscores, and hyphens." - ) + raise ValueError("Invalid name. Name must contain only alphanumeric characters, underscores, and hyphens.") if self.subject: if not SUBJECT_REGEX.match(self.subject): - raise ValueError( - "Invalid subject. Subject must not contain spaces, and can only have '>' at the end." - ) + raise ValueError("Invalid subject. Subject must not contain spaces, and can only have '>' at the end.") if self.queue_group: if not SUBJECT_REGEX.match(self.queue_group): - raise ValueError( - "Invalid queue group. Queue group must not contain spaces." - ) + raise ValueError("Invalid queue group. Queue group must not contain spaces.") @dataclass @@ -272,9 +266,7 @@ async def _handle_request(self, msg: Msg) -> None: elapsed_time = current_time - start_time self._processing_time += elapsed_time - self._average_processing_time = int( - self._processing_time / self._num_requests - ) + self._average_processing_time = int(self._processing_time / self._num_requests) @dataclass @@ -294,8 +286,7 @@ class EndpointManager(Protocol): """ @overload - async def add_endpoint(self, config: EndpointConfig) -> None: - ... + async def add_endpoint(self, config: EndpointConfig) -> None: ... @overload async def add_endpoint( @@ -306,13 +297,9 @@ async def add_endpoint( queue_group: Optional[str] = None, subject: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, - ) -> None: - ... + ) -> None: ... - async def add_endpoint( - self, config: Optional[EndpointConfig] = None, **kwargs - ) -> None: - ... + async def add_endpoint(self, config: Optional[EndpointConfig] = None, **kwargs) -> None: ... class GroupManager(Protocol): @@ -321,31 +308,22 @@ class GroupManager(Protocol): """ @overload - def add_group( - self, *, name: str, queue_group: Optional[str] = None - ) -> Group: - ... + def add_group(self, *, name: str, queue_group: Optional[str] = None) -> Group: ... @overload - def add_group(self, config: GroupConfig) -> Group: - ... + def add_group(self, config: GroupConfig) -> Group: ... - def add_group( - self, config: Optional[GroupConfig] = None, **kwargs - ) -> Group: - ... + def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: ... class Group(GroupManager, EndpointManager): - def __init__(self, service: "Service", config: GroupConfig) -> None: self._service = service self._prefix = config.name self._queue_group = config.queue_group @overload - async def add_endpoint(self, config: EndpointConfig) -> None: - ... + async def add_endpoint(self, config: EndpointConfig) -> None: ... @overload async def add_endpoint( @@ -356,12 +334,9 @@ async def add_endpoint( queue_group: Optional[str] = None, subject: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, - ) -> None: - ... + ) -> None: ... - async def add_endpoint( - self, config: Optional[EndpointConfig] = None, **kwargs - ) -> None: + async def add_endpoint(self, config: Optional[EndpointConfig] = None, **kwargs) -> None: if config is None: config = EndpointConfig(**kwargs) else: @@ -369,26 +344,19 @@ async def add_endpoint( config = replace( config, - subject=f"{self._prefix.strip('.')}.{config.subject or config.name}" - .strip("."), + subject=f"{self._prefix.strip('.')}.{config.subject or config.name}".strip("."), queue_group=config.queue_group or self._queue_group, ) await self._service.add_endpoint(config) @overload - def add_group( - self, *, name: str, queue_group: Optional[str] = None - ) -> Group: - ... + def add_group(self, *, name: str, queue_group: Optional[str] = None) -> Group: ... @overload - def add_group(self, config: GroupConfig) -> Group: - ... + def add_group(self, config: GroupConfig) -> Group: ... - def add_group( - self, config: Optional[GroupConfig] = None, **kwargs - ) -> Group: + def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: if config: config = replace(config, **kwargs) else: @@ -440,17 +408,13 @@ def __post_init__(self) -> None: raise ValueError("Name cannot be empty.") if not NAME_REGEX.match(self.name): - raise ValueError( - "Invalid name. It must contain only alphanumeric characters, dashes, and underscores." - ) + raise ValueError("Invalid name. It must contain only alphanumeric characters, dashes, and underscores.") if not self.version: raise ValueError("Version cannot be empty.") if not SEMVER_REGEX.match(self.version): - raise ValueError( - "Invalid version. It must follow semantic versioning (e.g., 1.0.0, 2.1.3-alpha.1)." - ) + raise ValueError("Invalid version. It must follow semantic versioning (e.g., 1.0.0, 2.1.3-alpha.1).") if self.queue_group: if not SUBJECT_REGEX.match(self.queue_group): @@ -542,13 +506,8 @@ def from_dict(cls, data: Dict[str, Any]) -> ServiceStats: id=data["id"], name=data["name"], version=data["version"], - started=datetime.strptime( - data["started"], "%Y-%m-%dT%H:%M:%S.%fZ" - ), - endpoints=[ - EndpointStats.from_dict(endpoint) - for endpoint in data["endpoints"] - ], + started=datetime.strptime(data["started"], "%Y-%m-%dT%H:%M:%S.%fZ"), + endpoints=[EndpointStats.from_dict(endpoint) for endpoint in data["endpoints"]], metadata=data["metadata"], ) @@ -560,8 +519,7 @@ def to_dict(self) -> Dict[str, Any]: "name": self.name, "id": self.id, "version": self.version, - "started": self.started.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + - "Z", + "started": self.started.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", "endpoints": [endpoint.to_dict() for endpoint in self.endpoints], "metadata": self.metadata, } @@ -620,10 +578,7 @@ def from_dict(cls, data: Dict[str, Any]) -> ServiceInfo: name=data["name"], version=data["version"], description=data.get("description"), - endpoints=[ - EndpointInfo.from_dict(endpoint) - for endpoint in data["endpoints"] - ], + endpoints=[EndpointInfo.from_dict(endpoint) for endpoint in data["endpoints"]], metadata=data["metadata"], type=data.get("type", "io.nats.micro.v1.info_response"), ) @@ -645,7 +600,6 @@ def to_dict(self) -> Dict[str, Any]: class Service(AsyncContextManager): - def __init__(self, client: Client, config: ServiceConfig) -> None: self._id = client._nuid.next().decode() self._name = config.name @@ -683,38 +637,26 @@ async def start(self) -> None: verb_subjects = [ ( f"{verb}-all", - control_subject( - verb, name=None, id=None, prefix=self._prefix - ), + control_subject(verb, name=None, id=None, prefix=self._prefix), ), ( f"{verb}-kind", - control_subject( - verb, name=self._name, id=None, prefix=self._prefix - ), + control_subject(verb, name=self._name, id=None, prefix=self._prefix), ), ( verb, - control_subject( - verb, - name=self._name, - id=self._id, - prefix=self._prefix - ), + control_subject(verb, name=self._name, id=self._id, prefix=self._prefix), ), ] for key, subject in verb_subjects: - self._subscriptions[key] = await self._client.subscribe( - subject, cb=verb_handler - ) + self._subscriptions[key] = await self._client.subscribe(subject, cb=verb_handler) self._started = datetime.utcnow() await self._client.flush() @overload - async def add_endpoint(self, config: EndpointConfig) -> None: - ... + async def add_endpoint(self, config: EndpointConfig) -> None: ... @overload async def add_endpoint( @@ -725,46 +667,33 @@ async def add_endpoint( queue_group: Optional[str] = None, subject: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, - ) -> None: - ... + ) -> None: ... - async def add_endpoint( - self, config: Optional[EndpointConfig] = None, **kwargs - ) -> None: + async def add_endpoint(self, config: Optional[EndpointConfig] = None, **kwargs) -> None: if config is None: config = EndpointConfig(**kwargs) else: config = replace(config, **kwargs) - config = replace( - config, queue_group=config.queue_group or self._queue_group - ) + config = replace(config, queue_group=config.queue_group or self._queue_group) endpoint = Endpoint(self, config) await endpoint._start() self._endpoints.append(endpoint) @overload - def add_group( - self, *, name: str, queue_group: Optional[str] = None - ) -> Group: - ... + def add_group(self, *, name: str, queue_group: Optional[str] = None) -> Group: ... @overload - def add_group(self, config: GroupConfig) -> Group: - ... + def add_group(self, config: GroupConfig) -> Group: ... - def add_group( - self, config: Optional[GroupConfig] = None, **kwargs - ) -> Group: + def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: if config: config = replace(config, **kwargs) else: config = GroupConfig(**kwargs) - config = replace( - config, queue_group=config.queue_group or self._queue_group - ) + config = replace(config, queue_group=config.queue_group or self._queue_group) return Group(self, config) @@ -787,7 +716,8 @@ def stats(self) -> ServiceStats: last_error=endpoint._last_error, processing_time=endpoint._processing_time, average_processing_time=endpoint._average_processing_time, - ) for endpoint in (self._endpoints or []) + ) + for endpoint in (self._endpoints or []) ], started=self._started, ) @@ -811,7 +741,8 @@ def info(self) -> ServiceInfo: subject=endpoint._subject, queue_group=endpoint._queue_group, metadata=endpoint._metadata, - ) for endpoint in self._endpoints + ) + for endpoint in self._endpoints ], ) diff --git a/nats/src/nats/protocol/command.py b/nats/src/nats/protocol/command.py index 7218f48de..9b8337be2 100644 --- a/nats/src/nats/protocol/command.py +++ b/nats/src/nats/protocol/command.py @@ -12,19 +12,13 @@ def pub_cmd(subject, reply, payload) -> bytes: - return ( - f"{PUB_OP} {subject} {reply} {len(payload)}{_CRLF_}".encode() + - payload + _CRLF_.encode() - ) + return f"{PUB_OP} {subject} {reply} {len(payload)}{_CRLF_}".encode() + payload + _CRLF_.encode() def hpub_cmd(subject, reply, hdr, payload) -> bytes: hdr_len = len(hdr) total_size = len(payload) + hdr_len - return ( - f"{HPUB_OP} {subject} {reply} {hdr_len} {total_size}{_CRLF_}".encode() - + hdr + payload + _CRLF_.encode() - ) + return f"{HPUB_OP} {subject} {reply} {hdr_len} {total_size}{_CRLF_}".encode() + hdr + payload + _CRLF_.encode() def sub_cmd(subject, queue, sid) -> bytes: diff --git a/nats/src/nats/protocol/parser.py b/nats/src/nats/protocol/parser.py index 6b8c72555..472f75221 100644 --- a/nats/src/nats/protocol/parser.py +++ b/nats/src/nats/protocol/parser.py @@ -23,12 +23,8 @@ from nats.errors import ProtocolError -MSG_RE = re.compile( - b"\\AMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?(\\d+)\r\n" -) -HMSG_RE = re.compile( - b"\\AHMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?([\\d]+)\\s+(\\d+)\r\n" -) +MSG_RE = re.compile(b"\\AMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?(\\d+)\r\n") +HMSG_RE = re.compile(b"\\AHMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?([\\d]+)\\s+(\\d+)\r\n") OK_RE = re.compile(b"\\A\\+OK\\s*\r\n") ERR_RE = re.compile(b"\\A-ERR\\s+('.+')?\r\n") PING_RE = re.compile(b"\\APING\\s*\r\n") @@ -72,7 +68,6 @@ class Parser: - def __init__(self, nc=None) -> None: self.nc = nc self.reset() @@ -106,7 +101,7 @@ async def parse(self, data: bytes = b""): else: self.msg_arg["reply"] = b"" self.needed = int(needed_bytes) - del self.buf[:msg.end()] + del self.buf[: msg.end()] self.state = AWAITING_MSG_PAYLOAD continue except Exception: @@ -115,8 +110,7 @@ async def parse(self, data: bytes = b""): msg = HMSG_RE.match(self.buf) if msg: try: - subject, sid, _, reply, header_size, needed_bytes = msg.groups( - ) + subject, sid, _, reply, header_size, needed_bytes = msg.groups() self.msg_arg["subject"] = subject self.msg_arg["sid"] = int(sid) if reply: @@ -125,7 +119,7 @@ async def parse(self, data: bytes = b""): self.msg_arg["reply"] = b"" self.needed = int(needed_bytes) self.header_needed = int(header_size) - del self.buf[:msg.end()] + del self.buf[: msg.end()] self.state = AWAITING_MSG_PAYLOAD continue except Exception: @@ -134,7 +128,7 @@ async def parse(self, data: bytes = b""): ok = OK_RE.match(self.buf) if ok: # Do nothing and just skip. - del self.buf[:ok.end()] + del self.buf[: ok.end()] continue err = ERR_RE.match(self.buf) @@ -142,18 +136,18 @@ async def parse(self, data: bytes = b""): err_msg = err.groups() emsg = err_msg[0].decode().lower() await self.nc._process_err(emsg) - del self.buf[:err.end()] + del self.buf[: err.end()] continue ping = PING_RE.match(self.buf) if ping: - del self.buf[:ping.end()] + del self.buf[: ping.end()] await self.nc._process_ping() continue pong = PONG_RE.match(self.buf) if pong: - del self.buf[:pong.end()] + del self.buf[: pong.end()] await self.nc._process_pong() continue @@ -162,11 +156,10 @@ async def parse(self, data: bytes = b""): info_line = info.groups()[0] srv_info = json.loads(info_line.decode()) await self.nc._process_info(srv_info) - del self.buf[:info.end()] + del self.buf[: info.end()] continue - if len(self.buf - ) < MAX_CONTROL_LINE_SIZE and _CRLF_ in self.buf: + if len(self.buf) < MAX_CONTROL_LINE_SIZE and _CRLF_ in self.buf: # FIXME: By default server uses a max protocol # line of 4096 bytes but it can be tuned in latest # releases, in that case we won't reach here but @@ -187,21 +180,17 @@ async def parse(self, data: bytes = b""): # Consume msg payload from buffer and set next parser state. if self.header_needed > 0: - hbuf = bytes(self.buf[:self.header_needed]) - payload = bytes( - self.buf[self.header_needed:self.needed] - ) + hbuf = bytes(self.buf[: self.header_needed]) + payload = bytes(self.buf[self.header_needed : self.needed]) hdr = hbuf - del self.buf[:self.needed + CRLF_SIZE] + del self.buf[: self.needed + CRLF_SIZE] self.header_needed = 0 else: - payload = bytes(self.buf[:self.needed]) - del self.buf[:self.needed + CRLF_SIZE] + payload = bytes(self.buf[: self.needed]) + del self.buf[: self.needed + CRLF_SIZE] self.state = AWAITING_CONTROL_LINE - await self.nc._process_msg( - sid, subject, reply, payload, hdr - ) + await self.nc._process_msg(sid, subject, reply, payload, hdr) else: # Wait until we have enough bytes in buffer. break diff --git a/nats/tests/test_client.py b/nats/tests/test_client.py index e16ee09a7..aff0dad1b 100644 --- a/nats/tests/test_client.py +++ b/nats/tests/test_client.py @@ -27,7 +27,6 @@ class ClientUtilsTest(unittest.TestCase): - def test_default_connect_command(self): nc = NATS() nc.options["verbose"] = False @@ -106,16 +105,36 @@ def test_semver_parsing(self): # Check that some common server versions do not panic. versions = [ - "2.2.2", "2.2.2", "2.2.2", "2.2.2-prerelease+meta", "2.2.2+meta", - "2.2.2+meta-valid", "2.2.2-alpha", "2.2.2-beta", - "2.2.2-alpha.beta", "2.2.2-alpha.beta.1", "2.2.2-alpha.1", - "2.2.2-alpha0.valid", "2.2.2-alpha.0valid", + "2.2.2", + "2.2.2", + "2.2.2", + "2.2.2-prerelease+meta", + "2.2.2+meta", + "2.2.2+meta-valid", + "2.2.2-alpha", + "2.2.2-beta", + "2.2.2-alpha.beta", + "2.2.2-alpha.beta.1", + "2.2.2-alpha.1", + "2.2.2-alpha0.valid", + "2.2.2-alpha.0valid", "2.2.2-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay", - "2.2.2-rc.1+build.1", "2.2.2-rc.1+build.123", "2.2.2-RC.1+build.1", - "2.2.2-RC.1+build.123", "2.2.2-rc.1", "2.2.2-RC.1", - "2.2.2-RC.1+foo", "2.2.2-beta", "2.2.2-DEV-SNAPSHOT", - "2.2.2-SNAPSHOT-123", "2.2.2", "2.2.2", "2.2.2", - "2.2.2+build.1848", "2.2.2-alpha.1227", "2.2.2-alpha+beta" + "2.2.2-rc.1+build.1", + "2.2.2-rc.1+build.123", + "2.2.2-RC.1+build.1", + "2.2.2-RC.1+build.123", + "2.2.2-rc.1", + "2.2.2-RC.1", + "2.2.2-RC.1+foo", + "2.2.2-beta", + "2.2.2-DEV-SNAPSHOT", + "2.2.2-SNAPSHOT-123", + "2.2.2", + "2.2.2", + "2.2.2", + "2.2.2+build.1848", + "2.2.2-alpha.1227", + "2.2.2-alpha+beta", ] for version in versions: v = ServerVersion(version) @@ -125,7 +144,6 @@ def test_semver_parsing(self): class ClientTest(SingleServerTestCase): - @async_test async def test_default_connect(self): nc = await nats.connect() @@ -162,10 +180,7 @@ async def test_default_module_connect(self): def test_connect_syntax_sugar(self): nc = NATS() - nc._setup_server_pool([ - "nats://127.0.0.1:4222", "nats://127.0.0.1:4223", - "nats://127.0.0.1:4224" - ]) + nc._setup_server_pool(["nats://127.0.0.1:4222", "nats://127.0.0.1:4223", "nats://127.0.0.1:4224"]) self.assertEqual(3, len(nc._server_pool)) nc = NATS() @@ -396,9 +411,7 @@ async def subscription_handler(arg1, msg): partial_arg = arg1 msgs.append(msg) - partial_sub_handler = functools.partial( - subscription_handler, "example" - ) + partial_sub_handler = functools.partial(subscription_handler, "example") payload = b"hello world" await nc.connect() @@ -869,19 +882,13 @@ async def slow_worker_handler(msg): await nc.subscribe("help", cb=worker_handler) await nc.subscribe("slow.help", cb=slow_worker_handler) - response = await nc.request( - "help", b"please", timeout=1, old_style=True - ) + response = await nc.request("help", b"please", timeout=1, old_style=True) self.assertEqual(b"Reply:1", response.data) - response = await nc.request( - "help", b"please", timeout=1, old_style=True - ) + response = await nc.request("help", b"please", timeout=1, old_style=True) self.assertEqual(b"Reply:2", response.data) with self.assertRaises(nats.errors.TimeoutError): - msg = await nc.request( - "slow.help", b"please", timeout=0.1, old_style=True - ) + msg = await nc.request("slow.help", b"please", timeout=0.1, old_style=True) with self.assertRaises(nats.errors.NoRespondersError): await nc.request("nowhere", b"please", timeout=0.1, old_style=True) @@ -1115,7 +1122,6 @@ async def receiver_cb(msg): class ClientReconnectTest(MultiServerAuthTestCase): - @async_test async def test_connect_with_auth(self): nc = NATS() @@ -1144,9 +1150,7 @@ async def test_module_connect_with_auth(self): @async_test async def test_module_connect_with_options(self): - nc = await nats.connect( - "nats://127.0.0.1:4223", user="foo", password="bar" - ) + nc = await nats.connect("nats://127.0.0.1:4223", user="foo", password="bar") self.assertTrue(nc.is_connected) await nc.drain() self.assertTrue(nc.is_closed) @@ -1163,7 +1167,9 @@ async def err_cb(e): options = { "reconnect_time_wait": 0.2, - "servers": ["nats://hello:world@127.0.0.1:4223", ], + "servers": [ + "nats://hello:world@127.0.0.1:4223", + ], "max_reconnect_attempts": 3, "error_cb": err_cb, } @@ -1190,7 +1196,9 @@ async def err_cb(e): options = { "reconnect_time_wait": 0.2, - "servers": ["nats://hello:world@127.0.0.1:4223", ], + "servers": [ + "nats://hello:world@127.0.0.1:4223", + ], "max_reconnect_attempts": 3, "error_cb": err_cb, } @@ -1231,12 +1239,8 @@ async def err_cb(e): self.assertTrue(nc.is_connected) # Stop all servers so that there aren't any available to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].stop) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.2) @@ -1248,9 +1252,7 @@ async def err_cb(e): # Restart one of the servers and confirm we are reconnected # even after many tries from small reconnect_time_wait. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].start) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.2) @@ -1311,12 +1313,8 @@ async def err_cb(e): # Stop all servers so that there aren't any available to reconnect # then start one of them again. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].stop - ) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].stop) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.1) @@ -1329,18 +1327,14 @@ async def err_cb(e): # Restart one of the servers and confirm we are reconnected # even after many tries from small reconnect_time_wait. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].start) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.1) await asyncio.sleep(0) # Stop the server once again - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].stop) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.1) @@ -1452,9 +1446,7 @@ async def cb(msg): if not done_once: await nc.flush(2) post_flush_pending_data = nc.pending_data_size - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) done_once = True self.assertTrue(largest_pending_data_size > 0) @@ -1529,9 +1521,7 @@ async def cb(msg): if not done_once: await nc.flush(2) post_flush_pending_data = nc.pending_data_size - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) done_once = True self.assertTrue(largest_pending_data_size > 0) @@ -1607,9 +1597,7 @@ async def worker_handler(msg): self.assertEqual(b"Reply:1", response.data) # Stop the first server and connect to another one asap. - asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) # FIXME: Find better way to wait for the server to be stopped. await asyncio.sleep(0.5) @@ -1627,12 +1615,15 @@ async def worker_handler(msg): class ClientAuthTokenTest(MultiServerAuthTokenTestCase): - @async_test async def test_connect_with_auth_token(self): nc = NATS() - options = {"servers": ["nats://token@127.0.0.1:4223", ]} + options = { + "servers": [ + "nats://token@127.0.0.1:4223", + ] + } await nc.connect(**options) self.assertIn("auth_required", nc._server_info) self.assertTrue(nc.is_connected) @@ -1645,7 +1636,9 @@ async def test_connect_with_auth_token_option(self): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4223", ], + "servers": [ + "nats://127.0.0.1:4223", + ], "token": "token", } await nc.connect(**options) @@ -1660,7 +1653,9 @@ async def test_connect_with_bad_auth_token(self): nc = NATS() options = { - "servers": ["nats://token@127.0.0.1:4225", ], + "servers": [ + "nats://token@127.0.0.1:4225", + ], "allow_reconnect": False, "reconnect_time_wait": 0.1, "max_reconnect_attempts": 1, @@ -1717,9 +1712,7 @@ async def worker_handler(msg): self.assertTrue(nc.is_connected) # Trigger a reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) await nc.subscribe("test", cb=worker_handler) @@ -1735,7 +1728,6 @@ async def worker_handler(msg): class ClientTLSTest(TLSServerTestCase): - @async_test async def test_connect(self): nc = NATS() @@ -1755,9 +1747,7 @@ async def test_default_connect_using_tls_scheme(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): - await nc.connect( - servers=["tls://127.0.0.1:4224"], allow_reconnect=False - ) + await nc.connect(servers=["tls://127.0.0.1:4224"], allow_reconnect=False) @async_test async def test_default_connect_using_tls_scheme_in_url(self): @@ -1810,7 +1800,6 @@ async def subscription_handler(msg): class ClientTLSReconnectTest(MultiTLSServerAuthTestCase): - @async_test async def test_tls_reconnect(self): nc = NATS() @@ -1863,9 +1852,7 @@ async def worker_handler(msg): self.assertEqual(b"Reply:1", response.data) # Trigger a reconnect and should be fine - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) await nc.subscribe("example", cb=worker_handler) @@ -1883,17 +1870,12 @@ async def worker_handler(msg): class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase): - @async_test async def test_connect(self): if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") - nc = await nats.connect( - "nats://127.0.0.1:4224", - tls=self.ssl_ctx, - tls_handshake_first=True - ) + nc = await nats.connect("nats://127.0.0.1:4224", tls=self.ssl_ctx, tls_handshake_first=True) self.assertEqual(nc._server_info["max_payload"], nc.max_payload) self.assertTrue(nc._server_info["tls_required"]) self.assertTrue(nc._server_info["tls_verify"]) @@ -1927,11 +1909,7 @@ async def test_default_connect_using_tls_scheme_in_url(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): - await nc.connect( - "tls://127.0.0.1:4224", - allow_reconnect=False, - tls_handshake_first=True - ) + await nc.connect("tls://127.0.0.1:4224", allow_reconnect=False, tls_handshake_first=True) @async_test async def test_connect_tls_with_custom_hostname(self): @@ -1987,23 +1965,22 @@ async def subscription_handler(msg): class ClusterDiscoveryTest(ClusteringTestCase): - @async_test async def test_discover_servers_on_first_connect(self): nc = NATS() # Start rest of cluster members so that we receive them # connect_urls on the first connect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].start) await asyncio.sleep(1) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[2].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[2].start) await asyncio.sleep(1) - options = {"servers": ["nats://127.0.0.1:4223", ]} + options = { + "servers": [ + "nats://127.0.0.1:4223", + ] + } discovered_server_cb = mock.AsyncMock() await nc.connect(**options, discovered_server_cb=discovered_server_cb) self.assertTrue(nc.is_connected) @@ -2017,19 +1994,19 @@ async def test_discover_servers_on_first_connect(self): async def test_discover_servers_after_first_connect(self): nc = NATS() - options = {"servers": ["nats://127.0.0.1:4223", ]} + options = { + "servers": [ + "nats://127.0.0.1:4223", + ] + } discovered_server_cb = mock.AsyncMock() await nc.connect(**options, discovered_server_cb=discovered_server_cb) # Start rest of cluster members so that we receive them # connect_urls on the first connect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].start) await asyncio.sleep(1) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[2].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[2].start) await asyncio.sleep(1) await nc.close() @@ -2040,7 +2017,6 @@ async def test_discover_servers_after_first_connect(self): class ClusterDiscoveryReconnectTest(ClusteringDiscoveryAuthTestCase): - @async_test async def test_reconnect_to_new_server_with_auth(self): nc = NATS() @@ -2056,7 +2032,9 @@ async def err_cb(e): errors.append(e) options = { - "servers": ["nats://127.0.0.1:4223", ], + "servers": [ + "nats://127.0.0.1:4223", + ], "reconnected_cb": reconnected_cb, "error_cb": err_cb, "reconnect_time_wait": 0.1, @@ -2074,9 +2052,7 @@ async def handler(msg): await nc.subscribe("foo", cb=handler) # Remove first member and try to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(reconnected, 2) msg = await nc.request("foo", b"hi") @@ -2137,9 +2113,7 @@ async def handler(msg): self.assertEqual(b"ok", msg.data) # Remove first member and try to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(disconnected, 2) # Publishing while disconnected is an error if pending size is disabled. @@ -2207,9 +2181,7 @@ async def handler(msg): self.assertEqual(b"ok", msg.data) # Remove first member and try to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(disconnected, 2) # While reconnecting the pending data will accumulate. @@ -2352,10 +2324,8 @@ async def handler(msg): class ConnectFailuresTest(SingleServerTestCase): - @async_test async def test_empty_info_op_uses_defaults(self): - async def bad_server(reader, writer): writer.write(b"INFO {}\r\n") await writer.drain() @@ -2374,7 +2344,9 @@ async def disconnected_cb(): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "disconnected_cb": disconnected_cb, } await nc.connect(**options) @@ -2385,7 +2357,6 @@ async def disconnected_cb(): @async_test async def test_empty_response_from_server(self): - async def bad_server(reader, writer): writer.write(b"") await asyncio.sleep(0.2) @@ -2401,7 +2372,9 @@ async def error_cb(e): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "error_cb": error_cb, "allow_reconnect": False, } @@ -2413,7 +2386,6 @@ async def error_cb(e): @async_test async def test_malformed_info_response_from_server(self): - async def bad_server(reader, writer): writer.write(b"INF") await asyncio.sleep(0.2) @@ -2429,7 +2401,9 @@ async def error_cb(e): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "error_cb": error_cb, "allow_reconnect": False, } @@ -2441,7 +2415,6 @@ async def error_cb(e): @async_test async def test_malformed_info_json_response_from_server(self): - async def bad_server(reader, writer): writer.write(b"INFO {\r\n") await asyncio.sleep(0.2) @@ -2457,7 +2430,9 @@ async def error_cb(e): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "error_cb": error_cb, "allow_reconnect": False, } @@ -2470,7 +2445,6 @@ async def error_cb(e): @async_test async def test_connect_timeout(self): - async def slow_server(reader, writer): await asyncio.sleep(1) writer.close() @@ -2490,7 +2464,9 @@ async def reconnected_cb(): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "disconnected_cb": disconnected_cb, "reconnected_cb": reconnected_cb, "connect_timeout": 0.5, @@ -2507,7 +2483,6 @@ async def reconnected_cb(): @async_test async def test_connect_timeout_then_connect_to_healthy_server(self): - async def slow_server(reader, writer): await asyncio.sleep(1) writer.close() @@ -2561,7 +2536,6 @@ async def error_cb(e): class ClientDrainTest(SingleServerTestCase): - @async_test async def test_drain_subscription(self): nc = NATS() @@ -2683,17 +2657,9 @@ async def replies(msg): await nc2.subscribe("my-replies.*", cb=replies) for i in range(0, 201): - await nc2.publish( - "foo", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" - ) - await nc2.publish( - "bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" - ) - await nc2.publish( - "quux", - b"help", - reply=f"my-replies.{nc._nuid.next().decode()}" - ) + await nc2.publish("foo", b"help", reply=f"my-replies.{nc._nuid.next().decode()}") + await nc2.publish("bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}") + await nc2.publish("quux", b"help", reply=f"my-replies.{nc._nuid.next().decode()}") # Relinquish control so that messages are processed. await asyncio.sleep(0) @@ -2752,9 +2718,7 @@ async def closed_cb(): nonlocal drain_done drain_done.set_result(True) - await nc.connect( - closed_cb=closed_cb, error_cb=error_cb, drain_timeout=0.1 - ) + await nc.connect(closed_cb=closed_cb, error_cb=error_cb, drain_timeout=0.1) nc2 = NATS() await nc2.connect() @@ -2779,17 +2743,9 @@ async def replies(msg): await nc2.subscribe("my-replies.*", cb=replies) for i in range(0, 201): - await nc2.publish( - "foo", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" - ) - await nc2.publish( - "bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" - ) - await nc2.publish( - "quux", - b"help", - reply=f"my-replies.{nc._nuid.next().decode()}" - ) + await nc2.publish("foo", b"help", reply=f"my-replies.{nc._nuid.next().decode()}") + await nc2.publish("bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}") + await nc2.publish("quux", b"help", reply=f"my-replies.{nc._nuid.next().decode()}") # Relinquish control so that messages are processed. await asyncio.sleep(0) @@ -2816,11 +2772,11 @@ def f(): pass for cb in [ - "error_cb", - "disconnected_cb", - "discovered_server_cb", - "closed_cb", - "reconnected_cb", + "error_cb", + "disconnected_cb", + "discovered_server_cb", + "closed_cb", + "reconnected_cb", ]: with self.assertRaises(nats.errors.InvalidCallbackTypeError): await nc.connect( @@ -2834,17 +2790,11 @@ def f(): async def test_protocol_mixing(self): nc = NATS() with self.assertRaises(nats.errors.Error): - await nc.connect( - servers=["nats://127.0.0.1:4222", "ws://127.0.0.1:8080"] - ) + await nc.connect(servers=["nats://127.0.0.1:4222", "ws://127.0.0.1:8080"]) with self.assertRaises(nats.errors.Error): - await nc.connect( - servers=["nats://127.0.0.1:4222", "wss://127.0.0.1:8080"] - ) + await nc.connect(servers=["nats://127.0.0.1:4222", "wss://127.0.0.1:8080"]) with self.assertRaises(nats.errors.Error): - await nc.connect( - servers=["tls://127.0.0.1:4222", "wss://127.0.0.1:8080"] - ) + await nc.connect(servers=["tls://127.0.0.1:4222", "wss://127.0.0.1:8080"]) @async_test async def test_drain_cancelled_errors_raised(self): @@ -2865,16 +2815,14 @@ async def cb(msg): await asyncio.sleep(0.1) with self.assertRaises(asyncio.CancelledError): with unittest.mock.patch( - "asyncio.wait_for", - unittest.mock.AsyncMock(side_effect=asyncio.CancelledError - ), + "asyncio.wait_for", + unittest.mock.AsyncMock(side_effect=asyncio.CancelledError), ): await sub.drain() await nc.close() class NoAuthUserClientTest(NoAuthUserServerTestCase): - @async_test async def test_connect_user(self): fut = asyncio.Future() @@ -2895,11 +2843,11 @@ async def err_cb(e): await nc.publish("foo", b"hello") await asyncio.wait_for(fut, 2) err = fut.result() - assert str( - err - ) == 'nats: permissions violation for subscription to "foo"' + assert str(err) == 'nats: permissions violation for subscription to "foo"' - nc2 = await nats.connect("nats://127.0.0.1:4555", ) + nc2 = await nats.connect( + "nats://127.0.0.1:4555", + ) async def cb(msg): await msg.respond(b"pong") @@ -2930,11 +2878,11 @@ async def err_cb(e): await nc.publish("foo", b"hello") await asyncio.wait_for(fut, 2) err = fut.result() - assert str( - err - ) == 'nats: permissions violation for subscription to "foo"' + assert str(err) == 'nats: permissions violation for subscription to "foo"' - nc2 = await nats.connect("nats://127.0.0.1:4555", ) + nc2 = await nats.connect( + "nats://127.0.0.1:4555", + ) async def cb(msg): await msg.respond(b"pong") @@ -2949,7 +2897,6 @@ async def cb(msg): class ClientDisconnectTest(SingleServerTestCase): - @async_test async def test_close_while_disconnected(self): reconnected = asyncio.Future() @@ -2981,9 +2928,7 @@ async def disconnected_cb(): msg = await sub.next_msg() self.assertEqual(msg.data, b"First") - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(disconnected, 2) await nc.close() diff --git a/nats/tests/test_client_async_await.py b/nats/tests/test_client_async_await.py index d73a4d761..ec1047542 100644 --- a/nats/tests/test_client_async_await.py +++ b/nats/tests/test_client_async_await.py @@ -6,7 +6,6 @@ class ClientAsyncAwaitTest(SingleServerTestCase): - @async_test async def test_async_await_subscribe_async(self): nc = NATS() diff --git a/nats/tests/test_client_nkeys.py b/nats/tests/test_client_nkeys.py index 71481c94e..b7a6d10a5 100644 --- a/nats/tests/test_client_nkeys.py +++ b/nats/tests/test_client_nkeys.py @@ -23,7 +23,6 @@ class ClientNkeysAuthTest(NkeysServerTestCase): - @async_test async def test_nkeys_connect(self): import os @@ -34,12 +33,8 @@ async def test_nkeys_connect(self): seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) args_list = [ - { - "nkeys_seed": config_file - }, - { - "nkeys_seed_str": seed.decode() - }, + {"nkeys_seed": config_file}, + {"nkeys_seed_str": seed.decode()}, ] for nkeys_args in args_list: if not nkeys_installed: @@ -81,7 +76,6 @@ async def help_handler(msg): class ClientJWTAuthTest(TrustedServerTestCase): - @async_test async def test_nkeys_jwt_creds_user_connect(self): if not nkeys_installed: diff --git a/nats/tests/test_client_v2.py b/nats/tests/test_client_v2.py index 755cefb23..54bf8cd4c 100644 --- a/nats/tests/test_client_v2.py +++ b/nats/tests/test_client_v2.py @@ -6,19 +6,13 @@ class HeadersTest(SingleServerTestCase): - @async_test async def test_simple_headers(self): nc = await nats.connect() sub = await nc.subscribe("foo") await nc.flush() - await nc.publish( - "foo", b"hello world", headers={ - "foo": "bar", - "hello": "world-1" - } - ) + await nc.publish("foo", b"hello world", headers={"foo": "bar", "hello": "world-1"}) msg = await sub.next_msg() self.assertTrue(msg.headers != None) @@ -40,12 +34,7 @@ async def service(msg): await nc.subscribe("foo", cb=service) await nc.flush() - msg = await nc.request( - "foo", b"hello world", headers={ - "foo": "bar", - "hello": "world" - } - ) + msg = await nc.request("foo", b"hello world", headers={"foo": "bar", "hello": "world"}) self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 3) @@ -73,9 +62,7 @@ async def test_empty_headers(self): self.assertTrue(msg.headers == None) # Empty long key - await nc.publish( - "foo", b"hello world", headers={"": " "} - ) + await nc.publish("foo", b"hello world", headers={"": " "}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) diff --git a/nats/tests/test_client_websocket.py b/nats/tests/test_client_websocket.py index f3172979b..654f3b0ca 100644 --- a/nats/tests/test_client_websocket.py +++ b/nats/tests/test_client_websocket.py @@ -8,13 +8,13 @@ try: import aiohttp + aiohttp_installed = True except ModuleNotFoundError: aiohttp_installed = False class WebSocketTest(SingleWebSocketServerTestCase): - @async_test async def test_simple_headers(self): if not aiohttp_installed: @@ -24,12 +24,7 @@ async def test_simple_headers(self): sub = await nc.subscribe("foo") await nc.flush() - await nc.publish( - "foo", b"hello world", headers={ - "foo": "bar", - "hello": "world-1" - } - ) + await nc.publish("foo", b"hello world", headers={"foo": "bar", "hello": "world-1"}) msg = await sub.next_msg() self.assertTrue(msg.headers != None) @@ -54,12 +49,7 @@ async def service(msg): await nc.subscribe("foo", cb=service) await nc.flush() - msg = await nc.request( - "foo", b"hello world", headers={ - "foo": "bar", - "hello": "world" - } - ) + msg = await nc.request("foo", b"hello world", headers={"foo": "bar", "hello": "world"}) self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 3) @@ -90,9 +80,7 @@ async def test_empty_headers(self): self.assertTrue(msg.headers == None) # Empty long key - await nc.publish( - "foo", b"hello world", headers={"": " "} - ) + await nc.publish("foo", b"hello world", headers={"": " "}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) @@ -140,13 +128,9 @@ async def bar_cb(msg): self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].start) await asyncio.wait_for(reconnected, 2) # Get another message. @@ -190,9 +174,7 @@ async def bar_cb(msg): self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) # Should not fail closing while disconnected. @@ -200,7 +182,6 @@ async def bar_cb(msg): class WebSocketTLSTest(SingleWebSocketTLSServerTestCase): - @async_test async def test_pub_sub(self): if not aiohttp_installed: @@ -231,11 +212,7 @@ async def reconnected_cb(): if not reconnected.done(): reconnected.set_result(True) - nc = await nats.connect( - "wss://localhost:8081", - reconnected_cb=reconnected_cb, - tls=self.ssl_ctx - ) + nc = await nats.connect("wss://localhost:8081", reconnected_cb=reconnected_cb, tls=self.ssl_ctx) sub = await nc.subscribe("foo") @@ -252,13 +229,9 @@ async def bar_cb(msg): self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].start) await asyncio.wait_for(reconnected, 2) # Get another message. @@ -303,9 +276,7 @@ async def bar_cb(msg): self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) # Should not fail closing while disconnected. diff --git a/nats/tests/test_compatibility.py b/nats/tests/test_compatibility.py index 838c3a8a4..efdcc83a0 100644 --- a/nats/tests/test_compatibility.py +++ b/nats/tests/test_compatibility.py @@ -30,7 +30,6 @@ @skipIf("NATS_URL" not in os.environ, "NATS_URL not set in environment") class CompatibilityTest(TestCase): - def setUp(self): self.loop = asyncio.new_event_loop() @@ -40,15 +39,12 @@ def tearDown(self): async def validate_test_result(self, sub: Subscription): try: msg = await sub.next_msg(timeout=5) - self.assertNotIn( - "fail", msg.subject, f"Test step failed: {msg.subject}" - ) + self.assertNotIn("fail", msg.subject, f"Test step failed: {msg.subject}") except asyncio.TimeoutError: self.fail("Timeout waiting for test result") @async_long_test async def test_service_compatibility(self): - @dataclass class TestGroupConfig: name: str @@ -56,9 +52,7 @@ class TestGroupConfig: @classmethod def from_dict(cls, data: Dict[str, Any]) -> TestGroupConfig: - return cls( - name=data["name"], queue_group=data.get("queue_group") - ) + return cls(name=data["name"], queue_group=data.get("queue_group")) @dataclass class TestEndpointConfig: @@ -96,14 +90,8 @@ def from_dict(cls, data: Dict[str, Any]) -> TestServiceConfig: description=data["description"], queue_group=data.get("queue_group"), metadata=data["metadata"], - groups=[ - TestGroupConfig.from_dict(group) - for group in data.get("groups", []) - ], - endpoints=[ - TestEndpointConfig.from_dict(endpoint) - for endpoint in data.get("endpoints", []) - ], + groups=[TestGroupConfig.from_dict(group) for group in data.get("groups", [])], + endpoints=[TestEndpointConfig.from_dict(endpoint) for endpoint in data.get("endpoints", [])], ) @dataclass @@ -153,9 +141,7 @@ def stats_handler(endpoint: EndpointStats) -> Dict[str, str]: groups = {} for group_config in test_step.config.groups: - group = svc.add_group( - name=group_config.name, queue_group=group_config.queue_group - ) + group = svc.add_group(name=group_config.name, queue_group=group_config.queue_group) groups[group_config.name] = group for step_endpoint_config in test_step.config.endpoints: diff --git a/nats/tests/test_js.py b/nats/tests/test_js.py index 42778f89b..983950228 100644 --- a/nats/tests/test_js.py +++ b/nats/tests/test_js.py @@ -29,7 +29,6 @@ class PublishTest(SingleJetStreamServerTestCase): - @async_test async def test_publish(self): nc = NATS() @@ -100,9 +99,7 @@ async def test_publish_async(self): await js.add_stream(name="QUUX", subjects=["quux"]) - futures = [ - await js.publish_async("quux", b"bar:1") for i in range(0, 100) - ] + futures = [await js.publish_async("quux", b"bar:1") for i in range(0, 100)] await js.publish_async_completed() results = await asyncio.gather(*futures) @@ -114,18 +111,12 @@ async def test_publish_async(self): self.assertEqual(result.seq, seq) with pytest.raises(TooManyStalledMsgsError): - publishes = [ - js.publish_async("quux", b"bar:1", wait_stall=0.0) - for i in range(0, 100) - ] + publishes = [js.publish_async("quux", b"bar:1", wait_stall=0.0) for i in range(0, 100)] futures = await asyncio.gather(*publishes) results = await asyncio.gather(*futures) self.assertEqual(len(results), 100) - publishes = [ - js.publish_async("quux", b"bar:1", wait_stall=1.0) - for i in range(0, 1000) - ] + publishes = [js.publish_async("quux", b"bar:1", wait_stall=1.0) for i in range(0, 1000)] futures = await asyncio.gather(*publishes) results = await asyncio.gather(*futures) self.assertEqual(len(results), 1000) @@ -134,7 +125,6 @@ async def test_publish_async(self): class PullSubscribeTest(SingleJetStreamServerTestCase): - @async_test async def test_auto_create_consumer(self): nc = NATS() @@ -163,9 +153,7 @@ async def test_auto_create_consumer(self): await sub.fetch(1, timeout=1) # Customize consumer config. - sub = await js.pull_subscribe( - "a2", "auto2", config=nats.js.api.ConsumerConfig(max_waiting=10) - ) + sub = await js.pull_subscribe("a2", "auto2", config=nats.js.api.ConsumerConfig(max_waiting=10)) msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() @@ -229,9 +217,7 @@ async def test_fetch_one(self): msg = msgs[0] assert msg.metadata.sequence.stream == 1 assert msg.metadata.sequence.consumer == 1 - assert datetime.datetime.now( - datetime.timezone.utc - ) > msg.metadata.timestamp + assert datetime.datetime.now(datetime.timezone.utc) > msg.metadata.timestamp assert msg.metadata.num_pending == 0 assert msg.metadata.num_delivered == 1 @@ -239,9 +225,7 @@ async def test_fetch_one(self): await sub.fetch(timeout=1) for i in range(0, 10): - await js.publish( - "foo.1", f"i:{i}".encode(), headers={"hello": "world"} - ) + await js.publish("foo.1", f"i:{i}".encode(), headers={"hello": "world"}) # nak msgs = await sub.fetch() @@ -296,9 +280,7 @@ async def test_fetch_one_wait_forever(self): msg = msgs[0] assert msg.metadata.sequence.stream == 1 assert msg.metadata.sequence.consumer == 1 - assert datetime.datetime.now( - datetime.timezone.utc - ) > msg.metadata.timestamp + assert datetime.datetime.now(datetime.timezone.utc) > msg.metadata.timestamp assert msg.metadata.num_pending == 0 assert msg.metadata.num_delivered == 1 @@ -1075,7 +1057,6 @@ async def test_fetch_heartbeats(self): class JSMTest(SingleJetStreamServerTestCase): - @async_test async def test_stream_management(self): nc = NATS() @@ -1086,9 +1067,7 @@ async def test_stream_management(self): assert isinstance(acc, nats.js.api.AccountInfo) # Create stream - stream = await jsm.add_stream( - name="hello", subjects=["hello", "world", "hello.>"] - ) + stream = await jsm.add_stream(name="hello", subjects=["hello", "world", "hello.>"]) assert isinstance(stream, nats.js.api.StreamInfo) assert isinstance(stream.config, nats.js.api.StreamConfig) assert stream.config.name == "hello" @@ -1101,9 +1080,7 @@ async def test_stream_management(self): with pytest.raises(ValueError): await jsm.add_stream(nats.js.api.StreamConfig()) # Create with config, name is provided as kwargs - stream_with_name = await jsm.add_stream( - nats.js.api.StreamConfig(), name="hi" - ) + stream_with_name = await jsm.add_stream(nats.js.api.StreamConfig(), name="hi") assert stream_with_name.config.name == "hi" # Get info @@ -1130,9 +1107,7 @@ async def test_stream_management(self): stream_config = current.config stream_config.subjects.append("extra") updated_stream = await jsm.update_stream(stream_config) - assert updated_stream.config.subjects == [ - "hello", "world", "hello.>", "extra" - ] + assert updated_stream.config.subjects == ["hello", "world", "hello.>", "extra"] # Purge Stream is_purged = await jsm.purge_stream("hello") @@ -1339,9 +1314,7 @@ async def test_number_of_consumer_replicas(self): await js.publish("test.replicas", f"{i}".encode()) # Create consumer - config = nats.js.api.ConsumerConfig( - num_replicas=1, durable_name="mycons" - ) + config = nats.js.api.ConsumerConfig(num_replicas=1, durable_name="mycons") cons = await js.add_consumer(stream="TESTREPLICAS", config=config) if cons.config.num_replicas: assert cons.config.num_replicas == 1 @@ -1457,10 +1430,7 @@ async def test_consumer_with_name(self): ack_policy="explicit", ) assert err.value.err_code == 10017 - assert ( - err.value.description == - "consumer name in subject does not match durable name in request" - ) + assert err.value.description == "consumer name in subject does not match durable name in request" # Create ephemeral pull consumer with a name and inactive threshold. stream_name = "ctests" @@ -1519,7 +1489,6 @@ async def test_jsm_stream_info_options(self): class SubscribeTest(SingleJetStreamServerTestCase): - @async_test async def test_queue_subscribe_deliver_group(self): nc = await nats.connect() @@ -1618,9 +1587,7 @@ async def cb2(msg): assert err.value.description == "consumer is already bound to a subscription" with pytest.raises(nats.js.errors.Error) as err: - await js.subscribe( - "pbound", queue="foo", cb=cb2, durable="singleton" - ) + await js.subscribe("pbound", queue="foo", cb=cb2, durable="singleton") exp = "cannot create queue subscription 'foo' to consumer 'singleton'" assert err.value.description == exp @@ -1842,38 +1809,33 @@ async def cb_s(msg): async def cb_d(msg): d.append(msg.data) - #Create config for our subscriber - cc = nats.js.api.ConsumerConfig( - name="pconfig-ps", deliver_subject="pconfig-deliver" - ) + # Create config for our subscriber + cc = nats.js.api.ConsumerConfig(name="pconfig-ps", deliver_subject="pconfig-deliver") - #Make stream consumer with set deliver_subjct - sub_s = await js.subscribe( - "pconfig", stream="pconfig", cb=cb_s, config=cc - ) - #Make direct sub on deliver_subject + # Make stream consumer with set deliver_subjct + sub_s = await js.subscribe("pconfig", stream="pconfig", cb=cb_s, config=cc) + # Make direct sub on deliver_subject sub_d = await nc.subscribe("pconfig-deliver", "check-queue", cb=cb_d) - #Stream consumer sub should have configured subject + # Stream consumer sub should have configured subject assert sub_s.subject == "pconfig-deliver" - #Publish some messages + # Publish some messages for i in range(10): - await js.publish("pconfig", f'Hello World {i}'.encode()) + await js.publish("pconfig", f"Hello World {i}".encode()) await asyncio.sleep(0.5) - #Both subs should recieve same messages, but we are not sure about order + # Both subs should recieve same messages, but we are not sure about order assert len(s) == len(d) assert set(s) == set(d) - #Cleanup + # Cleanup await js.delete_consumer("pconfig", "pconfig-ps") await js.delete_stream("pconfig") await nc.close() class AckPolicyTest(SingleJetStreamServerTestCase): - @async_test async def test_ack_v2_tokens(self): nc = await nats.connect() @@ -1898,9 +1860,7 @@ async def test_ack_v2_tokens(self): assert meta.sequence.consumer == consumer_sequence assert meta.num_delivered == num_delivered assert meta.num_pending == num_pending - exp = datetime.datetime( - 2022, 9, 11, 0, 28, 27, 340506, tzinfo=datetime.timezone.utc - ) + exp = datetime.datetime(2022, 9, 11, 0, 28, 27, 340506, tzinfo=datetime.timezone.utc) assert meta.timestamp.astimezone(datetime.timezone.utc) == exp # Complete v2 tokens (last one discarded) @@ -1914,17 +1874,9 @@ async def test_ack_v2_tokens(self): assert meta.sequence.consumer == consumer_sequence assert meta.num_delivered == num_delivered assert meta.num_pending == num_pending - assert meta.timestamp.astimezone(datetime.timezone.utc - ) == datetime.datetime( - 2022, - 9, - 11, - 0, - 28, - 27, - 340506, - tzinfo=datetime.timezone.utc - ) + assert meta.timestamp.astimezone(datetime.timezone.utc) == datetime.datetime( + 2022, 9, 11, 0, 28, 27, 340506, tzinfo=datetime.timezone.utc + ) @async_test async def test_double_acking_pull_subscribe(self): @@ -2055,7 +2007,6 @@ async def f(): class DiscardPolicyTest(SingleJetStreamServerTestCase): - @async_test async def test_with_discard_new_and_discard_new_per_subject_set(self): # Connect to NATS and create JetStream context @@ -2139,9 +2090,7 @@ async def test_with_discard_old_and_discard_new_per_subject_not_set(self): await nc.close() @async_test - async def test_with_discard_new_and_discard_new_per_subject_set_no_max_msgs( - self - ): + async def test_with_discard_new_and_discard_new_per_subject_set_no_max_msgs(self): # Connect to NATS and create JetStream context nc = await nats.connect() js = nc.jetstream() @@ -2160,7 +2109,6 @@ async def test_with_discard_new_and_discard_new_per_subject_set_no_max_msgs( class OrderedConsumerTest(SingleJetStreamServerTestCase): - @async_test async def test_flow_control(self): errors = [] @@ -2181,14 +2129,9 @@ async def cb(msg): with pytest.raises(nats.js.errors.APIError) as err: sub = await js.subscribe(subject, cb=cb, flow_control=True) - assert ( - err.value.description == - "consumer with flow control also needs heartbeats" - ) + assert err.value.description == "consumer with flow control also needs heartbeats" - sub = await js.subscribe( - subject, cb=cb, flow_control=True, idle_heartbeat=0.5 - ) + sub = await js.subscribe(subject, cb=cb, flow_control=True, idle_heartbeat=0.5) tasks = [] @@ -2204,7 +2147,7 @@ async def producer(): if mlen - i <= chunksize: chunk = msg[i:] else: - chunk = msg[i:i + chunksize] + chunk = msg[i : i + chunksize] i += chunksize task = asyncio.create_task(js.publish(subject, chunk)) tasks.append(task) @@ -2261,11 +2204,9 @@ async def producer(): if mlen - i <= chunksize: chunk = msg[i:] else: - chunk = msg[i:i + chunksize] + chunk = msg[i : i + chunksize] i += chunksize - task = asyncio.create_task( - js.publish(subject, chunk, headers={"data": "true"}) - ) + task = asyncio.create_task(js.publish(subject, chunk, headers={"data": "true"})) await asyncio.sleep(0) tasks.append(task) @@ -2328,9 +2269,7 @@ def _build_message(sid, subject, reply, data, headers): nc._build_message = _build_message subject = "osub2" - await js2.add_stream( - name=subject, subjects=[subject], storage="memory" - ) + await js2.add_stream(name=subject, subjects=[subject], storage="memory") # Consumer callback. future = asyncio.Future() @@ -2365,11 +2304,9 @@ async def producer(): if mlen - i <= chunksize: chunk = msg[i:] else: - chunk = msg[i:i + chunksize] + chunk = msg[i : i + chunksize] i += chunksize - task = asyncio.create_task( - nc2.publish(subject, chunk, headers={"data": "true"}) - ) + task = asyncio.create_task(nc2.publish(subject, chunk, headers={"data": "true"})) tasks.append(task) task = asyncio.create_task(producer()) @@ -2412,9 +2349,7 @@ async def error_handler(e): errors.append(e) # Consumer - nc = await nats.connect( - error_cb=error_handler, reconnected_cb=consumer_reconnected_cb - ) + nc = await nats.connect(error_cb=error_handler, reconnected_cb=consumer_reconnected_cb) # Producer nc2 = await nats.connect(error_cb=error_handler) @@ -2439,11 +2374,9 @@ async def producer(): if mlen - i <= chunksize: chunk = msg[i:] else: - chunk = msg[i:i + chunksize] + chunk = msg[i : i + chunksize] i += chunksize - task = asyncio.create_task( - nc2.publish(subject, chunk, headers={"data": "true"}) - ) + task = asyncio.create_task(nc2.publish(subject, chunk, headers={"data": "true"})) tasks.append(task) task = asyncio.create_task(producer()) @@ -2465,15 +2398,11 @@ async def cb(msg): if not done.done(): done.set_result(True) - sub = await js.subscribe( - subject, cb=cb, ordered_consumer=True, idle_heartbeat=0.5 - ) + sub = await js.subscribe(subject, cb=cb, ordered_consumer=True, idle_heartbeat=0.5) await asyncio.wait_for(done, 10) # Using only next_msg which would be slower. - sub = await js.subscribe( - subject, ordered_consumer=True, idle_heartbeat=0.5 - ) + sub = await js.subscribe(subject, ordered_consumer=True, idle_heartbeat=0.5) i = 0 while i < stream.state.messages: try: @@ -2488,19 +2417,13 @@ async def cb(msg): ###################### # Reconnecting # ###################### - sub = await js.subscribe( - subject, ordered_consumer=True, idle_heartbeat=0.5 - ) + sub = await js.subscribe(subject, ordered_consumer=True, idle_heartbeat=0.5) i = 0 while i < stream.state.messages: if i == 5000: - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(0.2) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].start) try: msg = await sub.next_msg() data = msg.data.decode("utf-8") @@ -2518,13 +2441,9 @@ async def cb(msg): nonlocal done if i == 10000: - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(0.2) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].start - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].start) data = msg.data.decode("utf-8") i += 1 @@ -2532,9 +2451,7 @@ async def cb(msg): if not done.done(): done.set_result(True) - sub = await js.subscribe( - subject, cb=cb, ordered_consumer=True, idle_heartbeat=0.5 - ) + sub = await js.subscribe(subject, cb=cb, ordered_consumer=True, idle_heartbeat=0.5) await asyncio.wait_for(done, 10) await nc.close() @@ -2549,22 +2466,16 @@ async def error_handler(e): nc = await nats.connect(error_cb=error_handler) js = nc.jetstream() - await js.add_stream( - name="MY_STREAM", subjects=["test.*"], storage="memory" - ) + await js.add_stream(name="MY_STREAM", subjects=["test.*"], storage="memory") subject = "test.1" for m in ["1", "2", "3"]: await js.publish(subject=subject, payload=m.encode()) - sub = await js.subscribe( - subject, ordered_consumer=True, idle_heartbeat=0.5 - ) + sub = await js.subscribe(subject, ordered_consumer=True, idle_heartbeat=0.5) info = await sub.consumer_info() orig_name = info.name await js.delete_consumer("MY_STREAM", info.name) - await asyncio.sleep( - 3 - ) # now the consumer should reset due to missing HB + await asyncio.sleep(3) # now the consumer should reset due to missing HB info = await sub.consumer_info() self.assertTrue(orig_name != info.name) @@ -2572,7 +2483,6 @@ async def error_handler(e): class KVTest(SingleJetStreamServerTestCase): - @async_test async def test_kv_simple(self): errors = [] @@ -2713,9 +2623,7 @@ async def test_bucket_name_validation(self): for bucket_name in invalid_bucket_names: with self.subTest(bucket_name): with pytest.raises(InvalidBucketNameError): - await js.create_key_value( - bucket=bucket_name, history=5, ttl=3600 - ) + await js.create_key_value(bucket=bucket_name, history=5, ttl=3600) with pytest.raises(InvalidBucketNameError): await js.key_value(bucket_name) @@ -2748,13 +2656,13 @@ async def test_key_validation(self): with self.subTest(key): # Invalid put (empty) with pytest.raises(InvalidKeyError): - await kv.put(key, b'') + await kv.put(key, b"") with pytest.raises(InvalidKeyError): await kv.get(key) with pytest.raises(InvalidKeyError): - await kv.update(key, b'') + await kv.update(key, b"") @async_test async def test_key_validation_bypass(self): @@ -2771,17 +2679,15 @@ async def test_key_validation_bypass(self): for key in invalid_keys: with self.subTest(key): # Should succeed with validate_keys=False - seq = await kv.put(key, b'test_value', validate_keys=False) + seq = await kv.put(key, b"test_value", validate_keys=False) assert seq > 0 # Should be able to get with validate_keys=False entry = await kv.get(key, validate_keys=False) - assert entry.value == b'test_value' + assert entry.value == b"test_value" # Should be able to update with validate_keys=False - seq2 = await kv.update( - key, b'updated_value', last=seq, validate_keys=False - ) + seq2 = await kv.update(key, b"updated_value", last=seq, validate_keys=False) assert seq2 > seq # Should be able to delete with validate_keys=False @@ -2790,7 +2696,7 @@ async def test_key_validation_bypass(self): # Should still fail with default validate_keys=True with pytest.raises(InvalidKeyError): - await kv.put(key, b'fail') + await kv.put(key, b"fail") with pytest.raises(InvalidKeyError): await kv.get(key) @@ -2810,13 +2716,7 @@ async def error_handler(e): await js.create_key_value(bucket="notok!") bucket = "TEST" - kv = await js.create_key_value( - bucket=bucket, - history=5, - ttl=3600, - description="Basic KV", - direct=False - ) + kv = await js.create_key_value(bucket=bucket, history=5, ttl=3600, description="Basic KV", direct=False) status = await kv.status() si = await js.stream_info("KV_TEST") @@ -2926,10 +2826,7 @@ async def error_handler(e): entry = await kv.get("age", revision=6) assert entry.value == b"fuga" assert entry.revision == 6 - assert ( - str(err.value) == - "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" - ) + assert str(err.value) == "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" with pytest.raises(KeyNotFoundError) as err: await kv.get("age", revision=5) @@ -2940,16 +2837,14 @@ async def error_handler(e): entry = await kv.get("name", revision=3) assert entry.value == b"bob" - with pytest.raises(KeyWrongLastSequenceError, - match="nats: wrong last sequence: 8"): + with pytest.raises(KeyWrongLastSequenceError, match="nats: wrong last sequence: 8"): await kv.create("age", b"1") # Now let's delete and recreate. await kv.delete("age", last=8) await kv.create("age", b"final") - with pytest.raises(KeyWrongLastSequenceError, - match="nats: wrong last sequence: 10"): + with pytest.raises(KeyWrongLastSequenceError, match="nats: wrong last sequence: 10"): await kv.create("age", b"1") entry = await kv.get("age") @@ -2972,13 +2867,7 @@ async def error_handler(e): js = nc.jetstream() bucket = "TEST" - kv = await js.create_key_value( - bucket=bucket, - history=5, - ttl=3600, - description="Direct KV", - direct=True - ) + kv = await js.create_key_value(bucket=bucket, history=5, ttl=3600, description="Direct KV", direct=True) si = await js.stream_info("KV_TEST") config = si.config @@ -3005,9 +2894,7 @@ async def error_handler(e): assert msg.data == b"333" # next by subject - msg = await js.get_msg( - "KV_TEST", seq=4, next=True, subject="$KV.TEST.C", direct=True - ) + msg = await js.get_msg("KV_TEST", seq=4, next=True, subject="$KV.TEST.C", direct=True) assert msg.data == b"33" @async_test @@ -3138,10 +3025,7 @@ async def error_handler(e): entry = await kv.get("age", revision=6) assert entry.value == b"fuga" assert entry.revision == 6 - assert ( - str(err.value) == - "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" - ) + assert str(err.value) == "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" with pytest.raises(KeyNotFoundError) as err: await kv.get("age", revision=5) @@ -3152,16 +3036,14 @@ async def error_handler(e): entry = await kv.get("name", revision=3) assert entry.value == b"bob" - with pytest.raises(KeyWrongLastSequenceError, - match="nats: wrong last sequence: 8"): + with pytest.raises(KeyWrongLastSequenceError, match="nats: wrong last sequence: 8"): await kv.create("age", b"1") # Now let's delete and recreate. await kv.delete("age", last=8) await kv.create("age", b"final") - with pytest.raises(KeyWrongLastSequenceError, - match="nats: wrong last sequence: 10"): + with pytest.raises(KeyWrongLastSequenceError, match="nats: wrong last sequence: 10"): await kv.create("age", b"1") entry = await kv.get("age") @@ -3170,10 +3052,7 @@ async def error_handler(e): with pytest.raises(Error) as err: await js.add_stream(name="mirror", mirror_direct=True) assert err.value.err_code == 10052 - assert ( - err.value.description == - "stream has no mirror but does have mirror direct" - ) + assert err.value.description == "stream has no mirror but does have mirror direct" await nc.close() @@ -3540,10 +3419,7 @@ async def error_handler(e): nc = await nats.connect(error_cb=error_handler) js = nc.jetstream() - kv = await js.create_key_value( - bucket="TEST_UPDATE", - republish=nats.js.api.RePublish(src=">", dest="bar.>") - ) + kv = await js.create_key_value(bucket="TEST_UPDATE", republish=nats.js.api.RePublish(src=">", dest="bar.>")) status = await kv.status() sinfo = await js.stream_info("KV_TEST_UPDATE") assert sinfo.config.republish is not None @@ -3585,9 +3461,7 @@ async def error_handler(e): js = nc.jetstream() # Create a KV bucket for testing - kv = await js.create_key_value( - bucket="TEST_LOGGING", history=5, ttl=3600 - ) + kv = await js.create_key_value(bucket="TEST_LOGGING", history=5, ttl=3600) # Add keys to the bucket await kv.put("hello", b"world") @@ -3603,7 +3477,6 @@ async def error_handler(e): class ObjectStoreTest(SingleJetStreamServerTestCase): - @async_test async def test_object_basics(self): errors = [] @@ -3618,9 +3491,7 @@ async def error_handler(e): with pytest.raises(nats.js.errors.InvalidBucketNameError): await js.create_object_store(bucket="notok!") - obs = await js.create_object_store( - bucket="OBJS", description="testing" - ) + obs = await js.create_object_store(bucket="OBJS", description="testing") assert obs._name == "OBJS" assert obs._stream == f"OBJ_OBJS" @@ -3644,10 +3515,7 @@ async def error_handler(e): assert sinfo.config.allow_direct == True assert sinfo.config.mirror_direct == False - bucketname = "".join( - random.SystemRandom().choice(string.ascii_letters) - for _ in range(10) - ) + bucketname = "".join(random.SystemRandom().choice(string.ascii_letters) for _ in range(10)) obs = await js.create_object_store(bucket=bucketname) assert obs._name == bucketname assert obs._stream == f"OBJ_{bucketname}" @@ -3694,9 +3562,7 @@ async def error_handler(e): h = sha256() h.update(filevalue) h.digest() - expected_digest = ( - f"SHA-256={base64.urlsafe_b64encode(h.digest()).decode('utf-8')}" - ) + expected_digest = f"SHA-256={base64.urlsafe_b64encode(h.digest()).decode('utf-8')}" assert info.digest == expected_digest assert info.deleted == False assert info.description == filedesc @@ -3806,10 +3672,7 @@ async def error_handler(e): obr = await obs.get("tmp", writeinto=f) assert obr.data == b"" assert obr.info.size == 1048609 - assert ( - obr.info.digest == - "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" - ) + assert obr.info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" w2 = tempfile.NamedTemporaryFile(delete=False) w2.close() @@ -3817,10 +3680,7 @@ async def error_handler(e): obr = await obs.get("tmp", writeinto=f.buffer) assert obr.data == b"" assert obr.info.size == 1048609 - assert ( - obr.info.digest == - "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" - ) + assert obr.info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" with open(w2.name) as f: result = f.read(-1) @@ -3896,7 +3756,9 @@ async def error_handler(e): obs = await js.create_object_store( "TEST_FILES", - config=nats.js.api.ObjectStoreConfig(description="multi_files", ), + config=nats.js.api.ObjectStoreConfig( + description="multi_files", + ), ) await obs.put("A", b"A") await obs.put("B", b"B") @@ -3940,7 +3802,9 @@ async def error_handler(e): obs = await js.create_object_store( "TEST_FILES", - config=nats.js.api.ObjectStoreConfig(description="multi_files", ), + config=nats.js.api.ObjectStoreConfig( + description="multi_files", + ), ) watcher = await obs.watch() @@ -4056,7 +3920,9 @@ async def error_handler(e): obs = await js.create_object_store( "TEST_LIST", - config=nats.js.api.ObjectStoreConfig(description="listing", ), + config=nats.js.api.ObjectStoreConfig( + description="listing", + ), ) await obs.put("A", b"AAA") await obs.put("B", b"BBB") @@ -4139,7 +4005,6 @@ async def error_handler(e): class ConsumerReplicasTest(SingleJetStreamServerTestCase): - @async_test async def test_number_of_consumer_replicas(self): nc = await nats.connect() @@ -4150,9 +4015,7 @@ async def test_number_of_consumer_replicas(self): await js.publish("test.replicas", f"{i}".encode()) # Create consumer - config = nats.js.api.ConsumerConfig( - num_replicas=1, durable_name="mycons" - ) + config = nats.js.api.ConsumerConfig(num_replicas=1, durable_name="mycons") cons = await js.add_consumer(stream="TESTREPLICAS", config=config) assert cons.config.num_replicas == 1 @@ -4161,7 +4024,6 @@ async def test_number_of_consumer_replicas(self): class AccountLimitsTest(SingleJetStreamServerLimitsTestCase): - @async_test async def test_account_limits(self): nc = await nats.connect() @@ -4171,30 +4033,17 @@ async def test_account_limits(self): with pytest.raises(BadRequestError) as err: await js.add_stream(name="limits", subjects=["limits"]) assert err.value.err_code == 10113 - assert ( - err.value.description == - "account requires a stream config to have max bytes set" - ) + assert err.value.description == "account requires a stream config to have max bytes set" with pytest.raises(BadRequestError) as err: - await js.add_stream( - name="limits", subjects=["limits"], max_bytes=65536 - ) + await js.add_stream(name="limits", subjects=["limits"], max_bytes=65536) assert err.value.err_code == 10122 - assert ( - err.value.description == - "stream max bytes exceeds account limit max stream bytes" - ) + assert err.value.description == "stream max bytes exceeds account limit max stream bytes" - si = await js.add_stream( - name="limits", subjects=["limits"], max_bytes=128 - ) + si = await js.add_stream(name="limits", subjects=["limits"], max_bytes=128) assert si.config.max_bytes == 128 - await js.publish( - "limits", - b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" - ) + await js.publish("limits", b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") si = await js.stream_info("limits") assert si.state.messages == 1 @@ -4304,40 +4153,38 @@ async def test_account_limits(self): api=nats.js.api.APIStats(total=6, errors=0), domain="ngs", tiers={ - "R1": - nats.js.api.Tier( - memory=0, - storage=6829550, - streams=1, - consumers=0, - limits=nats.js.api.AccountLimits( - max_memory=0, - max_storage=2000000000000, - max_streams=100, - max_consumers=1000, - max_ack_pending=-1, - memory_max_stream_bytes=-1, - storage_max_stream_bytes=-1, - max_bytes_required=True, - ), + "R1": nats.js.api.Tier( + memory=0, + storage=6829550, + streams=1, + consumers=0, + limits=nats.js.api.AccountLimits( + max_memory=0, + max_storage=2000000000000, + max_streams=100, + max_consumers=1000, + max_ack_pending=-1, + memory_max_stream_bytes=-1, + storage_max_stream_bytes=-1, + max_bytes_required=True, ), - "R3": - nats.js.api.Tier( - memory=0, - storage=0, - streams=0, - consumers=0, - limits=nats.js.api.AccountLimits( - max_memory=0, - max_storage=500000000000, - max_streams=25, - max_consumers=250, - max_ack_pending=-1, - memory_max_stream_bytes=-1, - storage_max_stream_bytes=-1, - max_bytes_required=True, - ), + ), + "R3": nats.js.api.Tier( + memory=0, + storage=0, + streams=0, + consumers=0, + limits=nats.js.api.AccountLimits( + max_memory=0, + max_storage=500000000000, + max_streams=25, + max_consumers=250, + max_ack_pending=-1, + memory_max_stream_bytes=-1, + storage_max_stream_bytes=-1, + max_bytes_required=True, ), + ), }, ) info = nats.js.api.AccountInfo.from_response(json.loads(blob)) @@ -4346,7 +4193,6 @@ async def test_account_limits(self): class V210FeaturesTest(SingleJetStreamServerTestCase): - @async_test async def test_subject_transforms(self): nc = await nats.connect() @@ -4357,9 +4203,7 @@ async def test_subject_transforms(self): await js.add_stream( name="TRANSFORMS", subjects=["test", "foo"], - subject_transform=nats.js.api.SubjectTransform( - src=">", dest="transformed.>" - ), + subject_transform=nats.js.api.SubjectTransform(src=">", dest="transformed.>"), ) for i in range(0, 10): await js.publish("test", f"{i}".encode()) @@ -4395,12 +4239,8 @@ async def test_subject_transforms(self): name="TRANSFORMS", # The source filters cannot overlap. subject_transforms=[ - nats.js.api.SubjectTransform( - src="transformed.>", dest="fromtest.transformed.>" - ), - nats.js.api.SubjectTransform( - src="foo.>", dest="fromtest.foo.>" - ), + nats.js.api.SubjectTransform(src="transformed.>", dest="fromtest.transformed.>"), + nats.js.api.SubjectTransform(src="foo.>", dest="fromtest.foo.>"), ], ) await js.add_stream( @@ -4427,9 +4267,7 @@ async def test_subject_transforms(self): transformed_source = nats.js.api.StreamSource( name="TRANSFORMS2", subject_transforms=[ - nats.js.api.SubjectTransform( - src=">", dest="fromtest.transformed.>" - ), + nats.js.api.SubjectTransform(src=">", dest="fromtest.transformed.>"), nats.js.api.SubjectTransform(src=">", dest="fromtest.foo.>"), ], ) @@ -4507,9 +4345,7 @@ async def test_stream_consumer_metadata(self): await js.add_consumer( "META", - config=nats.js.api.ConsumerConfig( - durable_name="b", metadata={"hello": "world"} - ), + config=nats.js.api.ConsumerConfig(durable_name="b", metadata={"hello": "world"}), ) cinfo = await js.consumer_info("META", "b") assert cinfo.config.metadata["hello"] == "world" @@ -4537,9 +4373,7 @@ async def test_fetch_pull_subscribe_bind(self): ) # Using named arguments. - psub = await js.pull_subscribe_bind( - stream=stream_name, consumer=cinfo.name - ) + psub = await js.pull_subscribe_bind(stream=stream_name, consumer=cinfo.name) msgs = await psub.fetch(1) msg = msgs[0] await msg.ack() @@ -4551,17 +4385,13 @@ async def test_fetch_pull_subscribe_bind(self): await msg.ack() # Using durable argument to refer to ephemeral is ok for backwards compatibility. - psub = await js.pull_subscribe_bind( - durable=cinfo.name, stream=stream_name - ) + psub = await js.pull_subscribe_bind(durable=cinfo.name, stream=stream_name) msgs = await psub.fetch(1) msg = msgs[0] await msg.ack() # stream, consumer name order - psub = await js.pull_subscribe_bind( - stream=stream_name, durable=cinfo.name - ) + psub = await js.pull_subscribe_bind(stream=stream_name, durable=cinfo.name) msgs = await psub.fetch(1) msg = msgs[0] await msg.ack() @@ -4569,9 +4399,7 @@ async def test_fetch_pull_subscribe_bind(self): assert msg.metadata.num_pending == 1 # name can also be used to refer to the consumer name - psub = await js.pull_subscribe_bind( - stream=stream_name, name=cinfo.name - ) + psub = await js.pull_subscribe_bind(stream=stream_name, name=cinfo.name) msgs = await psub.fetch(1) msg = msgs[0] await msg.ack() @@ -4591,7 +4419,6 @@ async def test_fetch_pull_subscribe_bind(self): class BadStreamNamesTest(SingleJetStreamServerTestCase): - @async_test async def test_add_stream_invalid_names(self): nc = NATS() @@ -4612,10 +4439,10 @@ async def test_add_stream_invalid_names(self): for name in invalid_names: with pytest.raises( - ValueError, - match= - (f"nats: stream name \\({re.escape(name)}\\) is invalid. Names cannot contain whitespace, '\\.', " - "'\\*', '>', path separators \\(forward or backward slash\\), or non-printable characters." - ), + ValueError, + match=( + f"nats: stream name \\({re.escape(name)}\\) is invalid. Names cannot contain whitespace, '\\.', " + "'\\*', '>', path separators \\(forward or backward slash\\), or non-printable characters." + ), ): await js.add_stream(name=name) diff --git a/nats/tests/test_micro_service.py b/nats/tests/test_micro_service.py index 9fa47fc70..6e079ab1b 100644 --- a/nats/tests/test_micro_service.py +++ b/nats/tests/test_micro_service.py @@ -10,7 +10,6 @@ class MicroServiceTest(SingleServerTestCase): - def test_invalid_service_name(self): with self.assertRaises(ValueError) as context: ServiceConfig(name="", version="0.1.0") @@ -32,7 +31,6 @@ def test_invalid_service_version(self): ) def test_invalid_endpoint_subject(self): - async def noop_handler(request: Request) -> None: pass @@ -67,9 +65,7 @@ async def add_handler(request: Request): metadata={"basic": "metadata"}, ) - endpoint_config = EndpointConfig( - name="default", subject="svc.add", handler=add_handler - ) + endpoint_config = EndpointConfig(name="default", subject="svc.add", handler=add_handler) for _ in range(5): svc = await add_service(nc, service_config) @@ -77,13 +73,7 @@ async def add_handler(request: Request): svcs.append(svc) for _ in range(50): - await nc.request( - "svc.add", - json.dumps({ - "x": 22, - "y": 11 - }).encode("utf-8") - ) + await nc.request("svc.add", json.dumps({"x": 22, "y": 11}).encode("utf-8")) for svc in svcs: info = svc.info() @@ -105,9 +95,7 @@ async def add_handler(request: Request): ping_responses = [] while True: try: - ping_responses.append( - await ping_subscription.next_msg(timeout=0.25) - ) + ping_responses.append(await ping_subscription.next_msg(timeout=0.25)) except: break @@ -122,51 +110,40 @@ async def add_handler(request: Request): stats_responses = [] while True: try: - stats_responses.append( - await stats_subscription.next_msg(timeout=0.25) - ) + stats_responses.append(await stats_subscription.next_msg(timeout=0.25)) except: break assert len(stats_responses) == 5 - stats = [ - ServiceStats.from_dict(json.loads(response.data.decode())) - for response in stats_responses - ] - total_requests = sum([ - stat.endpoints[0].num_requests for stat in stats - ]) + stats = [ServiceStats.from_dict(json.loads(response.data.decode())) for response in stats_responses] + total_requests = sum([stat.endpoints[0].num_requests for stat in stats]) assert total_requests == 50 @async_test async def test_add_service(self): - async def noop_handler(request: Request): pass sub_tests = { "no_endpoint": { - "service_config": - ServiceConfig( - name="test_service", - version="0.1.0", - metadata={"basic": "metadata"}, - ), - "expected_ping": - ServicePing( - id="*", - type="io.nats.micro.v1.ping_response", - name="test_service", - version="0.1.0", - metadata={"basic": "metadata"}, - ), + "service_config": ServiceConfig( + name="test_service", + version="0.1.0", + metadata={"basic": "metadata"}, + ), + "expected_ping": ServicePing( + id="*", + type="io.nats.micro.v1.ping_response", + name="test_service", + version="0.1.0", + metadata={"basic": "metadata"}, + ), }, "with_single_endpoint": { - "service_config": - ServiceConfig( - name="test_service", - version="0.1.0", - ), + "service_config": ServiceConfig( + name="test_service", + version="0.1.0", + ), "endpoint_configs": [ EndpointConfig( name="test", @@ -175,20 +152,18 @@ async def noop_handler(request: Request): metadata={"basic": "endpoint_metadata"}, ), ], - "expected_ping": - ServicePing( - id="*", - name="test_service", - version="0.1.0", - metadata={}, - ), + "expected_ping": ServicePing( + id="*", + name="test_service", + version="0.1.0", + metadata={}, + ), }, "with_multiple_endpoints": { - "service_config": - ServiceConfig( - name="test_service", - version="0.1.0", - ), + "service_config": ServiceConfig( + name="test_service", + version="0.1.0", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -203,13 +178,12 @@ async def noop_handler(request: Request): handler=noop_handler, ), ], - "expected_ping": - ServicePing( - id="*", - name="test_service", - version="0.1.0", - metadata={}, - ), + "expected_ping": ServicePing( + id="*", + name="test_service", + version="0.1.0", + metadata={}, + ), }, } @@ -249,45 +223,30 @@ async def test_groups(self): "no_groups": { "name": "no groups", "endpoint_name": "foo", - "expected_endpoint": { - "name": "foo", - "subject": "foo" - }, + "expected_endpoint": {"name": "foo", "subject": "foo"}, }, "single_group": { "name": "single group", "endpoint_name": "foo", "group_names": ["g1"], - "expected_endpoint": { - "name": "foo", - "subject": "g1.foo" - }, + "expected_endpoint": {"name": "foo", "subject": "g1.foo"}, }, "single_empty_group": { "name": "single empty group", "endpoint_name": "foo", "group_names": [""], - "expected_endpoint": { - "name": "foo", - "subject": "foo" - }, + "expected_endpoint": {"name": "foo", "subject": "foo"}, }, "empty_groups": { "name": "empty groups", "endpoint_name": "foo", "group_names": ["", "g1", ""], - "expected_endpoint": { - "name": "foo", - "subject": "g1.foo" - }, + "expected_endpoint": {"name": "foo", "subject": "g1.foo"}, }, "multiple_groups": { "endpoint_name": "foo", "group_names": ["g1", "g2", "g3"], - "expected_endpoint": { - "name": "foo", - "subject": "g1.g2.g3.foo" - }, + "expected_endpoint": {"name": "foo", "subject": "g1.g2.g3.foo"}, }, } @@ -299,24 +258,18 @@ async def test_groups(self): async def noop_handler(_): pass - svc = await add_service( - nc, ServiceConfig(name="test_service", version="0.0.1") - ) + svc = await add_service(nc, ServiceConfig(name="test_service", version="0.0.1")) group = svc for group_name in data.get("group_names", []): group = group.add_group(name=group_name) - await group.add_endpoint( - name=data["endpoint_name"], handler=noop_handler - ) + await group.add_endpoint(name=data["endpoint_name"], handler=noop_handler) info = svc.info() assert info.endpoints assert len(info.endpoints) == 1 - expected_endpoint = EndpointInfo( - **data["expected_endpoint"], queue_group="q" - ) + expected_endpoint = EndpointInfo(**data["expected_endpoint"], queue_group="q") assert info.endpoints[0].name == expected_endpoint.name assert info.endpoints[0].subject == expected_endpoint.subject @@ -324,7 +277,6 @@ async def noop_handler(_): @async_test async def test_monitoring_handlers(self): - async def noop_handler(request: Request): pass @@ -385,14 +337,14 @@ async def noop_handler(request: Request): "description": None, "version": "0.1.0", "id": svc.id, - "endpoints": [{ - "name": "default", - "subject": "test.func", - "queue_group": "q", - "metadata": { - "basic": "schema" - }, - }], + "endpoints": [ + { + "name": "default", + "subject": "test.func", + "queue_group": "q", + "metadata": {"basic": "schema"}, + } + ], "metadata": {}, }, }, @@ -404,14 +356,14 @@ async def noop_handler(request: Request): "description": None, "version": "0.1.0", "id": svc.id, - "endpoints": [{ - "name": "default", - "subject": "test.func", - "queue_group": "q", - "metadata": { - "basic": "schema" - }, - }], + "endpoints": [ + { + "name": "default", + "subject": "test.func", + "queue_group": "q", + "metadata": {"basic": "schema"}, + } + ], "metadata": {}, }, }, @@ -423,14 +375,14 @@ async def noop_handler(request: Request): "description": None, "version": "0.1.0", "id": svc.id, - "endpoints": [{ - "name": "default", - "subject": "test.func", - "queue_group": "q", - "metadata": { - "basic": "schema" - }, - }], + "endpoints": [ + { + "name": "default", + "subject": "test.func", + "queue_group": "q", + "metadata": {"basic": "schema"}, + } + ], "metadata": {}, }, }, @@ -447,17 +399,15 @@ async def noop_handler(request: Request): @async_test async def test_service_stats(self): - async def handler(request: Request): await request.respond(b"ok") sub_tests = { "stats_handler": { - "service_config": - ServiceConfig( - name="test_service", - version="0.1.0", - ), + "service_config": ServiceConfig( + name="test_service", + version="0.1.0", + ), "endpoint_configs": [ EndpointConfig( name="default", @@ -468,12 +418,11 @@ async def handler(request: Request): ], }, "with_stats_handler": { - "service_config": - ServiceConfig( - name="test_service", - version="0.1.0", - stats_handler=lambda endpoint: {"key": "val"}, - ), + "service_config": ServiceConfig( + name="test_service", + version="0.1.0", + stats_handler=lambda endpoint: {"key": "val"}, + ), "endpoint_configs": [ EndpointConfig( name="default", @@ -482,16 +431,13 @@ async def handler(request: Request): metadata={"test": "value"}, ) ], - "expected_stats": { - "key": "val" - }, + "expected_stats": {"key": "val"}, }, "with_endpoint": { - "service_config": - ServiceConfig( - name="test_service", - version="0.1.0", - ), + "service_config": ServiceConfig( + name="test_service", + version="0.1.0", + ), "endpoint_configs": [ EndpointConfig( name="default", @@ -521,12 +467,8 @@ async def handler(request: Request): info = svc.info() - stats_subject = control_subject( - ServiceVerb.STATS, "test_service" - ) - stats_response = await nc.request( - stats_subject, b"", timeout=1 - ) + stats_subject = control_subject(ServiceVerb.STATS, "test_service") + stats_response = await nc.request(stats_subject, b"", timeout=1) stats = ServiceStats.from_dict(json.loads(stats_response.data)) assert len(stats.endpoints) == len(info.endpoints) @@ -554,14 +496,10 @@ async def test_request_respond(self): "expected_response": b"OK", }, "byte_response_with_headers": { - "respond_headers": { - "key": "value" - }, + "respond_headers": {"key": "value"}, "respond_data": b"OK", "expected_response": b"OK", - "expected_headers": { - "key": "value" - }, + "expected_headers": {"key": "value"}, }, } @@ -583,11 +521,7 @@ async def handler(request: Request): description="test service", ), ) - await svc.add_endpoint( - EndpointConfig( - name="default", subject="test.func", handler=handler - ) - ) + await svc.add_endpoint(EndpointConfig(name="default", subject="test.func", handler=handler)) response = await nc.request( "test.func", @@ -623,24 +557,20 @@ def test_control_subject(self): for name, data in sub_tests.items(): with self.subTest(name=name): - subject = control_subject( - data["verb"], name=data.get("name"), id=data.get("id") - ) + subject = control_subject(data["verb"], name=data.get("name"), id=data.get("id")) assert subject == data["expected_subject"] @async_test async def test_custom_queue_group(self): - async def noop_handler(request: Request): pass sub_tests = { "default_queue_group": { - "service_config": - ServiceConfig( - name="test_service", - version="0.0.1", - ), + "service_config": ServiceConfig( + name="test_service", + version="0.0.1", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -652,12 +582,11 @@ async def noop_handler(request: Request): }, }, "custom_queue_group_on_service_config": { - "service_config": - ServiceConfig( - name="test_service", - version="0.0.1", - queue_group="custom", - ), + "service_config": ServiceConfig( + name="test_service", + version="0.0.1", + queue_group="custom", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -670,12 +599,11 @@ async def noop_handler(request: Request): }, }, "endpoint_config_overriding_queue_groups": { - "service_config": - ServiceConfig( - name="test_service", - version="0.0.1", - queue_group="q-config", - ), + "service_config": ServiceConfig( + name="test_service", + version="0.0.1", + queue_group="q-config", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -689,12 +617,11 @@ async def noop_handler(request: Request): }, "empty_queue_group_in_option_inherit_from_parent": { "name": "empty queue group in option, inherit from parent", - "service_config": - ServiceConfig( - name="test_service", - version="0.0.1", - queue_group="q-service", - ), + "service_config": ServiceConfig( + name="test_service", + version="0.0.1", + queue_group="q-service", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -725,13 +652,9 @@ async def noop_handler(request: Request): info = svc.info() - assert len(info.endpoints - ) == len(data["expected_queue_groups"]) + assert len(info.endpoints) == len(data["expected_queue_groups"]) for endpoint in info.endpoints: - assert ( - endpoint.queue_group == data["expected_queue_groups"][ - endpoint.name] - ) + assert endpoint.queue_group == data["expected_queue_groups"][endpoint.name] await svc.stop() diff --git a/nats/tests/test_nuid.py b/nats/tests/test_nuid.py index 392422bfd..fc3de70b6 100644 --- a/nats/tests/test_nuid.py +++ b/nats/tests/test_nuid.py @@ -19,7 +19,6 @@ class NUIDTest(unittest.TestCase): - def setUp(self): super().setUp() @@ -31,18 +30,14 @@ def test_nuid_are_unique(self): nuid = NUID() entries = [nuid.next().decode() for i in range(500000)] counted_entries = Counter(entries) - repeated = [ - entry for entry, count in counted_entries.items() if count > 1 - ] + repeated = [entry for entry, count in counted_entries.items() if count > 1] self.assertEqual(len(repeated), 0) def test_nuid_are_very_unique(self): nuid = NUID() entries = [nuid.next().decode() for i in range(1000000)] counted_entries = Counter(entries) - repeated = [ - entry for entry, count in counted_entries.items() if count > 1 - ] + repeated = [entry for entry, count in counted_entries.items() if count > 1] self.assertEqual(len(repeated), 0) def test_subsequent_nuid_equal(self): diff --git a/nats/tests/test_parser.py b/nats/tests/test_parser.py index 8cd51b65d..df638b639 100644 --- a/nats/tests/test_parser.py +++ b/nats/tests/test_parser.py @@ -8,7 +8,6 @@ class MockNatsClient: - def __init__(self): self._subs = {} self._pongs = [] @@ -38,7 +37,6 @@ async def _process_info(self, info): class ProtocolParserTest(unittest.TestCase): - def setUp(self): self.loop = asyncio.new_event_loop() diff --git a/nats/tests/utils.py b/nats/tests/utils.py index d200937ab..73362a899 100644 --- a/nats/tests/utils.py +++ b/nats/tests/utils.py @@ -23,7 +23,6 @@ class NATSD: - def __init__( self, port=4222, @@ -63,11 +62,8 @@ def start(self): if Path(self.bin_name).is_file(): self.bin_name = Path(self.bin_name).absolute() # Path in `../scripts/install_nats.sh` - elif Path.home().joinpath(SERVER_BIN_DIR_NAME, - self.bin_name).is_file(): - self.bin_name = str( - Path.home().joinpath(SERVER_BIN_DIR_NAME, self.bin_name) - ) + elif Path.home().joinpath(SERVER_BIN_DIR_NAME, self.bin_name).is_file(): + self.bin_name = str(Path.home().joinpath(SERVER_BIN_DIR_NAME, self.bin_name)) # This directory contains binary elif Path(THIS_DIR).joinpath(self.bin_name).is_file(): self.bin_name = str(Path(THIS_DIR).joinpath(self.bin_name)) @@ -126,55 +122,38 @@ def start(self): if self.debug: self.proc = subprocess.Popen(cmd) else: - self.proc = subprocess.Popen( - cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL - ) + self.proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) if self.debug: if self.proc is None: - print( - "[\031[0;33mDEBUG\033[0;0m] Failed to start server listening on port %d started." - % self.port - ) + print("[\031[0;33mDEBUG\033[0;0m] Failed to start server listening on port %d started." % self.port) else: - print( - "[\033[0;33mDEBUG\033[0;0m] Server listening on port %d started." - % self.port - ) + print("[\033[0;33mDEBUG\033[0;0m] Server listening on port %d started." % self.port) return self.proc def stop(self): if self.debug: - print( - "[\033[0;33mDEBUG\033[0;0m] Server listening on %d will stop." - % self.port - ) + print("[\033[0;33mDEBUG\033[0;0m] Server listening on %d will stop." % self.port) if self.debug: if self.proc is None: - print( - "[\033[0;31mDEBUG\033[0;0m] Failed terminating server listening on port %d" - % self.port - ) + print("[\033[0;31mDEBUG\033[0;0m] Failed terminating server listening on port %d" % self.port) if self.proc.returncode is not None: if self.debug: print( - "[\033[0;31mDEBUG\033[0;0m] Server listening on port {port} finished running already with exit {ret}" - .format(port=self.port, ret=self.proc.returncode) + "[\033[0;31mDEBUG\033[0;0m] Server listening on port {port} finished running already with exit {ret}".format( + port=self.port, ret=self.proc.returncode + ) ) else: os.kill(self.proc.pid, signal.SIGKILL) self.proc.wait() if self.debug: - print( - "[\033[0;33mDEBUG\033[0;0m] Server listening on %d was stopped." - % self.port - ) + print("[\033[0;33mDEBUG\033[0;0m] Server listening on %d was stopped." % self.port) class SingleServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() @@ -191,7 +170,6 @@ def tearDown(self): class MultiServerAuthTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] @@ -199,9 +177,7 @@ def setUp(self): server1 = NATSD(port=4223, user="foo", password="bar", http_port=8223) self.server_pool.append(server1) - server2 = NATSD( - port=4224, user="hoge", password="fuga", http_port=8224 - ) + server2 = NATSD(port=4224, user="hoge", password="fuga", http_port=8224) self.server_pool.append(server2) for natsd in self.server_pool: start_natsd(natsd) @@ -213,7 +189,6 @@ def tearDown(self): class MultiServerAuthTokenTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] @@ -235,7 +210,6 @@ def tearDown(self): class TLSServerTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() @@ -243,9 +217,7 @@ def setUp(self): self.natsd = NATSD(port=4224, tls=True) start_natsd(self.natsd) - self.ssl_ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH - ) + self.ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( @@ -259,20 +231,14 @@ def tearDown(self): class TLSServerHandshakeFirstTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() - self.natsd = NATSD( - port=4224, - config_file=get_config_file("conf/tls_handshake_first.conf") - ) + self.natsd = NATSD(port=4224, config_file=get_config_file("conf/tls_handshake_first.conf")) start_natsd(self.natsd) - self.ssl_ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH - ) + self.ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( @@ -286,26 +252,19 @@ def tearDown(self): class MultiTLSServerAuthTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] self.loop = asyncio.new_event_loop() - server1 = NATSD( - port=4223, user="foo", password="bar", http_port=8223, tls=True - ) + server1 = NATSD(port=4223, user="foo", password="bar", http_port=8223, tls=True) self.server_pool.append(server1) - server2 = NATSD( - port=4224, user="hoge", password="fuga", http_port=8224, tls=True - ) + server2 = NATSD(port=4224, user="hoge", password="fuga", http_port=8224, tls=True) self.server_pool.append(server2) for natsd in self.server_pool: start_natsd(natsd) - self.ssl_ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH - ) + self.ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( @@ -320,7 +279,6 @@ def tearDown(self): class ClusteringTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] @@ -370,7 +328,6 @@ def tearDown(self): class ClusteringDiscoveryAuthTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] @@ -420,14 +377,11 @@ def tearDown(self): class NkeysServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() - server = NATSD( - port=4222, config_file=get_config_file("nkeys/nkeys_server.conf") - ) + server = NATSD(port=4222, config_file=get_config_file("nkeys/nkeys_server.conf")) self.server_pool.append(server) for natsd in self.server_pool: start_natsd(natsd) @@ -439,16 +393,12 @@ def tearDown(self): class TrustedServerTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] self.loop = asyncio.new_event_loop() - server = NATSD( - port=4222, - config_file=(get_config_file("nkeys/resolver_preload.conf")) - ) + server = NATSD(port=4222, config_file=(get_config_file("nkeys/resolver_preload.conf"))) self.server_pool.append(server) for natsd in self.server_pool: start_natsd(natsd) @@ -460,7 +410,6 @@ def tearDown(self): class SingleJetStreamServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() @@ -478,7 +427,6 @@ def tearDown(self): class SingleWebSocketServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() @@ -495,25 +443,18 @@ def tearDown(self): class SingleWebSocketTLSServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() - self.ssl_ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH - ) + self.ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( certfile=get_config_file("certs/client-cert.pem"), keyfile=get_config_file("certs/client-key.pem"), ) - server = NATSD( - port=4222, - tls=True, - config_file=get_config_file("conf/ws_tls.conf") - ) + server = NATSD(port=4222, tls=True, config_file=get_config_file("conf/ws_tls.conf")) self.server_pool.append(server) for natsd in self.server_pool: start_natsd(natsd) @@ -525,7 +466,6 @@ def tearDown(self): class SingleJetStreamServerLimitsTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() @@ -547,14 +487,11 @@ def tearDown(self): class NoAuthUserServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() - server = NATSD( - port=4555, config_file=get_config_file("conf/no_auth_user.conf") - ) + server = NATSD(port=4555, config_file=get_config_file("conf/no_auth_user.conf")) self.server_pool.append(server) for natsd in self.server_pool: start_natsd(natsd) @@ -590,36 +527,27 @@ def get_config_file(file_path): def async_test(test_case_fun, timeout=5): - @wraps(test_case_fun) def wrapper(test_case, *args, **kw): asyncio.set_event_loop(test_case.loop) - return asyncio.run( - asyncio.wait_for(test_case_fun(test_case, *args, **kw), timeout) - ) + return asyncio.run(asyncio.wait_for(test_case_fun(test_case, *args, **kw), timeout)) return wrapper def async_long_test(test_case_fun, timeout=20): - @wraps(test_case_fun) def wrapper(test_case, *args, **kw): asyncio.set_event_loop(test_case.loop) - return asyncio.run( - asyncio.wait_for(test_case_fun(test_case, *args, **kw), timeout) - ) + return asyncio.run(asyncio.wait_for(test_case_fun(test_case, *args, **kw), timeout)) return wrapper def async_debug_test(test_case_fun, timeout=3600): - @wraps(test_case_fun) def wrapper(test_case, *args, **kw): asyncio.set_event_loop(test_case.loop) - return asyncio.run( - asyncio.wait_for(test_case_fun(test_case, *args, **kw), timeout) - ) + return asyncio.run(asyncio.wait_for(test_case_fun(test_case, *args, **kw), timeout)) return wrapper diff --git a/pyproject.toml b/pyproject.toml index 820bda4b8..4c175d257 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,17 +1,13 @@ -[dependency-groups] -dev = [ +[tool.uv] +dev-dependencies = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "pytest-cov>=7.0.0", "pytest-xdist>=3.0.0", "mypy>=1.0.0", - "yapf>=0.40.0", "ruff>=0.1.0", - "isort>=5.0.0", "flake8>=7.0.0", ] - -[tool.uv] workspace = { members = ["nats", "nats-server"] } [tool.mypy] @@ -22,14 +18,6 @@ follow_imports = "silent" show_error_codes = true check_untyped_defs = false -[tool.yapf] -split_before_first_argument = true -dedent_closing_brackets = true -coalesce_brackets = true -allow_split_before_dict_value = false -indent_dictionary_value = true -split_before_expression_after_opening_paren = true - [tool.ruff] line-length = 120 target-version = "py37" @@ -38,11 +26,9 @@ target-version = "py37" select = ["E", "F", "W", "I"] ignore = ["E501"] -[tool.isort] -combine_as_imports = true -multi_line_output = 3 -include_trailing_comma = true -src_paths = ["nats/src", "nats/tests"] +[tool.ruff.format] +quote-style = "double" +indent-style = "space" [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/uv.lock b/uv.lock index 653e7d24c..a0d4a7f3b 100644 --- a/uv.lock +++ b/uv.lock @@ -15,14 +15,12 @@ members = [ [manifest.dependency-groups] dev = [ - { name = "isort", specifier = ">=5.0.0" }, { name = "mypy", specifier = ">=1.0.0" }, { name = "pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", specifier = ">=0.21.0" }, { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "pytest-xdist", specifier = ">=3.0.0" }, { name = "ruff", specifier = ">=0.1.0" }, - { name = "yapf", specifier = ">=0.40.0" }, ] [[package]] @@ -405,15 +403,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, ] -[[package]] -name = "isort" -version = "6.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1e/82/fa43935523efdfcce6abbae9da7f372b627b27142c3419fcf13bf5b0c397/isort-6.1.0.tar.gz", hash = "sha256:9b8f96a14cfee0677e78e941ff62f03769a06d412aabb9e2a90487b3b7e8d481", size = 824325, upload-time = "2025-10-01T16:26:45.027Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/cc/9b681a170efab4868a032631dea1e8446d8ec718a7f657b94d49d1a12643/isort-6.1.0-py3-none-any.whl", hash = "sha256:58d8927ecce74e5087aef019f778d4081a3b6c98f15a80ba35782ca8a2097784", size = 94329, upload-time = "2025-10-01T16:26:43.291Z" }, -] - [[package]] name = "multidict" version = "6.6.4" @@ -598,15 +587,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, ] -[[package]] -name = "platformdirs" -version = "4.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/23/e8/21db9c9987b0e728855bd57bff6984f67952bea55d6f75e055c46b5383e8/platformdirs-4.4.0.tar.gz", hash = "sha256:ca753cf4d81dc309bc67b0ea38fd15dc97bc30ce419a7f58d13eb3bf14c4febf", size = 21634, upload-time = "2025-08-26T14:32:04.268Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/4b/2028861e724d3bd36227adfa20d3fd24c3fc6d52032f4a93c133be5d17ce/platformdirs-4.4.0-py3-none-any.whl", hash = "sha256:abd01743f24e5287cd7a5db3752faf1a2d65353f38ec26d98e25a6db65958c85", size = 18654, upload-time = "2025-08-26T14:32:02.735Z" }, -] - [[package]] name = "pluggy" version = "1.6.0" @@ -874,18 +854,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, ] -[[package]] -name = "yapf" -version = "0.43.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "platformdirs" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/23/97/b6f296d1e9cc1ec25c7604178b48532fa5901f721bcf1b8d8148b13e5588/yapf-0.43.0.tar.gz", hash = "sha256:00d3aa24bfedff9420b2e0d5d9f5ab6d9d4268e72afbf59bb3fa542781d5218e", size = 254907, upload-time = "2024-11-14T00:11:41.584Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/37/81/6acd6601f61e31cfb8729d3da6d5df966f80f374b78eff83760714487338/yapf-0.43.0-py3-none-any.whl", hash = "sha256:224faffbc39c428cb095818cf6ef5511fdab6f7430a10783fdfb292ccf2852ca", size = 256158, upload-time = "2024-11-14T00:11:39.37Z" }, -] - [[package]] name = "yarl" version = "1.20.1"