Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 33 additions & 15 deletions src/lean_spec/subspecs/networking/enr/eth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from typing import ClassVar

from lean_spec.subspecs.networking.types import ForkDigest
from lean_spec.subspecs.networking.types import ForkDigest, Version
from lean_spec.types import StrictBaseModel, Uint64
from lean_spec.types.bitfields import BaseBitvector
from lean_spec.types.boolean import Boolean
Expand All @@ -45,18 +45,18 @@ class Eth2Data(StrictBaseModel):
fork_digest: ForkDigest
"""Current active fork identifier (4 bytes)."""

next_fork_version: ForkDigest
"""Fork version of next scheduled fork. Equals current if none scheduled."""
next_fork_version: Version
"""Fork version of next scheduled fork. Equals current version if none scheduled."""

next_fork_epoch: Uint64
"""Epoch when next fork activates. FAR_FUTURE_EPOCH if none scheduled."""

@classmethod
def no_scheduled_fork(cls, current_digest: ForkDigest) -> "Eth2Data":
"""Create Eth2Data with no scheduled fork."""
def no_scheduled_fork(cls, current_digest: ForkDigest, current_version: Version) -> "Eth2Data":
"""Create Eth2Data indicating no scheduled fork."""
return cls(
fork_digest=current_digest,
next_fork_version=current_digest,
next_fork_version=current_version,
next_fork_epoch=FAR_FUTURE_EPOCH,
)

Expand All @@ -74,32 +74,32 @@ class AttestationSubnets(BaseBitvector):
@classmethod
def none(cls) -> "AttestationSubnets":
"""No subscriptions."""
return cls(data=[Boolean(False)] * 64)
return cls(data=[Boolean(False)] * cls.LENGTH)

@classmethod
def all(cls) -> "AttestationSubnets":
"""Subscribe to all 64 subnets."""
return cls(data=[Boolean(True)] * 64)
return cls(data=[Boolean(True)] * cls.LENGTH)

@classmethod
def from_subnet_ids(cls, subnet_ids: list[int]) -> "AttestationSubnets":
"""Subscribe to specific subnets."""
bits = [Boolean(False)] * 64
bits = [Boolean(False)] * cls.LENGTH
for sid in subnet_ids:
if not 0 <= sid < 64:
if not 0 <= sid < cls.LENGTH:
raise ValueError(f"Subnet ID must be 0-63, got {sid}")
bits[sid] = Boolean(True)
return cls(data=bits)

def is_subscribed(self, subnet_id: int) -> bool:
"""Check if subscribed to a subnet."""
if not 0 <= subnet_id < 64:
if not 0 <= subnet_id < self.LENGTH:
raise ValueError(f"Subnet ID must be 0-63, got {subnet_id}")
return bool(self.data[subnet_id])

def subscribed_subnets(self) -> list[int]:
"""List of subscribed subnet IDs."""
return [i for i in range(64) if self.data[i]]
return [i for i in range(self.LENGTH) if self.data[i]]

def subscription_count(self) -> int:
"""Number of subscribed subnets."""
Expand All @@ -119,15 +119,33 @@ class SyncCommitteeSubnets(BaseBitvector):
@classmethod
def none(cls) -> "SyncCommitteeSubnets":
"""No subscriptions."""
return cls(data=[Boolean(False)] * 4)
return cls(data=[Boolean(False)] * cls.LENGTH)

@classmethod
def all(cls) -> "SyncCommitteeSubnets":
"""Subscribe to all 4 subnets."""
return cls(data=[Boolean(True)] * 4)
return cls(data=[Boolean(True)] * cls.LENGTH)

@classmethod
def from_subnet_ids(cls, subnet_ids: list[int]) -> "SyncCommitteeSubnets":
"""Subscribe to specific sync subnets."""
bits = [Boolean(False)] * cls.LENGTH
for sid in subnet_ids:
if not 0 <= sid < cls.LENGTH:
raise ValueError(f"Sync subnet ID must be 0-3, got {sid}")
bits[sid] = Boolean(True)
return cls(data=bits)

