Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add encode cache #37

Merged
merged 3 commits into from
Feb 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion ssz/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
)

from ssz.sedes import (
Serializable,
sedes_by_name,
)
from ssz.sedes.base import (
Expand All @@ -13,14 +14,34 @@
)


def encode(value, sedes=None):
def encode(value, sedes=None, cache=True):
"""
Encode object in SSZ format.
`sedes` must be given explicitly when encoding/decoding
integers (as of now).
`sedes` parameter could be given as a string or as the
actual sedes object itself.

If `value` is a :class:`ssz.sedes.Serializable` whose :attr:`_cached_ssz`
is not `None`, that cached value is returned, bypassing serialization and
encoding, unless `sedes` is given (the cache is assumed to refer to the
standard serialization, which can be replaced by specifying `sedes`).
If `value` is a :class:`ssz.sedes.Serializable`, `cache` is true and no
`sedes` is given, the result of the encoding will be stored in
:attr:`_cached_ssz` if it is empty.
"""
if isinstance(value, Serializable):
cached_ssz = value._cached_ssz
if sedes is None and cached_ssz is not None:
return cached_ssz
else:
really_cache = (
cache and
sedes is None
)
else:
really_cache = False

if sedes is not None:
if sedes in sedes_by_name:
# Get the actual sedes object from string representation
Expand All @@ -35,6 +56,10 @@ def encode(value, sedes=None):
sedes_obj = infer_sedes(value)

serialized_obj = sedes_obj.serialize(value)

if really_cache:
value._cached_ssz = serialized_obj

return serialized_obj


Expand Down
2 changes: 2 additions & 0 deletions ssz/sedes/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def merge_args_to_kwargs(args, kwargs, arg_names):

class BaseSerializable(collections.Sequence):

_cached_ssz = None

def __init__(self, *args, **kwargs):
validate_args_and_kwargs(args, kwargs, self._meta.field_names)
field_values = merge_kwargs_to_args(args, kwargs, self._meta.field_names)
Expand Down
113 changes: 113 additions & 0 deletions tests/sedes/test_speed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import time

from ssz import (
encode,
)
from ssz.sedes import (
List,
Serializable,
bytes32,
uint24,
uint64,
uint384,
)


class ValidatorRecord(Serializable):
    """Serializable fixture with ten fields, used as the validator registry entry."""

    # (field name, sedes) pairs; the declaration order is kept as written.
    fields = [
        # Key and credential/commitment material.
        ('pubkey', uint384),
        ('withdrawal_credentials', bytes32),
        ('randao_commitment', bytes32),
        # 64-bit counters, status and slot values.
        ('randao_layers', uint64),
        ('status', uint64),
        ('latest_status_change_slot', uint64),
        ('exit_count', uint64),
        # Proof-of-custody related values (per the field names).
        ('poc_commitment', bytes32),
        ('last_poc_change_slot', uint64),
        ('second_last_poc_change_slot', uint64),
    ]


class CrosslinkRecord(Serializable):
    """Serializable fixture pairing a ``slot`` with a 32-byte ``shard_block_root``."""

    # (field name, sedes) pairs.
    fields = [('slot', uint64), ('shard_block_root', bytes32)]


class ShardCommittee(Serializable):
    """Serializable fixture: a shard number, a committee list and a validator count."""

    # (field name, sedes) pairs; ``committee`` is a list of uint24 values.
    fields = [
        ('shard', uint64),
        ('committee', List(uint24)),
        ('total_validator_count', uint64),
    ]


class State(Serializable):
    """Top-level Serializable fixture that the speed test encodes."""

    # (field name, sedes) pairs built from the record classes of this module.
    fields = [
        # Flat list of validator records.
        ('validator_registry', List(ValidatorRecord)),
        # List of lists of shard committees.
        ('shard_and_committee_for_slots', List(List(ShardCommittee))),
        # Flat list of crosslink records.
        ('latest_crosslinks', List(CrosslinkRecord)),
    ]


# Module-level fixture instances shared by every state that make_state() builds.

