Skip to content

Commit

Permalink
Merge pull request #347 from nats-io/kv-updates
Browse files Browse the repository at this point in the history
Updates to KV implementation
  • Loading branch information
wallyqs authored Sep 6, 2022
2 parents 0bd775d + a3b1730 commit 2ba281c
Show file tree
Hide file tree
Showing 7 changed files with 708 additions and 78 deletions.
70 changes: 69 additions & 1 deletion nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
Subscription,
)

__version__ = '2.1.7'
__version__ = '2.2.0'
__lang__ = 'python3'
_logger = logging.getLogger(__name__)
PROTOCOL = 1
Expand Down Expand Up @@ -115,6 +115,61 @@ class Srv:
did_connect: bool = False
discovered: bool = False
tls_name: Optional[str] = None
server_version: Optional[str] = None


class ServerVersion:

def __init__(self, server_version: str):
self._server_version = server_version
self._major_version = None
self._minor_version = None
self._patch_version = None
self._dev_version = None

def parse_version(self):
v = (self._server_version).split('-')
if len(v) > 1:
self._dev_version = v[1]
tokens = v[0].split('.')
n = len(tokens)
if n > 1:
self._major_version = int(tokens[0])
if n > 2:
self._minor_version = int(tokens[1])
if n > 3:
self._patch_version = int(tokens[2])

@property
def major(self) -> int:
version = self._major_version
if not version:
self.parse_version()
return self._major_version

@property
def minor(self) -> int:
version = self._minor_version
if not version:
self.parse_version()
return self._minor_version

@property
def patch(self) -> int:
version = self._patch_version
if not version:
self.parse_version()
return self._patch_version

@property
def dev(self) -> int:
version = self._dev_version
if not version:
self.parse_version()
return self._dev_version

def __repr__(self) -> str:
return f"<nats server v{self._server_version}>"


async def _default_error_callback(ex: Exception) -> None:
Expand Down Expand Up @@ -1100,6 +1155,16 @@ def is_draining(self) -> bool:
def is_draining_pubs(self) -> bool:
return self._status == Client.DRAINING_PUBS

@property
def connected_server_version(self) -> ServerVersion:
"""
Returns the ServerVersion of the server to which the client
is currently connected.
"""
if self._current_server and self._current_server.server_version:
return ServerVersion(self._current_server.server_version)
return ServerVersion("0.0.0-unknown")

async def _send_command(self, cmd: bytes, priority: bool = False) -> None:
if priority:
self._pending.insert(0, cmd)
Expand Down Expand Up @@ -1777,6 +1842,9 @@ async def _process_connect_init(self) -> None:
self._server_info = srv_info
self._process_info(srv_info, initial_connection=True)

if 'version' in self._server_info:
self._current_server.server_version = self._server_info['version']

if 'max_payload' in self._server_info:
self._max_payload = self._server_info["max_payload"]

Expand Down
12 changes: 11 additions & 1 deletion nats/js/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def _to_nanoseconds(val: Optional[float]) -> Optional[int]:
"""Convert the value from seconds to nanoseconds.
"""
if val is None:
return None
# We use 0 to avoid sending null to Go servers.
return 0
return int(val * _NANOSECOND)

@classmethod
Expand Down Expand Up @@ -235,17 +236,22 @@ class StreamConfig(Base):
deny_delete: bool = False
deny_purge: bool = False
allow_rollup_hdrs: bool = False
allow_direct: Optional[bool] = None

@classmethod
def from_response(cls, resp: Dict[str, Any]):
cls._convert_nanoseconds(resp, 'max_age')
cls._convert_nanoseconds(resp, 'duplicate_window')
cls._convert(resp, 'placement', Placement)
cls._convert(resp, 'mirror', StreamSource)
cls._convert(resp, 'sources', StreamSource)
return super().from_response(resp)

def as_dict(self) -> Dict[str, object]:
result = super().as_dict()
result['duplicate_window'] = self._to_nanoseconds(
self.duplicate_window
)
result['max_age'] = self._to_nanoseconds(self.max_age)
return result

Expand Down Expand Up @@ -468,6 +474,7 @@ class RawStreamMsg(Base):
data: Optional[bytes] = None
hdrs: Optional[bytes] = None
headers: Optional[dict] = None
stream: Optional[str] = None
# TODO: Add 'time'

@property
Expand Down Expand Up @@ -495,6 +502,9 @@ class KeyValueConfig(Base):
max_bytes: Optional[int] = None
storage: Optional[StorageType] = None
replicas: int = 1
placement: Optional[Placement] = None
republish: Optional[bool] = None
direct: Optional[bool] = None

def as_dict(self) -> Dict[str, object]:
result = super().as_dict()
Expand Down
64 changes: 20 additions & 44 deletions nats/js/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The NATS Authors
# Copyright 2021-2022 The NATS Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -13,7 +13,6 @@
#

