Skip to content

Commit

Permalink
Merge pull request #7 from Bhargavasomu/uint_and_hashN
Browse files Browse the repository at this point in the history
Add serialization/deserialization functionality for Integers
  • Loading branch information
jannikluhn committed Dec 10, 2018
2 parents 1457d96 + 70070ed commit d87f05d
Show file tree
Hide file tree
Showing 8 changed files with 505 additions and 2 deletions.
26 changes: 24 additions & 2 deletions ssz/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,39 @@

from ssz.exceptions import (
DecodingError,
InvalidSedesError,
)
from ssz.sedes import (
sedes_by_name,
)
from ssz.utils import (
infer_sedes,
is_sedes,
)


def encode(obj, sedes=None):
    """
    Encode object in SSZ format.

    `sedes` needs to be explicitly mentioned for encode/decode
    of integers (as of now).
    `sedes` parameter could be given as a string (a key of
    `sedes_by_name`) or as the actual sedes object itself.
    If `sedes` is omitted, the sedes is inferred from `obj`.

    Raises `InvalidSedesError` if `sedes` is given but is neither a
    known sedes name nor a sedes object.
    """
    if sedes is None:
        # Only a missing sedes triggers inference; other falsy values
        # (e.g. an empty string) are reported as invalid below.
        sedes_obj = infer_sedes(obj)
    else:
        if isinstance(sedes, str) and sedes in sedes_by_name:
            # Get the actual sedes object from string representation
            sedes_obj = sedes_by_name[sedes]
        else:
            sedes_obj = sedes

        if not is_sedes(sedes_obj):
            raise InvalidSedesError("Invalid sedes object", sedes)

    return sedes_obj.serialize(obj)


Expand Down
10 changes: 10 additions & 0 deletions ssz/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ class SSZException(Exception):
pass


class InvalidSedesError(SSZException):
    """
    Exception raised if an object given as sedes is not a valid sedes.

    The offending object is available as the `sedes` attribute.
    """

    def __init__(self, message, sedes):
        super(InvalidSedesError, self).__init__(message)
        self.sedes = sedes


