Skip to content

Commit

Permalink
add decode_response option for protocol, fix Thriftpy#190
Browse files Browse the repository at this point in the history
  • Loading branch information
lxyu authored and dan-blanchard committed May 6, 2016
1 parent 93be6ec commit 35a9889
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 36 deletions.
7 changes: 7 additions & 0 deletions tests/test_protocol_binary.py
Expand Up @@ -89,6 +89,13 @@ def test_unpack_string():
assert u("你好世界") == proto.read_val(b, TType.STRING)


def test_unpack_binary():
bs = BytesIO(b"\x00\x00\x00\x0c"
b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
assert u("你好世界").encode("utf-8") == proto.read_val(
bs, TType.STRING, decode_response=False)


def test_write_message_begin():
b = BytesIO()
proto.TBinaryProtocol(b).write_message_begin("test", TType.STRING, 1)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_protocol_compact.py
Expand Up @@ -107,6 +107,14 @@ def test_unpack_string():
assert u('你好世界') == proto.read_val(TType.STRING)


def test_unpack_binary():
b, proto = gen_proto(b'\x0c\xe4\xbd\xa0\xe5\xa5'
b'\xbd\xe4\xb8\x96\xe7\x95\x8c')
proto.decode_response = False

assert u('你好世界').encode("utf-8") == proto.read_val(TType.STRING)


def test_pack_bool():
b, proto = gen_proto()
proto.write_bool(True)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_protocol_cybinary.py
Expand Up @@ -140,6 +140,13 @@ def test_read_string():
assert u("你好世界") == proto.read_val(b, TType.STRING)


def test_read_binary():
b = TCyMemoryBuffer(b"\x00\x00\x00\x0c"
b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
assert u("你好世界").encode("utf-8") == proto.read_val(
b, TType.STRING, decode_response=False)


def test_write_message_begin():
trans = TCyMemoryBuffer()
b = proto.TCyBinaryProtocol(trans)
Expand Down
36 changes: 24 additions & 12 deletions thriftpy/protocol/binary.py
Expand Up @@ -205,7 +205,7 @@ def read_map_begin(inbuf):
return k_type, v_type, sz


def read_val(inbuf, ttype, spec=None):
def read_val(inbuf, ttype, spec=None, decode_response=True):
if ttype == TType.BOOL:
return bool(unpack_i8(inbuf.read(1)))

Expand All @@ -227,11 +227,15 @@ def read_val(inbuf, ttype, spec=None):
elif ttype == TType.STRING:
sz = unpack_i32(inbuf.read(4))
byte_payload = inbuf.read(sz)
# Since we cannot tell if we're getting STRING or BINARY, try both
try:
return byte_payload.decode('utf-8')
except UnicodeDecodeError:
return byte_payload

# Since we cannot tell if we're getting STRING or BINARY
# if not asked not to decode, try both
if decode_response:
try:
return byte_payload.decode('utf-8')
except UnicodeDecodeError:
pass
return byte_payload

elif ttype == TType.SET or ttype == TType.LIST:
if isinstance(spec, tuple):
Expand Down Expand Up @@ -285,7 +289,7 @@ def read_val(inbuf, ttype, spec=None):
return obj


def read_struct(inbuf, obj):
def read_struct(inbuf, obj, decode_response=True):
while True:
f_type, fid = read_field_begin(inbuf)
if f_type == TType.STOP:
Expand All @@ -307,7 +311,8 @@ def read_struct(inbuf, obj):
skip(inbuf, f_type)
continue

setattr(obj, f_name, read_val(inbuf, f_type, f_container_spec))
setattr(obj, f_name,
read_val(inbuf, f_type, f_container_spec, decode_response))


def skip(inbuf, ftype):
Expand Down Expand Up @@ -351,10 +356,13 @@ def skip(inbuf, ftype):
class TBinaryProtocol(object):
"""Binary implementation of the Thrift protocol driver."""

def __init__(self, trans, strict_read=True, strict_write=True):
def __init__(self, trans,
strict_read=True, strict_write=True,
decode_response=True):
self.trans = trans
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response

def skip(self, ttype):
skip(self.trans, ttype)
Expand All @@ -375,16 +383,20 @@ def write_message_end(self):
pass

def read_struct(self, obj):
return read_struct(self.trans, obj)
return read_struct(self.trans, obj, self.decode_response)

def write_struct(self, obj):
write_val(self.trans, TType.STRUCT, obj)


class TBinaryProtocolFactory(object):
def __init__(self, strict_read=True, strict_write=True):
def __init__(self,strict_read=True, strict_write=True,
decode_response=True):
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response

def get_protocol(self, trans):
return TBinaryProtocol(trans, self.strict_read, self.strict_write)
return TBinaryProtocol(trans,
self.strict_read, self.strict_write,
self.decode_response)
20 changes: 13 additions & 7 deletions thriftpy/protocol/compact.py
Expand Up @@ -122,12 +122,13 @@ class TCompactProtocol(object):
TYPE_BITS = 0x07
TYPE_SHIFT_AMOUNT = 5

def __init__(self, trans):
def __init__(self, trans, decode_response=True):
self.trans = trans
self._last_fid = 0
self._bool_fid = None
self._bool_value = None
self._structs = []
self.decode_response = decode_response

def _get_ttype(self, byte):
return TTYPES[byte & 0x0f]
Expand Down Expand Up @@ -227,12 +228,14 @@ def read_double(self):

def read_string(self):
len = self._read_size()

byte_payload = self.trans.read(len)
try:
return byte_payload.decode('utf-8')
except UnicodeDecodeError:
return byte_payload

if self.decode_response:
try:
byte_payload = byte_payload.decode('utf-8')
except UnicodeDecodeError:
pass
return byte_payload

def read_bool(self):
if self._bool_value is not None:
Expand Down Expand Up @@ -556,5 +559,8 @@ def skip(self, ttype):


class TCompactProtocolFactory(object):
def __init__(self, decode_response=True):
self.decode_response = decode_response

def get_protocol(self, trans):
return TCompactProtocol(trans)
return TCompactProtocol(trans, decode_response=self.decode_response)
44 changes: 27 additions & 17 deletions thriftpy/protocol/cybin/cybin.pyx
Expand Up @@ -153,7 +153,7 @@ cdef inline write_dict(CyTransportBase buf, object val, spec):
c_write_val(buf, v_type, v, v_spec)


cdef inline read_struct(CyTransportBase buf, obj):
cdef inline read_struct(CyTransportBase buf, obj, decode_response=True):
cdef dict field_specs = obj.thrift_spec
cdef int fid
cdef TType field_type, ttype
Expand Down Expand Up @@ -182,7 +182,7 @@ cdef inline read_struct(CyTransportBase buf, obj):
else:
spec = field_spec[2]

setattr(obj, name, c_read_val(buf, ttype, spec))
setattr(obj, name, c_read_val(buf, ttype, spec, decode_response))

return obj

Expand Down Expand Up @@ -217,7 +217,7 @@ cdef inline write_struct(CyTransportBase buf, obj):
write_i08(buf, T_STOP)


cdef inline c_read_string(CyTransportBase buf, int32_t size):
cdef inline c_read_binary(CyTransportBase buf, int32_t size):
cdef char string_val[STACK_STRING_LEN]

if size > STACK_STRING_LEN:
Expand All @@ -229,13 +229,15 @@ cdef inline c_read_string(CyTransportBase buf, int32_t size):
buf.c_read(size, string_val)
py_data = string_val[:size]

try:
return py_data.decode("utf-8")
except UnicodeDecodeError:
return py_data
return py_data


cdef inline c_read_string(CyTransportBase buf, int32_t size):
return c_read_binary(buf, size).decode("utf-8")

cdef c_read_val(CyTransportBase buf, TType ttype, spec=None):

cdef c_read_val(CyTransportBase buf, TType ttype, spec=None,
decode_response=True):
cdef int size
cdef int64_t n
cdef TType v_type, k_type, orig_type, orig_key_type
Expand All @@ -261,7 +263,10 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None):

elif ttype == T_STRING:
size = read_i32(buf)
return c_read_string(buf, size)
if decode_response:
return c_read_string(buf, size)
else:
return c_read_binary(buf, size)

elif ttype == T_SET or ttype == T_LIST:
if isinstance(spec, int):
Expand Down Expand Up @@ -343,7 +348,6 @@ cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None):
write_string(buf, val)

