Skip to content

Commit

Permalink
Fix problems with serialization refactor found by full testing
Browse files Browse the repository at this point in the history
  • Loading branch information
gmfeinberg committed Dec 13, 2022
1 parent 3f2e36c commit 928e24b
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 30 deletions.
9 changes: 8 additions & 1 deletion src/borneo/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@

from .serdeutil import SerdeUtil

try:
from . import serde
except ImportError:
import serde



class PlanIterState(object):
STATE = enum(OPEN=0,
RUNNING=1,
Expand Down Expand Up @@ -656,7 +663,7 @@ class ConstIter(PlanIter):

def __init__(self, bis):
super(ConstIter, self).__init__(bis)
self._value = SerdeUtil.read_field_value(bis)
self._value = serde.BinaryProtocol.read_field_value(bis)

def close(self, rcb):
state = rcb.get_state(self.state_pos)
Expand Down
12 changes: 6 additions & 6 deletions src/borneo/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def read_field_value(bis):
if t == SerdeUtil.FIELD_VALUE_TYPE.ARRAY:
return BinaryProtocol.read_list(bis)
elif t == SerdeUtil.FIELD_VALUE_TYPE.BINARY:
return SerdeUtil.read_bytearray(bis)
return SerdeUtil.read_bytearray(bis, False)
elif t == SerdeUtil.FIELD_VALUE_TYPE.BOOLEAN:
return bis.read_boolean()
elif t == SerdeUtil.FIELD_VALUE_TYPE.DOUBLE:
Expand Down Expand Up @@ -211,7 +211,7 @@ def read_topology_info(bis):

@staticmethod
def read_version(bis):
return Version.create_version(SerdeUtil.read_bytearray(bis))
return Version.create_version(SerdeUtil.read_bytearray(bis, False))

# Writes fields from ReadRequest.
@staticmethod
Expand Down Expand Up @@ -536,7 +536,7 @@ def deserialize(self, request, bis, serial_version):
result = operations.MultiDeleteResult()
BinaryProtocol.deserialize_consumed_capacity(bis, result)
result.set_num_deletions(SerdeUtil.read_packed_int(bis))
result.set_continuation_key(SerdeUtil.read_bytearray(bis))
result.set_continuation_key(SerdeUtil.read_bytearray(bis, False))
return result


Expand Down Expand Up @@ -672,7 +672,7 @@ def serialize(self, request, bos, serial_version):
SerdeUtil.write_packed_int(bos, len(variables))
for key in variables:
SerdeUtil.write_string(bos, key)
SerdeUtil.write_field_value(bos, variables[key])
BinaryProtocol.write_field_value(bos, variables[key])
else:
SerdeUtil.write_packed_int(bos, 0)
else:
Expand All @@ -698,10 +698,10 @@ def deserialize(self, request, bis, serial_version):
SerdeUtil.read_packed_int_array(bis))
cont_keys = list()
for i in range(len(pids)):
cont_keys.append(SerdeUtil.read_bytearray(bis))
cont_keys.append(SerdeUtil.read_bytearray(bis, False))
result.set_partition_cont_keys(cont_keys)
BinaryProtocol.deserialize_consumed_capacity(bis, result)
result.set_continuation_key(SerdeUtil.read_bytearray(bis))
result.set_continuation_key(SerdeUtil.read_bytearray(bis, False))
request.set_cont_key(result.get_continuation_key())
# In V2, if the QueryRequest was not initially prepared, the prepared
# statement created at the proxy is returned back along with the query
Expand Down
42 changes: 40 additions & 2 deletions src/borneo/serdeutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,15 @@ def map_exception(code, msg):
'Unknown error code ' + str(code) + ': ' + msg)

@staticmethod
def read_bytearray(bis):
def read_bytearray(bis, skip):
"""
Reads a possibly None byte array as a
:py:meth:`read_sequence_length` followed by the array contents.
:param bis: the byte input stream.
:type bis: ByteInputStream
:param skip: True if skipping vs reading
:type bis: boolean
:returns: the array or None.
:rtype: bytearray
"""
Expand All @@ -293,12 +295,20 @@ def read_bytearray(bis):
raise IOError('Invalid length of byte array: ' + str(length))
if length == -1:
return None
if length == 0:
if length == 0 and not skip:
return bytearray()
if skip:
bis.set_offset(bis.get_offset() + length)
return None
buf = bytearray(length)
bis.read_fully(buf)
return buf

@staticmethod
def read_full_int(bis):
# Reads a full, 4-byte int
return bis.read_int()

@staticmethod
def read_bytearray_with_int(bis):
# Reads a byte array that has a not-packed integer size.
Expand Down Expand Up @@ -433,6 +443,13 @@ def read_string_array(bis):
array.append(SerdeUtil.read_string(bis))
return array

@staticmethod
def read_float(bis):
"""
Reads a float, which is a double-precision floating point
"""
return bis.read_float()

@staticmethod
def trace(msg, level):
if level <= SerdeUtil.TRACE_LEVEL:
Expand Down Expand Up @@ -460,6 +477,11 @@ def write_bytearray_with_int(bos, value):
bos.write_int(len(value))
bos.write_bytearray(value)

@staticmethod
def write_int_at_offset(bos, offset, value):
# Writes a full 4-byte int at the specified offset
bos.write_int_at_offset(offset, value)

@staticmethod
def write_datetime(bos, value):
# Serialize a datetime value.
Expand Down Expand Up @@ -521,6 +543,11 @@ def write_packed_long(bos, value):
offset = PackedInteger.write_sorted_long(buf, 0, value)
bos.write_bytearray(buf, 0, offset)

@staticmethod
def write_full_int(bos, value):
# Writes a full, 4-byte int
bos.write_int(value)

@staticmethod
def write_sequence_length(bos, length):
"""
Expand All @@ -540,6 +567,17 @@ def write_sequence_length(bos, length):
'Invalid sequence length: ' + str(length))
SerdeUtil.write_packed_int(bos, length)

@staticmethod
def write_float(bos, value):
"""
Writes a float, which is a double-precision floating point
:param bos: the byte output stream.
:type bos: ByteOutputStream
:param value: the float to be written.
:type value: float
"""
bos.write_float(value)

@staticmethod
def write_serial_version(bos, serial_version):
# Writes the (short) serial version
Expand Down
21 changes: 0 additions & 21 deletions test/write_multiple.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,27 +150,6 @@ def testWriteMultipleAddIllegalRequestAndAbortIfUnsuccessful(self):
self.requests[0], True).add(self.illegal_requests[1], False)
self.assertRaises(IllegalArgumentException, self.handle.write_multiple,
self.write_multiple_request)
if not is_onprem():
# add operations when the request size exceeded the limit
self.write_multiple_request.clear()
for op in range(64):
row = get_row()
row['fld_str'] = self.get_random_str(0.4)
self.write_multiple_request.add(PutRequest().set_value(
row).set_table_name(table_name), True)
self.assertRaises(RequestSizeLimitException,
self.handle.write_multiple,
self.write_multiple_request)
# add operations when sub requests reached the max number
self.write_multiple_request.clear()
for op in range(51):
row = get_row()
row['fld_id'] = op
self.write_multiple_request.add(PutRequest().set_value(
row).set_table_name(table_name), True)
self.assertRaises(BatchOperationNumberLimitException,
self.handle.write_multiple,
self.write_multiple_request)

def testWriteMultipleGetRequestWithIllegalIndex(self):
self.assertRaises(IllegalArgumentException,
Expand Down

0 comments on commit 928e24b

Please sign in to comment.