class EncodingError(SSZException):
"""
Exception raised if encoding fails.
Expand Down
37 changes: 37 additions & 0 deletions ssz/sedes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,40 @@
boolean,
Boolean,
)

from .hash import ( # noqa: F401
address,
hash32,
Hash,
)

from .integer import ( # noqa: F401
uint8,
uint16,
uint24,
uint32,
uint40,
uint48,
uint56,
uint64,
uint128,
uint256,
uint384,
uint512,
UnsignedInteger,
)


# Maps the string name of each predefined sedes to the sedes object, so
# that a sedes can be referred to by name. Every predefined sedes imported
# above is listed here.
sedes_by_name = {
    "address": address,
    "boolean": boolean,
    "hash32": hash32,
    "uint8": uint8,
    "uint16": uint16,
    "uint24": uint24,
    "uint32": uint32,
    "uint40": uint40,
    "uint48": uint48,
    "uint56": uint56,
    "uint64": uint64,
    "uint128": uint128,
    "uint256": uint256,
    "uint384": uint384,
    "uint512": uint512,
}
55 changes: 55 additions & 0 deletions ssz/sedes/hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from ssz.exceptions import (
DeserializationError,
SerializationError,
)


class Hash:
    """
    A sedes for hashes (hash<N>).

    Values are byte strings of exactly `num_bytes` bytes; the serialized
    form is the value itself.
    """
    num_bytes = 0

    def __init__(self, num_bytes):
        """
        Create a sedes for values of exactly `num_bytes` bytes.

        Raises `ValueError` if `num_bytes` is zero or negative.
        """
        if num_bytes <= 0:
            raise ValueError(
                "Number of bytes should be greater than 0"
            )

        self.num_bytes = num_bytes

    def serialize(self, val):
        """
        Return `val` unchanged after validating it.

        Raises `SerializationError` if `val` is not a bytes-like value
        (`bytes` or `bytearray`) of exactly `num_bytes` bytes.
        """
        # Check the type first so that unsized or non-binary input is
        # reported as a serialization failure instead of slipping through.
        if not isinstance(val, (bytes, bytearray)) or len(val) != self.num_bytes:
            raise SerializationError(
                "Can only serialize values of {} bytes".format(self.num_bytes),
                val
            )

        return val

    def deserialize_segment(self, data, start_index):
        """
        Deserialize the data from the given start_index

        Returns a tuple of the `num_bytes` bytes starting at `start_index`
        and the index right after them. Raises `DeserializationError` if
        fewer than `num_bytes` bytes are available from `start_index`.
        """
        end_index = start_index + self.num_bytes
        # Make sure we have sufficient data for deserializing
        if len(data) < end_index:
            raise DeserializationError(
                'Insufficient data for deserializing',
                data
            )
        return data[start_index:end_index], end_index

    def deserialize(self, data):
        """
        Deserialize `data`, which must be exactly `num_bytes` bytes long.

        Raises `DeserializationError` if `data` is too short or has
        trailing bytes.
        """
        deserialized_data, end_index = self.deserialize_segment(data, 0)
        if end_index != len(data):
            raise DeserializationError(
                'Data to be deserialized is too long',
                data
            )

        return deserialized_data


# Predefined hash sedes instances.
hash32 = Hash(32)  # values of exactly 32 bytes
address = Hash(20)  # values of exactly 20 bytes
79 changes: 79 additions & 0 deletions ssz/sedes/integer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from ssz.exceptions import (
DeserializationError,
SerializationError,
)


class UnsignedInteger:
    """
    A sedes for integers (uint<N>).

    Values are non-negative integers, serialized as big-endian byte
    strings of exactly `num_bytes` (= N / 8) bytes.
    """
    num_bytes = 0

    def __init__(self, num_bits):
        """
        Create a sedes for unsigned integers of `num_bits` bits.

        Raises `ValueError` if `num_bits` is not a positive multiple of 8.
        """
        # Make sure the number of bits are multiple of 8
        if num_bits % 8 != 0:
            raise ValueError(
                "Number of bits should be multiple of 8"
            )
        if num_bits <= 0:
            raise ValueError(
                "Number of bits should be greater than 0"
            )
        self.num_bytes = num_bits // 8

    def serialize(self, val):
        """
        Serialize `val` as a big-endian byte string of `num_bytes` bytes.

        Raises `SerializationError` if `val` is not a non-negative integer
        (booleans are rejected too) or does not fit into `num_bytes` bytes.
        """
        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(val, bool) or not isinstance(val, int) or val < 0:
            raise SerializationError(
                'As per specified sedes object, can only serialize non-negative integer values',
                val
            )

        try:
            return val.to_bytes(self.num_bytes, 'big')
        except OverflowError as err:
            raise SerializationError('As per specified sedes object, %s' % err, val) from err

    def deserialize_segment(self, data, start_index):
        """
        Deserialize the data from the given start_index

        Returns a tuple of the decoded integer and the index right after
        the consumed bytes. Raises `DeserializationError` if fewer than
        `num_bytes` bytes are available from `start_index`.
        """
        end_index = start_index + self.num_bytes
        # Make sure we have sufficient data for deserializing. The bound must
        # account for start_index, otherwise a truncated slice would silently
        # decode to a wrong value.
        if len(data) < end_index:
            raise DeserializationError(
                'Insufficient data for deserializing',
                data
            )
        return int.from_bytes(data[start_index:end_index], 'big'), end_index

    def deserialize(self, data):
        """
        Deserialize `data`, which must be exactly `num_bytes` bytes long.

        Raises `DeserializationError` if `data` is too short or has
        trailing bytes.
        """
        deserialized_data, end_index = self.deserialize_segment(data, 0)
        if end_index != len(data):
            raise DeserializationError(
                'Data to be deserialized is too long',
                data
            )

        return deserialized_data


# Predefined unsigned integer sedes instances; the number in each name is
# the width in bits.
uint8 = UnsignedInteger(8)
uint16 = UnsignedInteger(16)
uint24 = UnsignedInteger(24)
uint32 = UnsignedInteger(32)
uint40 = UnsignedInteger(40)
uint48 = UnsignedInteger(48)
uint56 = UnsignedInteger(56)
uint64 = UnsignedInteger(64)
uint128 = UnsignedInteger(128)
uint256 = UnsignedInteger(256)
uint384 = UnsignedInteger(384)
uint512 = UnsignedInteger(512)
8 changes: 8 additions & 0 deletions ssz/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from ssz.exceptions import (
SerializationError,
)
from ssz.sedes import (
boolean,
)
Expand All @@ -18,6 +21,11 @@ def infer_sedes(obj):
"""
if isinstance(obj, bool):
return boolean
elif isinstance(obj, int):
raise SerializationError(
'uint sedes object or uint string needs to be specified for ints',
obj
)

