Skip to content

Commit

Permalink
add escape hatch for custom JSON serialization (#1955)
Browse files Browse the repository at this point in the history
* add escape hatch for custom JSON serialization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

* fix pydocstyle

* fix whitespace

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tomer Nosrati <tomer.nosrati@gmail.com>
Co-authored-by: Ville Lindholm <ville@lindholm.dev>
  • Loading branch information
4 people committed Mar 13, 2024
1 parent 85111fc commit 25d02e6
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 25 deletions.
66 changes: 42 additions & 24 deletions kombu/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def default(self, o):

for t, (marker, encoder) in _encoders.items():
if isinstance(o, t):
return _as(marker, encoder(o))
return (
encoder(o) if marker is None else _as(marker, encoder(o))
)

# Bytes is slightly trickier, so we cannot put them directly
# into _encoders, because we use two formats: bytes, and base64.
Expand All @@ -50,7 +52,11 @@ def _as(t: str, v: Any):


def dumps(
s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs
s,
_dumps=json.dumps,
cls=JSONEncoder,
default_kwargs=None,
**kwargs
):
"""Serialize object to json string."""
default_kwargs = default_kwargs or {}
Expand Down Expand Up @@ -94,35 +100,47 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):

def register_type(
t: type[T],
marker: str,
marker: str | None,
encoder: Callable[[T], EncodedT],
decoder: Callable[[EncodedT], T],
decoder: Callable[[EncodedT], T] = lambda d: d,
):
"""Add support for serializing/deserializing native python type."""
"""Add support for serializing/deserializing native python type.
If marker is `None`, the encoding is a pure transformation and the result
is not placed in an envelope, so `decoder` is unnecessary. Decoding must
instead be handled outside this library.
"""
_encoders[t] = (marker, encoder)
_decoders[marker] = decoder
if marker is not None:
_decoders[marker] = decoder


_encoders: dict[type, tuple[str, EncoderT]] = {}
_encoders: dict[type, tuple[str | None, EncoderT]] = {}
_decoders: dict[str, DecoderT] = {
"bytes": lambda o: o.encode("utf-8"),
"base64": lambda o: base64.b64decode(o.encode("utf-8")),
}

# NOTE: datetime should be registered before date,
# because datetime is also instance of date.
register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat)
register_type(
date,
"date",
lambda o: o.isoformat(),
lambda o: datetime.fromisoformat(o).date(),
)
register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
register_type(Decimal, "decimal", str, Decimal)
register_type(
uuid.UUID,
"uuid",
lambda o: {"hex": o.hex},
lambda o: uuid.UUID(**o),
)

def _register_default_types():
# NOTE: datetime should be registered before date,
# because datetime is also instance of date.
register_type(datetime, "datetime", datetime.isoformat,
datetime.fromisoformat)
register_type(
date,
"date",
lambda o: o.isoformat(),
lambda o: datetime.fromisoformat(o).date(),
)
register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
register_type(Decimal, "decimal", str, Decimal)
register_type(
uuid.UUID,
"uuid",
lambda o: {"hex": o.hex},
lambda o: uuid.UUID(**o),
)


_register_default_types()
43 changes: 42 additions & 1 deletion t/unit/utils/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import uuid
from collections import namedtuple
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal

Expand All @@ -11,7 +12,8 @@
from hypothesis import strategies as st

from kombu.utils.encoding import str_to_bytes
from kombu.utils.json import dumps, loads
from kombu.utils.json import (_register_default_types, dumps, loads,
register_type)

if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
Expand All @@ -28,6 +30,10 @@ def __json__(self):


class test_JSONEncoder:
@pytest.fixture(autouse=True)
def reset_registered_types(self):
_register_default_types()

@pytest.mark.freeze_time("2015-10-21")
def test_datetime(self):
now = datetime.utcnow()
Expand Down Expand Up @@ -82,6 +88,41 @@ def test_UUID(self):
assert loaded_value == {'u': id}
assert loaded_value["u"].version == id.version

def test_register_type_overrides_defaults(self):
# This type is already registered by default, let's override it
register_type(uuid.UUID, "uuid", lambda o: "custom", lambda o: o)
value = uuid.uuid4()
loaded_value = loads(dumps({'u': value}))
assert loaded_value == {'u': "custom"}

def test_register_type_with_new_type(self):
# Guaranteed never before seen type
@dataclass()
class SomeType:
a: int

register_type(SomeType, "some_type", lambda o: "custom", lambda o: o)
value = SomeType(42)
loaded_value = loads(dumps({'u': value}))
assert loaded_value == {'u': "custom"}

def test_register_type_with_empty_marker(self):
register_type(
datetime,
None,
lambda o: o.isoformat(),
lambda o: "should never be used"
)
now = datetime.utcnow()
serialized_str = dumps({'now': now})
deserialized_value = loads(serialized_str)

assert "__type__" not in serialized_str
assert "__value__" not in serialized_str

# Check that there is no extra deserialization happening
assert deserialized_value == {'now': now.isoformat()}

def test_default(self):
with pytest.raises(TypeError):
dumps({'o': object()})
Expand Down

0 comments on commit 25d02e6

Please sign in to comment.