elif ttype == T_SET or ttype == T_LIST:
assert not isinstance(val, basestring)
write_list(buf, val, spec)

elif ttype == T_MAP:
Expand All @@ -367,7 +371,7 @@ cpdef skip(CyTransportBase buf, TType ttype):
read_i64(buf)
elif ttype == T_STRING:
size = read_i32(buf)
c_read_string(buf, size)
c_read_binary(buf, size)
elif ttype == T_SET or ttype == T_LIST:
v_type = <TType>read_i08(buf)
size = read_i32(buf)
Expand All @@ -389,8 +393,8 @@ cpdef skip(CyTransportBase buf, TType ttype):
skip(buf, f_type)


def read_val(CyTransportBase buf, TType ttype):
return c_read_val(buf, ttype)
def read_val(CyTransportBase buf, TType ttype, decode_response=True):
return c_read_val(buf, ttype, None, decode_response)


def write_val(CyTransportBase buf, TType ttype, val, spec=None):
Expand All @@ -401,11 +405,14 @@ cdef class TCyBinaryProtocol(object):
cdef public CyTransportBase trans
cdef public bool strict_read
cdef public bool strict_write
cdef public bool decode_response

def __init__(self, trans, strict_read=True, strict_write=True):
def __init__(self, trans, strict_read=True, strict_write=True,
decode_response=True):
self.trans = trans
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response

def skip(self, ttype):
skip(self.trans, <TType>(ttype))
Expand Down Expand Up @@ -452,7 +459,7 @@ cdef class TCyBinaryProtocol(object):

def read_struct(self, obj):
try:
return read_struct(self.trans, obj)
return read_struct(self.trans, obj, self.decode_response)
except Exception:
self.trans.clean()
raise
Expand All @@ -466,9 +473,12 @@ cdef class TCyBinaryProtocol(object):


class TCyBinaryProtocolFactory(object):
def __init__(self, strict_read=True, strict_write=True):
def __init__(self, strict_read=True, strict_write=True,
decode_response=True):
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response

def get_protocol(self, trans):
return TCyBinaryProtocol(trans, self.strict_read, self.strict_write)
return TCyBinaryProtocol(
trans, self.strict_read, self.strict_write, self.decode_response)

0 comments on commit 35a9889

Please sign in to comment.