diff --git a/src/async_kernel/compat/json.py b/src/async_kernel/compat/json.py index 9e78bfff..3e57e0d7 100644 --- a/src/async_kernel/compat/json.py +++ b/src/async_kernel/compat/json.py @@ -9,9 +9,31 @@ if TYPE_CHECKING: from collections.abc import Callable - pack_json_bytes: Callable[[Any], bytes] - pack_json_str: Callable[[Any], str] - unpack_json: Callable[[str | bytes], Any] + from jupyter_client.jsonutil import json_default + + def pack_json_bytes(obj: Any, /, default: Callable[[Any], Any] | None = json_default) -> bytes: + """ + Pack obj into json serialized bytes. + + Args: + data: The data to serialize. + default: A function that should return a serializable version of obj or raise TypeError. + """ + ... + + def pack_json_str(obj: Any, /, default: Callable[[Any], Any] | None = json_default) -> str: + """ + Pack obj into json serialized string. + + Args: + data: The data to serialize. + default: A function that should return a serializable version of obj or raise TypeError. + """ + ... + + def unpack_json(data: str | bytes, /) -> Any: + "Deserialize data in a Python object." + ... if importlib.util.find_spec("jupyter_client"): @@ -19,11 +41,11 @@ from jupyter_client.jsonutil import json_default - def _jc_pack_bytes(data: Any) -> bytes: - return json.dumps(data, default=json_default).encode() + def _jc_pack_bytes(obj: Any, default: Callable[[Any], Any] | None = json_default) -> bytes: + return json.dumps(obj, default=default).encode() - def _jc_pack_str(data: Any) -> str: - return json.dumps(data, default=json_default) + def _jc_pack_str(obj: Any, default: Callable[[Any], Any] | None = json_default) -> str: + return json.dumps(obj, allow_nan=False, default=default) pack_json_bytes, pack_json_str, unpack_json = _jc_pack_bytes, _jc_pack_str, json.loads @@ -36,10 +58,10 @@ def _jc_pack_str(data: Any) -> str: ORJSON_OPTION = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z - def _oj_pack_bytes(data) -> bytes: - return orjson.dumps(data, default=json_default, option=ORJSON_OPTION) + def _oj_pack_bytes(obj: Any, default: Callable[[Any], Any] | None = json_default) -> bytes: + return orjson.dumps(obj, default=default, option=ORJSON_OPTION) - def _oj_pack_str(data) -> str: - return orjson.dumps(data, default=json_default, option=ORJSON_OPTION).decode() + def _oj_pack_str(obj: Any, default: Callable[[Any], Any] | None = json_default) -> str: + return orjson.dumps(obj, default=default, option=ORJSON_OPTION).decode() pack_json_bytes, pack_json_str, unpack_json = _oj_pack_bytes, _oj_pack_str, orjson.loads