Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Commit

Permalink
Merge pull request #149 from maralla/validate_field_types
Browse files Browse the repository at this point in the history
validate field types when write data to transport
  • Loading branch information
lxyu committed Aug 5, 2015
2 parents a9a65d4 + 4958f04 commit 32f9696
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 49 deletions.
18 changes: 18 additions & 0 deletions tests/test_base.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

import thriftpy
from thriftpy.thrift import parse_spec, TType


def test_obj_equalcheck():
Expand Down Expand Up @@ -37,3 +38,20 @@ def test_default_value():
ab = thriftpy.load("addressbook.thrift")

assert ab.PhoneNumber().type == ab.PhoneType.MOBILE


def test_parse_spec():
ab = thriftpy.load("addressbook.thrift")

cases = [
((TType.I32, None), "I32"),
((TType.STRUCT, ab.PhoneNumber), "PhoneNumber"),
((TType.LIST, TType.I32), "LIST<I32>"),
((TType.LIST, (TType.STRUCT, ab.PhoneNumber)), "LIST<PhoneNumber>"),
((TType.MAP, (TType.STRING, (
TType.LIST, (TType.MAP, (TType.STRING, TType.STRING))))),
"MAP<STRING, LIST<MAP<STRING, STRING>>>")
]

for spec, res in cases:
assert parse_spec(*spec) == res
29 changes: 28 additions & 1 deletion tests/test_protocol_cybinary.py
Expand Up @@ -7,7 +7,7 @@
import pytest

from thriftpy._compat import u
from thriftpy.thrift import TType, TPayload
from thriftpy.thrift import TType, TPayload, TDecodeException
from thriftpy.transport import TSocket, TServerSocket
from thriftpy.utils import hexlify

Expand Down Expand Up @@ -402,3 +402,30 @@ def test_multiple_read_struct():
p.read_struct(_item2)

assert _item1 == item1 and _item2 == item2


def test_write_decode_error():
t = TCyMemoryBuffer()
p = proto.TCyBinaryProtocol(t)

class T(TPayload):
thrift_spec = {
1: (TType.I32, "id", False),
2: (TType.LIST, "phones", TType.STRING, False),
3: (TType.STRUCT, "item", TItem, False),
4: (TType.MAP, "mm", (TType.STRING, (TType.STRUCT, TItem)), False)
}
default_spec = [("id", None), ("phones", None), ("item", None),
("mm", None)]

cases = [
(T(id="hello"), "Field 'id(1)' of 'T' needs type 'I32', but the value is `'hello'`"), # noqa
(T(phones=[90, 12]), "Field 'phones(2)' of 'T' needs type 'LIST<STRING>', but the value is `[90, 12]`"), # noqa
(T(item=12), "Field 'item(3)' of 'T' needs type 'TItem', but the value is `12`"), # noqa
(T(mm=[45, 56]), "Field 'mm(4)' of 'T' needs type 'MAP<STRING, TItem>', but the value is `[45, 56]`") # noqa
]

for obj, res in cases:
with pytest.raises(TDecodeException) as exc:
p.write_struct(obj)
assert str(exc.value) == res
121 changes: 73 additions & 48 deletions thriftpy/protocol/cybin/cybin.pyx
Expand Up @@ -4,6 +4,8 @@ from cpython cimport bool

from thriftpy.transport.cybase cimport CyTransportBase, STACK_STRING_LEN

from ..thrift import TDecodeException

cdef extern from "endian_port.h":
int16_t htobe16(int16_t n)
int32_t htobe32(int32_t n)
Expand Down Expand Up @@ -94,6 +96,63 @@ cdef inline int write_double(CyTransportBase buf, double val) except -1:
return 0


cdef inline write_list(CyTransportBase buf, list val, spec):
cdef TType e_type
cdef int val_len

if isinstance(spec, int):
e_type = spec
e_spec = None
else:
e_type = spec[0]
e_spec = spec[1]

val_len = len(val)
write_i08(buf, e_type)
write_i32(buf, val_len)

for e_val in val:
c_write_val(buf, e_type, e_val, e_spec)


cdef inline write_string(CyTransportBase buf, bytes val):
cdef int val_len = len(val)
write_i32(buf, val_len)

buf.c_write(<char*>val, val_len)


cdef inline write_dict(CyTransportBase buf, dict val, spec):
cdef int val_len
cdef TType v_type, k_type