import asyncio
import base64
import json
import time
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
Expand Down Expand Up @@ -954,40 +953,6 @@ async def _fetch_n(

return msgs

#############################
# #
# JetStream Manager Context #
# #
#############################

async def get_last_msg(
self,
stream_name: str,
subject: str,
) -> api.RawStreamMsg:
"""
get_last_msg retrieves a message from a stream.
"""
req_subject = f"{self._prefix}.STREAM.MSG.GET.{stream_name}"
req = {'last_by_subj': subject}
data = json.dumps(req)
resp = await self._api_request(
req_subject, data.encode(), timeout=self._timeout
)
raw_msg = api.RawStreamMsg.from_response(resp['message'])
if raw_msg.hdrs:
hdrs = base64.b64decode(raw_msg.hdrs)
raw_headers = hdrs[NATS_HDR_LINE_SIZE + _CRLF_LEN_:]
parsed_headers = self._jsm._hdr_parser.parsebytes(raw_headers)
headers = None
if len(parsed_headers.items()) > 0:
headers = {}
for k, v in parsed_headers.items():
headers[k] = v
raw_msg.headers = headers

return raw_msg

######################
# #
# KeyValue Context #
Expand All @@ -1008,6 +973,7 @@ async def key_value(self, bucket: str) -> KeyValue:
stream=stream,
pre=KV_PRE_TEMPLATE.format(bucket=bucket),
js=self,
direct=si.config.allow_direct
)

async def create_key_value(
Expand All @@ -1022,27 +988,37 @@ async def create_key_value(
config = api.KeyValueConfig(bucket=params["bucket"])
config = config.evolve(**params)

duplicate_window = 2 * 60 # 2 minutes
if config.ttl and config.ttl < duplicate_window:
duplicate_window = config.ttl

stream = api.StreamConfig(
name=KV_STREAM_TEMPLATE.format(bucket=config.bucket),
description=None,
description=config.description,
subjects=[f"$KV.{config.bucket}.>"],
max_msgs_per_subject=config.history,
max_bytes=config.max_bytes,
allow_direct=config.direct,
allow_rollup_hdrs=True,
deny_delete=True,
discard=api.DiscardPolicy.NEW,
duplicate_window=duplicate_window,
max_age=config.ttl,
max_bytes=config.max_bytes,
max_consumers=-1,
max_msg_size=config.max_value_size,
storage=config.storage,
max_msgs=-1,
max_msgs_per_subject=config.history,
num_replicas=config.replicas,
allow_rollup_hdrs=True,
deny_delete=True,
storage=config.storage,
)
await self.add_stream(stream)

si = await self.add_stream(stream)
assert stream.name is not None

return KeyValue(
name=config.bucket,
stream=stream.name,
pre=KV_PRE_TEMPLATE.format(bucket=config.bucket),
js=self,
direct=si.config.allow_direct
)

async def delete_key_value(self, bucket: str) -> bool:
Expand Down
46 changes: 41 additions & 5 deletions nats/js/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2016-2021 The NATS Authors
# Copyright 2016-2022 The NATS Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -37,7 +37,7 @@ def __str__(self) -> str:
return f"nats: JetStream.{self.__class__.__name__} {desc}"


@dataclass
@dataclass(repr=False, init=False)
class APIError(Error):
"""
An Error that is the result of interacting with NATS JetStream.
Expand Down Expand Up @@ -117,7 +117,7 @@ class NotFoundError(APIError):

class BadRequestError(APIError):
"""
A 400 error
A 400 error.
"""
pass

Expand Down Expand Up @@ -162,11 +162,18 @@ class BucketNotFoundError(NotFoundError):
pass


class BadBucketError(Error):
class BadBucketError(APIError):
pass


class KeyDeletedError(Error):
class KeyValueError(APIError):
"""
Raised when there is an issue interacting with the KeyValue store.
"""
pass


class KeyDeletedError(KeyValueError, NotFoundError):
"""
Raised when trying to get a key that was deleted from a JetStream KeyValue store.
"""
Expand All @@ -177,3 +184,32 @@ def __init__(self, entry=None, op=None) -> None:

def __str__(self) -> str:
return "nats: key was deleted"


class KeyNotFoundError(KeyValueError, NotFoundError):
"""
Raised when trying to get a key that does not exists from a JetStream KeyValue store.
"""

def __init__(self, entry=None, op=None, message=None) -> None:
self.entry = entry
self.op = op
self.message = message

def __str__(self) -> str:
s = "nats: key not found"
if self.message:
s += f": {self.message}"
return s


class KeyWrongLastSequenceError(KeyValueError, BadRequestError):
"""
Raised when trying to update a key with the wrong last sequence.
"""

def __init__(self, description=None) -> None:
self.description = description

def __str__(self) -> str:
return f"nats: {self.description}"
Loading

0 comments on commit 2ba281c

Please sign in to comment.