Skip to content

Commit

Permalink
Merge pull request #37 from hwwhww/encode_cache
Browse files Browse the repository at this point in the history
Add `encode` cache
  • Loading branch information
hwwhww committed Feb 8, 2019
2 parents 394f530 + d166d6c commit 60865ce
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 1 deletion.
27 changes: 26 additions & 1 deletion ssz/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
)

from ssz.sedes import (
Serializable,
sedes_by_name,
)
from ssz.sedes.base import (
Expand All @@ -13,14 +14,34 @@
)


def encode(value, sedes=None):
def encode(value, sedes=None, cache=True):
"""
Encode object in SSZ format.
`sedes` needs to be explicitly mentioned for encode/decode
of integers(as of now).
`sedes` parameter could be given as a string or as the
actual sedes object itself.
If `value` has an attribute :attr:`_cached_ssz` (as, notably,
:class:`ssz.sedes.Serializable`) and its value is not `None`, this value is
returned bypassing serialization and encoding, unless `sedes` is given (as
the cache is assumed to refer to the standard serialization which can be
replaced by specifying `sedes`).
If `value` is a :class:`ssz.sedes.Serializable` and `cache` is true, the result of
the encoding will be stored in :attr:`_cached_ssz` if it is empty.
"""
if isinstance(value, Serializable):
cached_ssz = value._cached_ssz
if sedes is None and cached_ssz is not None:
return cached_ssz
else:
really_cache = (
cache and
sedes is None
)
else:
really_cache = False

if sedes is not None:
if sedes in sedes_by_name:
# Get the actual sedes object from string representation
Expand All @@ -35,6 +56,10 @@ def encode(value, sedes=None):
sedes_obj = infer_sedes(value)

serialized_obj = sedes_obj.serialize(value)

if really_cache:
value._cached_ssz = serialized_obj

return serialized_obj


Expand Down
2 changes: 2 additions & 0 deletions ssz/sedes/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def merge_args_to_kwargs(args, kwargs, arg_names):

class BaseSerializable(collections.Sequence):

_cached_ssz = None

def __init__(self, *args, **kwargs):
validate_args_and_kwargs(args, kwargs, self._meta.field_names)
field_values = merge_kwargs_to_args(args, kwargs, self._meta.field_names)
Expand Down
113 changes: 113 additions & 0 deletions tests/sedes/test_speed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import time

from ssz import (
encode,
)
from ssz.sedes import (
List,
Serializable,
bytes32,
uint24,
uint64,
uint384,
)


class ValidatorRecord(Serializable):
    """
    Serializable fixture with ten fields, used as the element type of
    ``State.validator_registry`` in the encode-cache speed test.
    """
    # Each entry is a ``(field_name, sedes)`` pair.
    fields = [
        ('pubkey', uint384),
        ('withdrawal_credentials', bytes32),
        ('randao_commitment', bytes32),
        ('randao_layers', uint64),
        ('status', uint64),
        ('latest_status_change_slot', uint64),
        ('exit_count', uint64),
        ('poc_commitment', bytes32),
        ('last_poc_change_slot', uint64),
        ('second_last_poc_change_slot', uint64),
    ]


class CrosslinkRecord(Serializable):
    """
    Two-field serializable fixture, used as the element type of
    ``State.latest_crosslinks`` in the encode-cache speed test.
    """
    # Each entry is a ``(field_name, sedes)`` pair.
    fields = [
        ('slot', uint64),
        ('shard_block_root', bytes32),
    ]


class ShardCommittee(Serializable):
    """
    Serializable fixture holding a shard number, a list of ``uint24``
    committee entries and a total validator count; nested two list levels
    deep inside ``State.shard_and_committee_for_slots``.
    """
    # Each entry is a ``(field_name, sedes)`` pair.
    fields = [
        ('shard', uint64),
        ('committee', List(uint24)),
        ('total_validator_count', uint64),
    ]


class State(Serializable):
    """
    Top-level container that the speed test encodes; it aggregates lists of
    the other fixture classes declared in this module.
    """
    # Each entry is a ``(field_name, sedes)`` pair.
    fields = [
        ('validator_registry', List(ValidatorRecord)),
        ('shard_and_committee_for_slots', List(List(ShardCommittee))),
        ('latest_crosslinks', List(CrosslinkRecord)),
    ]


# Shared fixture instances, created once at import time and reused by
# ``make_state`` so that building a large state stays cheap.
validator_record = ValidatorRecord(
    pubkey=123,
    withdrawal_credentials=b'\x56' * 32,
    randao_commitment=b'\x56' * 32,
    randao_layers=123,
    status=123,
    latest_status_change_slot=123,
    exit_count=123,
    poc_commitment=b'\x56' * 32,
    last_poc_change_slot=123,
    second_last_poc_change_slot=123,
)
crosslink_record = CrosslinkRecord(
    slot=12847,
    shard_block_root=b'\x67' * 32,
)
# 1024 references to the very same record object.
crosslink_record_stubs = [crosslink_record] * 1024


def make_state(num_validators):
    """
    Build a ``State`` fixture sized by ``num_validators``.

    The validator registry holds ``num_validators`` references to the shared
    module-level ``validator_record``; the committee structure is a 64 x 16
    grid of one shared ``ShardCommittee`` whose ``committee`` is
    ``range(num_validators // 1024)``; the crosslinks are the module-level
    ``crosslink_record_stubs``.
    """
    committee_members = tuple(range(num_validators // 1024))
    committee = ShardCommittee(
        shard=1,
        committee=committee_members,
        total_validator_count=num_validators,
    )

    # Tuples are immutable, so repeating a single tuple is equivalent to
    # rebuilding an equal one for every slot.
    committees_per_slot = (committee,) * 16
    committees_for_slots = (committees_per_slot,) * 64

    return State(
        validator_registry=(validator_record,) * num_validators,
        shard_and_committee_for_slots=committees_for_slots,
        latest_crosslinks=crosslink_record_stubs,
    )


def do_test_serialize(state, rounds=100):
    """
    Encode ``state`` ``rounds`` times with ``cache=True`` and return the
    result of the last call.

    Returns ``None`` when ``rounds`` is zero or negative (no encoding is
    performed) instead of failing on an unbound local.
    """
    encoded = None
    for _ in range(rounds):
        encoded = encode(state, cache=True)
    return encoded


def do_test_serialize_no_cache(state, rounds=100):
    """
    Encode ``state`` ``rounds`` times with ``cache=False`` and return the
    result of the last call.

    Returns ``None`` when ``rounds`` is zero or negative (no encoding is
    performed) instead of failing on an unbound local.
    """
    encoded = None
    for _ in range(rounds):
        encoded = encode(state, cache=False)
    return encoded


def test_encode_cache():
    """
    Check that repeated ``encode`` calls with caching enabled return the same
    bytes as uncached calls and are more than ten times faster.

    NOTE: the speed assertion is timing-based and may be sensitive to machine
    load.
    """
    state = make_state(2**10)

    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    without_cache_result = do_test_serialize_no_cache(state)
    without_cache_actual_performance = time.perf_counter() - start_time
    print("Performance of serialization without cache", without_cache_actual_performance)

    # Use a freshly built state for the cached run.
    state = make_state(2**10)
    start_time = time.perf_counter()
    with_cache_result = do_test_serialize(state)
    with_cache_actual_performance = time.perf_counter() - start_time
    print("Performance of serialization with cache", with_cache_actual_performance)

    assert with_cache_result == without_cache_result
    assert with_cache_actual_performance * 10 < without_cache_actual_performance

0 comments on commit 60865ce

Please sign in to comment.