key = spec[0]
if isinstance(key, int):
k_type = key
k_spec = None
else:
k_type = key[0]
k_spec = key[1]

value = spec[1]
if isinstance(value, int):
v_type = value
v_spec = None
else:
v_type = value[0]
v_spec = value[1]

val_len = len(val)

write_i08(buf, k_type)
write_i08(buf, v_type)
write_i32(buf, val_len)

for k, v in val.items():
c_write_val(buf, k_type, k, k_spec)
c_write_val(buf, v_type, v, v_spec)


cdef inline read_struct(CyTransportBase buf, obj):
cdef dict field_specs = obj.thrift_spec
cdef int fid
Expand Down Expand Up @@ -149,7 +208,11 @@ cdef inline write_struct(CyTransportBase buf, obj):

write_i08(buf, f_type)
write_i16(buf, fid)
c_write_val(buf, f_type, v, container_spec)
try:
c_write_val(buf, f_type, v, container_spec)
except (TypeError, AttributeError):
raise TDecodeException(obj.__class__.__name__, fid, f_name, v,
f_type, container_spec)

write_i08(buf, T_STOP)

Expand Down Expand Up @@ -253,9 +316,6 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None):


cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None):
cdef int val_len
cdef TType e_type, v_type, k_type

if ttype == T_BOOL:
write_i08(buf, 1 if val else 0)

Expand All @@ -276,54 +336,19 @@ cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None):

elif ttype == T_STRING:
if not isinstance(val, bytes):
val = val.encode("utf-8")

val_len = len(val)
write_i32(buf, val_len)

buf.c_write(<char*>val, val_len)
try:
val = val.encode("utf-8")
except Exception:
pass
write_string(buf, val)

elif ttype == T_SET or ttype == T_LIST:
if isinstance(spec, int):
e_type = spec
e_spec = None
else:
e_type = spec[0]
e_spec = spec[1]

val_len = len(val)
write_i08(buf, e_type)
write_i32(buf, val_len)

for e_val in val:
c_write_val(buf, e_type, e_val, e_spec)
if isinstance(val, (tuple, dict, set)):
val = list(val)
write_list(buf, val, spec)

elif ttype == T_MAP:
key = spec[0]
if isinstance(key, int):
k_type = key
k_spec = None
else:
k_type = key[0]
k_spec = key[1]

value = spec[1]
if isinstance(value, int):
v_type = value
v_spec = None
else:
v_type = value[0]
v_spec = value[1]

val_len = len(val)

write_i08(buf, k_type)
write_i08(buf, v_type)
write_i32(buf, val_len)

for k, v in val.items():
c_write_val(buf, k_type, k, k_spec)
c_write_val(buf, v_type, v, v_spec)
write_dict(buf, val, spec)

elif ttype == T_STRUCT:
write_struct(buf, val)
Expand Down
34 changes: 34 additions & 0 deletions thriftpy/thrift.py
Expand Up @@ -19,6 +19,23 @@ def args2kwargs(thrift_spec, *args):
return dict(zip(arg_names, args))


def parse_spec(ttype, spec=None):
name_map = TType._VALUES_TO_NAMES
_type = lambda s: parse_spec(*s) if isinstance(s, tuple) else name_map[s]

if spec is None:
return name_map[ttype]

if ttype == TType.STRUCT:
return spec.__name__

if ttype in (TType.LIST, TType.SET):
return "%s<%s>" % (name_map[ttype], _type(spec))

if ttype == TType.MAP:
return "MAP<%s, %s>" % (_type(spec[0]), _type(spec[1]))


class TType(object):
STOP = 0
VOID = 1
Expand Down Expand Up @@ -307,6 +324,23 @@ class TException(TPayload, Exception):
"""Base class for all thrift exceptions."""


class TDecodeException(TException):
def __init__(self, name, fid, field, value, ttype, spec=None):
self.struct_name = name
self.fid = fid
self.field = field
self.value = value

self.type_repr = parse_spec(ttype, spec)

def __str__(self):
return (
"Field '%s(%s)' of '%s' needs type '%s', "
"but the value is `%r`"
) % (self.field, self.fid, self.struct_name, self.type_repr,
self.value)


class TApplicationException(TException):
"""Application level thrift exceptions."""

Expand Down

0 comments on commit 32f9696

Please sign in to comment.