diff --git a/msgpack/_unpacker.pyx b/msgpack/_unpacker.pyx index 3727f50c..0935d15a 100644 --- a/msgpack/_unpacker.pyx +++ b/msgpack/_unpacker.pyx @@ -38,6 +38,7 @@ cdef extern from "unpack.h": Py_ssize_t max_array_len Py_ssize_t max_map_len Py_ssize_t max_ext_len + PyObject* memo ctypedef struct unpack_context: msgpack_user user @@ -61,7 +62,7 @@ cdef inline init_ctx(unpack_context *ctx, const char* encoding, const char* unicode_errors, Py_ssize_t max_str_len, Py_ssize_t max_bin_len, Py_ssize_t max_array_len, Py_ssize_t max_map_len, - Py_ssize_t max_ext_len): + Py_ssize_t max_ext_len, object d): unpack_init(ctx) ctx.user.use_list = use_list ctx.user.raw = raw @@ -101,6 +102,9 @@ cdef inline init_ctx(unpack_context *ctx, ctx.user.encoding = encoding ctx.user.unicode_errors = unicode_errors + Py_INCREF(d) + ctx.user.memo = d + def default_read_extended_type(typecode, data): raise NotImplementedError("Cannot decode extended type with typecode=%d" % typecode) @@ -195,9 +199,10 @@ def unpackb(object packed, object object_hook=None, object list_hook=None, max_ext_len = buf_len try: + memo = {} init_ctx(&ctx, object_hook, object_pairs_hook, list_hook, ext_hook, use_list, raw, strict_map_key, cenc, cerr, - max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len) + max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len, memo) ret = unpack_construct(&ctx, buf, buf_len, &off) finally: if new_protocol: @@ -404,7 +409,7 @@ cdef class Unpacker(object): init_ctx(&self.ctx, object_hook, object_pairs_hook, list_hook, ext_hook, use_list, raw, strict_map_key, cenc, cerr, max_str_len, max_bin_len, max_array_len, - max_map_len, max_ext_len) + max_map_len, max_ext_len, {}) def feed(self, object next_bytes): """Append `next_bytes` to internal buffer.""" diff --git a/msgpack/fallback.py b/msgpack/fallback.py index 3836e830..127d6c1f 100644 --- a/msgpack/fallback.py +++ b/msgpack/fallback.py @@ -321,6 +321,8 @@ def __init__(self, file_like=None, read_size=0, use_list=True, raw=True, strict_ self._max_ext_len = max_ext_len self._stream_offset = 0 + self._memo = {} # used to memoize keys to reduce memory consumption + if list_hook is not None and not callable(list_hook): raise TypeError('`list_hook` is not callable') if object_hook is not None and not callable(object_hook): @@ -654,7 +656,9 @@ def _unpack(self, execute=EX_CONSTRUCT): ret = {} for _ in xrange(n): key = self._unpack(EX_CONSTRUCT) - if self._strict_map_key and type(key) not in (unicode, bytes): + if type(key) in (unicode, bytes): + key = self._memo.setdefault(key, key) + elif self._strict_map_key: raise ValueError("%s is not allowed for map key" % str(type(key))) ret[key] = self._unpack(EX_CONSTRUCT) if self._object_hook is not None: diff --git a/msgpack/unpack.h b/msgpack/unpack.h index 85dbbed5..6eba18a0 100644 --- a/msgpack/unpack.h +++ b/msgpack/unpack.h @@ -30,6 +30,7 @@ typedef struct unpack_user { const char *encoding; const char *unicode_errors; Py_ssize_t max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len; + PyObject *memo; } unpack_user; typedef PyObject* msgpack_unpack_object; @@ -189,7 +190,22 @@ static inline int unpack_callback_map(unpack_user* u, unsigned int n, msgpack_un static inline int unpack_callback_map_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object k, msgpack_unpack_object v) { - if (u->strict_map_key && !PyUnicode_CheckExact(k) && !PyBytes_CheckExact(k)) { + if (PyUnicode_CheckExact(k) || PyBytes_CheckExact(k)) { + PyObject *memokey = PyDict_GetItem(u->memo, k); + if (memokey != NULL) { + Py_INCREF(memokey); + Py_DECREF(k); + k = memokey; + } + else { + if (PyDict_SetItem(u->memo, k, k) < 0) { + Py_DECREF(k); + Py_DECREF(v); + return -1; + } + } + } + else if (u->strict_map_key) { PyErr_Format(PyExc_ValueError, "%.100s is not allowed for map key", Py_TYPE(k)->tp_name); return -1; } diff --git a/msgpack/unpack_template.h b/msgpack/unpack_template.h index 9924b9c6..398df42f 100644 --- a/msgpack/unpack_template.h +++ b/msgpack/unpack_template.h @@ -73,6 +73,7 @@ static inline PyObject* unpack_data(unpack_context* ctx) static inline void unpack_clear(unpack_context *ctx) { Py_CLEAR(ctx->stack[0].obj); + Py_CLEAR(ctx->user.memo); } template diff --git a/test/test_unpack.py b/test/test_unpack.py index 00a10612..dbaa3240 100644 --- a/test/test_unpack.py +++ b/test/test_unpack.py @@ -65,6 +65,15 @@ def _hook(self, code, data): assert unpacker.unpack() == {'a': ExtType(2, b'321')} +def test_unpacker_shares_stringly_keys(): + f = BytesIO(packb([{" a ?!": 1}, {" a ?!": 2}])) + unpacker = Unpacker(f) + d1, d2 = unpacker.unpack() + key1, = d1 + key2, = d2 + assert key1 is key2 + + if __name__ == '__main__': test_unpack_array_header_from_file() test_unpacker_hook_refcnt()