# One crosslink record, referenced 1024 times (the same object in every slot).
crosslink_record = CrosslinkRecord(slot=12847, shard_block_root=b'\x67' * 32)
crosslink_record_stubs = [crosslink_record] * 1024

# A single validator record that make_state() repeats to fill the registry.
validator_record = ValidatorRecord(
    pubkey=123,
    withdrawal_credentials=b'\x56' * 32,
    randao_commitment=b'\x56' * 32,
    randao_layers=123,
    status=123,
    latest_status_change_slot=123,
    exit_count=123,
    poc_commitment=b'\x56' * 32,
    last_poc_change_slot=123,
    second_last_poc_change_slot=123,
)


def make_state(num_validators):
    """
    Build a `State` fixture sized by `num_validators`.

    The registry holds `num_validators` references to the module-level
    `validator_record`. A single `ShardCommittee`, whose committee is
    ``range(num_validators // 1024)``, is repeated 16 times per row for
    64 rows. The crosslinks are the module-level `crosslink_record_stubs`.
    """
    committee_members = tuple(range(num_validators // 1024))
    shard_committee = ShardCommittee(
        shard=1,
        committee=committee_members,
        total_validator_count=num_validators,
    )
    # 64 rows, each a fresh tuple of 16 references to the same committee.
    committees_per_slot = tuple((shard_committee,) * 16 for _ in range(64))
    registry = (validator_record,) * num_validators
    return State(
        validator_registry=registry,
        shard_and_committee_for_slots=committees_per_slot,
        latest_crosslinks=crosslink_record_stubs,
    )


def do_test_serialize(state, rounds=100):
    """
    Encode `state` `rounds` times with ``cache=True``.

    :param state: object passed unchanged to :func:`encode` on every round.
    :param rounds: number of :func:`encode` calls to make; must be at least 1.
    :return: the value returned by the final :func:`encode` call.
    :raises ValueError: if `rounds` is less than 1, since no call would be
        made and there would be nothing to return.
    """
    if rounds < 1:
        # Without this guard the loop body never runs and the return below
        # would fail with an opaque UnboundLocalError.
        raise ValueError("rounds must be at least 1")

    for _ in range(rounds):
        encoded = encode(state, cache=True)

    return encoded


def do_test_serialize_no_cache(state, rounds=100):
    """
    Encode `state` `rounds` times with ``cache=False``.

    :param state: object passed unchanged to :func:`encode` on every round.
    :param rounds: number of :func:`encode` calls to make; must be at least 1.
    :return: the value returned by the final :func:`encode` call.
    :raises ValueError: if `rounds` is less than 1, since no call would be
        made and there would be nothing to return.
    """
    if rounds < 1:
        # Without this guard the loop body never runs and the return below
        # would fail with an opaque UnboundLocalError.
        raise ValueError("rounds must be at least 1")

    for _ in range(rounds):
        encoded = encode(state, cache=False)

    return encoded


def _timed_call(func, *args):
    """Call ``func(*args)`` and return ``(result, elapsed_seconds)``."""
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring short durations; time.time() can jump when the system clock
    # is adjusted and may be too coarse for the fast cached run.
    started = time.perf_counter()
    result = func(*args)
    return result, time.perf_counter() - started


def test_encode_cache():
    """
    Check that repeated encoding with the cache enabled returns the same
    bytes as without it and takes less than a tenth of the time.

    A new state is built for each measurement, and building it is not
    included in the timed region.
    """
    without_cache_result, without_cache_actual_performance = _timed_call(
        do_test_serialize_no_cache,
        make_state(2**10),
    )
    print("Performance of serialization without cache", without_cache_actual_performance)

    with_cache_result, with_cache_actual_performance = _timed_call(
        do_test_serialize,
        make_state(2**10),
    )
    print("Performance of serialization with cache", with_cache_actual_performance)

    assert with_cache_result == without_cache_result
    # The cached run must be more than ten times faster than the uncached one.
    assert with_cache_actual_performance * 10 < without_cache_actual_performance