msg = 'Did not find sedes handling type {}'.format(type(obj).__name__)
raise TypeError(msg)
120 changes: 120 additions & 0 deletions tests/core/test_hash_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import pytest

from ssz import (
DeserializationError,
SerializationError,
decode,
encode,
)
from ssz.sedes import (
Hash,
address,
hash32,
)


@pytest.mark.parametrize('num_bytes', [0, -10, -100])
def test_reject_hash_object_negative_bytes(num_bytes):
    """A Hash sedes cannot be created with a zero or negative size."""
    with pytest.raises(ValueError):
        Hash(num_bytes)


def test_hash_serialize_values():
    """Serializing a value of the right size returns it unchanged."""
    for size in range(1, 33):
        payload = b'\x01' * size
        sedes_obj = Hash(size)
        assert sedes_obj.serialize(payload) == payload


@pytest.mark.parametrize(
    'value,sedes',
    (
        # Values too long
        (b'\x01' * 32, Hash(16)),
        (b'\x01' * 32, Hash(20)),
        # Values too short
        (b'\x01' * 16, Hash(20)),
        (b'\x01' * 16, hash32),
        # Empty value
        (b'', Hash(20)),
    ),
)
def test_hash_serialize_bad_values(value, sedes):
    """Values whose length differs from the sedes size are rejected."""
    with pytest.raises(SerializationError):
        sedes.serialize(value)


def test_hash_deserialize_values():
    """Deserializing data of the right size returns it unchanged."""
    for size in range(1, 33):
        payload = b'\x01' * size
        sedes_obj = Hash(size)
        assert sedes_obj.deserialize(payload) == payload


@pytest.mark.parametrize(
    'value,sedes',
    [
        # Values too short
        (b'\x01' * 15, Hash(16)),
        (b'\x01' * 16, Hash(20)),
        (b'\x01' * 10, Hash(20)),
        (b'\x01' * 5, Hash(20)),
        (b'\x01' * 16, Hash(32)),
        # Values too long
        (b'\x01' * 20, Hash(16)),
        (b'\x01' * 25, Hash(20)),
        (b'\x01' * 40, Hash(32)),
    ],
)
def test_hash_deserialize_bad_values(value, sedes):
    """Data shorter or longer than the sedes size fails to deserialize."""
    with pytest.raises(DeserializationError):
        sedes.deserialize(value)


def test_hash_round_trip():
    """Serializing and then deserializing yields the original value."""
    for size in range(1, 33):
        original = b'\x01' * size
        codec = Hash(size)
        encoded = codec.serialize(original)
        assert codec.deserialize(encoded) == original


@pytest.mark.parametrize('value', [b'\x00' * 20, b'\x01' * 20])
def test_address_round_trip(value):
    """20-byte values survive a round trip through the address sedes."""
    encoded = address.serialize(value)
    assert address.deserialize(encoded) == value


@pytest.mark.parametrize(
    'value,sedes',
    (
        (b'\x01' * 32, 'hash32'),
        (b'\x01' * 32, hash32),
        (b'\x01' * 32, Hash(32)),
        (b'\x01' * 64, Hash(64)),
    ),
)
def test_hash_round_trip_codec(value, sedes):
    """Values survive encode/decode with the sedes given by name or object."""
    # Resolve names through the registry rather than eval()-ing the string.
    from ssz.sedes import sedes_by_name

    sedes_obj = sedes_by_name[sedes] if isinstance(sedes, str) else sedes
    assert decode(encode(value, sedes), sedes_obj) == value


@pytest.mark.parametrize('value', [b'\x00' * 20, b'\x01' * 20])
def test_address_round_trip_codec(value):
    """20-byte values survive encode/decode with the address sedes."""
    encoded = encode(value, address)
    assert decode(encoded, address) == value

0 comments on commit d87f05d

Please sign in to comment.