def is_subscribed(self, subnet_id: int) -> bool:
"""Check if subscribed to a sync subnet."""
if not 0 <= subnet_id < 4:
if not 0 <= subnet_id < self.LENGTH:
raise ValueError(f"Sync subnet ID must be 0-3, got {subnet_id}")
return bool(self.data[subnet_id])

def subscribed_subnets(self) -> list[int]:
"""List of subscribed sync subnet IDs."""
return [i for i in range(self.LENGTH) if self.data[i]]

def subscription_count(self) -> int:
"""Number of subscribed sync subnets."""
return sum(1 for b in self.data if b)
3 changes: 3 additions & 0 deletions src/lean_spec/subspecs/networking/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
ForkDigest = Bytes4
"""4-byte fork identifier ensuring network isolation between forks."""

Version = Bytes4
"""4-byte fork version number (e.g., 0x01000000 for Phase0)."""

SeqNumber = Uint64
"""Sequence number used in ENR records, metadata, and ping messages."""

Expand Down
108 changes: 104 additions & 4 deletions tests/lean_spec/subspecs/networking/enr/test_eth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from pydantic import ValidationError

from lean_spec.subspecs.networking.enr import Eth2Data
from lean_spec.subspecs.networking.enr.eth2 import AttestationSubnets, SyncCommitteeSubnets
from lean_spec.subspecs.networking.enr.eth2 import (
FAR_FUTURE_EPOCH,
AttestationSubnets,
SyncCommitteeSubnets,
)
from lean_spec.types import Uint64
from lean_spec.types.byte_arrays import Bytes4

Expand All @@ -25,11 +29,12 @@ def test_create_eth2_data(self) -> None:
def test_no_scheduled_fork_factory(self) -> None:
"""no_scheduled_fork factory creates correct data."""
digest = Bytes4(b"\xab\xcd\xef\x01")
data = Eth2Data.no_scheduled_fork(digest)
version = Bytes4(b"\x01\x00\x00\x00")
data = Eth2Data.no_scheduled_fork(digest, version)

assert data.fork_digest == digest
assert data.next_fork_version == digest
assert data.next_fork_epoch == Uint64(2**64 - 1)
assert data.next_fork_version == version
assert data.next_fork_epoch == FAR_FUTURE_EPOCH

def test_eth2_data_immutable(self) -> None:
"""Eth2Data is immutable (frozen)."""
Expand All @@ -41,6 +46,10 @@ def test_eth2_data_immutable(self) -> None:
with pytest.raises(ValidationError):
data.fork_digest = Bytes4(b"\x00\x00\x00\x00")

def test_far_future_epoch_value(self) -> None:
"""FAR_FUTURE_EPOCH is max uint64."""
assert FAR_FUTURE_EPOCH == Uint64(2**64 - 1)


class TestAttestationSubnets:
"""Tests for AttestationSubnets bitvector."""
Expand Down Expand Up @@ -93,6 +102,39 @@ def test_invalid_subnet_id_in_is_subscribed(self) -> None:
with pytest.raises(ValueError):
subnets.is_subscribed(-1)

def test_from_subnet_ids_empty_list(self) -> None:
"""from_subnet_ids with empty list creates no subscriptions."""
subnets = AttestationSubnets.from_subnet_ids([])
assert subnets.subscription_count() == 0
assert subnets.subscribed_subnets() == []

def test_from_subnet_ids_with_duplicates(self) -> None:
"""from_subnet_ids handles duplicates correctly."""
subnets = AttestationSubnets.from_subnet_ids([5, 5, 5, 10])
assert subnets.subscription_count() == 2
assert subnets.subscribed_subnets() == [5, 10]

