Skip to content

Commit

Permalink
Merge pull request #13 from Bhargavasomu/bytes
Browse files Browse the repository at this point in the history
Add support for Bytes and Bytearray objects
  • Loading branch information
jannikluhn committed Dec 19, 2018
2 parents 7f8cf6f + 0686a0f commit 8ed62a3
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 15 deletions.
2 changes: 2 additions & 0 deletions ssz/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
BYTES_PREFIX_LENGTH = 4
LIST_PREFIX_LENGTH = 4
8 changes: 8 additions & 0 deletions ssz/sedes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
Boolean,
)

from .bytes import ( # noqa: F401
bytes_sedes,
Bytes,
)

from .hash import ( # noqa: F401
address,
hash32,
Expand All @@ -28,6 +33,7 @@
from .list import ( # noqa: F401
address_list,
boolean_list,
bytes_list,
empty_list,
hash32_list,
uint32_list,
Expand All @@ -40,6 +46,8 @@
"address": address,
"boolean": boolean,
"boolean_list": boolean_list,
"bytes_list": bytes_list,
"bytes_sedes": bytes_sedes,
"empty_list": empty_list,
"hash32": hash32,
"hash32_list": hash32_list,
Expand Down
72 changes: 72 additions & 0 deletions ssz/sedes/bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from ssz.constants import (
BYTES_PREFIX_LENGTH,
)
from ssz.exceptions import (
DeserializationError,
SerializationError,
)


class Bytes:
"""
A sedes for byte objects.
"""

def serialize(self, val):
if not isinstance(val, (bytes, bytearray)):
raise SerializationError(
"Can only serialize bytes or bytearray objects",
val
)

object_len = len(val)
if object_len >= 2 ** (BYTES_PREFIX_LENGTH * 8):
raise SerializationError(
f'Object too long for its length to fit into {BYTES_PREFIX_LENGTH} bytes'
f'after serialization',
val
)

# Convert the length of bytes to a 4 bytes value
object_len_bytes = object_len.to_bytes(BYTES_PREFIX_LENGTH, 'big')

return object_len_bytes + val

def deserialize_segment(self, data, start_index):
"""
Deserialize the data from the given start_index
"""
# Make sure we have sufficient data for inferring length of bytes object
if len(data) < start_index + BYTES_PREFIX_LENGTH:
raise DeserializationError(
'Insufficient data: Cannot retrieve the length of bytes object',
data
)

# object_len contains the length of the original bytes object
object_len = int.from_bytes(data[start_index:start_index + BYTES_PREFIX_LENGTH], 'big')
# object_start_index is the start index of bytes object in the serialized bytes string
object_start_index = start_index + BYTES_PREFIX_LENGTH
object_end_index = object_start_index + object_len

# Make sure we have sufficent data for inferring the whole bytes object
if len(data) < object_end_index:
raise DeserializationError(
'Insufficient data: Cannot retrieve the whole list bytes object',
data
)

return data[object_start_index:object_end_index], object_end_index

def deserialize(self, data):
deserialized_data, end_index = self.deserialize_segment(data, 0)
if end_index != len(data):
raise DeserializationError(
'Data to be deserialized is too long',
data
)

return deserialized_data


bytes_sedes = Bytes()
24 changes: 13 additions & 11 deletions ssz/sedes/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
Iterable,
)

from ssz.constants import (
LIST_PREFIX_LENGTH,
)
from ssz.exceptions import (
DeserializationError,
SerializationError,
)
from ssz.sedes import (
address,
boolean,
bytes_sedes,
hash32,
uint32,
)
Expand All @@ -21,7 +25,6 @@ class List:
WARNING: Avoid sets if possible, may not always lead to expected results
(This is because iteration in sets doesn't always happen in the same order)
"""
LENGTH_BYTES = 4

def __init__(self, element_sedes=None, empty=False):
if element_sedes and empty:
Expand All @@ -42,9 +45,7 @@ def __init__(self, element_sedes=None, empty=False):
def serialize(self, val):
if (
not isinstance(val, Iterable) or
isinstance(val, bytes) or
isinstance(val, bytes) or
isinstance(val, str)
isinstance(val, (bytes, bytearray, str))
):
raise SerializationError(
'Can only serialize Iterable objects, except Dictionaries',
Expand All @@ -60,12 +61,12 @@ def serialize(self, val):
self.element_sedes.serialize(element) for element in val
)

if len(serialized_iterable_string) >= 2 ** (self.LENGTH_BYTES * 8):
if len(serialized_iterable_string) >= 2 ** (LIST_PREFIX_LENGTH * 8):
raise SerializationError(
'List too long to fit into {} bytes after serialization'.format(self.LENGTH_BYTES),
'List too long to fit into {} bytes after serialization'.format(LIST_PREFIX_LENGTH),
val
)
serialized_len = len(serialized_iterable_string).to_bytes(self.LENGTH_BYTES, 'big')
serialized_len = len(serialized_iterable_string).to_bytes(LIST_PREFIX_LENGTH, 'big')

return serialized_len + serialized_iterable_string

Expand All @@ -74,15 +75,15 @@ def deserialize_segment(self, data, start_index):
Deserialize the data from the given start_index
"""
# Make sure we have sufficient data for inferring length of list
if len(data) < start_index + self.LENGTH_BYTES:
if len(data) < start_index + LIST_PREFIX_LENGTH:
raise DeserializationError(
'Insufficient data: Cannot retrieve the length of list',
data
)

# Number of bytes of only the list data, excluding the prepended list length
list_length = int.from_bytes(data[start_index:start_index + self.LENGTH_BYTES], 'big')
list_end_index = start_index + self.LENGTH_BYTES + list_length
list_length = int.from_bytes(data[start_index:start_index + LIST_PREFIX_LENGTH], 'big')
list_end_index = start_index + LIST_PREFIX_LENGTH + list_length
# Make sure we have sufficent data for inferring the whole list
if len(data) < list_end_index:
raise DeserializationError(
Expand All @@ -92,7 +93,7 @@ def deserialize_segment(self, data, start_index):

deserialized_list = []
# element_start_index is the start index of an element in the serialized bytes string
element_start_index = start_index + self.LENGTH_BYTES
element_start_index = start_index + LIST_PREFIX_LENGTH
while element_start_index < list_end_index:
element, element_start_index = self.element_sedes.deserialize_segment(
data, element_start_index
Expand All @@ -114,6 +115,7 @@ def deserialize(self, data):

address_list = List(address)
boolean_list = List(boolean)
bytes_list = List(bytes_sedes)
empty_list = List(empty=True)
hash32_list = List(hash32)
uint32_list = List(uint32)
6 changes: 6 additions & 0 deletions ssz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ssz.sedes import (
List,
boolean,
bytes_sedes,
empty_list,
)

Expand Down Expand Up @@ -46,11 +47,16 @@ def infer_sedes(obj):
"""
if isinstance(obj, bool):
return boolean

elif isinstance(obj, int):
raise TypeError(
'uint sedes object or uint string needs to be specified for ints',
obj
)

elif isinstance(obj, (bytes, bytearray)):
return bytes_sedes

elif isinstance(obj, Iterable):
return infer_list_sedes(obj)

Expand Down
133 changes: 133 additions & 0 deletions tests/core/test_byte_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import pytest

from ssz import (
DeserializationError,
SerializationError,
decode,
encode,
)
from ssz.sedes import (
bytes_sedes,
)


@pytest.mark.parametrize(
'value,expected',
(
(b"", b'\x00\x00\x00\x00'),
(b"I", b'\x00\x00\x00\x01I'),
(b"foo", b'\x00\x00\x00\x03foo'),
(b"hello", b'\x00\x00\x00\x05hello'),
(bytearray(b""), b'\x00\x00\x00\x00'),
(bytearray(b"I"), b'\x00\x00\x00\x01I'),
(bytearray(b"foo"), b'\x00\x00\x00\x03foo'),
(bytearray(b"hello"), b'\x00\x00\x00\x05hello'),
),
)
def test_bytes_serialize_values(value, expected):
assert bytes_sedes.serialize(value) == expected


@pytest.mark.parametrize(
'value',
(
# Non-byte objects
None,
1,
1.0,
'',
'True',
[1, 2, 3],
(1, 2, 3),
{b"0": b"1"},
{b"0", b"1", b"2", b"3"},
),
)
def test_bytes_serialize_bad_values(value):
with pytest.raises(SerializationError):
bytes_sedes.serialize(value)


@pytest.mark.parametrize(
'value,expected',
(
(b'\x00\x00\x00\x00', b""),
(b'\x00\x00\x00\x01I', b"I"),
(b'\x00\x00\x00\x03foo', b"foo"),
(b'\x00\x00\x00\x05hello', b"hello"),
),
)
def test_bytes_deserialize_values(value, expected):
assert bytes_sedes.deserialize(value) == expected


@pytest.mark.parametrize(
'value',
(
# Less than 4 bytes of serialized data
b'\x00\x00\x01',
# Insufficient serialized object data as per found out byte object length
b'\x00\x00\x00\x04',
# Serialized data given is more than what is required
b'\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x01' + b'\x00'
),
)
def test_bytes_deserialization_bad_value(value):
with pytest.raises(DeserializationError):
bytes_sedes.deserialize(value)


@pytest.mark.parametrize(
'value,expected',
(
(b"", b''),
(b"I", b'I'),
(b"foo", b'foo'),
(b"hello", b'hello'),
(bytearray(b""), b''),
(bytearray(b"I"), b'I'),
(bytearray(b"foo"), b'foo'),
(bytearray(b"hello"), b'hello'),
),
)
def test_bytes_round_trip(value, expected):
assert bytes_sedes.deserialize(bytes_sedes.serialize(value)) == expected


@pytest.mark.parametrize(
'value,sedes',
(
(b"", 'bytes_sedes'),
(b"I", 'bytes_sedes'),
(b"foo", 'bytes_sedes'),
(b"hello", 'bytes_sedes'),
(b"", bytes_sedes),
(b"I", bytes_sedes),
(b"foo", bytes_sedes),
(b"hello", bytes_sedes),
),
)
def test_bytes_round_trip_codec(value, sedes):
if isinstance(sedes, str):
sedes_obj = eval(sedes)
else:
sedes_obj = sedes
assert decode(encode(value, sedes), sedes_obj) == value


@pytest.mark.parametrize(
'value',
(
b"",
b"I",
b"foo",
b"hello",
),
)
def test_bytes_round_trip_no_sedes(value):
assert decode(encode(value), bytes_sedes) == value

0 comments on commit 8ed62a3

Please sign in to comment.