Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions msgpack/_packer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,21 @@ cdef class Packer(object):
:param bool use_bin_type:
Use bin type introduced in msgpack spec 2.0 for bytes.
It also enable str8 type for unicode.
:param bool strict_types:
If set to true, types will be checked to be exact. Derived classes
from serializeable types will not be serialized and will be
treated as unsupported type and forwarded to default.
Additionally tuples will not be serialized as lists.
This is useful when trying to implement accurate serialization
for python types.
"""
cdef msgpack_packer pk
cdef object _default
cdef object _bencoding
cdef object _berrors
cdef char *encoding
cdef char *unicode_errors
cdef bint strict_types
cdef bool use_float
cdef bint autoreset

Expand All @@ -82,10 +90,12 @@ cdef class Packer(object):
self.pk.length = 0

def __init__(self, default=None, encoding='utf-8', unicode_errors='strict',
use_single_float=False, bint autoreset=1, bint use_bin_type=0):
use_single_float=False, bint autoreset=1, bint use_bin_type=0,
bint strict_types=0):
"""
"""
self.use_float = use_single_float
self.strict_types = strict_types
self.autoreset = autoreset
self.pk.use_bin_type = use_bin_type
if default is not None:
Expand Down Expand Up @@ -121,19 +131,20 @@ cdef class Packer(object):
cdef dict d
cdef size_t L
cdef int default_used = 0
cdef bint strict_types = self.strict_types

if nest_limit < 0:
raise PackValueError("recursion limit exceeded.")

while True:
if o is None:
ret = msgpack_pack_nil(&self.pk)
elif isinstance(o, bool):
elif PyBool_Check(o) if strict_types else isinstance(o, bool):
if o:
ret = msgpack_pack_true(&self.pk)
else:
ret = msgpack_pack_false(&self.pk)
elif PyLong_Check(o):
elif PyLong_CheckExact(o) if strict_types else PyLong_Check(o):
# PyInt_Check(long) is True for Python 3.
# So we should test long before int.
try:
Expand All @@ -150,25 +161,25 @@ cdef class Packer(object):
continue
else:
raise
elif PyInt_Check(o):
elif PyInt_CheckExact(o) if strict_types else PyInt_Check(o):
longval = o
ret = msgpack_pack_long(&self.pk, longval)
elif PyFloat_Check(o):
elif PyFloat_CheckExact(o) if strict_types else PyFloat_Check(o):
if self.use_float:
fval = o
ret = msgpack_pack_float(&self.pk, fval)
else:
dval = o
ret = msgpack_pack_double(&self.pk, dval)
elif PyBytes_Check(o):
elif PyBytes_CheckExact(o) if strict_types else PyBytes_Check(o):
L = len(o)
if L > (2**32)-1:
raise ValueError("bytes is too large")
rawval = o
ret = msgpack_pack_bin(&self.pk, L)
if ret == 0:
ret = msgpack_pack_raw_body(&self.pk, rawval, L)
elif PyUnicode_Check(o):
elif PyUnicode_CheckExact(o) if strict_types else PyUnicode_Check(o):
if not self.encoding:
raise TypeError("Can't encode unicode string: no encoding is specified")
o = PyUnicode_AsEncodedString(o, self.encoding, self.unicode_errors)
Expand All @@ -191,7 +202,7 @@ cdef class Packer(object):
if ret != 0: break
ret = self._pack(v, nest_limit-1)
if ret != 0: break
elif PyDict_Check(o):
elif not strict_types and PyDict_Check(o):
L = len(o)
if L > (2**32)-1:
raise ValueError("dict is too large")
Expand All @@ -202,7 +213,7 @@ cdef class Packer(object):
if ret != 0: break
ret = self._pack(v, nest_limit-1)
if ret != 0: break
elif isinstance(o, ExtType):
elif type(o) is ExtType if strict_types else isinstance(o, ExtType):
# This should be before Tuple because ExtType is namedtuple.
longval = o.code
rawval = o.data
Expand All @@ -211,7 +222,7 @@ cdef class Packer(object):
raise ValueError("EXT data is too large")
ret = msgpack_pack_ext(&self.pk, longval, L)
ret = msgpack_pack_raw_body(&self.pk, rawval, L)
elif PyTuple_Check(o) or PyList_Check(o):
elif PyList_CheckExact(o) if strict_types else (PyTuple_Check(o) or PyList_Check(o)):
L = len(o)
if L > (2**32)-1:
raise ValueError("list is too large")
Expand Down
44 changes: 33 additions & 11 deletions msgpack/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def getvalue(self):
DEFAULT_RECURSE_LIMIT = 511


def _check_type_strict(obj, t, type=type, tuple=tuple):
if type(t) is tuple:
return type(obj) in t
else:
return type(obj) is t


def unpack(stream, **kwargs):
"""
Unpack an object from `stream`.
Expand Down Expand Up @@ -609,9 +616,18 @@ class Packer(object):
:param bool use_bin_type:
Use bin type introduced in msgpack spec 2.0 for bytes.
It also enable str8 type for unicode.
:param bool strict_types:
If set to true, types will be checked to be exact. Derived classes
from serializeable types will not be serialized and will be
treated as unsupported type and forwarded to default.
Additionally tuples will not be serialized as lists.
This is useful when trying to implement accurate serialization
for python types.
"""
def __init__(self, default=None, encoding='utf-8', unicode_errors='strict',
use_single_float=False, autoreset=True, use_bin_type=False):
use_single_float=False, autoreset=True, use_bin_type=False,
strict_types=False):
self._strict_types = strict_types
self._use_float = use_single_float
self._autoreset = autoreset
self._use_bin_type = use_bin_type
Expand All @@ -623,18 +639,24 @@ def __init__(self, default=None, encoding='utf-8', unicode_errors='strict',
raise TypeError("default must be callable")
self._default = default

def _pack(self, obj, nest_limit=DEFAULT_RECURSE_LIMIT, isinstance=isinstance):
def _pack(self, obj, nest_limit=DEFAULT_RECURSE_LIMIT,
check=isinstance, check_type_strict=_check_type_strict):
default_used = False
if self._strict_types:
check = check_type_strict
list_types = list
else:
list_types = (list, tuple)
while True:
if nest_limit < 0:
raise PackValueError("recursion limit exceeded")
if obj is None:
return self._buffer.write(b"\xc0")
if isinstance(obj, bool):
if check(obj, bool):
if obj:
return self._buffer.write(b"\xc3")
return self._buffer.write(b"\xc2")
if isinstance(obj, int_types):
if check(obj, int_types):
if 0 <= obj < 0x80:
return self._buffer.write(struct.pack("B", obj))
if -0x20 <= obj < 0:
Expand All @@ -660,7 +682,7 @@ def _pack(self, obj, nest_limit=DEFAULT_RECURSE_LIMIT, isinstance=isinstance):
default_used = True
continue
raise PackValueError("Integer value out of range")
if self._use_bin_type and isinstance(obj, bytes):
if self._use_bin_type and check(obj, bytes):
n = len(obj)
if n <= 0xff:
self._buffer.write(struct.pack('>BB', 0xc4, n))
Expand All @@ -671,8 +693,8 @@ def _pack(self, obj, nest_limit=DEFAULT_RECURSE_LIMIT, isinstance=isinstance):
else:
raise PackValueError("Bytes is too large")
return self._buffer.write(obj)
if isinstance(obj, (Unicode, bytes)):
if isinstance(obj, Unicode):
if check(obj, (Unicode, bytes)):
if check(obj, Unicode):
if self._encoding is None:
raise TypeError(
"Can't encode unicode string: "
Expand All @@ -690,11 +712,11 @@ def _pack(self, obj, nest_limit=DEFAULT_RECURSE_LIMIT, isinstance=isinstance):
else:
raise PackValueError("String is too large")
return self._buffer.write(obj)
if isinstance(obj, float):
if check(obj, float):
if self._use_float:
return self._buffer.write(struct.pack(">Bf", 0xca, obj))
return self._buffer.write(struct.pack(">Bd", 0xcb, obj))
if isinstance(obj, ExtType):
if check(obj, ExtType):
code = obj.code
data = obj.data
assert isinstance(code, int)
Expand All @@ -719,13 +741,13 @@ def _pack(self, obj, nest_limit=DEFAULT_RECURSE_LIMIT, isinstance=isinstance):
self._buffer.write(struct.pack("b", code))
self._buffer.write(data)
return
if isinstance(obj, (list, tuple)):
if check(obj, list_types):
n = len(obj)
self._fb_pack_array_header(n)
for i in xrange(n):
self._pack(obj[i], nest_limit - 1)
return
if isinstance(obj, dict):
if check(obj, dict):
return self._fb_pack_map_pairs(len(obj), dict_iteritems(obj),
nest_limit - 1)
if not default_used and self._default is not None:
Expand Down
15 changes: 15 additions & 0 deletions test/test_stricttype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding: utf-8

from collections import namedtuple
from msgpack import packb, unpackb


def test_namedtuple():
T = namedtuple('T', "foo bar")
def default(o):
if isinstance(o, T):
return dict(o._asdict())
raise TypeError('Unsupported type %s' % (type(o),))
packed = packb(T(1, 42), strict_types=True, use_bin_type=True, default=default)
unpacked = unpackb(packed, encoding='utf-8')
assert unpacked == {'foo': 1, 'bar': 42}