def test_encode_bytes_empty(self) -> None:
"""Empty subscriptions serialize to 8 zero bytes."""
subnets = AttestationSubnets.none()
assert subnets.encode_bytes() == b"\x00" * 8

def test_encode_bytes_all(self) -> None:
"""All subscriptions serialize to 8 0xff bytes."""
subnets = AttestationSubnets.all()
assert subnets.encode_bytes() == b"\xff" * 8

def test_decode_bytes_roundtrip(self) -> None:
"""Encode then decode produces equivalent result."""
original = AttestationSubnets.from_subnet_ids([0, 5, 63])
encoded = original.encode_bytes()
decoded = AttestationSubnets.decode_bytes(encoded)
assert decoded.subscribed_subnets() == original.subscribed_subnets()

def test_length_constant(self) -> None:
"""LENGTH constant is 64."""
assert AttestationSubnets.LENGTH == 64


class TestSyncCommitteeSubnets:
"""Tests for SyncCommitteeSubnets bitvector."""
Expand Down Expand Up @@ -128,3 +170,61 @@ def test_is_subscribed_raises_for_negative_id(self) -> None:
subnets = SyncCommitteeSubnets.none()
with pytest.raises(ValueError, match="must be 0-3"):
subnets.is_subscribed(-1)

def test_from_subnet_ids_specific(self) -> None:
"""from_subnet_ids() creates specific subscriptions."""
subnets = SyncCommitteeSubnets.from_subnet_ids([0, 2])
assert subnets.is_subscribed(0)
assert not subnets.is_subscribed(1)
assert subnets.is_subscribed(2)
assert not subnets.is_subscribed(3)

def test_from_subnet_ids_empty_list(self) -> None:
"""from_subnet_ids with empty list creates no subscriptions."""
subnets = SyncCommitteeSubnets.from_subnet_ids([])
assert subnets.subscription_count() == 0

def test_from_subnet_ids_with_duplicates(self) -> None:
"""from_subnet_ids handles duplicates correctly."""
subnets = SyncCommitteeSubnets.from_subnet_ids([1, 1, 1, 3])
assert subnets.subscription_count() == 2
assert subnets.subscribed_subnets() == [1, 3]

def test_from_subnet_ids_invalid(self) -> None:
"""from_subnet_ids() raises for invalid subnet IDs."""
with pytest.raises(ValueError, match="must be 0-3"):
SyncCommitteeSubnets.from_subnet_ids([4])

with pytest.raises(ValueError, match="must be 0-3"):
SyncCommitteeSubnets.from_subnet_ids([-1])

def test_subscribed_subnets(self) -> None:
"""subscribed_subnets() returns correct list."""
subnets = SyncCommitteeSubnets.from_subnet_ids([1, 3])
assert subnets.subscribed_subnets() == [1, 3]

def test_subscription_count(self) -> None:
"""subscription_count() returns correct count."""
subnets = SyncCommitteeSubnets.from_subnet_ids([0, 2, 3])
assert subnets.subscription_count() == 3

def test_encode_bytes_empty(self) -> None:
"""Empty subscriptions serialize to 1 zero byte."""
subnets = SyncCommitteeSubnets.none()
assert subnets.encode_bytes() == b"\x00"

def test_encode_bytes_all(self) -> None:
"""All subscriptions serialize to 0x0f (lower 4 bits set)."""
subnets = SyncCommitteeSubnets.all()
assert subnets.encode_bytes() == b"\x0f"

def test_decode_bytes_roundtrip(self) -> None:
"""Encode then decode produces equivalent result."""
original = SyncCommitteeSubnets.from_subnet_ids([0, 2])
encoded = original.encode_bytes()
decoded = SyncCommitteeSubnets.decode_bytes(encoded)
assert decoded.subscribed_subnets() == original.subscribed_subnets()

def test_length_constant(self) -> None:
"""LENGTH constant is 4."""
assert SyncCommitteeSubnets.LENGTH == 4
Loading