diff --git a/.gitignore b/.gitignore index 420842c6d..705422e42 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,6 @@ scripts/ # Client codebase used by Claude skill for running reference tests clients/ + +# Local agent worktrees +.claude/worktrees/ diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index 2832a5e81..1998f6ba8 100755 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -69,7 +69,6 @@ Slot, Uint64, ValidatorIndex, - ValidatorIndices, ) KeyRole = Literal["attestation", "proposal"] @@ -534,13 +533,13 @@ def sign_and_aggregate( """ raw_xmss = [ ( + vid, self.get_public_keys(vid)[0], self.sign_attestation_data(vid, attestation_data), ) for vid in validator_ids ] return TypeOneMultiSignature.aggregate( - xmss_participants=ValidatorIndices(data=validator_ids).to_aggregation_bits(), children=[], raw_xmss=raw_xmss, message=hash_tree_root(attestation_data), @@ -599,8 +598,7 @@ def build_attestation_proofs( proofs.append( TypeOneMultiSignature.aggregate( children=[], - raw_xmss=list(zip(public_keys, signatures, strict=True)), - xmss_participants=agg.aggregation_bits, + raw_xmss=list(zip(validator_ids, public_keys, signatures, strict=True)), message=hash_tree_root(agg.data), slot=agg.data.slot, ) diff --git a/packages/testing/src/consensus_testing/test_types/block_spec.py b/packages/testing/src/consensus_testing/test_types/block_spec.py index f95b2f459..72aa4a0ec 100644 --- a/packages/testing/src/consensus_testing/test_types/block_spec.py +++ b/packages/testing/src/consensus_testing/test_types/block_spec.py @@ -276,7 +276,6 @@ def _sign_block( Complete signed block. 
""" block_root = hash_tree_root(final_block) - proposer_participants = ValidatorIndices(data=[proposer_index]).to_aggregation_bits() proposer_pubkey = key_manager.get_public_keys(proposer_index)[1] # The binding rejects placeholder bytes; if anything in the merged @@ -295,8 +294,7 @@ def _sign_block( ) proposer_type_1 = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=[(proposer_pubkey, proposer_signature)], - xmss_participants=proposer_participants, + raw_xmss=[(proposer_index, proposer_pubkey, proposer_signature)], message=block_root, slot=self.slot, ) diff --git a/pyproject.toml b/pyproject.toml index e3bfe06ff..30f716da2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,10 +78,14 @@ known-first-party = ["lean_spec"] [tool.ruff.lint.per-file-ignores] "tests/**" = ["D"] +"src/lean_spec/subspecs/xmss/constants.py" = ["N802"] [tool.ty.environment] python-version = "3.12" +[tool.ty.src] +exclude = [".claude/"] + [tool.ty.terminal] error-on-warning = true diff --git a/src/lean_spec/forks/lstar/aggregation_select.py b/src/lean_spec/forks/lstar/aggregation_select.py new file mode 100644 index 000000000..7f3951503 --- /dev/null +++ b/src/lean_spec/forks/lstar/aggregation_select.py @@ -0,0 +1,51 @@ +"""Greedy proof selection for lstar block production.""" + +from lean_spec.subspecs.xmss.aggregation import TypeOneMultiSignature +from lean_spec.types import ValidatorIndex + + +def select_greedily( + *proof_sets: set[TypeOneMultiSignature] | None, +) -> tuple[list[TypeOneMultiSignature], set[ValidatorIndex]]: + """ + Greedy set-cover over Type-1 proofs maximizing validator coverage. + + Iterates the proof sets in order, repeatedly picking the proof with the + most uncovered validators until no further coverage is possible. + Earlier proof sets are prioritized so gossip-fresh proofs win over + already-known ones. + + The validator-index sets are materialized once per proof, not inside the + inner max key, so the loop runs in O(P * V) instead of O(P^2 * V). 
+ + Args: + *proof_sets: One or more sets of Type-1 proofs, ordered by priority. + None entries are skipped. + + Returns: + The chosen proofs and the union of validator indices they cover. + """ + selected: list[TypeOneMultiSignature] = [] + covered: set[ValidatorIndex] = set() + + for proofs in proof_sets: + if not proofs: + continue + + # Materialize each proof's validator index set once. + # The greedy loop below would otherwise recompute it on every comparison. + coverage_of: dict[TypeOneMultiSignature, set[ValidatorIndex]] = { + p: set(p.participants.to_validator_indices()) for p in proofs + } + remaining = list(proofs) + + while remaining: + best = max(remaining, key=lambda p: len(coverage_of[p] - covered)) + new_coverage = coverage_of[best] - covered + if not new_coverage: + break + selected.append(best) + covered |= new_coverage + remaining.remove(best) + + return selected, covered diff --git a/src/lean_spec/forks/lstar/spec.py b/src/lean_spec/forks/lstar/spec.py index 66e3a2b2b..702ac3c02 100644 --- a/src/lean_spec/forks/lstar/spec.py +++ b/src/lean_spec/forks/lstar/spec.py @@ -5,6 +5,7 @@ from collections.abc import Iterable, Sequence, Set as AbstractSet from typing import Any, ClassVar +from lean_spec.forks.lstar.aggregation_select import select_greedily from lean_spec.forks.lstar.containers import ( AggregatedAttestation, AttestationData, @@ -55,7 +56,6 @@ Uint8, Uint64, ValidatorIndex, - ValidatorIndices, ) from ..protocol import ForkProtocol, SpecBlockType, SpecStateType @@ -763,7 +763,7 @@ def build_block( found_entries = True - selected, _ = TypeOneMultiSignature.select_greedily(proofs) + selected, _ = select_greedily(proofs) aggregated_signatures.extend(selected) for proof in selected: aggregated_attestations.append( @@ -836,7 +836,6 @@ def build_block( for proof in proofs ] sig = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=children, raw_xmss=[], message=hash_tree_root(att_data), @@ -1661,9 +1660,7 @@ def aggregate(self, 
store: LstarStore) -> tuple[LstarStore, list[SignedAggregate # New payloads go first because they represent uncommitted # work — known payloads fill remaining gaps. - child_proofs, covered = TypeOneMultiSignature.select_greedily( - new.get(data), known.get(data) - ) + child_proofs, covered = select_greedily(new.get(data), known.get(data)) # Phase 2: Fill # @@ -1689,17 +1686,6 @@ def aggregate(self, store: LstarStore) -> tuple[LstarStore, list[SignedAggregate if not raw_entries and len(child_proofs) < 2: continue - # Encode raw signers as a compact bitfield when present. - # Child-only aggregation (no raw signatures) must pass None. - if raw_entries: - xmss_participants = ValidatorIndices( - data=[vid for vid, _, _ in raw_entries] - ).to_aggregation_bits() - raw_xmss = [(pk, sig) for _, pk, sig in raw_entries] - else: - xmss_participants = None - raw_xmss = [] - # Phase 3: Aggregate # # Build the recursive proof tree. @@ -1719,11 +1705,11 @@ def aggregate(self, store: LstarStore) -> tuple[LstarStore, list[SignedAggregate ] # Hand everything to the XMSS subspec. + # Each fresh entry already carries its validator index alongside its key and signature. # Out comes a single proof covering all selected validators. proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=children, - raw_xmss=raw_xmss, + raw_xmss=raw_entries, message=hash_tree_root(data), slot=data.slot, ) diff --git a/src/lean_spec/subspecs/sync/service.py b/src/lean_spec/subspecs/sync/service.py index 305c8dd84..5d3e1b9fa 100644 --- a/src/lean_spec/subspecs/sync/service.py +++ b/src/lean_spec/subspecs/sync/service.py @@ -617,13 +617,13 @@ def _deconstruct_block_into_store( continue try: + # The split takes the bits from the block attestation this + # component binds, since the Rust binding does not return them. 
block_t1 = type_two.split_by_msg( message=data_root, public_keys_per_message=public_keys_per_message, + participants=att.aggregation_bits, ) - # split_by_msg returns an empty participant bitfield; restore - # the bits from the block attestation this component binds. - block_t1 = block_t1.model_copy(update={"participants": att.aggregation_bits}) if local_proofs: combined = TypeOneMultiSignature.aggregate( @@ -638,7 +638,6 @@ def _deconstruct_block_into_store( for child in (block_t1, *local_proofs) ], raw_xmss=[], - xmss_participants=None, message=data_root, slot=data.slot, ) diff --git a/src/lean_spec/subspecs/validator/service.py b/src/lean_spec/subspecs/validator/service.py index 2cf71bf01..c0d3099f9 100644 --- a/src/lean_spec/subspecs/validator/service.py +++ b/src/lean_spec/subspecs/validator/service.py @@ -50,7 +50,7 @@ from lean_spec.subspecs.xmss import TARGET_SIGNATURE_SCHEME from lean_spec.subspecs.xmss.aggregation import TypeOneMultiSignature, TypeTwoMultiSignature from lean_spec.subspecs.xmss.containers import PublicKey, Signature -from lean_spec.types import ByteList512KiB, Bytes32, Slot, Uint64, ValidatorIndex, ValidatorIndices +from lean_spec.types import ByteList512KiB, Bytes32, Slot, Uint64, ValidatorIndex from .constants import HYSTERESIS_BAND, NETWORK_STALL_THRESHOLD, SYNC_LAG_THRESHOLD from .registry import ValidatorEntry, ValidatorRegistry @@ -447,12 +447,10 @@ def _sign_block( proposer_pubkey = validators[validator_index].get_proposal_pubkey() # Wrap the proposer's raw XMSS signature into a singleton Type-1. - # The participant set is just the proposer index. - proposer_participants = ValidatorIndices(data=[validator_index]).to_aggregation_bits() + # The single fresh entry carries the proposer index alongside its key and signature. 
proposer_type_1 = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=[(proposer_pubkey, proposer_signature)], - xmss_participants=proposer_participants, + raw_xmss=[(validator_index, proposer_pubkey, proposer_signature)], message=block_root, slot=block.slot, ) diff --git a/src/lean_spec/subspecs/xmss/__init__.py b/src/lean_spec/subspecs/xmss/__init__.py index a699229c9..3fd1f7ed3 100644 --- a/src/lean_spec/subspecs/xmss/__init__.py +++ b/src/lean_spec/subspecs/xmss/__init__.py @@ -1,8 +1,10 @@ -""" -This package provides a Python specification for the Generalized XMSS -hash-based signature scheme. +"""Generalized XMSS hash-based signature scheme. -It exposes the core data structures and the main interface functions. +References: + - Hash-Based Multi-Signatures for Post-Quantum Ethereum. + https://eprint.iacr.org/2025/055.pdf + - Aborting Random Oracles, How to Build Them, How to Use Them. + https://eprint.iacr.org/2026/016.pdf """ from .containers import PublicKey, SecretKey diff --git a/src/lean_spec/subspecs/xmss/aggregation.py b/src/lean_spec/subspecs/xmss/aggregation.py index 6b2ba0094..4a0816d0e 100644 --- a/src/lean_spec/subspecs/xmss/aggregation.py +++ b/src/lean_spec/subspecs/xmss/aggregation.py @@ -1,13 +1,12 @@ -""" -Multi-signature aggregation for the Lean Ethereum consensus spec. +"""Multi-signature aggregation over Generalized XMSS. Two proof shapes: -- Type-1: many validators on one message (one AttestationData, or one block root). -- Type-2: a merge of N Type-1 proofs over distinct messages. -""" +- Type-1: many validators on a single message (one attestation, or one block root). +- Type-2: a merge of several Type-1 proofs over distinct messages. -from collections.abc import Sequence +The Rust binding owns proof construction and cryptographic checks. 
+""" from lean_multisig_py import ( aggregate_type_1, @@ -18,7 +17,7 @@ verify_type_2_with_messages, ) -from lean_spec.config import LEAN_ENV, LeanEnvMode +from lean_spec.config import LEAN_ENV from lean_spec.types import ( AggregationBits, ByteList512KiB, @@ -28,33 +27,35 @@ ValidatorIndex, ValidatorIndices, ) -from lean_spec.types.boolean import Boolean from .containers import PublicKey, Signature -LOG_INV_RATE_TEST = 1 -""" -Inverse rate exponent for test mode (fastest, biggest proofs). +LOG_INV_RATE: int = 1 if LEAN_ENV == "test" else 2 +"""Inverse-rate exponent forwarded to the SNARK backend. -This parameter is forwarded to `lean_multisig_py` prover and controls a performance/size trade-off: - -- Lower values generate proofs faster but increase proof size. -- Higher values reduce proof size but increase prover work. +- A smaller rate trades verifier cost for prover speed. +- Test mode favors prover speed. """ -LOG_INV_RATE_PROD = 2 -"""Inverse rate exponent for production mode (balanced speed vs proof size).""" +# The environment is fixed for the lifetime of the process. +# +# One setup call covers every aggregation, verification, split, and merge below. +# +# Per-call invocations then default to the mode established here. +setup_prover(mode=LEAN_ENV) class AggregationError(Exception): - """Raised when signature aggregation, merging, splitting, or verification fails.""" + """Raised when aggregation, merging, splitting, or verification fails.""" class TypeOneMultiSignature(Container): - """A single-message proof aggregating signatures from many validators. + """Single-message proof aggregating signatures from many validators. + + Every validator signs the same message for the same slot. - The signed message and slot are rederived by the verifier from the - block body it already trusts, so they live outside the proof envelope. + The message and slot stay outside the proof. + The verifier rederives them from the block body it already trusts. 
""" participants: AggregationBits @@ -63,275 +64,274 @@ class TypeOneMultiSignature(Container): proof: ByteList512KiB """Aggregated proof bytes in compact no-pubkeys representation.""" - @staticmethod - def select_greedily( - *proof_sets: set["TypeOneMultiSignature"] | None, - ) -> tuple[list["TypeOneMultiSignature"], set[ValidatorIndex]]: - """Greedy set-cover over Type-1 proofs to maximise validator coverage. - - Repeatedly selects the proof covering the most uncovered validators - until no proof adds new coverage. Earlier proof sets are - prioritised: gossip-fresh proofs win over already-known ones. - """ - selected: list[TypeOneMultiSignature] = [] - covered: set[ValidatorIndex] = set() - - for proofs in proof_sets: - if not proofs: - continue - - remaining = list(proofs) - - while remaining: - best = max( - remaining, - key=lambda p: len(set(p.participants.to_validator_indices()) - covered), - ) - new_coverage = set(best.participants.to_validator_indices()) - covered - - if not new_coverage: - break - - selected.append(best) - covered |= new_coverage - remaining.remove(best) - - return selected, covered - - @staticmethod + @classmethod def aggregate( - children: Sequence[tuple["TypeOneMultiSignature", Sequence[PublicKey]]], - raw_xmss: Sequence[tuple[PublicKey, Signature]], - xmss_participants: AggregationBits | None, + cls, + children: list[tuple["TypeOneMultiSignature", list[PublicKey]]], + raw_xmss: list[tuple[ValidatorIndex, PublicKey, Signature]], message: Bytes32, slot: Slot, - mode: LeanEnvMode | None = None, ) -> "TypeOneMultiSignature": - """Aggregate raw XMSS signatures and child Type-1 proofs into one Type-1 proof. + """Fold fresh signatures and child proofs into one single-message proof. - Proof bytes are stored in compact no-pubkeys form. Participant identity is - tracked separately in participants (attestation bits on the wire). 
- """ - if not raw_xmss and not children: - raise AggregationError("At least one raw signature or child proof is required") + # Overview - if raw_xmss and xmss_participants is None: - raise AggregationError("xmss_participants is required when raw_xmss is provided") + Two kinds of contribution merge into one proof. - if not raw_xmss and len(children) < 2: - raise AggregationError( - "At least two child proofs are required when no raw signatures are provided" - ) + - A fresh signer contributes a single raw signature. + - A child proof contributes an already-aggregated bundle of signers. - aggregated_validator_ids: set[ValidatorIndex] = set() - if xmss_participants is not None: - aggregated_validator_ids.update(xmss_participants.to_validator_indices()) + The result names the union of every contributing validator. + The prover compresses all contributions into one proof over the shared message. - if len(aggregated_validator_ids) != len(raw_xmss): - raise AggregationError("Raw signature count does not match XMSS participant count") + # Why the index travels with each fresh signer - # Include child participants in the aggregated participants - for child, _ in children: - aggregated_validator_ids.update(child.participants.to_validator_indices()) - participants = ValidatorIndices(data=sorted(aggregated_validator_ids)).to_aggregation_bits() + A public key carries no validator index on its own. + Pairing the index with each fresh entry lets the bitfield be derived, not passed in. + An empty list of fresh signers simply contributes no indices. - mode = mode or LEAN_ENV - setup_prover(mode=mode) - log_inv_rate = LOG_INV_RATE_TEST if mode == "test" else LOG_INV_RATE_PROD + Args: + children: Child proofs, each paired with the public keys it names. + raw_xmss: Fresh entries, each carrying its validator index, public key, and signature. + message: The 32-byte message every signer signed. + slot: The slot every signer signed for. 
- raw_pubkeys_ssz = [pk.encode_bytes() for pk, _ in raw_xmss] - raw_signatures_ssz = [sig.encode_bytes() for _, sig in raw_xmss] + Returns: + A single-message proof covering the union of all participants. - children_bytes: list[tuple[list[bytes], bytes]] = [] - for idx, (child, child_public_keys_raw) in enumerate(children): - child_public_keys = list(child_public_keys_raw) - expected = child.participants.data.count(Boolean(1)) - if len(child_public_keys) != expected: - raise AggregationError( - f"Type-1 aggregate child {idx} expected {expected} pubkeys, " - f"got {len(child_public_keys)}" - ) - - child_pks_ssz = [pk.encode_bytes() for pk in child_public_keys] - child_wire = bytes(child.proof.data) - if not child_wire: - raise AggregationError(f"Child proof {idx} has empty proof bytes") - children_bytes.append((child_pks_ssz, child_wire)) + Raises: + AggregationError: When the prover rejects the inputs. + """ + # Phase 1: union every contributing validator index. + # + # Fresh signers bring their own index. + # Child proofs expose theirs through the participant bitfield. + all_indices = {vid for vid, _, _ in raw_xmss}.union( + *(child.participants.to_validator_indices() for child, _ in children) + ) + participants = ValidatorIndices(data=sorted(all_indices)).to_aggregation_bits() + + # Phase 2: serialize inputs to the prover's wire format. + raw_pubkeys_ssz = [pk.encode_bytes() for _, pk, _ in raw_xmss] + raw_signatures_ssz = [sig.encode_bytes() for _, _, sig in raw_xmss] + children_bytes = [ + ([pk.encode_bytes() for pk in pubkeys], bytes(child.proof.data)) + for child, pubkeys in children + ] + # Phase 3: hand off to the Rust prover. + # The mode argument routes the call to the matching backend bytecode. 
try: _, type1_wire = aggregate_type_1( raw_pubkeys_ssz, raw_signatures_ssz, bytes(message), int(slot), - log_inv_rate, - children_bytes if children_bytes else None, - mode=mode, + LOG_INV_RATE, + children_bytes or None, + mode=LEAN_ENV, ) except Exception as exc: - raise AggregationError(f"Type-1 aggregation failed: {exc}") from exc + raise AggregationError(str(exc)) from exc - return TypeOneMultiSignature( - participants=participants, - proof=ByteList512KiB(data=type1_wire), - ) + return cls(participants=participants, proof=ByteList512KiB(data=type1_wire)) def verify( self, - public_keys: Sequence[PublicKey], + public_keys: list[PublicKey], message: Bytes32, slot: Slot, - mode: LeanEnvMode | None = None, ) -> None: - """Verify this single-message Type-1 proof against a resolved set of pubkeys.""" - mode = mode or LEAN_ENV - setup_prover(mode=mode) + """Verify this single-message Type-1 proof against a pubkey set. - expected = self.participants.data.count(Boolean(1)) + Args: + public_keys: Pubkeys for the validators named by participants. + message: Message bound by the proof. + slot: Slot bound by the proof. + + Raises: + AggregationError: When the pubkey count does not match the bitfield + or the Rust verifier rejects the proof. + """ + # The bitfield names one validator per set bit. + # The caller must supply exactly that many keys, in the same order. + # A miscount would otherwise fail deep in the verifier with an opaque error. + expected = len(self.participants.to_validator_indices()) if len(public_keys) != expected: raise AggregationError( f"Type-1 verify expected {expected} pubkeys for participants, " f"got {len(public_keys)}" ) - pks_ssz = [pk.encode_bytes() for pk in public_keys] + # Hand the resolved keys, message, and slot to the Rust verifier. + # The mode argument selects the matching backend bytecode. 
try: verify_type_1( - pks_ssz, + [pk.encode_bytes() for pk in public_keys], bytes(message), int(slot), bytes(self.proof.data), - mode=mode, + mode=LEAN_ENV, ) except Exception as exc: raise AggregationError(f"Type-1 verification failed: {exc}") from exc class TypeTwoMultiSignature(Container): - """A merged proof covering many distinct messages. + """Merged proof covering many distinct messages. - On the wire a SignedBlock carries the SSZ-serialised form of this - container as its single proof blob. + Each component is a single-message proof over its own message. + Merging binds the components into one proof the block can carry whole. + + A signed block stores this proof as a single serialized blob. """ proof: ByteList512KiB """Compact no-pubkeys serialized Type-2 proof bytes.""" - @staticmethod + @classmethod def aggregate( - parts: Sequence[TypeOneMultiSignature], - public_keys_per_part: Sequence[Sequence[PublicKey]] | None = None, - mode: LeanEnvMode | None = None, + cls, + parts: list[TypeOneMultiSignature], + public_keys_per_part: list[list[PublicKey]], ) -> "TypeTwoMultiSignature": - """Merge several Type-1 proofs (each over a distinct message) into one Type-2 proof. + """Merge several single-message proofs over distinct messages into one. + + # Why the public keys are passed in + + - A merged proof stores no public keys. + - The prover needs them as external context to fold the components together. + - They cannot be recovered from the proofs, so the caller supplies them. - The returned Type-2 proof bytes are stored in compact no-pubkeys form. + Args: + parts: The single-message proofs to merge, one per distinct message. + public_keys_per_part: Public keys for each component, in the same order as the proofs. + + Returns: + A merged proof binding every component to its own message. + + Raises: + AggregationError: When no proofs are given, a pubkey list disagrees + with its participant count, or the prover rejects the inputs. 
""" if not parts: raise AggregationError("Type-2 aggregate requires at least one Type-1 input") - mode = mode or LEAN_ENV - setup_prover(mode=mode) - log_inv_rate = LOG_INV_RATE_TEST if mode == "test" else LOG_INV_RATE_PROD - - if public_keys_per_part is not None and len(public_keys_per_part) != len(parts): - raise AggregationError( - f"Type-2 aggregate expected pubkeys for {len(parts)} parts, " - f"got {len(public_keys_per_part)}" - ) - + # Each component carries the public keys named by its bitfield, in the same order. + # + # A miscount would otherwise fail deep in the prover with an opaque error. type1_entries: list[tuple[list[bytes], bytes]] = [] - for idx, part in enumerate(parts): - expected = part.participants.data.count(Boolean(1)) - if public_keys_per_part is None: - raise AggregationError( - "public_keys_per_part is required when Type-1 proofs are stored without pubkeys" - ) - pubkeys = list(public_keys_per_part[idx]) + for idx, (part, pubkeys) in enumerate(zip(parts, public_keys_per_part, strict=True)): + expected = len(part.participants.to_validator_indices()) if len(pubkeys) != expected: raise AggregationError( f"Type-2 aggregate entry {idx} expected {expected} pubkeys, got {len(pubkeys)}" ) - pks_ssz = [pk.encode_bytes() for pk in pubkeys] - type1_entries.append((pks_ssz, bytes(part.proof.data))) + type1_entries.append(([pk.encode_bytes() for pk in pubkeys], bytes(part.proof.data))) + # Hand the per-component keys and proof bytes to the Rust prover. + # + # The mode argument selects the matching backend bytecode. 
try: - _, type2_wire = merge_many_type_1(type1_entries, log_inv_rate, mode=mode) + _, type2_wire = merge_many_type_1(type1_entries, LOG_INV_RATE, mode=LEAN_ENV) except Exception as exc: - raise AggregationError(f"Type-2 aggregation failed: {exc}") from exc + raise AggregationError(str(exc)) from exc - return TypeTwoMultiSignature(proof=ByteList512KiB(data=type2_wire)) + return cls(proof=ByteList512KiB(data=type2_wire)) def split_by_msg( self, message: Bytes32, - public_keys_per_message: Sequence[Sequence[PublicKey]], - mode: LeanEnvMode | None = None, + public_keys_per_message: list[list[PublicKey]], + participants: AggregationBits, ) -> TypeOneMultiSignature: - """Recover the Type-1 proof bound to a specific message from this Type-2 merge. + """Recover the Type-1 proof bound to one message from this Type-2 merge. - public_keys_per_message defines the per-component pubkey layout the - Type-2 was built with. - """ - mode = mode or LEAN_ENV - setup_prover(mode=mode) - log_inv_rate = LOG_INV_RATE_TEST if mode == "test" else LOG_INV_RATE_PROD + # Why the layout and participants are passed in + + - A merged proof stores neither the public keys nor the participant bitfields. + - The prover needs the original key layout to isolate one component. + - The caller supplies both, drawn from the block attestation this component binds. + Args: + message: Message that selects the Type-1 component. + public_keys_per_message: Pubkey layout this Type-2 was built with. + participants: Bitfield naming the validators of the recovered component. + + Returns: + The Type-1 proof bound to the message. + + Raises: + AggregationError: When the Rust binding rejects the split. + """ + # Each component carries the public keys named by its bitfield, in the same order. pub_keys_per_component_ssz: list[list[bytes]] = [ [pk.encode_bytes() for pk in pks] for pks in public_keys_per_message ] + # Hand the key layout, merged proof, and selector message to the Rust prover. 
+ # + # The mode argument selects the matching backend bytecode. try: _, type1_wire = split_type_2_by_msg( pub_keys_per_component_ssz, bytes(self.proof.data), bytes(message), - log_inv_rate, - mode=mode, + LOG_INV_RATE, + mode=LEAN_ENV, ) except Exception as exc: - raise AggregationError(f"Type-2 split-by-message failed: {exc}") from exc + raise AggregationError(f"Type-2 split failed: {exc}") from exc return TypeOneMultiSignature( - participants=AggregationBits(data=[]), + participants=participants, proof=ByteList512KiB(data=type1_wire), ) def verify( self, - public_keys_per_message: Sequence[Sequence[PublicKey]], - messages: Sequence[tuple[Bytes32, Slot]], - mode: LeanEnvMode | None = None, + public_keys_per_message: list[list[PublicKey]], + messages: list[tuple[Bytes32, Slot]], ) -> None: - """Verify this multi-message Type-2 proof. - - Each entry of public_keys_per_message corresponds to one Type-1 - component merged into this Type-2. - The parallel messages entry binds that component to a specific - message hash and slot. - Without this binding the proof would verify against any attacker - chosen attestation data that resolves to the same pubkeys. - """ - mode = mode or LEAN_ENV - setup_prover(mode=mode) + """Verify this multi-message proof against its per-component bindings. + + # The message bindings + Each component is checked against one message and slot supplied by the caller. + Without that binding the proof would accept attacker-chosen data resolving to the same keys. + The parallel lists pin every component to the message it actually signed. + + Args: + public_keys_per_message: Public keys for each component, in component order. + messages: Message-slot pair each component is bound to, parallel to the keys. + + Raises: + AggregationError: When the two lists disagree in length, or the verifier rejects. + """ + # Each component needs exactly one message-slot binding. + # + # A length mismatch would leave components unbound or misaligned. 
if len(messages) != len(public_keys_per_message): raise AggregationError( f"Type-2 verify expected {len(public_keys_per_message)} message bindings, " f"got {len(messages)}" ) + # Serialize the key layout and the per-component message bindings. pub_keys_per_component_ssz: list[list[bytes]] = [ [pk.encode_bytes() for pk in pks] for pks in public_keys_per_message ] expected_messages = [(bytes(msg), int(slot)) for msg, slot in messages] + # Hand the layout, bindings, and merged proof to the Rust verifier. + # + # The mode argument selects the matching backend bytecode. try: verify_type_2_with_messages( pub_keys_per_component_ssz, expected_messages, bytes(self.proof.data), - mode=mode, + mode=LEAN_ENV, ) except Exception as exc: raise AggregationError(f"Type-2 verification failed: {exc}") from exc diff --git a/src/lean_spec/subspecs/xmss/constants.py b/src/lean_spec/subspecs/xmss/constants.py index f6dd295a1..a95d6ec56 100644 --- a/src/lean_spec/subspecs/xmss/constants.py +++ b/src/lean_spec/subspecs/xmss/constants.py @@ -1,15 +1,4 @@ -""" -Defines the cryptographic constants and configuration presets for the -XMSS spec. - -This specification corresponds to the "hashing-optimized" Top Level Target Sum -instantiation from the canonical Rust implementation -(production instantiation). - -We also provide a test instantiation for testing purposes. -""" - -from __future__ import annotations +"""Cryptographic constants and configuration presets for the XMSS spec.""" import math from typing import Final @@ -26,51 +15,34 @@ class XmssConfig(StrictBaseModel): """A model holding the configuration constants for an XMSS preset.""" - # --- Core Scheme Configuration --- MESSAGE_LENGTH: int """The length in bytes for all messages to be signed.""" LOG_LIFETIME: int """The base-2 logarithm of the scheme's maximum lifetime.""" - @property - def LIFETIME(self) -> Uint64: # noqa: N802 - """ - The maximum number of slots supported by this configuration. 
- - An individual key pair can be active for a smaller sub-range. - """ - return Uint64(1 << self.LOG_LIFETIME) - DIMENSION: int - """The total number of hash chains, `v`.""" + """The total number of hash chains, v.""" BASE: int """The alphabet size for the digits of the encoded message.""" Z: int - """Number of base-`BASE` digits extracted from each field element.""" + """Number of base-BASE digits extracted from each field element.""" Q: int - """Quotient such that `Q * BASE^Z == P - 1`.""" + """Quotient such that Q * BASE^Z == P - 1.""" TARGET_SUM: int """The required sum of all codeword chunks for a signature to be valid.""" MAX_TRIES: int - """ - How often one should try at most to resample a random value. - - This is currently based on experiments with the Rust implementation. - Should probably be modified in production. - """ + """How often one should try at most to resample a random value.""" PARAMETER_LEN: int - """ - The length of the public parameter `P`. + """The length of the public parameter P. - It is used to specialize the hash function. 
- """ + It is used to specialize the hash function.""" TWEAK_LEN_FE: int """The length of a domain-separating tweak.""" @@ -79,7 +51,7 @@ def LIFETIME(self) -> Uint64: # noqa: N802 """The length of a message after being encoded into field elements.""" RAND_LEN_FE: int - """The length of the randomness `rho` used during message encoding.""" + """The length of the randomness rho used during message encoding.""" HASH_LEN_FE: int """The output length of the main tweakable hash function.""" @@ -88,33 +60,53 @@ def LIFETIME(self) -> Uint64: # noqa: N802 """The capacity of the Poseidon1 sponge, defining its security level.""" @model_validator(mode="after") - def _validate_decomposition(self) -> XmssConfig: + def _validate_decomposition(self) -> "XmssConfig": """Verify that Q * BASE^Z == P - 1.""" if self.Q * self.BASE**self.Z != P - 1: raise ValueError(f"Q * BASE^Z must equal P-1={P - 1}") return self @property - def MH_HASH_LEN_FE(self) -> int: # noqa: N802 + def LIFETIME(self) -> Uint64: + """The maximum number of slots supported by this configuration. + + An individual key pair can be active for a smaller sub-range. + """ + return Uint64(1 << self.LOG_LIFETIME) + + @property + def LEAVES_PER_BOTTOM_TREE(self) -> int: + """Slots covered by one bottom tree, W = sqrt(LIFETIME) = 2^(LOG_LIFETIME / 2).""" + return 1 << (self.LOG_LIFETIME // 2) + + @property + def MH_HASH_LEN_FE(self) -> int: """Number of Poseidon output field elements needed for the aborting decode.""" return math.ceil(self.DIMENSION / self.Z) @property - def SIGNATURE_LEN_BYTES(self) -> int: # noqa: N802 - """ - The SSZ-encoded size of a signature in bytes. + def SIGNATURE_LEN_BYTES(self) -> int: + """The SSZ-encoded size of a signature in bytes. 
- Includes raw field data plus SSZ offset overhead for variable-size fields: + # Layout - - Signature container: 2 offsets (path, hashes) - - HashTreeOpening container: 1 offset (siblings) + authentication path : one sibling digest per tree level (variable) + encoding randomness : fixed run of field elements (fixed) + released chain ends : one digest per hash chain (variable) """ - # Raw data sizes + # One sibling digest per level climbed from leaf to root. path_siblings_size = self.LOG_LIFETIME * self.HASH_LEN_FE * P_BYTES rho_size = self.RAND_LEN_FE * P_BYTES + # One released chain end per chain, so the count is the scheme dimension. hashes_size = self.DIMENSION * self.HASH_LEN_FE * P_BYTES - # SSZ offset overhead: 3 variable fields × 4 bytes each + # SSZ writes a four-byte offset ahead of each variable-length field. + # + # path -> offset 1 (top level) + # chain ends -> offset 2 (top level) + # siblings -> offset 3 (nested inside the path) + # + # The randomness is fixed-length, so it carries no offset. 
ssz_offset_overhead = 3 * BYTES_PER_LENGTH_OFFSET return path_siblings_size + rho_size + hashes_size + ssz_offset_overhead @@ -170,10 +162,5 @@ def SIGNATURE_LEN_BYTES(self) -> int: # noqa: N802 PRF_KEY_LENGTH: Final = 32 """The length of the PRF secret key in bytes.""" -_LEAN_ENV_TO_CONFIG = { - "test": TEST_CONFIG, - "prod": PROD_CONFIG, -} - -TARGET_CONFIG: Final = _LEAN_ENV_TO_CONFIG[LEAN_ENV] -"""The active XMSS configuration based on LEAN_ENV environment variable.""" +TARGET_CONFIG: Final = TEST_CONFIG if LEAN_ENV == "test" else PROD_CONFIG +"""Active configuration selected at import time from the LEAN_ENV environment variable.""" diff --git a/src/lean_spec/subspecs/xmss/containers.py b/src/lean_spec/subspecs/xmss/containers.py index 2057bc6c8..06aaef326 100644 --- a/src/lean_spec/subspecs/xmss/containers.py +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -11,13 +11,13 @@ from ...types.container import Container from ...types.exceptions import SSZError from .constants import TARGET_CONFIG -from .subtree import HashSubTree +from .merkle import HashSubTree +from .prf import PRFKey from .types import ( HashDigestList, HashDigestVector, HashTreeOpening, Parameter, - PRFKey, Randomness, ) @@ -33,7 +33,7 @@ class PublicKey(Container): class Signature(Container): - """A single XMSS signature for one (slot, message) under one public key.""" + """A single XMSS signature for one slot and message under one public key.""" path: HashTreeOpening """Authentication path from the one-time key up to the Merkle root.""" @@ -58,56 +58,57 @@ def get_byte_length(cls) -> int: @model_serializer(mode="plain", when_used="json") def _serialize_as_bytes(self) -> str: - """Serialize as "0x"-prefixed hex of the SSZ bytes for JSON output.""" + """Serialize as a 0x-prefixed hex string for JSON output.""" return "0x" + self.encode_bytes().hex() class SecretKey(Container): - """ - Private state of an XMSS key pair. MUST BE KEPT CONFIDENTIAL. 
- - The tree of one-time keys is split into a top tree plus many bottom trees. - - Each bottom tree: - - has width W = 2^(LOG_LIFETIME / 2), - - covers W slots. + """Private state of an XMSS key pair. - Bottom tree i covers slots [i*W, (i+1)*W). + Must be kept confidential. - The signer keeps the full top tree resident, plus two adjacent bottom trees. + # Tree layout - This means that we have a sliding window of 2W consecutive slots. + - The one-time keys split into one top tree over many bottom trees. + - Each bottom tree spans W = 2^(LOG_LIFETIME / 2) slots. + - Bottom tree i covers the W slots starting at i times W. - These slots can be signed immediately without recomputing anything. + # Sliding window - Memory stays O(sqrt(LIFETIME)) regardless of how long the key lives. + - The signer keeps the top tree resident plus two adjacent bottom trees. + - That window can sign 2W consecutive slots. + - Resident memory stays near the square root of the lifetime. """ prf_key: PRFKey - """Master secret seed; every one-time key is derived from this.""" + """Master secret seed. + + Every one-time key is derived from this. + """ parameter: Parameter """Public parameter mirrored here so signing is self-contained.""" activation_slot: Slot - """ - First slot this key can sign for. + """First slot this key can sign for. - Aligned down to a multiple of W so each bottom tree covers exactly W slots. + - Aligned down to a multiple of W. + - Each bottom tree then covers exactly W slots. """ num_active_slots: Uint64 - """ - Number of consecutive slots this key can sign for. + """Number of consecutive slots this key can sign for. - Rounded up to a multiple of W, with a minimum of 2W so the prepared window always fits. + - Rounded up to a multiple of W, with a minimum of 2W. + - The prepared window then always fits. """ top_tree: HashSubTree """ Full top tree, always resident. - Its lowest layer holds the bottom-tree roots; its top is the public-key root. 
+ - Its lowest layer holds the bottom-tree roots. + - Its top layer is the public-key root. """ left_bottom_tree_index: Uint64 @@ -137,7 +138,10 @@ class KeyPair(StrictBaseModel): @field_validator("public_key", mode="before") @classmethod def _decode_public_key(cls, value: object) -> object: - """Decode a hex string into a PublicKey; pass other inputs through.""" + """Decode hex strings to a public key. + + Other input shapes pass through unchanged. + """ if not isinstance(value, str): return value try: @@ -148,7 +152,10 @@ def _decode_public_key(cls, value: object) -> object: @field_validator("secret_key", mode="before") @classmethod def _decode_secret_key(cls, value: object) -> object: - """Decode a hex string into a SecretKey; pass other inputs through.""" + """Decode hex strings to a secret key. + + Other input shapes pass through unchanged. + """ if not isinstance(value, str): return value try: @@ -163,25 +170,16 @@ def _encode_hex(self, value: PublicKey | SecretKey) -> str: class ValidatorKeyPair(StrictBaseModel): - """ - Two independent XMSS key pairs for one validator's two signing roles. + """Two independent XMSS key pairs for one validator's two signing roles. A validator signs two messages per slot: - - one attestation (always), - - one block proposal (only when chosen as proposer). - - A one-time signature exhausts a leaf, so a single XMSS key cannot serve - both roles in the same slot. - Splitting into two independent pairs lets a validator sign both - messages from independent Winternitz chains. - - JSON shape on disk: + - one attestation, always, + - one block proposal, only when chosen as proposer. - { - "attestation_keypair": {"public_key": "", "secret_key": ""}, - "proposal_keypair": {"public_key": "", "secret_key": ""} - } + A one-time signature exhausts a leaf. + So one key cannot cover both roles in the same slot. + Two independent pairs let each role sign from its own Winternitz chains. 
""" attestation_keypair: KeyPair diff --git a/src/lean_spec/subspecs/xmss/encoding.py b/src/lean_spec/subspecs/xmss/encoding.py new file mode 100644 index 000000000..0dd35aba7 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/encoding.py @@ -0,0 +1,177 @@ +"""Message-to-codeword pipeline for the Generalized XMSS scheme. + +# Overview + +The pipeline has two layers: + +- Encoding maps a message to a codeword, the vector of digits a signature commits to. +- Decoding is the inner step that unpacks hash field elements into those digits. + +A codeword is a vertex of a high-dimensional hypercube. +Encoding accepts the vertex only when its digits lie on the target-sum layer. +A signer retries with fresh randomness until the filter accepts. + +Concretely, the test preset uses 4 digits in base 8 with target sum 6: + +- The vector [5, 0, 1, 0] sums to 6, lands on the layer, and is accepted. +- The vector [2, 2, 2, 2] sums to 8, misses the layer, and forces a retry. + +# Construction + +The hypercube decode is the aborting encoding of Aborting Random Oracles, Section 6.1. +https://eprint.iacr.org/2026/016.pdf + +The target-sum filter is the top-level acceptance test from the canonical Rust instantiation. + +# Why the decode can abort + +The decode turns each uniform field element into uniform base-BASE digits. +Uniform output is what lets the security analysis model the hash as a random oracle. + +A field element takes one of the P values from 0 to P - 1. +The prime is chosen so that P - 1 = Q * BASE^Z. +The values 0 to P - 2 therefore form BASE^Z groups of Q consecutive integers. + +Integer division by Q maps each group to one quotient in 0 to BASE^Z - 1: + + 0 .. Q-1 -> quotient 0 + Q .. 2Q-1 -> quotient 1 + ... + P-1-Q .. P-2 -> quotient BASE^Z - 1 + P-1 -> no quotient + +Each quotient expands into Z base-BASE digits. +Every quotient is equally likely, which makes the digits uniform. + +The value P - 1 falls outside every group. 
+The decode rejects it, a rare event near 4.7e-10 that barely affects signing. +""" + +from lean_spec.types import Bytes32, Uint64 + +from ..koalabear import Fp +from .constants import TWEAK_PREFIX_MESSAGE, XmssConfig +from .field import int_to_base_p +from .poseidon import PoseidonXmss +from .types import Parameter, Randomness + + +def encode_message(config: XmssConfig, message: Bytes32) -> list[Fp]: + """Encode a 32-byte message into field elements via base-P decomposition. + + The bytes are read little-endian as a single integer. + """ + acc = int.from_bytes(message, "little") + return int_to_base_p(acc, config.MSG_LEN_FE) + + +def encode_epoch(config: XmssConfig, epoch: Uint64) -> list[Fp]: + """Encode the epoch and the message-hash subdomain into field elements. + + The 8-bit prefix separates the message-hash subdomain from the chain and tree subdomains. + """ + # Layout: + # + # (epoch << 8) | MESSAGE_PREFIX + acc = (int(epoch) << 8) | TWEAK_PREFIX_MESSAGE + return int_to_base_p(acc, config.TWEAK_LEN_FE) + + +def aborting_decode(config: XmssConfig, field_elements: list[Fp]) -> list[int] | None: + """Reject-sample each field element into base-BASE digits. + + For each element A_i: + + 1. If A_i >= Q * BASE^Z, that is A_i == P - 1, abort and return None. + 2. Compute d_i = A_i // Q in [0, BASE^Z - 1]. + 3. Emit Z base-BASE digits of d_i, least significant first. + + Return the first DIMENSION digits. + """ + threshold = config.Q * config.BASE**config.Z + + digits: list[int] = [] + for fe in field_elements: + a = int(fe) + + # The only rejection case is A_i == P - 1. + if a >= threshold: + return None + + # Quotient by Q strips the residue. + # The quotient is uniform in [0, BASE^Z - 1].
+ d = a // config.Q + for _ in range(config.Z): + d, digit = divmod(d, config.BASE) + digits.append(digit) + + return digits[: config.DIMENSION] + + +def message_hash( + poseidon: PoseidonXmss, + config: XmssConfig, + parameter: Parameter, + epoch: Uint64, + rho: Randomness, + message: Bytes32, +) -> list[int] | None: + """Hash the inputs with Poseidon1 and decode into a candidate codeword. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + parameter: Public parameter P. + epoch: Current epoch. + rho: Per-attempt randomness. + message: Message being signed. + + Returns: + Codeword of DIMENSION digits in [0, BASE-1], or None on rejection. + """ + # Encode the message and epoch as field elements before hashing. + message_fe = encode_message(config, message) + epoch_fe = encode_epoch(config, epoch) + + # One Poseidon1 call produces enough output for the aborting decode. + base_input = message_fe + parameter.elements + epoch_fe + rho.elements + poseidon_output = poseidon.compress(base_input, 24, config.MH_HASH_LEN_FE) + + return aborting_decode(config, poseidon_output) + + +def target_sum_encode( + poseidon: PoseidonXmss, + config: XmssConfig, + parameter: Parameter, + message: Bytes32, + rho: Randomness, + epoch: Uint64, +) -> list[int] | None: + """Encode a message into a codeword if it meets the target sum. + + The signer retries with fresh randomness on rejection. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + parameter: Public parameter for domain separation. + message: Message being signed. + rho: Per-attempt randomness. + epoch: Current epoch. + + Returns: + Codeword on success, None when the attempt must be retried. + """ + # Phase 1: aborting hypercube decode of the Poseidon1 output. + codeword_candidate = message_hash(poseidon, config, parameter, epoch, rho, message) + if codeword_candidate is None: + return None + + # Phase 2: target-sum acceptance condition. 
+ # A valid codeword is a vertex on the hypercube layer whose digit sum is TARGET_SUM. + if sum(codeword_candidate) == config.TARGET_SUM: + return codeword_candidate + + # The caller retries with new randomness. + return None diff --git a/src/lean_spec/subspecs/xmss/field.py b/src/lean_spec/subspecs/xmss/field.py new file mode 100644 index 000000000..053b0b8bc --- /dev/null +++ b/src/lean_spec/subspecs/xmss/field.py @@ -0,0 +1,32 @@ +"""Field-element decomposition and secure sampling for the Generalized XMSS scheme.""" + +import secrets + +from ..koalabear import Fp, P +from .constants import XmssConfig +from .types import HashDigestVector, Parameter + + +def int_to_base_p(value: int, num_limbs: int) -> list[Fp]: + """Decompose an integer into a fixed-size list of base-P field elements.""" + acc = value + limbs: list[Fp] = [] + for _ in range(num_limbs): + limbs.append(Fp(value=acc)) + acc //= P + return limbs + + +def random_field_elements(length: int) -> list[Fp]: + """Sample a list of secure-random field elements in [0, P).""" + return [Fp(value=secrets.randbelow(P)) for _ in range(length)] + + +def random_parameter(config: XmssConfig) -> Parameter: + """Sample a fresh public parameter for one XMSS key pair.""" + return Parameter(data=random_field_elements(config.PARAMETER_LEN)) + + +def random_domain(config: XmssConfig) -> HashDigestVector: + """Sample a fresh hash-digest-sized vector of field elements.""" + return HashDigestVector(data=random_field_elements(config.HASH_LEN_FE)) diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 4a45a152f..361863456 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -1,175 +1,144 @@ -""" -Defines the core interface for the Generalized XMSS signature scheme. - -Specification for the high-level functions (`key_gen`, `sign`, `verify`). - -This constitutes the public API of the signature scheme. 
-""" - -from __future__ import annotations +"""Public interface for the Generalized XMSS signature scheme.""" from lean_spec.config import LEAN_ENV -from lean_spec.subspecs.xmss.target_sum import ( - PROD_TARGET_SUM_ENCODER, - TEST_TARGET_SUM_ENCODER, - TargetSumEncoder, -) from lean_spec.types import Bytes32, Slot, StrictBaseModel, Uint64 -from .constants import ( - PROD_CONFIG, - TEST_CONFIG, - XmssConfig, -) +from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig from .containers import KeyPair, PublicKey, SecretKey, Signature -from .prf import PROD_PRF, TEST_PRF, Prf -from .rand import PROD_RAND, TEST_RAND, Rand -from .subtree import HashSubTree, combined_path, verify_path -from .tweak_hash import ( - PROD_TWEAK_HASHER, - TEST_TWEAK_HASHER, - TweakHasher, -) +from .encoding import target_sum_encode +from .field import random_parameter +from .merkle import HashSubTree, combined_path, verify_path +from .poseidon import PROD_POSEIDON, TEST_POSEIDON, PoseidonXmss +from .prf import PRFKey from .types import HashDigestList, HashDigestVector -from .utils import expand_activation_time - - -class GeneralizedXmssScheme(StrictBaseModel): - """ - Instance of the Generalized XMSS signature scheme for a given config. - - This class holds the configuration and component instances needed to - perform key generation, signing, and verification operations. - """ - config: XmssConfig - """Configuration parameters for the XMSS scheme.""" - prf: Prf - """Pseudorandom function for deriving secret values.""" +def _expand_activation_time( + log_lifetime: int, desired_activation_slot: int, desired_num_active_slots: int +) -> tuple[int, int]: + """Snap a requested slot window onto whole bottom trees. - hasher: TweakHasher - """Hash function with tweakable domain separation.""" + # Overview - encoder: TargetSumEncoder - """Message encoder that produces valid codewords.""" + A bottom tree covers C consecutive slots, where C is the square root of the lifetime. 
+ The lifetime is C such trees laid end to end, so C * C slots in total. - rand: Rand - """Random data generator for key generation.""" + Args: + log_lifetime: Base-2 logarithm of the lifetime in slots. + desired_activation_slot: First slot the caller wants to sign. + desired_num_active_slots: Number of slots the caller wants to sign. - def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: - """ - Generates a new cryptographic key pair for a specified range of slots. - - This is a **randomized** algorithm that establishes a signer's identity using - the memory-efficient Top-Bottom Tree Traversal approach. + Returns: + The half-open bottom-tree index range (start, end). + It covers slots [start * C, end * C). + """ + # C is one bottom tree's worth of slots, the square root of the lifetime. + # C is a power of two, so clearing the low bits rounds a slot down to a tree boundary. + c = 1 << (log_lifetime // 2) + c_mask = ~(c - 1) + + desired_end_slot = desired_activation_slot + desired_num_active_slots + + # Phase 1: round the start down and the end up onto tree boundaries. + # Adding C - 1 before clearing the low bits rounds the end up rather than down. + start = desired_activation_slot & c_mask + end = (desired_end_slot + c - 1) & c_mask + + # Phase 2: widen to two trees so the resident signing window always fits. + if end - start < 2 * c: + end = start + 2 * c + + # Phase 3: clamp the window into the lifetime. + lifetime = c * c + if end > lifetime: + duration = end - start + if duration > lifetime: + # The request is wider than the whole lifetime, so cover all of it. + start = 0 + end = lifetime + else: + # Slide the window back so it ends exactly at the lifetime boundary. + end = lifetime + start = (lifetime - duration) & c_mask - ### Key Generation Algorithm + # Convert the slot boundaries to bottom-tree indices. + return (start // c, end // c) - 1. 
**Expand Activation Time**: Align the requested activation interval to - `sqrt(LIFETIME)` boundaries to enable efficient tree partitioning. - This ensures the interval starts at a multiple of `sqrt(LIFETIME)` and - has a minimum duration of `2 * sqrt(LIFETIME)` slots. - 2. **Generate Master Secrets**: Generate PRF key and public parameter `P`. - The PRF key allows deterministic on-demand regeneration of one-time keys. +class GeneralizedXmssScheme(StrictBaseModel): + """Generalized XMSS signature scheme bound to one configuration.""" - 3. **Generate First Two Bottom Trees**: Create the first two bottom trees - (covering the initial `2 * sqrt(LIFETIME)` slots) and keep them in memory. - Each bottom tree covers `sqrt(LIFETIME)` consecutive slots. + config: XmssConfig + """Configuration parameters for this instance.""" - 4. **Generate Remaining Bottom Tree Roots**: For all other bottom trees in - the range, generate only their roots (not the full trees). This saves - memory since we only need the first two trees for the prepared window. + poseidon: PoseidonXmss + """Cached Poseidon1 engine used by every primitive in the scheme.""" - 5. **Build Top Tree**: Construct the top tree from all bottom tree roots. - The top tree's lowest layer contains the bottom tree roots, and it is - built upward to the global Merkle root. + def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: + """Generate a fresh key pair active for an aligned slot range. - ### Memory Efficiency + # Overview - Traditional approach: O(LIFETIME) memory - Top-Bottom approach: O(sqrt(LIFETIME)) memory + The signer keeps only the two leftmost bottom trees resident, plus every bottom-tree root. + That bounds secret-key memory near the square root of the lifetime. - For LOG_LIFETIME=32 (2^32 slots): - - Traditional: ~hundreds of GiB - - Top-Bottom: much more reasonable + The requested range is snapped outward to whole bottom trees. 
+ So the returned key may cover more slots than asked for. Args: - activation_slot: The starting slot for which this key is valid. - - Will be aligned downward to `sqrt(LIFETIME)` boundary. - num_active_slots: The number of consecutive slots the key can be used for. - - Will be rounded up to at least `2 * sqrt(LIFETIME)`. + activation_slot: Requested first signable slot. + num_active_slots: Requested number of signable slots. Returns: - A `KeyPair` containing the public and secret keys. + A key pair holding the public root and the resident signer state. - Note: - The actual activation slot and num_active_slots in the returned SecretKey - may be larger than requested due to alignment requirements. - - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + Raises: + ValueError: When the requested range exceeds the lifetime. """ - # Retrieve the scheme's configuration parameters. config = self.config - # Ensure the requested activation range is within the scheme's total supported lifetime. + # The requested range must fit within the global lifetime. if int(activation_slot) + int(num_active_slots) > int(config.LIFETIME): raise ValueError("Activation range exceeds the key's lifetime.") - # Generate the random public parameter `P` and the master PRF key. - # - `P` ensures hash function outputs are unique to this key pair. - # - PRF key is the single master secret from which all one-time keys are derived. - parameter = self.rand.parameter() - prf_key = self.prf.key_gen() + # Phase 1: draw the master secret and the public parameter. + parameter = random_parameter(config) + prf_key = PRFKey.generate() - # Step 1: Expand and align activation time to sqrt(LIFETIME) boundaries. 
- start_bottom_tree_index, end_bottom_tree_index = expand_activation_time( + start_bottom_tree_index, end_bottom_tree_index = _expand_activation_time( config.LOG_LIFETIME, int(activation_slot), int(num_active_slots) ) + actual_activation_slot = start_bottom_tree_index * config.LEAVES_PER_BOTTOM_TREE + actual_num_active_slots = ( + end_bottom_tree_index - start_bottom_tree_index + ) * config.LEAVES_PER_BOTTOM_TREE - num_bottom_trees = end_bottom_tree_index - start_bottom_tree_index - leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - - # Calculate the actual (expanded) activation slot and count. - actual_activation_slot = start_bottom_tree_index * leaves_per_bottom_tree - actual_num_active_slots = num_bottom_trees * leaves_per_bottom_tree - - # Step 2: Generate the first two bottom trees (kept in memory). + # Phase 3: build the two leftmost bottom trees and keep them resident. left_bottom_tree = HashSubTree.from_prf_key( - prf=self.prf, - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, config=config, prf_key=prf_key, bottom_tree_index=Uint64(start_bottom_tree_index), parameter=parameter, ) right_bottom_tree = HashSubTree.from_prf_key( - prf=self.prf, - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, config=config, prf_key=prf_key, bottom_tree_index=Uint64(start_bottom_tree_index + 1), parameter=parameter, ) - # Collect roots for building the top tree. bottom_tree_roots: list[HashDigestVector] = [ left_bottom_tree.root(), right_bottom_tree.root(), ] - # Step 3: Generate remaining bottom trees (only their roots). + # Phase 4: build every remaining bottom tree and remember only its root.
for i in range(start_bottom_tree_index + 2, end_bottom_tree_index): tree = HashSubTree.from_prf_key( - prf=self.prf, - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, config=config, prf_key=prf_key, bottom_tree_index=Uint64(i), @@ -177,21 +146,18 @@ def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: ) bottom_tree_roots.append(tree.root()) - # Step 4: Build the top tree from bottom tree roots. + # Phase 5: assemble the top tree from all bottom-tree roots. top_tree = HashSubTree.new_top_tree( - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, + config=config, depth=config.LOG_LIFETIME, start_bottom_tree_index=Uint64(start_bottom_tree_index), parameter=parameter, bottom_tree_roots=bottom_tree_roots, ) - # Extract the global root. - root = top_tree.root() - - # Assemble and return the keys. - pk = PublicKey(root=root, parameter=parameter) + # Pack the public root and the resident signer state into the key pair. + pk = PublicKey(root=top_tree.root(), parameter=parameter) sk = SecretKey( prf_key=prf_key, parameter=parameter, @@ -205,113 +171,84 @@ def key_gen(self, activation_slot: Slot, num_active_slots: Uint64) -> KeyPair: return KeyPair(public_key=pk, secret_key=sk) def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: - """ - Produces a digital signature for a given message at a specific slot. - - This is a **deterministic** algorithm. Calling `sign` twice with the same - (sk, slot, message) triple produces the same signature. - - **CRITICAL SECURITY WARNING**: A secret key for a given slot must **NEVER** be used - to sign two different messages. Doing so would reveal parts of the secret key - and allow an attacker to forge signatures. This is the fundamental security - property of a synchronized (stateful) signature scheme. + """Produce a signature for a message at a specific slot. - ### Signing Algorithm + Phase 1: enforce that the slot is inside the activation and prepared windows. 
+ Phase 2: search for randomness rho whose encoding lands on the target-sum layer. + Phase 3: walk each Winternitz chain to the released hash dictated by the codeword. + Phase 4: build the combined Merkle path through the bottom and top trees. - 1. **Message Encoding with Randomness (`rho`)**: The "Target Sum" scheme - requires the message hash to be encoded into a `codeword` whose digits - sum to a predefined target. A direct hash of the message is unlikely to - satisfy this. Therefore, the algorithm repeatedly hashes the message - combined with deterministic randomness (`rho`) derived from the PRF - until a valid `codeword` is found. - - 2. **One-Time Signature**: The `codeword` dictates how the one-time signature is - formed. For each digit `x_i` in the codeword, the signer reveals an intermediate - hash value by applying the hash function `x_i` times to the secret start of the - `i`-th hash chain. - The collection of these intermediate hashes forms the one-time signature. - - 3. **Merkle Path**: The signer retrieves the Merkle authentication path for the leaf - corresponding to the current `slot`. This path proves that the one-time public key - for this slot is part of the main public key (the Merkle root). + Signing is deterministic in (sk, slot, message). + A secret key must never sign two different messages for the same slot. Args: - sk: The secret key to use for signing. - slot: The slot for which the signature is being created. - message: The message to be signed. + sk: Secret key. + slot: Signing slot. + message: Message to sign. Returns: - The resulting `Signature` object. + A signature carrying the OTS, the Merkle path, and the randomness. 
- For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + Raises: + ValueError: When the slot is outside the activation or prepared window. + RuntimeError: When no valid encoding is found within MAX_TRIES attempts. """ - # Retrieve the scheme's configuration parameters. config = self.config - - # Verify that the secret key is currently active for the requested signing slot. slot_int = int(slot) activation_int = int(sk.activation_slot) + + # Phase 1a: the slot must lie in the key's activation range. + # + # This is a synchronized one-time scheme: each slot consumes a distinct one-time key. + # The key only holds material for the contiguous range fixed at generation. if not (activation_int <= slot_int < activation_int + int(sk.num_active_slots)): raise ValueError("Key is not active for the specified slot.") - # Verify that the slot is within the prepared interval (covered by loaded bottom trees). + # Phase 1b: the slot must lie in the prepared window. # - # With top-bottom tree traversal, only slots within the prepared interval can be - # signed without computing additional bottom trees. - # - # If the slot is outside this range, we need to slide the window forward. - leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - prepared_start = int(sk.left_bottom_tree_index) * leaves_per_bottom_tree - prepared_end = prepared_start + 2 * leaves_per_bottom_tree - if not (prepared_start <= slot_int < prepared_end): + # The signature opens this slot's leaf through the bottom tree that holds it. + # Only the two adjacent resident bottom trees are available. + # A slot outside them would force rebuilding a tree from the PRF, so we refuse. 
+ prepared = self.get_prepared_interval(sk) + if slot_int not in prepared: raise ValueError( f"Slot {slot} is outside the prepared interval " - f"[{prepared_start}, {prepared_end}). " + f"[{prepared.start}, {prepared.stop}). " f"Call advance_preparation() to slide the window forward." ) - # Find a valid message encoding. + # Phase 2: find randomness whose encoding lands on the target-sum layer. # - # This loop repeatedly tries different randomness `rho` until the encoder - # produces a valid codeword (i.e., one that meets the target sum constraint). + # A valid codeword must have digits summing to the target. + # That constant sum is what makes distinct codewords incomparable, hence unforgeable. + # The message hash hits the layer only for some randomness, so resample until it does. # - # The randomness is deterministically derived from the PRF to ensure - # that signing is reproducible for the same (sk, slot, message). + # The randomness is derived from the PRF, keyed by the message and an attempt counter. + # Signing the same message twice is therefore reproducible. for attempts in range(config.MAX_TRIES): - # Derive deterministic randomness `rho` from PRF using the attempt counter. - rho = self.prf.get_randomness(sk.prf_key, slot, message, Uint64(attempts)) - # Attempt to encode the message with the deterministic `rho`. - codeword = self.encoder.encode(sk.parameter, message, rho, slot) - # If encoding is successful, we've found our `rho` and `codeword`. - # - # We can exit the loop. + rho = sk.prf_key.derive_randomness(config, slot, message, Uint64(attempts)) + codeword = target_sum_encode(self.poseidon, config, sk.parameter, message, rho, slot) if codeword is not None: break else: - # This block executes only if the `for` loop completes without a `break`. - # - # This means that no valid encoding was found after the maximum number of tries. raise RuntimeError( f"Failed to find a valid message encoding after {config.MAX_TRIES} tries." 
) - # Sanity check to ensure the encoder returned a codeword of the correct length. - if len(codeword) != self.config.DIMENSION: + # A valid codeword carries exactly one digit per Winternitz chain. + if len(codeword) != config.DIMENSION: raise RuntimeError("Encoding is broken: returned too many or too few chunks.") - # Compute the one-time signature hashes based on the codeword. + # Phase 3: release each Winternitz chain at the position its digit selects. + # + # Every chain starts from a secret derived from the PRF. + # Hashing that start forward by the digit gives the value to reveal. + # The verifier later finishes the remaining steps to reach the chain end. ots_hashes: list[HashDigestVector] = [] for chain_index, steps in enumerate(codeword): - # Derive the secret start of the current chain using the master PRF key. - start_digest = self.prf.apply(sk.prf_key, slot, Uint64(chain_index)) - # Walk the hash chain for the number of `steps` specified by the - # corresponding digit in the codeword. - # - # The result is one component of the OTS. - ots_digest = self.hasher.hash_chain( + start_digest = sk.prf_key.derive_chain_start(config, slot, Uint64(chain_index)) + ots_digest = self.poseidon.hash_chain( + config=config, parameter=sk.parameter, epoch=slot, chain_index=chain_index, @@ -321,106 +258,71 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: ) ots_hashes.append(ots_digest) - # Retrieve the Merkle authentication path for the current slot's leaf. - # With top-bottom tree traversal, we use combined_path to merge paths from - # the bottom tree and top tree. - - # Determine which bottom tree contains this slot (reuse leaves_per_bottom_tree from above). - boundary = (int(sk.left_bottom_tree_index) + 1) * leaves_per_bottom_tree + # Phase 4: open this slot's leaf up to the public root. + # + # The opening climbs the bottom tree that holds the slot, then the top tree. 
+ # The slot's side of the prepared window selects which resident bottom tree to climb. + boundary = prepared.start + config.LEAVES_PER_BOTTOM_TREE bottom_tree = sk.left_bottom_tree if slot_int < boundary else sk.right_bottom_tree + path = combined_path(sk.top_tree, bottom_tree, Uint64(slot)) - # Generate the combined authentication path - path = combined_path(sk.top_tree, bottom_tree, Uint64(int(slot))) - - # Assemble and return the final signature, which contains: - # - The OTS, - # - The Merkle path, - # - The randomness `rho` needed for verification. + # The signature carries the opening, the randomness, and the released chain values. + # The randomness lets the verifier recompute the same codeword. return Signature(path=path, rho=rho, hashes=HashDigestList(data=ots_hashes)) def verify(self, pk: PublicKey, slot: Slot, message: Bytes32, sig: Signature) -> bool: - r""" - Verifies a digital signature against a public key, message, and slot. - - This is a **deterministic** algorithm. - - ### Verification Algorithm - - 1. **Re-encode Message**: The verifier uses the randomness `rho` from the - signature to re-compute the codeword $x = (x_1, \dots, x_v)$ from the message `m`. - If the encoding is invalid (e.g., does not meet the target sum), verification fails. - - 2. **Reconstruct One-Time Public Key**: For each intermediate hash $y_i$ in the - signature's `hashes` field, the verifier completes the corresponding hash chain. - Since $y_i$ was computed by hashing $x_i$ times, the verifier applies the - hash function an additional `BASE - 1 - x_i` times to arrive at the - chain's public endpoint, which is one component of the one-time public key. - - 3. **Compute Merkle Leaf**: The verifier hashes the full set of reconstructed - chain endpoints to compute the expected Merkle leaf for the given slot. + """Verify a signature against a public key, message, and slot. - 4. 
**Verify Merkle Path**: The verifier uses the authentication `path` from the - signature to compute a candidate Merkle root, starting from the leaf computed - in the previous step. Verification succeeds if and only if this candidate root - matches the `root` stored in the `PublicKey`. + Phase 1: bound-check the slot. + Phase 1 rejects without raising on bad input. + Phase 2: recompute the codeword using the randomness carried by the signature. + Phase 3: complete each Winternitz chain from the released hash to its endpoint. + Phase 4: rebuild the Merkle root from the chain endpoints and the opening. Args: - pk: The public key to verify against. - slot: The slot the signature corresponds to. - message: The message that was supposedly signed. - sig: The signature object to be verified. + pk: Public key. + slot: Signing slot claimed by the signature. + message: Message claimed by the signature. + sig: Signature to verify. Returns: - `True` if the signature is valid, `False` otherwise. - - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + True when the signature is valid against the public key, false otherwise. """ - # Retrieve the scheme's configuration parameters. config = self.config - # Validate slot bounds. - # - # Return False instead of raising to avoid panic on invalid signatures. - # The slot is attacker-controlled input. - if int(slot) >= int(self.config.LIFETIME): + # Phase 1: bound check on the slot. + # The slot is attacker-controlled, so a malformed value returns False + # rather than panicking deep in the verification routine. + if int(slot) >= int(config.LIFETIME): return False - # Re-encode the message using the randomness `rho` from the signature. 
- # - # If the encoding is invalid (e.g., fails the target sum check), the signature is invalid. - codeword = self.encoder.encode(pk.parameter, message, sig.rho, slot) + # Phase 2: rederive the codeword from the signature's randomness. + # If the aborting decode rejects, the signature cannot be valid. + codeword = target_sum_encode(self.poseidon, config, pk.parameter, message, sig.rho, slot) if codeword is None: return False - # Reconstruct the one-time public key (the list of chain endpoints). + # Phase 3: finish each chain from the released hash to its endpoint. chain_ends: list[HashDigestVector] = [] for chain_index, xi in enumerate(codeword): - # The signature provides `start_digest`, which is the hash value after `xi` steps. + # The signature provides the digest after xi steps along the chain. + # Hashing it BASE - 1 - xi more times reaches the chain endpoint. start_digest = sig.hashes[chain_index] - # We must perform the remaining `BASE - 1 - xi` hashing steps - # to compute the public endpoint of the chain. - num_steps_remaining = config.BASE - 1 - xi - end_digest = self.hasher.hash_chain( + end_digest = self.poseidon.hash_chain( + config=config, parameter=pk.parameter, epoch=slot, chain_index=chain_index, start_step=xi, - num_steps=num_steps_remaining, + num_steps=config.BASE - 1 - xi, start_digest=start_digest, ) chain_ends.append(end_digest) - # Verify the Merkle path. - # - # This function internally: - # - Hashes the `chain_ends` to get the leaf node for the slot, - # - Uses the `opening` path from the signature to compute a candidate root. - # - It returns true if and only if this candidate root matches the public key's root. + # Phase 4: rebuild and compare against the trusted root.
return verify_path( - hasher=self.hasher, + poseidon=self.poseidon, + config=config, parameter=pk.parameter, root=pk.root, position=slot, @@ -429,94 +331,57 @@ def verify(self, pk: PublicKey, slot: Slot, message: Bytes32, sig: Signature) -> ) def get_activation_interval(self, sk: SecretKey) -> range: - """ - Returns the slot range for which this secret key is active. - - The activation interval is `[activation_slot, activation_slot + num_active_slots)`. - A signature can only be created for a slot within this range. + """Return the activation interval as a Python range. - Args: - sk: The secret key to query. - - Returns: - A Python range object representing the valid slot range. + A signature is only valid for a slot inside this range. """ start = int(sk.activation_slot) - end = start + int(sk.num_active_slots) - return range(start, end) + return range(start, start + int(sk.num_active_slots)) def get_prepared_interval(self, sk: SecretKey) -> range: - """ - Returns the slot range currently prepared (covered by loaded bottom trees). - - With top-bottom tree traversal, a secret key maintains a sliding window of - two consecutive bottom trees. This method returns the range of slots that - can be signed with the currently loaded trees, without needing to compute - additional bottom trees. + """Return the prepared interval as a Python range. - The prepared interval is: - `[left_bottom_tree_index * sqrt(LIFETIME), (left_bottom_tree_index + 2) * sqrt(LIFETIME))` - - Args: - sk: The secret key to query. - - Returns: - A Python range object representing the prepared slot range. - - Raises: - ValueError: If the secret key is missing top-bottom tree structures. + The prepared interval is the slot window covered by the two resident bottom trees. + A signer can sign any slot in this range without paying the cost of + rebuilding a bottom tree from the PRF. 
""" - leaves_per_bottom_tree = 1 << (self.config.LOG_LIFETIME // 2) + leaves_per_bottom_tree = self.config.LEAVES_PER_BOTTOM_TREE start = int(sk.left_bottom_tree_index) * leaves_per_bottom_tree return range(start, start + 2 * leaves_per_bottom_tree) def advance_preparation(self, sk: SecretKey) -> SecretKey: - """ - Advances the prepared interval by computing the next bottom tree. - - This method implements the "sliding window" strategy for top-bottom tree - traversal. It: - 1. Computes a new bottom tree for the next interval - 2. Shifts the current right tree to become the new left tree - 3. The newly computed tree becomes the new right tree - 4. Increments `left_bottom_tree_index` + """Slide the prepared window one bottom tree forward. - After this operation, the prepared interval moves forward by `sqrt(LIFETIME)` slots. + Phase 1: bail out when the next window would exceed the activation interval. + Phase 2: regenerate the new right bottom tree from the PRF key. + Phase 3: shift the previous right tree to the left slot. - **When to call**: Call this method after signing with a slot that is in the - right half of the prepared interval, to ensure the next slot range is ready. + Returning the same key when no advancement is possible keeps callers simple. Args: - sk: The secret key to advance. + sk: Secret key whose prepared window should advance. Returns: - A new SecretKey with the advanced preparation window. - - Raises: - ValueError: If advancing would exceed the activation interval. + A secret key with the window shifted by one bottom tree. """ - leaves_per_bottom_tree = 1 << (self.config.LOG_LIFETIME // 2) left_index = int(sk.left_bottom_tree_index) - # Check if advancing would exceed the activation interval - next_prepared_end_slot = (left_index + 3) * leaves_per_bottom_tree + # Phase 1: no advancement once the activation interval is fully consumed. 
+ next_prepared_end_slot = (left_index + 3) * self.config.LEAVES_PER_BOTTOM_TREE activation_end = int(sk.activation_slot) + int(sk.num_active_slots) if next_prepared_end_slot > activation_end: - # Nothing to do - we're already at the end of the activation interval return sk - # Compute the next bottom tree (the one after the current right tree) + # Phase 2: rebuild the next bottom tree from the master PRF key. new_right_bottom_tree = HashSubTree.from_prf_key( - prf=self.prf, - hasher=self.hasher, - rand=self.rand, + poseidon=self.poseidon, config=self.config, prf_key=sk.prf_key, bottom_tree_index=Uint64(left_index + 2), parameter=sk.parameter, ) - # Return a new SecretKey with the advanced window + # Phase 3: rotate the right tree into the left slot, advance the index. return sk.model_copy( update={ "left_bottom_tree": sk.right_bottom_tree, @@ -526,28 +391,11 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: ) -PROD_SIGNATURE_SCHEME = GeneralizedXmssScheme( - config=PROD_CONFIG, - prf=PROD_PRF, - hasher=PROD_TWEAK_HASHER, - encoder=PROD_TARGET_SUM_ENCODER, - rand=PROD_RAND, -) -"""An instance configured for production-level parameters.""" - -TEST_SIGNATURE_SCHEME = GeneralizedXmssScheme( - config=TEST_CONFIG, - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - encoder=TEST_TARGET_SUM_ENCODER, - rand=TEST_RAND, -) -"""A lightweight instance for test environments.""" - -_LEAN_ENV_TO_SCHEME = { - "test": TEST_SIGNATURE_SCHEME, - "prod": PROD_SIGNATURE_SCHEME, -} - -TARGET_SIGNATURE_SCHEME = _LEAN_ENV_TO_SCHEME[LEAN_ENV] -"""The active XMSS signature scheme based on LEAN_ENV environment variable.""" +PROD_SIGNATURE_SCHEME = GeneralizedXmssScheme(config=PROD_CONFIG, poseidon=PROD_POSEIDON) +"""Signature scheme instance with production parameters.""" + +TEST_SIGNATURE_SCHEME = GeneralizedXmssScheme(config=TEST_CONFIG, poseidon=TEST_POSEIDON) +"""Signature scheme instance with test parameters.""" + +TARGET_SIGNATURE_SCHEME = TEST_SIGNATURE_SCHEME if LEAN_ENV == 
"test" else PROD_SIGNATURE_SCHEME +"""Active scheme selected at import time from the LEAN_ENV environment variable.""" diff --git a/src/lean_spec/subspecs/xmss/merkle.py b/src/lean_spec/subspecs/xmss/merkle.py new file mode 100644 index 000000000..c581eb86b --- /dev/null +++ b/src/lean_spec/subspecs/xmss/merkle.py @@ -0,0 +1,612 @@ +r""" +Sparse Merkle subtrees for the top-bottom traversal of an XMSS key. + +# Overview + +The long-lived public key of an XMSS signature is the root of one Merkle tree. +That tree commits to one one-time public key per slot of the key's lifetime. +A signature opens the leaf for its slot with a path of sibling hashes up to the root. + +A full lifetime tree has one leaf per slot. +For a lifetime of 2^32 slots, holding every node in memory is infeasible. + +# The top-bottom split + +The tree is cut into one top tree sitting above many bottom trees. +Each bottom tree covers a contiguous run of leaves, the square root of the lifetime of them. +There are that many bottom trees, and their roots are exactly the leaves of the top tree. + +The signer keeps the whole top tree resident, plus the bottom trees around the active slot. +As the active slot advances, stale bottom trees are dropped and fresh ones regenerated on demand. +Resident memory stays near the square root of the lifetime instead of the full lifetime. + +# Sparse layers + +Only a window of leaves is resident, so a stored layer holds a contiguous slice, not a full level. +Each layer records the absolute index of its first node, keeping positions in full-tree coordinates. 
+""" + +from itertools import batched +from typing import Self + +from lean_spec.types import Uint64 +from lean_spec.types.collections import SSZList +from lean_spec.types.container import Container + +from .constants import TARGET_CONFIG, XmssConfig +from .field import random_domain +from .poseidon import PoseidonXmss +from .prf import PRFKey +from .types import ( + HashDigestList, + HashDigestVector, + HashTreeOpening, + Parameter, + TreeTweak, +) + + +class HashTreeLayer(Container): + """ + A single horizontal slice of a sparse Merkle subtree. + + The tree is sparse, so a layer stores only the nodes computed for the active leaf range. + """ + + start_index: Uint64 + """Absolute index of the first stored node within its level.""" + + nodes: HashDigestList + """Stored hash digests for this layer, ordered left to right.""" + + @classmethod + def padded( + cls, + config: XmssConfig, + nodes: list[HashDigestVector], + start_index: Uint64, + ) -> Self: + """ + Build a layer whose nodes can all be paired at the next level up. + + # Why pad + + The level above pairs nodes two at a time, then hashes each pair. + - A run starting on an odd index lacks a left neighbor for its first node. + - A run ending on an even index lacks a right neighbor for its last node. + + Padding either gap with a fresh random digest lets every node pair. + + # Invariant + + The result starts on an even index and ends on an odd index. + A single-node layer is the sole exception, since it is a subtree root. + + # Layout + + indices 5, 6, 7 -> [pad] 5 6 7 starts 4, ends 7 + indices 4, 5, 6 -> 4 5 6 [pad] starts 4, ends 7 + + Args: + config: Active XMSS configuration. + nodes: Active nodes at this layer, in ascending index order. + start_index: Absolute index of the first active node. + + Returns: + A layer satisfying the alignment invariant above. 
+ """ + nodes_with_padding: list[HashDigestVector] = [] + end_index = start_index + Uint64(len(nodes)) - Uint64(1) + + # Prepend one random sibling when the layer begins on an odd index. + if start_index % Uint64(2) == Uint64(1): + nodes_with_padding.append(random_domain(config)) + + # The padded layer always starts on the even index at or before start_index. + actual_start_index = start_index - (start_index % Uint64(2)) + + nodes_with_padding.extend(nodes) + + # Append one random sibling when the layer ends on an even index. + if end_index % Uint64(2) == Uint64(0): + nodes_with_padding.append(random_domain(config)) + + return cls( + start_index=actual_start_index, + nodes=HashDigestList(data=nodes_with_padding), + ) + + +class HashTreeLayers(SSZList[HashTreeLayer]): + """ + The layers of a subtree, ordered from the lowest layer up to the root. + + A bottom tree and a top tree each cover half the depth. + The cap admits the full lifetime tree. + """ + + LIMIT = TARGET_CONFIG.LOG_LIFETIME + 1 + """Layers run from level zero, the leaves, up to the lifetime depth, the root, inclusive.""" + + +class HashSubTree(Container): + """ + A contiguous slice of an XMSS lifetime tree, stored layer by layer. + + # Overview + + A subtree holds every node from its lowest layer up to a single root node. + Two shapes exist, told apart by where the lowest layer sits. + + - A bottom tree starts at layer zero and covers a window of leaves. + - A top tree starts at the split layer and covers the bottom-tree roots. + + # Invariant + + Every stored layer starts on an even index and ends on an odd index. + The exception is the top layer, which holds the single subtree root. + """ + + depth: Uint64 + """Depth of the full lifetime tree this subtree belongs to. + + A subtree starting at layer k stores depth - k layers. + """ + + lowest_layer: Uint64 + """Lowest layer included in this subtree: + - Zero for bottom trees, + - LOG_LIFETIME/2 for top trees. 
+ """ + + layers: HashTreeLayers + """Layers stored from lowest_layer up to the subtree root. + + The last entry holds a single node, the subtree root. + """ + + @classmethod + def new( + cls, + poseidon: PoseidonXmss, + config: XmssConfig, + lowest_layer: Uint64, + depth: Uint64, + start_index: Uint64, + parameter: Parameter, + lowest_layer_nodes: list[HashDigestVector], + highest_layer: Uint64 | None = None, + ) -> Self: + """ + Build a subtree from its lowest layer up to a bounding layer. + + # Overview + + Each layer is hashed pairwise into the layer above, climbing one level per step. + Padding keeps every intermediate layer aligned so sibling pairs form cleanly. + Building stops at the bounding layer, which defaults to the global root. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + lowest_layer: Starting layer for this subtree. + depth: Total depth of the full lifetime tree. + start_index: Absolute index of the first input node. + parameter: Public parameter for the hash function. + lowest_layer_nodes: Active nodes at the lowest layer. + highest_layer: Layer to stop building at, defaulting to the full depth. + + Returns: + A subtree holding every layer from the lowest layer up to the bounding layer. + + Raises: + ValueError: When the input nodes do not fit the level they start in. + """ + # Build to the global root unless a lower bounding layer is requested. + highest_layer = depth if highest_layer is None else highest_layer + + # The input nodes must fit in the layer they belong to. + max_positions = 1 << int(depth - lowest_layer) + if int(start_index) + len(lowest_layer_nodes) > max_positions: + raise ValueError( + f"Overflow at layer {lowest_layer}: " + f"start={start_index}, count={len(lowest_layer_nodes)}, max={max_positions}" + ) + + # Phase 1: pad the input layer. 
+ layers: list[HashTreeLayer] = [] + current = HashTreeLayer.padded(config, lowest_layer_nodes, start_index) + layers.append(current) + + # Phases 2 + 3: hash sibling pairs, pad, repeat. + for level in range(lowest_layer, highest_layer): + parent_start = current.start_index // Uint64(2) + parents = [ + poseidon.tweak_hash( + config, + parameter, + TreeTweak(level=level + 1, index=parent_start + Uint64(i)), + [left, right], + ) + for i, (left, right) in enumerate(batched(current.nodes, 2)) + ] + current = HashTreeLayer.padded(config, parents, parent_start) + layers.append(current) + + return cls( + depth=depth, + lowest_layer=lowest_layer, + layers=HashTreeLayers(data=layers), + ) + + @classmethod + def new_top_tree( + cls, + poseidon: PoseidonXmss, + config: XmssConfig, + depth: int, + start_bottom_tree_index: Uint64, + parameter: Parameter, + bottom_tree_roots: list[HashDigestVector], + ) -> Self: + """Build the top tree from bottom-tree roots up to the global root. + + The top tree starts at layer depth/2 and treats bottom-tree roots as its leaves. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + depth: Total depth of the full lifetime tree. + start_bottom_tree_index: Index of the first bottom tree in the range. + parameter: Public parameter for the hash function. + bottom_tree_roots: Roots of all bottom trees in the range, in order. + + Returns: + A top tree whose root is the global Merkle root. + + Raises: + ValueError: When depth is odd. + """ + # Top-bottom split requires an even depth. 
+ if depth % 2 != 0: + raise ValueError(f"Depth must be even for top-bottom split, got {depth}.") + + return cls.new( + poseidon=poseidon, + config=config, + lowest_layer=Uint64(depth // 2), + depth=Uint64(depth), + start_index=start_bottom_tree_index, + parameter=parameter, + lowest_layer_nodes=bottom_tree_roots, + ) + + @classmethod + def new_bottom_tree( + cls, + poseidon: PoseidonXmss, + config: XmssConfig, + depth: int, + bottom_tree_index: Uint64, + parameter: Parameter, + leaves: list[HashDigestVector], + ) -> Self: + """ + Build one bottom tree from its leaf hashes up to its standalone root. + + # Overview + + A bottom tree spans the lower half of the lifetime tree for one window of slots. + Its root is later placed as a single leaf of the top tree. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + depth: Total depth of the full lifetime tree. + bottom_tree_index: Index of this bottom tree. + parameter: Public parameter for the hash function. + leaves: Pre-hashed one-time public keys for this bottom tree's slots. + + Returns: + A subtree spanning the lower half of the tree, ending in the bottom-tree root. + + Raises: + ValueError: When the depth is odd, or the leaf count is not the square + root of the lifetime. + """ + if depth % 2 != 0: + raise ValueError(f"Depth must be even for top-bottom split, got {depth}.") + + # Each bottom tree spans exactly sqrt(LIFETIME) leaves. + leaves_per_tree = 1 << (depth // 2) + if len(leaves) != leaves_per_tree: + raise ValueError( + f"Expected {leaves_per_tree} leaves for depth={depth}, got {len(leaves)}." + ) + + # Phase 1: build only layers 0 through depth/2, the bottom tree's own height. 
+ subtree = cls.new( + poseidon=poseidon, + config=config, + lowest_layer=Uint64(0), + depth=Uint64(depth), + start_index=bottom_tree_index * Uint64(leaves_per_tree), + parameter=parameter, + lowest_layer_nodes=leaves, + highest_layer=Uint64(depth // 2), + ) + + # Phase 2: the top built layer is padded to a sibling pair. + # The real root is the node at this tree's index, not always position zero. + # An odd index leaves a random pad at position zero, so select by absolute index. + top = subtree.layers[-1] + root_idx = int(bottom_tree_index - top.start_index) + root_layer = HashTreeLayer( + start_index=bottom_tree_index, + nodes=HashDigestList(data=[top.nodes[root_idx]]), + ) + return cls( + depth=Uint64(depth), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=list(subtree.layers[:-1]) + [root_layer]), + ) + + @classmethod + def from_prf_key( + cls, + poseidon: PoseidonXmss, + config: XmssConfig, + prf_key: PRFKey, + bottom_tree_index: Uint64, + parameter: Parameter, + ) -> Self: + """ + Regenerate one bottom tree on demand from the master secret seed. + + # Overview + + The secret key is not stored slot by slot. + One short master seed deterministically expands into every chain start. + This lets the signer keep only a sliding window resident and rebuild the rest on demand. + + # What a leaf is + + Each slot owns one one-time signature made of many independent hash chains. + Walking a chain from its secret start to its far end yields one public chain end. + Hashing all of a slot's chain ends together produces the leaf committed at that slot. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + prf_key: Master secret seed. + bottom_tree_index: Index of the bottom tree to regenerate. + parameter: Public parameter for the hash function. + + Returns: + The requested bottom tree. + """ + # Each bottom tree covers sqrt(LIFETIME) consecutive epochs. 
+ leaves_per_bottom_tree = config.LEAVES_PER_BOTTOM_TREE + start_epoch = bottom_tree_index * Uint64(leaves_per_bottom_tree) + end_epoch = start_epoch + Uint64(leaves_per_bottom_tree) + + leaf_hashes: list[HashDigestVector] = [] + for epoch in range(start_epoch, end_epoch): + # Derive each chain start from the seed, then walk it to its public end. + # The far end is the chain start hashed forward the full length minus one. + chain_ends: list[HashDigestVector] = [] + for chain_index in range(config.DIMENSION): + start_digest = prf_key.derive_chain_start( + config, Uint64(epoch), Uint64(chain_index) + ) + end_digest = poseidon.hash_chain( + config=config, + parameter=parameter, + epoch=Uint64(epoch), + chain_index=chain_index, + start_step=0, + num_steps=config.BASE - 1, + start_digest=start_digest, + ) + chain_ends.append(end_digest) + + # The leaf for this slot is the hash of all its chain ends together. + leaf_tweak = TreeTweak(level=0, index=Uint64(epoch)) + leaf_hash = poseidon.tweak_hash(config, parameter, leaf_tweak, chain_ends) + leaf_hashes.append(leaf_hash) + + return cls.new_bottom_tree( + poseidon=poseidon, + config=config, + depth=config.LOG_LIFETIME, + bottom_tree_index=bottom_tree_index, + parameter=parameter, + leaves=leaf_hashes, + ) + + def root(self) -> HashDigestVector: + """Return the single node in the highest stored layer. + + Raises: + ValueError: When the subtree is empty or the highest layer has no nodes. + """ + if not self.layers: + raise ValueError("Empty subtree has no root.") + if not self.layers[-1].nodes: + raise ValueError("Top layer is empty.") + return self.layers[-1].nodes[0] + + def path(self, position: Uint64) -> HashTreeOpening: + """ + Collect the sibling hashes that connect one leaf to the subtree root. + + # Overview + + At each level the node has exactly one sibling, the node sharing its parent. + That sibling sits at the current position with its lowest bit flipped. 
+ Recording one sibling per level, then halving the position to climb, yields the full path. + The root has no sibling, so the walk stops one level below it. + + # Layout + + climbing from leaf position 5 in a three-level subtree: + + position 5 -> sibling 4 (flip low bit), then halve to 2 + position 2 -> sibling 3 (flip low bit), then halve to 1 + position 1 -> root, no sibling, stop + + Args: + position: Absolute index of the leaf in full-tree coordinates. + + Returns: + An opening of sibling hashes ordered from the leaf upward. + + Raises: + ValueError: When the subtree is empty or the position is out of bounds. + """ + if not self.layers: + raise ValueError("Empty subtree.") + + first = self.layers[0] + if not (first.start_index <= position < first.start_index + Uint64(len(first.nodes))): + raise ValueError(f"Position {position} out of bounds.") + + siblings: list[HashDigestVector] = [] + pos = int(position) + + # Stop one short of the root layer. + # The root has no sibling. + for layer in self.layers[:-1]: + # The sibling sits at the position with the last bit flipped, then we + # rebase by the layer's start_index because the layer is sparse. + sibling_idx = (pos ^ 1) - int(layer.start_index) + if not (0 <= sibling_idx < len(layer.nodes)): + raise ValueError(f"Sibling index {sibling_idx} out of bounds.") + siblings.append(layer.nodes[sibling_idx]) + pos //= 2 + + return HashTreeOpening(siblings=HashDigestList(data=siblings)) + + +def combined_path( + top_tree: HashSubTree, + bottom_tree: HashSubTree, + position: Uint64, +) -> HashTreeOpening: + """ + Stitch a bottom-tree opening and a top-tree opening into one full path. + + # Overview + + A signature authenticates its leaf all the way up to the global root. + No single resident subtree spans that whole distance, so two openings are joined. + + # Proof flow + + bottom opening : proves the leaf sits under its bottom-tree root. + top opening : proves that bottom-tree root sits under the global root. 
+ + Args: + top_tree: The top tree containing the global root. + bottom_tree: The bottom tree containing the leaf. + position: Absolute index of the leaf. + + Returns: + One opening that authenticates the leaf against the global root. + + Raises: + ValueError: When the tree depths disagree, the depth is odd, or the position + does not belong to the supplied bottom tree. + """ + if top_tree.depth != bottom_tree.depth: + raise ValueError(f"Depth mismatch: top={top_tree.depth}, bottom={bottom_tree.depth}.") + + depth = int(top_tree.depth) + if depth % 2 != 0: + raise ValueError(f"Depth must be even, got {depth}.") + + # The position must belong to the supplied bottom tree, not a sibling one. + leaves_per_tree = Uint64(1 << (depth // 2)) + expected_start = (position // leaves_per_tree) * leaves_per_tree + if bottom_tree.layers[0].start_index != expected_start: + raise ValueError( + f"Wrong bottom tree: position {position} needs start {expected_start}, " + f"got {bottom_tree.layers[0].start_index}." + ) + + # The opening climbs from leaf to root, so bottom siblings come before top siblings. + bottom_path = bottom_tree.path(position) + top_path = top_tree.path(position // leaves_per_tree) + combined = tuple(bottom_path.siblings.data) + tuple(top_path.siblings.data) + + return HashTreeOpening(siblings=HashDigestList(data=combined)) + + +def verify_path( + poseidon: PoseidonXmss, + config: XmssConfig, + parameter: Parameter, + root: HashDigestVector, + position: Uint64, + leaf_parts: list[HashDigestVector], + opening: HashTreeOpening, +) -> bool: + """ + Recompute a root from a leaf and its opening, then compare against a trusted root. + + # Overview + + Verification mirrors construction in reverse. + The leaf is hashed, then folded with each sibling while climbing one level per step. + The walk succeeds when the recomputed root equals the trusted root. + + # Why return false instead of raising + + The opening arrives inside an untrusted signature. 
+ A malformed opening must be a quiet verification failure, never a crash. + So out-of-range input returns false rather than raising. + + Args: + poseidon: Cached Poseidon1 engine. + config: Active XMSS configuration. + parameter: Public parameter for the hash function. + root: Trusted root taken from the public key. + position: Absolute index of the leaf being verified. + leaf_parts: Digests that constitute the original leaf. + opening: Sibling path from leaf to root. + + Returns: + True when the path reconstructs the trusted root, false otherwise. + """ + # Guard against malformed openings. + # The opening list caps at 32 entries. + # A depth greater than 32 would overflow the position bound check below. + depth = len(opening.siblings) + if depth > 32: + return False + if int(position) >= (1 << depth): + return False + + # Phase 1: hash the leaf parts to derive the starting node. + current = poseidon.tweak_hash( + config, + parameter, + TreeTweak(level=0, index=Uint64(position)), + leaf_parts, + ) + pos = int(position) + + # Phase 2: hash with each sibling, climbing one layer per iteration. + for level, sibling in enumerate(opening.siblings): + # The current node sits on the left when its position is even. + left, right = (current, sibling) if pos % 2 == 0 else (sibling, current) + pos //= 2 + current = poseidon.tweak_hash( + config, + parameter, + TreeTweak(level=level + 1, index=Uint64(pos)), + [left, right], + ) + + # Phase 3: compare against the trusted root. + return current == root diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py deleted file mode 100644 index f1d253750..000000000 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -Defines the message hashing for the signature scheme using aborting hypercube encoding. 
- -### The Challenge: Efficiently Encoding a Message as a Codeword - -The "Target Sum" signature scheme requires the signer to find a `codeword` whose -digits sum to a specific value. This requires hashing a message and mapping the -output to a vertex in a high-dimensional hypercube. - -### The Solution: Aborting Hypercube Encoding - -This module implements a circuit-friendly encoding based on rejection sampling of -individual field elements, eliminating all big-integer arithmetic. - -For KoalaBear (`P = 2^31 - 2^24 + 1`), `P - 1 = Q * BASE^Z`, so each field element -can be decomposed into `Z` base-`BASE` digits after dividing by `Q`. The only reject -case is `A_i == P - 1` (probability ~4.7e-10 per FE — essentially never aborts). - -This is backed by the "Aborting Random Oracles" paper which proves -indifferentiability from a theta-aborting random oracle when modeling Poseidon as a -standard random oracle. - -The encoding proceeds in two stages: - -1. **Input Preparation**: All inputs are encoded into field elements. -2. **Poseidon Hashing + Aborting Decode**: Poseidon1 produces `ceil(DIMENSION/Z)` - field elements, each decoded into `Z` base-`BASE` digits via rejection sampling. -""" - -from __future__ import annotations - -from lean_spec.subspecs.xmss.poseidon import ( - PROD_POSEIDON, - TEST_POSEIDON, - PoseidonXmss, -) -from lean_spec.types import Bytes32, StrictBaseModel, Uint64 - -from ..koalabear import Fp -from .constants import ( - PROD_CONFIG, - TEST_CONFIG, - TWEAK_PREFIX_MESSAGE, - XmssConfig, -) -from .types import Parameter, Randomness -from .utils import int_to_base_p - - -class MessageHasher(StrictBaseModel): - """An instance of the message hasher using aborting hypercube encoding.""" - - config: XmssConfig - """Configuration parameters for the hasher.""" - - poseidon: PoseidonXmss - """Poseidon hash engine.""" - - def encode_message(self, message: Bytes32) -> list[Fp]: - """ - Encodes a 32-byte message into a list of field elements. 
- - The message bytes are interpreted as a single little-endian integer, - which is then decomposed into its base-`P` representation, where `P` - is the field prime. This provides a canonical mapping from bytes to - the algebraic structure required by Poseidon1. - """ - # Interpret the 32 little-endian bytes as a single large integer. - acc = int.from_bytes(message, "little") - - # Decompose the integer into a list of field elements (base-P). - return int_to_base_p(acc, self.config.MSG_LEN_FE) - - def encode_epoch(self, epoch: Uint64) -> list[Fp]: - """ - Encodes the epoch and a domain separator prefix into field elements. - - This function packs the epoch and the message hash prefix into a single - integer, then decomposes it. This ensures the epoch is included in the - hash input in a structured, domain-separated way. - """ - # Combine the epoch and the message hash prefix into a single integer. - acc = (int(epoch) << 8) | TWEAK_PREFIX_MESSAGE - - # Decompose the integer into its base-P representation. - return int_to_base_p(acc, self.config.TWEAK_LEN_FE) - - def _aborting_decode(self, field_elements: list[Fp]) -> list[int] | None: - """ - Decodes Poseidon output field elements into base-`BASE` digits via rejection sampling. - - For each field element `A_i`: - - 1. If `A_i >= Q * BASE^Z` (i.e. `A_i == P - 1`), abort and return `None`. - 2. Compute `d_i = A_i // Q`, an integer in `[0, BASE^Z - 1]`. - 3. Decompose `d_i` into `Z` base-`BASE` digits, least significant first. - - Collect all digits and return the first `DIMENSION` of them. - """ - config = self.config - threshold = config.Q * config.BASE**config.Z - - digits: list[int] = [] - for fe in field_elements: - a = int(fe) - - # Rejection: the only failing case is A_i == P - 1. - if a >= threshold: - return None - - # Integer quotient removes the Q-residue, leaving a uniform value in [0, BASE^Z - 1]. - d = a // config.Q - - # Decompose d into Z base-BASE digits, least significant first. 
- for _ in range(config.Z): - d, digit = divmod(d, config.BASE) - digits.append(digit) - - # Take exactly DIMENSION digits. - return digits[: config.DIMENSION] - - def apply( - self, - parameter: Parameter, - epoch: Uint64, - rho: Randomness, - message: Bytes32, - ) -> list[int] | None: - """ - Applies message hashing followed by aborting hypercube decode. - - Hashes the inputs with Poseidon1 to produce `MH_HASH_LEN_FE` field elements, - then decodes them into a candidate codeword via rejection sampling. - - Args: - parameter: The public parameter `P`. - epoch: The current epoch. - rho: A random value `rho` to ensure a unique hash output. - message: The 32-byte message to be hashed. - - Returns: - A candidate codeword (list of `DIMENSION` digits in `[0, BASE-1]`), - or `None` if the aborting decode rejects. - """ - # Encode the message and epoch as field elements. - message_fe = self.encode_message(message) - epoch_fe = self.encode_epoch(epoch) - - # Call Poseidon1 once to produce the required number of output field elements. - base_input = message_fe + list(parameter.data) + epoch_fe + list(rho.data) - poseidon_output = self.poseidon.compress(base_input, 24, self.config.MH_HASH_LEN_FE) - - # Decode the field elements into base-BASE digits via rejection sampling. - return self._aborting_decode(poseidon_output) - - -PROD_MESSAGE_HASHER = MessageHasher(config=PROD_CONFIG, poseidon=PROD_POSEIDON) -"""An instance configured for production-level parameters.""" - -TEST_MESSAGE_HASHER = MessageHasher(config=TEST_CONFIG, poseidon=TEST_POSEIDON) -"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 37a6ff7b1..b14962937 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -1,30 +1,10 @@ -""" -Defines the Poseidon1 hash functions for the Generalized XMSS scheme. 
+"""Poseidon1 hash engine in compression and sponge modes for the Generalized XMSS scheme.""" -### The Cryptographic Engine: Why Poseidon1? - -This module provides the low-level cryptographic engine for all internal hashing -operations. It is built on **Poseidon1** hash function. - -The choice of Poseidon1 is deliberate and critical for the scheme's ultimate goal. -Unlike traditional hashes like SHA-3, Poseidon1 is an **arithmetization-friendly** -(or **SNARK-friendly**) hash function. Its algebraic structure is simple, making it -exponentially faster to prove and verify inside a zero-knowledge proof system, -which is essential for aggregating many signatures into a single, compact proof. - -This file provides wrappers for the two primary ways Poseidon1 is used: - -1. **Compression Mode**: A fast, fixed-input-size mode for hashing small, - predictable data structures like a single hash digest or a pair of them. -2. **Sponge Mode**: A flexible, variable-input-size mode for hashing large - amounts of data, like the many digests that form a Merkle tree leaf. 
-""" - -from __future__ import annotations +from itertools import batched from pydantic import PrivateAttr -from lean_spec.types import StrictBaseModel +from lean_spec.types import StrictBaseModel, Uint64 from ..koalabear import Fp from ..poseidon1.permutation import ( @@ -33,107 +13,93 @@ Poseidon1, Poseidon1Params, ) -from .utils import int_to_base_p +from .constants import TWEAK_PREFIX_CHAIN, TWEAK_PREFIX_TREE, XmssConfig +from .field import int_to_base_p +from .types import ChainTweak, HashDigestVector, Parameter, TreeTweak class PoseidonXmss(StrictBaseModel): - """An instance of the Poseidon1 hash engine for the XMSS scheme.""" + """Poseidon1 hash engine wrapper used inside the XMSS scheme.""" params16: Poseidon1Params - """Poseidon1 parameters for 16-width permutation.""" + """Permutation parameters for the width-16 state.""" params24: Poseidon1Params - """Poseidon1 parameters for 24-width permutation.""" + """Permutation parameters for the width-24 state.""" - _engine16: Poseidon1 | None = PrivateAttr(default=None) - _engine24: Poseidon1 | None = PrivateAttr(default=None) + _engines: dict[int, Poseidon1] = PrivateAttr(default_factory=dict) def _get_engine(self, width: int) -> Poseidon1: - """Return a cached Poseidon1 engine for the given width.""" - if width == 16: - if self._engine16 is None: - self._engine16 = Poseidon1(self.params16) - return self._engine16 - if self._engine24 is None: - self._engine24 = Poseidon1(self.params24) - return self._engine24 + """Return a cached Poseidon1 engine for the given width. - def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]: + Raises: + ValueError: When the width is neither 16 nor 24. """ - Implements the Poseidon1 hash in **compression mode**. - - This mode is used for hashing fixed-size inputs and is the most efficient - way to use Poseidon1. It is used for traversing hash chains and building - the internal nodes of the Merkle tree. 
+ if width not in self._engines: + match width: + case 16: + params = self.params16 + case 24: + params = self.params24 + case _: + raise ValueError(f"Width must be 16 or 24, got {width}") + self._engines[width] = Poseidon1(params) + return self._engines[width] - ### Compression Algorithm + def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]: + """Poseidon1 in compression mode. - The function computes: `Truncate(Permute(padded_input) + padded_input)`. - 1. **Padding**: The `input_vec` is padded with zeros to match the full state `width`. - 2. **Permutation**: The core cryptographic permutation is applied to the padded state. - 3. **Feed-Forward**: The original padded input is added element-wise to the - permuted state. This is a key feature of the Poseidon1 design that - provides security against certain attacks. - 4. **Truncation**: The result is truncated to the desired `output_len`. + Computes Truncate(Permute(padded_input) + padded_input). + The padded input is the original vector zero-extended to the state width. + The feed-forward addition is part of the Poseidon1 design and is required + for security. + Used for hash chains and Merkle interior nodes. Args: - input_vec: The list of field elements to be hashed. - width: The state width of the Poseidon1 permutation (16 or 24). - output_len: The number of field elements in the output digest. + input_vec: Field elements to hash. + width: Permutation state width, either 16 or 24. + output_len: Number of output field elements to return. Returns: - A hash digest of `output_len` field elements. + Truncated digest of output_len field elements. """ - # Check that the input vector is long enough to produce the output. + # The output cannot be longer than the input vector after padding. if len(input_vec) < output_len: raise ValueError("Input vector is too short for requested output length.") - # Select the correct permutation parameters based on the state width. 
- if width not in (16, 24): - raise ValueError(f"Width must be 16 or 24, got {width}") + # Select the cached engine matching the requested permutation width. engine = self._get_engine(width) - # Create a padded input by extending with zeros to match the state width. + # Zero-pad to the state width before applying the permutation. padded_input = list(input_vec) + [Fp(value=0)] * (width - len(input_vec)) - # Apply the Poseidon1 permutation. + # Permute, then add the original padded input element-wise. permuted_state = engine.permute(padded_input) - - # Apply the feed-forward step, adding the input back element-wise. final_state = [p + i for p, i in zip(permuted_state, padded_input, strict=True)] - # Truncate the state to the desired output length and return. return final_state[:output_len] def safe_domain_separator(self, lengths: list[int], capacity_len: int) -> list[Fp]: - """ - Computes a unique domain separator for the sponge construction (SAFE API). + """Build a capacity initialization vector for the sponge construction. - A sponge's security relies on its initial state being unique for each distinct - hashing task. This function creates a unique "configuration" or - "initialization vector" (`capacity_value`) by hashing the high-level - parameters of the sponge's usage (e.g., the dimensions of the data - being hashed). This prevents multi-user or cross-context attacks. + Hashes the packed length parameters into a fixed-size capacity value. + This prevents collisions between sponges that absorb data of different shapes. Args: - lengths: A list of integer parameters that define the hash context. - capacity_len: The desired length of the output capacity value. + lengths: Integer parameters that define the hash context. + capacity_len: Number of field elements in the returned capacity value. Returns: - A list of `capacity_len` field elements for initializing the sponge. + A capacity vector of length capacity_len. 
""" - # Pack all the length parameters into a single, large, unambiguous integer. + # Pack all lengths into a single unambiguous integer using 32-bit slots. acc = 0 for length in lengths: acc = (acc << 32) | length - # Decompose this integer into a fixed-size list of field elements. - # - # This list serves as the input to a one-off compression hash. - # NOTE: we always use this mode with a 24 width. + # Compress the decomposed vector through the width-24 engine. + # Width 24 is the only mode used for sponge domain separation. input_vec = int_to_base_p(acc, 24) - - # Compress the decomposed vector to produce the capacity value. return self.compress(input_vec, 24, capacity_len) def sponge( @@ -143,80 +109,157 @@ def sponge( output_len: int, width: int, ) -> list[Fp]: - """ - Implements the Poseidon1 hash using the **sponge construction**. - - This mode is used for hashing large or variable-length inputs. In this scheme, - it is specifically used to hash the Merkle tree leaves, which consist of many - concatenated hash digests. - - ### Sponge Algorithm - - 1. **Initialization**: The internal state is divided into a `rate` (for data) - and a `capacity` (for security). The `capacity` part is initialized - with the domain-separating `capacity_value`. - - 2. **Absorbing**: The input data is processed in `rate`-sized chunks. In each - step, a chunk is added to the `rate` part of the state, and then the - entire state is scrambled by the `permute` function. + """Poseidon1 in sponge mode. - 3. **Squeezing**: Once all input is absorbed, the `rate` part of the state is - extracted as output. If more output is needed, the state is permuted again, - and more is extracted, repeating until `output_len` elements are generated. + Phase 1: load capacity, zero-extend input to a multiple of the rate. + Phase 2: absorb each rate-sized chunk by replacement, then permute. + Phase 3: squeeze the rate slots until output_len elements are produced. 
Args: - input_vec: The input data of arbitrary length. - capacity_value: The domain-separating value from `safe_domain_separator`. - output_len: The number of field elements in the final output digest. - width: The width of the Poseidon1 permutation. + input_vec: Variable-length input. + capacity_value: Domain-separating capacity initialization. + output_len: Desired output length in field elements. + width: Permutation state width. Returns: - A hash digest of `output_len` field elements. + A digest of output_len field elements. """ - # Ensure that the capacity value is not too long. + # The capacity must leave at least one rate slot for absorbing input. if len(capacity_value) >= width: raise ValueError("Capacity length must be smaller than the state width.") - # Determine the permutation parameters and the size of the rate. - if width not in (16, 24): - raise ValueError(f"Width must be 16 or 24, got {width}") engine = self._get_engine(width) rate = width - len(capacity_value) - # Pad the input vector with zeros to be an exact multiple of the rate size. + # Zero-pad to a multiple of the rate so absorption iterates exact chunks. num_extra = (rate - (len(input_vec) % rate)) % rate padded_input = input_vec + [Fp(value=0)] * num_extra - # Initialize the state: - # - capacity part (domain separator) at the beginning, - # - rate part (zero) follows. + # Layout: capacity slots first, then rate slots. cap_len = len(capacity_value) state = [Fp(value=0)] * width state[:cap_len] = capacity_value - # Absorb the input in rate-sized chunks via replacement. - for i in range(0, len(padded_input), rate): - chunk = padded_input[i : i + rate] - # Replace the rate part of the state with the chunk. - for j in range(rate): - state[cap_len + j] = chunk[j] - # Apply the cryptographic permutation to mix the state. + # Phase 2: absorb each chunk by overwriting the rate slots. 
+ for chunk in batched(padded_input, rate): + for j, value in enumerate(chunk): + state[cap_len + j] = value state = engine.permute(state) - # Squeeze the output until enough elements have been generated. + # Phase 3: squeeze rate slots, permuting until enough output is available. output: list[Fp] = [] while len(output) < output_len: - # Extract the rate part of the state (after capacity) as output. output.extend(state[cap_len : cap_len + rate]) - # Permute the state. state = engine.permute(state) - # Truncate to the final output length and return. return output[:output_len] + def tweak_hash( + self, + config: XmssConfig, + parameter: Parameter, + tweak: TreeTweak | ChainTweak, + message_parts: list[HashDigestVector], + ) -> HashDigestVector: + """Apply the tweakable hash to one or more digests. + + Mode selection: + + - One digest input uses width-16 compression for hash chains. + - Two digest inputs use width-24 compression for Merkle interior nodes. + - More inputs use sponge mode for Merkle leaves. + + Args: + config: Active XMSS configuration. + parameter: Public parameter that personalizes the hash. + tweak: Position tweak for domain separation. + message_parts: Digests to hash together. + + Returns: + A digest of HASH_LEN_FE field elements. + """ + # Pack the tweak fields into one integer, then split it into base-P field elements. + # + # The low byte is a per-shape prefix. + # It stops a tree tweak and a chain tweak from packing to the same value. + # That keeps Merkle hashing domain-separated from chain hashing. + # + # Every other field sits in its own bit range above the prefix. 
+ match tweak: + case TreeTweak(level=level, index=index): + acc = (level << 40) | (int(index) << 8) | TWEAK_PREFIX_TREE + case ChainTweak(epoch=epoch, chain_index=chain_index, step=step): + acc = (int(epoch) << 24) | (chain_index << 16) | (step << 8) | TWEAK_PREFIX_CHAIN + encoded_tweak = int_to_base_p(acc, config.TWEAK_LEN_FE) + + if len(message_parts) == 1: + # Hash chain step: width-16 compression of (digest || parameter || tweak). + input_vec = message_parts[0].elements + parameter.elements + encoded_tweak + result = self.compress(input_vec, 16, config.HASH_LEN_FE) + + elif len(message_parts) == 2: + # Merkle node: width-24 compression of (parameter || tweak || left || right). + input_vec = ( + parameter.elements + + encoded_tweak + + message_parts[0].elements + + message_parts[1].elements + ) + result = self.compress(input_vec, 24, config.HASH_LEN_FE) + + else: + # Merkle leaf: sponge mode over many concatenated digests. + flattened_message = [elem for part in message_parts for elem in part.elements] + input_vec = parameter.elements + encoded_tweak + flattened_message + + # The domain separator binds the sponge to this hashing task shape. + lengths = [ + config.PARAMETER_LEN, + config.TWEAK_LEN_FE, + config.DIMENSION, + config.HASH_LEN_FE, + ] + capacity_value = self.safe_domain_separator(lengths, config.CAPACITY) + result = self.sponge(input_vec, capacity_value, config.HASH_LEN_FE, 24) + + return HashDigestVector(data=result) + + def hash_chain( + self, + config: XmssConfig, + parameter: Parameter, + epoch: Uint64, + chain_index: int, + start_step: int, + num_steps: int, + start_digest: HashDigestVector, + ) -> HashDigestVector: + """Iterate the tweakable hash along a Winternitz chain. + + Each iteration uses a distinct chain tweak so every step is domain-separated. + + Args: + config: Active XMSS configuration. + parameter: Public parameter that personalizes the hash. + epoch: Slot identifier for the one-time signature. 
+ chain_index: Index of the chain within the one-time signature. + start_step: Step number of the input digest. + num_steps: Number of additional hash applications. + start_digest: Digest at start_step. + + Returns: + Digest at start_step + num_steps. + """ + current_digest = start_digest + for i in range(num_steps): + # Steps are 1-indexed: step 1 is the first hash after the chain start. + tweak = ChainTweak(epoch=epoch, chain_index=chain_index, step=start_step + i + 1) + current_digest = self.tweak_hash(config, parameter, tweak, [current_digest]) + return current_digest + PROD_POSEIDON = PoseidonXmss(params16=PARAMS_16, params24=PARAMS_24) -"""An instance configured for production-level parameters.""" +"""Poseidon1 engine with production parameters.""" TEST_POSEIDON = PROD_POSEIDON -"""Test and production use the same Poseidon1 parameters; only XmssConfig differs.""" +"""Test environment reuses the production Poseidon1 parameters.""" diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index b90bed660..426b2a40a 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -1,209 +1,148 @@ -""" -Defines the pseudorandom function (PRF) used in the signature scheme. - -PRF based on the SHAKE128 extendable-output function (XOF). - -The PRF is used to derive the secret starting points of the hash chains -for each epoch from a single master secret key. 
-""" - -from __future__ import annotations +"""SHAKE128-based pseudorandom function for deterministic key derivation.""" import hashlib import os -from typing import Final +from itertools import batched +from typing import Final, Self -from lean_spec.subspecs.koalabear import Fp -from lean_spec.types import Bytes32, StrictBaseModel, Uint64 +from lean_spec.types import Bytes16, Bytes32, Uint64 +from lean_spec.types.byte_arrays import BaseBytes -from .constants import ( - PRF_KEY_LENGTH, - PROD_CONFIG, - TEST_CONFIG, - XmssConfig, -) -from .types import HashDigestVector, PRFKey, Randomness +from ..koalabear import Fp +from .constants import PRF_KEY_LENGTH, XmssConfig +from .types import HashDigestVector, Randomness -PRF_DOMAIN_SEP: Final[bytes] = b"\xae\xae\x22\xff\x00\x01\xfa\xff\x21\xaf\x12\x00\x01\x11\xff\x00" +PRF_DOMAIN_SEP: Final = Bytes16(b"\xae\xae\x22\xff\x00\x01\xfa\xff\x21\xaf\x12\x00\x01\x11\xff\x00") """ -A 16-byte domain separator to ensure PRF outputs are unique to this context. +Fixed domain separator prefixed to every PRF call. -This prevents any potential conflicts if the same underlying hash function -(SHAKE128) were used for other purposes in the system. +Prevents cross-context collisions if SHAKE128 is reused elsewhere in the system. """ PRF_DOMAIN_SEP_DOMAIN_ELEMENT: Final[bytes] = b"\x00" -""" -A 1-byte domain separator for deriving domain elements (used in `apply`). - -This distinguishes the PRF calls for generating hash chain starting points -from the PRF calls for generating randomness during signing. -""" +"""Subdomain tag for hash-chain start derivation.""" PRF_DOMAIN_SEP_RANDOMNESS: Final[bytes] = b"\x01" -""" -A 1-byte domain separator for deriving randomness (used in `get_randomness`). - -This distinguishes the PRF calls for generating signing randomness from the -PRF calls for generating domain elements, preventing any potential collisions -between the two use cases. 
-""" +"""Subdomain tag for signing-randomness derivation.""" PRF_BYTES_PER_FE: Final[int] = 16 """ -The number of bytes of SHAKE128 output used to generate one field element. +SHAKE128 bytes consumed per output field element. -We use 16 bytes (128 bits) of pseudorandom output, which is then reduced -modulo the 31-bit field prime `P`. This provides a significant statistical -safety margin to ensure the resulting field element is close to uniformly -random. +128 bits reduced modulo a 31-bit prime gives a statistical margin against bias. """ -def _bytes_to_field_elements(data: bytes, count: int) -> list[Fp]: +class PRFKey(BaseBytes): """ - Convert PRF output bytes into a list of field elements. - - Each field element is derived from `PRF_BYTES_PER_FE` bytes, - interpreted as a big-endian integer and reduced modulo the field prime. - - The extra bits provide statistical uniformity. + The PRF master secret key. - Args: - data: Raw bytes from SHAKE128 output. Must be exactly `count * PRF_BYTES_PER_FE` bytes. - count: Number of field elements to extract. + High-entropy byte string acting as the single root secret. - Returns: - List of `count` field elements. + Every one-time signing key is deterministically derived from this seed. """ - return [ - Fp(value=int.from_bytes(data[i : i + PRF_BYTES_PER_FE], "big")) - for i in range(0, count * PRF_BYTES_PER_FE, PRF_BYTES_PER_FE) - ] + LENGTH = PRF_KEY_LENGTH -class Prf(StrictBaseModel): - """An instance of the SHAKE128-based PRF for a given config.""" + @classmethod + def generate(cls) -> Self: + """Draw a fresh master key from the operating system entropy pool.""" + return cls(os.urandom(PRF_KEY_LENGTH)) - config: XmssConfig - """Configuration parameters for the PRF.""" - - def key_gen(self) -> PRFKey: + def derive_chain_start( + self, config: XmssConfig, epoch: Uint64, chain_index: Uint64 + ) -> HashDigestVector: """ - Generates a cryptographically secure random key for the PRF. 
+ Derive the secret start of one Winternitz hash chain. - This function sources randomness from the operating system's - entropy pool. + # Overview - Returns: - A new, randomly generated PRF key of `PRF_KEY_LENGTH` bytes. - """ - return PRFKey(os.urandom(PRF_KEY_LENGTH)) - - def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> HashDigestVector: - """ - Applies the PRF to derive the secret starting value for a single hash chain. - - ### PRF Construction - - The function constructs a unique input for the underlying SHAKE128 function - by concatenating several components: - `SHAKE128(DOMAIN_SEP || 0x00 || key || epoch || chain_index)` - - The 0x00 byte distinguishes this use case (deriving domain elements) from - randomness generation (which uses 0x01). The arbitrary-length output of - SHAKE128 is then processed to produce a list of field elements, which - serves as the secret starting digest for one chain. + Each slot signs with many independent hash chains. + A chain begins at a secret value. + Its public counterpart is that value hashed all the way up to the chain top. + Recreating the secret start from the seed means it never has to be stored. Args: - key: The secret master PRF key. - epoch: The epoch number, identifying the one-time signature instance. - chain_index: The index of the hash chain within that epoch's OTS. + config: Active XMSS configuration. + epoch: Slot identifier for this one-time signature instance. + chain_index: Position of the chain within the one-time signature. Returns: - A hash digest representing the secret start of a single hash chain. + The secret digest at the bottom of the chain. """ - # Retrieve the scheme's configuration parameters. - config = self.config - - # Construct the unique input for the PRF by concatenating its components: + # Layout: + # + # domain_sep || 0x00 || key || epoch (4 bytes) || chain_index (8 bytes) # - # - Domain Separation: Uniquely tag the PRF for this specific use case. 
- # - Domain Element Tag: 0x00 byte to distinguish from randomness generation. - # - Key Input: The master secret key. - # - Epoch: A 4-byte integer ensuring every epoch derives a different set of secrets. - # - Chain Index: An 8-byte integer ensuring each parallel hash chain gets a unique secret. + # The 0x00 byte separates chain-start derivation from randomness derivation. input_data = ( PRF_DOMAIN_SEP + PRF_DOMAIN_SEP_DOMAIN_ELEMENT - + key + + self + epoch.to_bytes(4, "big") + chain_index.to_bytes(8, "big") ) - # Determine the total number of bytes to extract from the SHAKE output. - # - # We need enough bytes to produce `HASH_LEN_FE` field elements. + # Pull enough SHAKE128 bytes to fill one digest of field elements. num_bytes_to_read = PRF_BYTES_PER_FE * config.HASH_LEN_FE prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) + return HashDigestVector( + data=[ + Fp(value=int.from_bytes(bytes(chunk), "big")) + for chunk in batched(prf_output_bytes, PRF_BYTES_PER_FE) + ] + ) - # Convert the raw byte output into a list of field elements. - return HashDigestVector(data=_bytes_to_field_elements(prf_output_bytes, config.HASH_LEN_FE)) - - def get_randomness( - self, key: PRFKey, epoch: Uint64, message: Bytes32, counter: Uint64 + def derive_randomness( + self, + config: XmssConfig, + epoch: Uint64, + message: Bytes32, + counter: Uint64, ) -> Randomness: """ - Derives pseudorandom field elements for use in deterministic signing. + Derive deterministic randomness for one signing attempt. - This method is used to generate deterministic randomness for the Information - Encoding step during signing. By deriving randomness from the PRF key, epoch, - message, and attempt counter, we ensure that signing is deterministic: calling - `sign` twice with the same (sk, epoch, message) triple produces the same signature. 
+ # Overview - This provides additional hardening against implementation errors where sign might - be called multiple times with the same epoch. However, calling sign with the same - epoch but *different* messages still compromises security. + Signing maps the message onto a codeword whose digits must add up to a fixed target. + - Each attempt folds in fresh randomness. + - It retries with a higher counter until the digit sum lands on the target. - ### Construction + Deriving that randomness from the seed makes the search reproducible. - Similar to `apply`, but includes the message and a counter in the input: - `SHAKE128(DOMAIN_SEP || 0x01 || key || epoch || message || counter)` + # Reproducibility - The 0x01 byte distinguishes this use case (generating randomness) from - domain element derivation (which uses 0x00). + A synchronized scheme signs each slot at most once. + Signing one slot twice is treated as misbehavior. + A deterministic attempt order means one slot and message always yield the same signature. Args: - key: The secret master PRF key. - epoch: The epoch number for this signature. - message: The message being signed (MESSAGE_LENGTH bytes). - counter: The attempt number (used when retrying encoding). + config: Active XMSS configuration. + epoch: Slot identifier for this signature. + message: Full message being signed. + counter: Attempt number, incremented when a previous attempt aborted. Returns: - Randomness for encoding (i.e., `rho`). + Randomness used to encode the message into a valid codeword. 
""" - config = self.config - - # Construct input: DOMAIN_SEP || 0x01 || key || epoch || message || counter + # Layout: + # + # domain_sep || 0x01 || key || epoch || message || counter input_data = ( PRF_DOMAIN_SEP + PRF_DOMAIN_SEP_RANDOMNESS - + key + + self + epoch.to_bytes(4, "big") + message + counter.to_bytes(8, "big") ) - # Extract enough bytes for RAND_LEN_FE field elements num_bytes_to_read = PRF_BYTES_PER_FE * config.RAND_LEN_FE prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) - - # Convert to field elements and wrap in Randomness - return Randomness(data=_bytes_to_field_elements(prf_output_bytes, config.RAND_LEN_FE)) - - -PROD_PRF = Prf(config=PROD_CONFIG) -"""An instance configured for production-level parameters.""" - -TEST_PRF = Prf(config=TEST_CONFIG) -"""A lightweight instance for test environments.""" + return Randomness( + data=[ + Fp(value=int.from_bytes(bytes(chunk), "big")) + for chunk in batched(prf_output_bytes, PRF_BYTES_PER_FE) + ] + ) diff --git a/src/lean_spec/subspecs/xmss/rand.py b/src/lean_spec/subspecs/xmss/rand.py deleted file mode 100644 index fa84e80ce..000000000 --- a/src/lean_spec/subspecs/xmss/rand.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Random data generator for the XMSS signature scheme.""" - -import secrets - -from lean_spec.types import StrictBaseModel - -from ..koalabear import Fp, P -from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig -from .types import HashDigestVector, Parameter - - -class Rand(StrictBaseModel): - """An instance of the random data generator for a given config.""" - - config: XmssConfig - """Configuration parameters for the random generator.""" - - def field_elements(self, length: int) -> list[Fp]: - """Generates a random list of field elements.""" - # For each element, generate a secure random integer in the range [0, P-1]. 
- return [Fp(value=secrets.randbelow(P)) for _ in range(length)] - - def parameter(self) -> Parameter: - """Generates a random public parameter.""" - return Parameter(data=self.field_elements(self.config.PARAMETER_LEN)) - - def domain(self) -> HashDigestVector: - """Generates a random hash digest.""" - return HashDigestVector(data=self.field_elements(self.config.HASH_LEN_FE)) - - -PROD_RAND = Rand(config=PROD_CONFIG) -"""An instance configured for production-level parameters.""" - -TEST_RAND = Rand(config=TEST_CONFIG) -"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/subtree.py b/src/lean_spec/subspecs/xmss/subtree.py deleted file mode 100644 index baf66bf04..000000000 --- a/src/lean_spec/subspecs/xmss/subtree.py +++ /dev/null @@ -1,615 +0,0 @@ -""" -Subtree construction and manipulation for top-bottom Merkle tree traversal. - -This module contains the `HashSubTree` type and its associated construction methods, -implementing the memory-efficient top-bottom tree traversal approach. -""" - -from typing import Self - -from lean_spec.types import Uint64 -from lean_spec.types.container import Container - -from .constants import XmssConfig -from .prf import Prf -from .rand import Rand -from .tweak_hash import TreeTweak, TweakHasher -from .types import ( - HashDigestList, - HashDigestVector, - HashTreeLayer, - HashTreeLayers, - HashTreeOpening, - Parameter, - PRFKey, -) -from .utils import get_padded_layer - - -class HashSubTree(Container): - """ - Represents a subtree of a sparse Merkle tree. - - This is the building block for the top-bottom tree traversal approach, - which splits a large Merkle tree into: - - **One top tree**: Contains the root and the top `LOG_LIFETIME/2` layers - - **Multiple bottom trees**: Each contains `sqrt(LIFETIME)` leaves - - A subtree can represent either a complete tree (from layer 0) or a partial tree - starting from a higher layer (like a top tree starting from layer `LOG_LIFETIME/2`). 
- - The layers are stored from `lowest_layer` up to the root, with padding applied - to ensure even alignment for efficient parent computation. - - Memory Efficiency - ----------------- - For a key with lifetime 2^32: - - Traditional approach: O(2^32) = requires hundreds of GiB - - Top-bottom approach: O(sqrt(2^32)) = O(2^16) ≈ much less memory - - The secret key maintains: - - The full top tree (sparse, only active roots) - - Two consecutive bottom trees (sliding window) - - SSZ Container with fields: - - depth: uint64 - - lowest_layer: uint64 - - layers: List[HashTreeLayer, LAYERS_LIMIT] - - Serialization is handled automatically by SSZ. - """ - - depth: Uint64 - """ - The total depth of the full tree (e.g., 32 for a 2^32 leaf space). - - This represents the depth of the complete Merkle tree, not just this subtree. - A subtree starting from layer `k` will have `depth - k` layers stored. - """ - - lowest_layer: Uint64 - """ - The lowest layer included in this subtree. - - - For bottom trees: `lowest_layer = 0` (includes leaves) - - For top trees: `lowest_layer = LOG_LIFETIME/2` (starts from middle) - - Example: For LOG_LIFETIME=32, top tree has lowest_layer=16, containing - layers 16 through 32 (the root). - """ - - layers: HashTreeLayers - """ - The layers of this subtree, from `lowest_layer` to the root. - - SSZ notation: `List[HashTreeLayer, LAYERS_LIMIT]` - - - `layers[0]` corresponds to layer `lowest_layer` in the full tree - - `layers[-1]` corresponds to the highest layer in this subtree - - For bottom trees: the last layer contains a single root - - For top trees: the last layer contains the global root - - Each layer maintains the padding invariant: start index is even, - end index is odd (except for single-node layers). 
- """ - - @classmethod - def new( - cls, - hasher: TweakHasher, - rand: Rand, - lowest_layer: Uint64, - depth: Uint64, - start_index: Uint64, - parameter: Parameter, - lowest_layer_nodes: list[HashDigestVector], - ) -> Self: - """ - Builds a new sparse Merkle subtree starting from a specified layer. - - This is the general constructor for subtrees and is used internally by - `new_top_tree()` and `new_bottom_tree()`. A subtree can start from any - layer, not just layer 0 (leaves). - - ### Construction Algorithm - - 1. **Initialization**: Start with the provided nodes at `lowest_layer`, - apply padding to ensure even alignment. - - 2. **Bottom-Up Iteration**: Build the tree layer by layer, from the - lowest layer towards the root. - - 3. **Parent Generation**: At each level, group nodes into pairs - (left, right) and hash them to create parent nodes. - - 4. **Padding**: Apply padding to each new layer to maintain the - even-alignment invariant. - - 5. **Termination**: Continue until reaching the root or the desired - highest layer. - - Args: - hasher: The tweakable hash instance for computing parent nodes. - rand: Random generator for padding values. - lowest_layer: The starting layer for this subtree (0 for full trees). - depth: The total depth of the full tree (e.g., 32 for 2^32 leaves). - start_index: The absolute index at `lowest_layer` where this subtree begins. - parameter: The public parameter `P` for the hash function. - lowest_layer_nodes: The hash nodes at `lowest_layer` to build from. - - Returns: - A `HashSubTree` containing all computed layers from `lowest_layer` to root. - """ - # Validate: nodes must fit in available positions at this layer. - max_positions = 1 << int(depth - lowest_layer) - if int(start_index) + len(lowest_layer_nodes) > max_positions: - raise ValueError( - f"Overflow at layer {lowest_layer}: " - f"start={start_index}, count={len(lowest_layer_nodes)}, max={max_positions}" - ) - - # Initialize with padded input layer. 
- layers: list[HashTreeLayer] = [] - current = get_padded_layer(rand, lowest_layer_nodes, start_index) - layers.append(current) - - # Build upward: hash pairs of children to create parents. - for level in range(lowest_layer, depth): - parent_start = current.start_index // Uint64(2) - - # Hash each pair of siblings into their parent using zip for cleaner indexing. - node_pairs = zip(current.nodes[::2], current.nodes[1::2], strict=True) - parents = [ - hasher.apply( - parameter, - TreeTweak(level=level + 1, index=parent_start + Uint64(i)), - [left, right], - ) - for i, (left, right) in enumerate(node_pairs) - ] - - # Pad and store the new layer. - current = get_padded_layer(rand, parents, parent_start) - layers.append(current) - - return cls( - depth=depth, - lowest_layer=lowest_layer, - layers=HashTreeLayers(data=layers), - ) - - @classmethod - def new_top_tree( - cls, - hasher: TweakHasher, - rand: Rand, - depth: int, - start_bottom_tree_index: Uint64, - parameter: Parameter, - bottom_tree_roots: list[HashDigestVector], - ) -> Self: - """ - Constructs a top tree from the roots of bottom trees. - - For top-bottom tree traversal, the full Merkle tree is split into: - - **Top tree**: Contains root and top `LOG_LIFETIME/2` layers - - **Bottom trees**: Each contains `sqrt(LIFETIME)` leaves - - The top tree's lowest layer contains the roots of all bottom trees, - and it is built upward from there to the global root. - - ### Algorithm - - 1. **Determine lowest layer**: For a tree of depth `d`, the top tree - starts at layer `d/2` (the middle of the tree). - - 2. **Build upward**: Use `new()` to build from the bottom tree - roots up to the global root. - - Args: - hasher: The tweakable hash instance for computing parent nodes. - rand: Random generator for padding values. - depth: The total depth of the full tree (must be even for top-bottom split). - start_bottom_tree_index: The index of the first bottom tree in the range. 
- parameter: The public parameter `P` for the hash function. - bottom_tree_roots: The list of roots from all bottom trees in order. - - Returns: - A `HashSubTree` representing the top tree with `lowest_layer = depth/2`. - - Raises: - ValueError: If depth is odd (top-bottom split requires even depth). - """ - if depth % 2 != 0: - raise ValueError(f"Depth must be even for top-bottom split, got {depth}.") - - # Build from middle layer using bottom tree roots as leaves. - return cls.new( - hasher=hasher, - rand=rand, - lowest_layer=Uint64(depth // 2), - depth=Uint64(depth), - start_index=start_bottom_tree_index, - parameter=parameter, - lowest_layer_nodes=bottom_tree_roots, - ) - - @classmethod - def new_bottom_tree( - cls, - hasher: TweakHasher, - rand: Rand, - depth: int, - bottom_tree_index: Uint64, - parameter: Parameter, - leaves: list[HashDigestVector], - ) -> Self: - """ - Constructs a single bottom tree from leaf hashes. - - A bottom tree covers `sqrt(LIFETIME)` consecutive epochs. For a tree with - `LOG_LIFETIME = 32`, each bottom tree covers 2^16 = 65536 epochs. - - Bottom trees are numbered 0, 1, 2, ... where tree `i` covers epochs - `[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))`. - - ### Algorithm - - 1. **Build full tree**: First, build a complete subtree from layer 0 - using the provided leaves. - - 2. **Truncate incompatible top layers**: The full tree computation adds - padding nodes in upper layers that would be incompatible with other - bottom trees. We remove these layers. - - 3. **Replace with standalone root**: Extract the root at layer `depth/2` - and make it the highest layer of this bottom tree. - - Args: - hasher: The tweakable hash instance for computing parent nodes. - rand: Random generator for padding values. - depth: The total depth of the full tree (must be even). - bottom_tree_index: The index of this bottom tree (0, 1, 2, ...). - parameter: The public parameter `P` for the hash function. 
- leaves: The pre-hashed leaf nodes (one-time public keys). - - Returns: - A `HashSubTree` with layers 0 through `depth/2`, where the highest - layer contains only the bottom tree's root. - - Raises: - ValueError: If depth is odd or leaves count doesn't match `sqrt(LIFETIME)`. - """ - if depth % 2 != 0: - raise ValueError(f"Depth must be even for top-bottom split, got {depth}.") - - # Each bottom tree has exactly sqrt(LIFETIME) leaves. - leaves_per_tree = 1 << (depth // 2) - if len(leaves) != leaves_per_tree: - raise ValueError( - f"Expected {leaves_per_tree} leaves for depth={depth}, got {len(leaves)}." - ) - - # Build full tree from leaves. - full_tree = cls.new( - hasher=hasher, - rand=rand, - lowest_layer=Uint64(0), - depth=Uint64(depth), - start_index=bottom_tree_index * Uint64(leaves_per_tree), - parameter=parameter, - lowest_layer_nodes=leaves, - ) - - # Extract root from middle layer. - middle = full_tree.layers[depth // 2] - root_idx = int(bottom_tree_index - middle.start_index) - root_layer = HashTreeLayer( - start_index=bottom_tree_index, - nodes=HashDigestList(data=[middle.nodes[root_idx]]), - ) - - # Keep bottom half + single root node. - truncated = list(full_tree.layers[: depth // 2]) - return cls( - depth=Uint64(depth), - lowest_layer=Uint64(0), - layers=HashTreeLayers(data=truncated + [root_layer]), - ) - - @classmethod - def from_prf_key( - cls, - prf: Prf, - hasher: TweakHasher, - rand: Rand, - config: XmssConfig, - prf_key: PRFKey, - bottom_tree_index: Uint64, - parameter: Parameter, - ) -> Self: - """ - Generates a single bottom tree on-demand from the PRF key. - - This is a key component of the top-bottom tree approach: instead of storing all - one-time secret keys, we regenerate them on-demand using the PRF. This enables - O(sqrt(LIFETIME)) memory usage. - - ### Algorithm - - 1. **Determine epoch range**: Bottom tree `i` covers epochs - `[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))` - - 2. 
**Generate leaves**: For each epoch in parallel: - - For each chain (0 to DIMENSION-1): - - Derive secret start: `PRF(prf_key, epoch, chain_index)` - - Compute public end: hash chain for `BASE - 1` steps - - Hash all chain ends to get the leaf - - 3. **Build bottom tree**: Construct the bottom tree from the leaves - - Args: - prf: The PRF instance for key derivation. - hasher: The tweakable hash instance. - rand: Random generator for padding values. - config: The XMSS configuration. - prf_key: The master PRF secret key. - bottom_tree_index: The index of the bottom tree to generate (0, 1, 2, ...). - parameter: The public parameter `P` for the hash function. - - Returns: - A `HashSubTree` representing the requested bottom tree. - """ - # Calculate the number of leaves per bottom tree: sqrt(LIFETIME). - leaves_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - - # Determine the epoch range for this bottom tree. - start_epoch = bottom_tree_index * Uint64(leaves_per_bottom_tree) - end_epoch = start_epoch + Uint64(leaves_per_bottom_tree) - - # Generate leaf hashes for all epochs in this bottom tree. - leaf_hashes: list[HashDigestVector] = [] - - for epoch in range(start_epoch, end_epoch): - # For each epoch, compute the one-time public key (chain endpoints). - chain_ends: list[HashDigestVector] = [] - - for chain_index in range(config.DIMENSION): - # Derive the secret start of the chain from the PRF key. - start_digest = prf.apply(prf_key, Uint64(epoch), Uint64(chain_index)) - - # Compute the public end by hashing BASE - 1 times. - end_digest = hasher.hash_chain( - parameter=parameter, - epoch=Uint64(epoch), - chain_index=chain_index, - start_step=0, - num_steps=config.BASE - 1, - start_digest=start_digest, - ) - chain_ends.append(end_digest) - - # Hash the chain ends to get the leaf for this epoch. 
- leaf_tweak = TreeTweak(level=0, index=Uint64(epoch)) - leaf_hash = hasher.apply(parameter, leaf_tweak, chain_ends) - leaf_hashes.append(leaf_hash) - - # Build the bottom tree from the leaf hashes. - return cls.new_bottom_tree( - hasher=hasher, - rand=rand, - depth=config.LOG_LIFETIME, - bottom_tree_index=bottom_tree_index, - parameter=parameter, - leaves=leaf_hashes, - ) - - def root(self) -> HashDigestVector: - """ - Extracts the root digest from this subtree. - - For top-bottom tree traversal, a subtree's root is the single node - in its highest layer. - - Returns: - The root hash digest of the subtree. - - Raises: - ValueError: If the subtree has no layers or the highest layer is empty. - """ - if not self.layers: - raise ValueError("Empty subtree has no root.") - if not self.layers[-1].nodes: - raise ValueError("Top layer is empty.") - return self.layers[-1].nodes[0] - - def path(self, position: Uint64) -> HashTreeOpening: - """ - Computes the authentication path for a leaf within this subtree. - - This is similar to full tree path computation but works with subtrees that may - not start from layer 0. The path is computed from the specified position up to - (but not including) the subtree's root. - - For a subtree covering layers L through H (where H is the highest/root layer), - this generates H - L siblings: one for each layer from L to H-1. - - Args: - position: The absolute index of the leaf in the full tree coordinate system. - - Returns: - A `HashTreeOpening` containing the sibling hashes for the path. - - Raises: - ValueError: If the subtree is empty or the position is out of bounds. - """ - if not self.layers: - raise ValueError("Empty subtree.") - - # Check bounds. - first = self.layers[0] - if not (first.start_index <= position < first.start_index + Uint64(len(first.nodes))): - raise ValueError(f"Position {position} out of bounds.") - - # Collect sibling at each layer (except root). 
- siblings: list[HashDigestVector] = [] - pos = position - - # Iterate over all layers except the last (root). - for layer in self.layers[:-1]: - # Sibling index: flip last bit of position, adjust for layer offset. - sibling_idx = int((pos ^ Uint64(1)) - layer.start_index) - if not (0 <= sibling_idx < len(layer.nodes)): - raise ValueError(f"Sibling index {sibling_idx} out of bounds.") - - siblings.append(layer.nodes[sibling_idx]) - pos = pos // Uint64(2) # Move to parent position. - - return HashTreeOpening(siblings=HashDigestList(data=siblings)) - - -def combined_path( - top_tree: HashSubTree, - bottom_tree: HashSubTree, - position: Uint64, -) -> HashTreeOpening: - """ - Generates a combined authentication path spanning top and bottom trees. - - For top-bottom tree traversal, a signature's authentication path must prove - that a leaf is part of the global Merkle root. This requires two proofs: - - 1. **Bottom tree path**: Proves the leaf is part of its bottom tree's root - 2. **Top tree path**: Proves the bottom tree's root is part of the global root - - This function combines both paths into a single `HashTreeOpening` that can - be used for verification. - - ### Algorithm - - 1. **Determine which bottom tree**: Calculate which bottom tree contains - the specified position. - - 2. **Get bottom tree path**: Extract the authentication path from the leaf - up to the bottom tree's root (depth/2 siblings). - - 3. **Get top tree path**: Extract the authentication path from the bottom - tree's root up to the global root (depth/2 siblings). - - 4. **Concatenate**: Combine both paths into a single path with `depth` siblings. - - Args: - top_tree: The top tree containing the global root. - bottom_tree: The bottom tree containing the specified position. - position: The absolute epoch/leaf index to generate a path for. - - Returns: - A `HashTreeOpening` with `depth` siblings that authenticates the leaf - against the global root. 
- - Raises: - ValueError: If trees have mismatched depths, odd depth, or position is - out of bounds for the bottom tree. - """ - # Validate matching depths. - if top_tree.depth != bottom_tree.depth: - raise ValueError(f"Depth mismatch: top={top_tree.depth}, bottom={bottom_tree.depth}.") - - depth = int(top_tree.depth) - if depth % 2 != 0: - raise ValueError(f"Depth must be even, got {depth}.") - - # Validate bottom tree matches position. - leaves_per_tree = Uint64(1 << (depth // 2)) - expected_start = (position // leaves_per_tree) * leaves_per_tree - if bottom_tree.layers[0].start_index != expected_start: - raise ValueError( - f"Wrong bottom tree: position {position} needs start {expected_start}, " - f"got {bottom_tree.layers[0].start_index}." - ) - - # Concatenate: bottom path + top path. - bottom_path = bottom_tree.path(position) - top_path = top_tree.path(position // leaves_per_tree) - combined = tuple(bottom_path.siblings.data) + tuple(top_path.siblings.data) - - return HashTreeOpening(siblings=HashDigestList(data=combined)) - - -def verify_path( - hasher: TweakHasher, - parameter: Parameter, - root: HashDigestVector, - position: Uint64, - leaf_parts: list[HashDigestVector], - opening: HashTreeOpening, -) -> bool: - """ - Verifies a Merkle authentication path against a known, trusted root. - - This function is the final check in signature verification. It proves that the - one-time public key used for the signature (represented by `leaf_parts`) is a - legitimate member of the set committed to by the Merkle `root`. - - ### Verification Algorithm - - 1. **Leaf Computation**: The process begins at the bottom. The verifier first - hashes the `leaf_parts` to compute the actual leaf digest. This becomes the - starting `current_node` for the climb up the tree. - - 2. **Bottom-Up Reconstruction**: The verifier iterates through the `opening.siblings` - path. At each `level`, it takes the `current_node` and the `sibling_node` - from the path. - - 3. 
**Parent Calculation**: It determines if the `current_node` is a left or - right child based on its `position`. The two nodes are placed in the - correct `(left, right)` order and hashed (with the correct `TreeTweak`) - to compute the parent. This parent becomes the `current_node` for the - next level. - - 4. **Final Comparison**: After all siblings are used, the final `current_node` - is the candidate root. The path is valid if and only if it matches the trusted `root`. - - Args: - hasher: The tweakable hash instance for computing parent nodes. - parameter: The public parameter `P` for the hash function. - root: The known, trusted Merkle root from the public key. - position: The absolute index of the leaf being verified. - leaf_parts: The list of digests that constitute the original leaf. - opening: The `HashTreeOpening` object containing the sibling path. - - Returns: - `True` if the path is valid and reconstructs the root, `False` otherwise. - Returns `False` for invalid inputs (depth > 32 or position out of bounds). - """ - # Validate depth and position bounds. - # - # These checks guard against malformed attacker-controlled input. - # Return False instead of raising to avoid panic on invalid signatures. - depth = len(opening.siblings) - if depth > 32: - return False - if int(position) >= (1 << depth): - return False - - # Start: hash leaf parts to get leaf node. - current = hasher.apply( - parameter, - TreeTweak(level=0, index=Uint64(position)), - leaf_parts, - ) - pos = int(position) - - # Walk up: hash current with each sibling. - for level, sibling in enumerate(opening.siblings): - # Left child has even position, right child has odd. - left, right = (current, sibling) if pos % 2 == 0 else (sibling, current) - pos //= 2 # Parent position. - current = hasher.apply( - parameter, - TreeTweak(level=level + 1, index=Uint64(pos)), - [left, right], - ) - - # Valid if we reconstructed the expected root. 
- return current == root diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py deleted file mode 100644 index 985b29c38..000000000 --- a/src/lean_spec/subspecs/xmss/target_sum.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Implements the Top Level Target Sum Winternitz incomparable encoding scheme. - -This module provides the logic for converting a message hash into a valid -codeword for the one-time signature part of the scheme. It acts as a filter on -top of the message hash output. -""" - -from __future__ import annotations - -from lean_spec.types import Bytes32, StrictBaseModel, Uint64 - -from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig -from .message_hash import ( - PROD_MESSAGE_HASHER, - TEST_MESSAGE_HASHER, - MessageHasher, -) -from .types import Parameter, Randomness - - -class TargetSumEncoder(StrictBaseModel): - """ - An instance of the Target Sum encoder for a given configuration. - - This class encapsulates the logic for validating a message hash against the - scheme's target sum constraint. - """ - - config: XmssConfig - """Configuration parameters for the encoder.""" - - message_hasher: MessageHasher - """Message hasher for encoding.""" - - def encode( - self, parameter: Parameter, message: Bytes32, rho: Randomness, epoch: Uint64 - ) -> list[int] | None: - """ - Encodes a message into a codeword if it meets the target sum criteria. - - ### Encoding Algorithm - - 1. **Hashing to a Vertex**: The function first hashes the inputs (`message`, - `rho`, etc.) to produce a candidate codeword. This can be viewed as - mapping the inputs to a vertex in a high-dimensional hypercube, where - the vertex's coordinates are the digits of the codeword. - - 2. **Target Sum Validation**: It then checks if the sum of the candidate's digits - matches the scheme's predefined `TARGET_SUM`. This is equivalent to - verifying that the vertex lies on the correct hypercube layer. 
This - constraint is critical for the scheme's security and ensures a - predictable number of hash operations during signature verification. - - Args: - parameter: The public parameter `P`, used for domain separation. - message: The message to encode. - rho: The randomness used for this specific encoding attempt. - epoch: The current epoch, used as part of the hash input. - - Returns: - - The codeword (a list of integers) if the sum matches the target. - - Otherwise, it returns `None` to signal that this attempt failed and - a new `rho` must be tried. - """ - # Hash the inputs to map them to a potential codeword (a vertex in the hypercube). - codeword_candidate = self.message_hasher.apply(parameter, epoch, rho, message) - - # The aborting decode may reject if a field element equals P - 1. - if codeword_candidate is None: - return None - - # A codeword is valid only if it lies on the predefined hypercube layer. - # - # This is verified by checking if the sum of its coordinates equals TARGET_SUM. - if sum(codeword_candidate) == self.config.TARGET_SUM: - # If the sum is correct, this is a valid codeword for the one-time signature. - return codeword_candidate - - # If the sum does not match, this `rho` is invalid for this message. - # The caller will need to try again with new randomness. - return None - - -PROD_TARGET_SUM_ENCODER = TargetSumEncoder(config=PROD_CONFIG, message_hasher=PROD_MESSAGE_HASHER) -"""An instance configured for production-level parameters.""" - -TEST_TARGET_SUM_ENCODER = TargetSumEncoder(config=TEST_CONFIG, message_hasher=TEST_MESSAGE_HASHER) -"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py deleted file mode 100644 index d83c2ea70..000000000 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ /dev/null @@ -1,253 +0,0 @@ -""" -Defines the Tweakable Hash function using Poseidon1. 
- -### The Problem: Hash Function Overload - -In a complex cryptographic scheme like XMSS, a single hash function (like Poseidon1) -is used for many different purposes: -1. Hashing iteratively to form **hash chains**. -2. Hashing pairs of nodes to build the **Merkle tree**. -3. Hashing the one-time public key to form a **Merkle leaf**. - -If we simply called `hash(data)` for all these cases, we could run into a critical -security issue: a "collision" between different contexts. For example, the output of a -hash in a chain might accidentally be identical to the hash of two nodes in the tree. -This could allow an attacker to create forgeries. - -### The Solution: Tweakable Hashing - -A **tweakable hash function** solves this by treating each hash computation as having a -unique "address" or "tweak". - -The function's signature becomes `hash(tweak, data)`. By ensuring every tweak is -unique across the entire scheme, we guarantee that every hash computation is -**domain-separated**, eliminating the risk of cross-context collisions. -""" - -from typing import NamedTuple - -from lean_spec.types import StrictBaseModel, Uint64 - -from ..koalabear import Fp -from .constants import ( - PROD_CONFIG, - TEST_CONFIG, - TWEAK_PREFIX_CHAIN, - TWEAK_PREFIX_TREE, - XmssConfig, -) -from .poseidon import ( - PROD_POSEIDON, - TEST_POSEIDON, - PoseidonXmss, -) -from .types import HashDigestVector, Parameter -from .utils import int_to_base_p - - -class TreeTweak(NamedTuple): - """ - A tweak used for hashing nodes within the Merkle tree. - - This structure ensures that every hash computed during the construction of the - Merkle tree has a unique context. - """ - - level: int - """The level (height) in the Merkle tree, where 0 is the leaf level.""" - - index: Uint64 - """The node's index (from the left) within that level.""" - - -class ChainTweak(NamedTuple): - """ - A tweak used for hashing elements within a WOTS+ hash chain. 
- - This structure ensures every iterative hash within every one-time signature - chain is distinct across all epochs. - """ - - epoch: Uint64 - """The signature epoch.""" - - chain_index: int - """The index of the hash chain (from 0 to DIMENSION-1).""" - - step: int - """The step number within the chain (from 1 to BASE-1).""" - - -class TweakHasher(StrictBaseModel): - """An instance of the Tweakable Hasher for a given config.""" - - config: XmssConfig - """Configuration parameters for the hasher.""" - - poseidon: PoseidonXmss - """Poseidon permutation instance for hashing.""" - - def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]: - """ - Encodes a structured tweak object into a list of field elements. - - It converts a high-level tweak context (like "Merkle tree, level 5, index 3") - into a low-level format that can be consumed by the Poseidon1 hash function. - - ### Encoding Algorithm - - 1. **Packing**: The integer components of the tweak are packed into a - single, large integer using bit-shifting. A unique prefix - (`TWEAK_PREFIX_TREE` or `TWEAK_PREFIX_CHAIN`) is included to - guarantee that `TreeTweak` and `ChainTweak` can never produce - the same integer. This process is injective (one-to-one). - - 2. **Decomposition**: The resulting large integer is then decomposed into a - list of base-`P` digits, where `P` is the field prime. This produces the - final list of field elements. - - Args: - tweak: The `TreeTweak` or `ChainTweak` object to encode. - length: The desired number of field elements in the output list. - - Returns: - A list of `length` field elements representing the encoded tweak. - """ - # Pack the tweak's integer fields into a single large integer. - # - # A hardcoded prefix is included for domain separation between tweak types. 
- match tweak: - case TreeTweak(level=level, index=index): - # Packing scheme: (level << 40) | (index << 8) | PREFIX - acc = (level << 40) | (int(index) << 8) | TWEAK_PREFIX_TREE - case ChainTweak(epoch=epoch, chain_index=chain_index, step=step): - # Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX - acc = (int(epoch) << 24) | (chain_index << 16) | (step << 8) | TWEAK_PREFIX_CHAIN - - # Decompose the packed integer `acc` into a list of base-P field elements. - return int_to_base_p(acc, length) - - def apply( - self, - parameter: Parameter, - tweak: TreeTweak | ChainTweak, - message_parts: list[HashDigestVector], - ) -> HashDigestVector: - """ - Applies the tweakable Poseidon1 hash function to a message. - - This is the main entry point for all internal hashing operations. It prepares - the inputs and routes them to the appropriate Poseidon1 function based on - the input size, ensuring optimal performance and security. - - ### Hashing Algorithm - - 1. **Input Assembly**: The final hash input is formed by concatenating: - `[parameter || encoded_tweak || flattened_message_parts]` - - 2. **Mode Selection**: - - For small inputs (1 or 2 `HashDigest` parts), it uses the highly - efficient **compression mode** of Poseidon1. - - For large inputs (many `HashDigest` parts, like a Merkle leaf), - it uses the more flexible **sponge mode**. - - Args: - parameter: The public parameter `P` for this key pair. - tweak: A `TreeTweak` or `ChainTweak` for domain separation. - message_parts: A list of one or more hash digests to be hashed together. - - Returns: - A new hash digest of `HASH_LEN_FE` field elements. - """ - # Get the config for this scheme. - config = self.config - - # Encode the high-level tweak structure into a list of field elements. - encoded_tweak = self._encode_tweak(tweak, config.TWEAK_LEN_FE) - - # Route to the correct Poseidon1 mode based on the input size. 
- if len(message_parts) == 1: - # Case 1: Hashing a single digest (used in hash chains). - # - # We use the efficient width-16 compression mode. - input_vec = message_parts[0].elements + parameter.elements + encoded_tweak - result = self.poseidon.compress(input_vec, 16, config.HASH_LEN_FE) - - elif len(message_parts) == 2: - # Case 2: Hashing two digests (used for Merkle tree nodes). - # - # We use the slightly larger width-24 compression mode. - input_vec = ( - parameter.elements - + encoded_tweak - + message_parts[0].elements - + message_parts[1].elements - ) - result = self.poseidon.compress(input_vec, 24, config.HASH_LEN_FE) - - else: - # Case 3: Hashing many digests (used for the Merkle tree leaf). - # - # We use the robust sponge mode. - # First, flatten the list of message parts into a single vector. - flattened_message = [elem for part in message_parts for elem in part.elements] - input_vec = parameter.elements + encoded_tweak + flattened_message - - # Create a domain separator for the sponge mode based on the input dimensions. - # - # This ensures the sponge is uniquely configured for this specific hashing task. - lengths = [ - config.PARAMETER_LEN, - config.TWEAK_LEN_FE, - config.DIMENSION, - config.HASH_LEN_FE, - ] - capacity_value = self.poseidon.safe_domain_separator(lengths, config.CAPACITY) - - result = self.poseidon.sponge(input_vec, capacity_value, config.HASH_LEN_FE, 24) - - return HashDigestVector(data=result) - - def hash_chain( - self, - parameter: Parameter, - epoch: Uint64, - chain_index: int, - start_step: int, - num_steps: int, - start_digest: HashDigestVector, - ) -> HashDigestVector: - """ - Performs repeated hashing to traverse a WOTS+ hash chain. - - This function iteratively calls the main `apply` method, creating a new, - unique `ChainTweak` for each step to ensure every hash in the sequence - is domain-separated. - - Args: - parameter: The public parameter `P`. - epoch: The signature epoch, part of the tweak. 
- chain_index: The index of the hash chain, part of the tweak. - start_step: The starting step number in the chain. - num_steps: The number of hashing steps to perform. - start_digest: The digest to begin hashing from. - - Returns: - The final hash digest after `num_steps` applications. - """ - current_digest = start_digest - for i in range(num_steps): - # Create a unique tweak for the current position in the chain. - # - # The `step` is `start_step + i + 1` because steps are 1-indexed. - tweak = ChainTweak(epoch=epoch, chain_index=chain_index, step=start_step + i + 1) - # Apply the hash function to get the next digest in the chain. - current_digest = self.apply(parameter, tweak, [current_digest]) - return current_digest - - -PROD_TWEAK_HASHER = TweakHasher(config=PROD_CONFIG, poseidon=PROD_POSEIDON) -"""An instance configured for production-level parameters.""" - -TEST_TWEAK_HASHER = TweakHasher(config=TEST_CONFIG, poseidon=TEST_POSEIDON) -"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/types.py b/src/lean_spec/subspecs/xmss/types.py index 41fe7fc00..b0c4567fd 100644 --- a/src/lean_spec/subspecs/xmss/types.py +++ b/src/lean_spec/subspecs/xmss/types.py @@ -1,83 +1,78 @@ """Base types for the XMSS signature scheme.""" -from typing import Final +from typing import Final, NamedTuple from lean_spec.subspecs.koalabear import Fp from ...types import Uint64 -from ...types.byte_arrays import BaseBytes from ...types.collections import SSZList, SSZVector from ...types.container import Container -from .constants import PRF_KEY_LENGTH, TARGET_CONFIG +from .constants import TARGET_CONFIG -class PRFKey(BaseBytes): - """ - The PRF master secret key. +class TreeTweak(NamedTuple): + """Tweak that domain-separates Merkle node hashes by their position.""" + + level: int + """Height in the Merkle tree. 
- This is a high-entropy byte string that acts as the single root secret from - which all one-time signing keys are deterministically derived. + Layer 0 is the leaf level. """ - LENGTH = PRF_KEY_LENGTH + index: Uint64 + """Node index within its level, counted from the left.""" -HASH_DIGEST_LENGTH: Final = TARGET_CONFIG.HASH_LEN_FE -""" -The fixed length of a hash digest in field elements. +class ChainTweak(NamedTuple): + """Tweak that domain-separates Winternitz chain hashes by their position.""" + + epoch: Uint64 + """Slot identifier for the one-time signature.""" + + chain_index: int + """Index of the chain within the one-time signature.""" + + step: int + """ + - Step number along the chain. + - Steps are 1-indexed. + - Step zero is the chain start. + """ -Derived from `TARGET_CONFIG.HASH_LEN_FE`. This corresponds to the output length -of the Poseidon1 hash function used in the XMSS scheme. -""" -# Calculate the maximum number of nodes in a sparse Merkle tree layer: -# - A bottom tree has at most 2^(LOG_LIFETIME/2) leaves -# - With padding, we may add up to 2 additional nodes -# - To be generous and future-proof, we use 2^(LOG_LIFETIME/2 + 1) -NODE_LIST_LIMIT: Final = 1 << (TARGET_CONFIG.LOG_LIFETIME // 2 + 1) +NODE_LIST_LIMIT: Final = 2 * TARGET_CONFIG.LEAVES_PER_BOTTOM_TREE """ -The maximum number of nodes that can be stored in a sparse Merkle tree layer. +Maximum number of nodes a sparse Merkle tree layer can hold. -Calculated as `2^(LOG_LIFETIME/2 + 1)` from TARGET_CONFIG to accommodate: -- Bottom trees with up to `2^(LOG_LIFETIME/2)` nodes -- Padding overhead (up to 2 additional nodes) -- Future-proofing with 2x margin +- The widest layer is a bottom tree's leaf row, the square root of the lifetime in leaves. +- Padding adds at most one sibling at each end. +- Twice the leaf count is a generous cap that absorbs the padding with room to spare. 
""" class HashDigestVector(SSZVector[Fp]): """ - A single hash digest represented as a fixed-size vector of field elements. + A single hash digest as a fixed-size vector of field elements. - This is the SSZ-compliant representation of a Poseidon1 hash output. - In SSZ notation: `Vector[Fp, HASH_DIGEST_LENGTH]` - - The fixed size enables efficient serialization when used in collections, - as SSZ can pack these back-to-back without per-element offsets. + The fixed size lets SSZ pack these back-to-back without per-element offsets. """ - LENGTH = HASH_DIGEST_LENGTH + LENGTH = TARGET_CONFIG.HASH_LEN_FE + """One Poseidon1 digest, measured in field elements.""" class HashDigestList(SSZList[HashDigestVector]): - """ - Variable-length list of hash digests. - - In SSZ notation: `List[Vector[Fp, HASH_DIGEST_LENGTH], NODE_LIST_LIMIT]` - - This type is used to represent collections of hash digests in the XMSS scheme. - """ + """Variable-length list of hash digests.""" LIMIT = NODE_LIST_LIMIT class Parameter(SSZVector[Fp]): - """ - The public parameter P. + """The public parameter P. - This is a unique, randomly generated value associated with a single key pair. It - is mixed into every hash computation to "personalize" the hash function, preventing - certain cross-key attacks. It is public knowledge. + - Unique, randomly generated value associated with a single key pair. + - Mixed into every hash to personalize the function and block cross-key attacks. + - Public knowledge. """ LENGTH = TARGET_CONFIG.PARAMETER_LEN @@ -85,13 +80,11 @@ class Parameter(SSZVector[Fp]): class Randomness(SSZVector[Fp]): """ - The randomness `rho` (ρ) used during signing. - - This value provides a variable input to the message hash, allowing the signer to - repeatedly try hashing until a valid "codeword" is found. It must be included in - the final signature for the verifier to reproduce the same hash. + Fresh randomness mixed into the message hash during signing. 
- SSZ notation: `Vector[Fp, RAND_LEN_FE]` + - Signing rehashes the message with new randomness on each attempt. + - Retries continue until the resulting codeword hits the target sum. + - The chosen randomness travels in the signature so the verifier recomputes the same codeword. """ LENGTH = TARGET_CONFIG.RAND_LEN_FE @@ -99,55 +92,12 @@ class Randomness(SSZVector[Fp]): class HashTreeOpening(Container): """ - A Merkle authentication path. + A Merkle authentication path proving one leaf sits under the root. - This object contains the minimal proof required to connect a specific leaf - to the Merkle root. It consists of the list of all sibling nodes along the - path from the leaf to the top of the tree. - - SSZ Container with fields: - - siblings: List[Vector[Fp, HASH_DIGEST_LENGTH], NODE_LIST_LIMIT] + - The path lists the sibling hashes met while climbing from the leaf up to the root. + - A verifier rehashes the leaf upward with these siblings. + - The reconstructed root must equal the trusted root. """ siblings: HashDigestList - """SSZ-compliant list of sibling hashes, from bottom to top.""" - - -class HashTreeLayer(Container): - """ - Represents a single horizontal "slice" of the sparse Merkle tree. - - Because the tree is sparse, we only store the nodes that are actually computed - for the active range of leaves, not the entire conceptual layer. - """ - - start_index: Uint64 - """The starting index of the first node in this layer.""" - nodes: HashDigestList - """SSZ-compliant list of hash digests stored for this layer.""" - - -LAYERS_LIMIT: Final = TARGET_CONFIG.LOG_LIFETIME + 1 -""" -The maximum number of layers in a subtree. - -This is `LOG_LIFETIME + 1` to accommodate all layers from 0 (leaves) to LOG_LIFETIME (root), -inclusive. For example, with LOG_LIFETIME=32, this allows up to 33 layers. -""" - - -class HashTreeLayers(SSZList[HashTreeLayer]): - """ - Variable-length list of Merkle tree layers. 
- - In SSZ notation: `List[HashTreeLayer, LAYERS_LIMIT]` - - This type represents the layers of a subtree, from the lowest layer up to the root. - - The number of layers varies based on the subtree structure: - - Bottom trees: `LOG_LIFETIME/2` layers - - Top trees: `LOG_LIFETIME/2` layers - - Maximum: `LOG_LIFETIME + 1` layers - """ - - LIMIT = LAYERS_LIMIT + """Sibling hashes, ordered from the leaf upward to the root.""" diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py deleted file mode 100644 index 623c3b765..000000000 --- a/src/lean_spec/subspecs/xmss/utils.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Utility functions for the XMSS signature scheme.""" - -from __future__ import annotations - -from ...types.uint import Uint64 -from ..koalabear import Fp, P -from .rand import Rand -from .types import HashDigestList, HashDigestVector, HashTreeLayer - - -def get_padded_layer( - rand: Rand, nodes: list[HashDigestVector], start_index: Uint64 -) -> HashTreeLayer: - """ - Pads a layer of nodes with random hashes to simplify tree construction. - - This helper enforces a crucial invariant: every active layer must start at an - even index and end at an odd index. This guarantees that every node within - the layer can be neatly paired with a sibling (a left child with a right - child), which dramatically simplifies the parent generation logic by - removing the need to handle edge cases. - - Args: - rand: Random generator for padding values. - nodes: The list of active nodes for the current layer. - start_index: The starting index of the first node in `nodes`. - - Returns: - A new `HashTreeLayer` with the necessary padding applied. - """ - nodes_with_padding: list[HashDigestVector] = [] - end_index = start_index + Uint64(len(nodes)) - Uint64(1) - - # Prepend random padding if the layer starts at an odd index. 
- if start_index % Uint64(2) == Uint64(1): - nodes_with_padding.append(rand.domain()) - - # The actual start index of the padded layer is always the even - # number at or immediately before the original start_index. - actual_start_index = start_index - (start_index % Uint64(2)) - - # Add the actual node content. - nodes_with_padding.extend(nodes) - - # Append random padding if the layer ends at an even index. - if end_index % Uint64(2) == Uint64(0): - nodes_with_padding.append(rand.domain()) - - return HashTreeLayer( - start_index=actual_start_index, nodes=HashDigestList(data=nodes_with_padding) - ) - - -def int_to_base_p(value: int, num_limbs: int) -> list[Fp]: - """ - Decomposes a large integer into a list of base-P field elements. - - This function performs a standard base conversion, where each "digit" - is an element in the prime field F_p. - - Args: - value: The integer to decompose. - num_limbs: The desired number of output field elements (limbs). - - Returns: - A list of `num_limbs` field elements representing the integer. - """ - limbs: list[Fp] = [] - acc = value - for _ in range(num_limbs): - limbs.append(Fp(value=acc)) - acc //= P - return limbs - - -def expand_activation_time( - log_lifetime: int, desired_activation_slot: int, desired_num_active_slots: int -) -> tuple[int, int]: - """ - Expands and aligns the activation time to top-bottom tree boundaries. - - For efficient top-bottom tree traversal, activation intervals must be aligned to - `sqrt(LIFETIME)` boundaries. This function takes the user's desired activation - interval and expands it to meet the following requirements: - - 1. **Start alignment**: Start slot is rounded down to a multiple of `sqrt(LIFETIME)` - 2. **End alignment**: End slot is rounded up to a multiple of `sqrt(LIFETIME)` - 3. **Minimum duration**: At least `2 * sqrt(LIFETIME)` slots (two bottom trees) - 4. **Lifetime bounds**: Clamped to `[0, LIFETIME)` - - ### Algorithm - - Let `C = 2^(LOG_LIFETIME/2) = sqrt(LIFETIME)` - - 1. 
Align start downward: `start = desired_start & c_mask` where `c_mask = ~(C - 1)` - 2. Round end upward: `end = (desired_end + C - 1) & c_mask` - 3. Enforce minimum: `if end - start < 2*C: end = start + 2*C` - 4. Clamp to bounds: Adjust if end exceeds `C^2 = LIFETIME` - - ### Example - - For `LOG_LIFETIME = 32` (LIFETIME = 2^32, C = 2^16 = 65536): - - Request: slots [10000, 80000) → 70000 slots - - Aligned: slots [0, 131072) → 131072 slots = 2 bottom trees - - Args: - log_lifetime: The logarithm (base 2) of the total lifetime. - desired_activation_slot: The user's requested first slot. - desired_num_active_slots: The user's requested number of slots. - - Returns: - A tuple `(start_bottom_tree_index, end_bottom_tree_index)` where: - - `start_bottom_tree_index`: Index of the first bottom tree (0, 1, 2, ...) - - `end_bottom_tree_index`: Index past the last bottom tree (exclusive) - - Actual slots: `[start_index * C, end_index * C)` - """ - # Calculate sqrt(LIFETIME) and the alignment mask. - c = 1 << (log_lifetime // 2) # C = 2^(LOG_LIFETIME/2) - c_mask = ~(c - 1) # Mask for rounding to multiples of C - - # Calculate the desired end slot. - desired_end_slot = desired_activation_slot + desired_num_active_slots - - # Step 1: Align start downward to a multiple of C. - start = desired_activation_slot & c_mask - - # Step 2: Round end upward to a multiple of C. - end = (desired_end_slot + c - 1) & c_mask - - # Step 3: Enforce minimum duration of 2*C. - if end - start < 2 * c: - end = start + 2 * c - - # Step 4: Clamp to lifetime bounds [0, C^2). - lifetime = c * c # LIFETIME = C^2 = 2^LOG_LIFETIME - if end > lifetime: - # If the expanded interval exceeds the lifetime, try to fit it at the end. - duration = end - start - - if duration > lifetime: - # The expanded interval is larger than the entire lifetime. - # Use the entire lifetime. - start = 0 - end = lifetime - else: - # Shift the interval to end at the lifetime boundary. 
- end = lifetime - start = (lifetime - duration) & c_mask # Keep alignment - - # Convert to bottom tree indices. - # Bottom tree i covers slots [i*C, (i+1)*C). - start_bottom_tree_index = start // c - end_bottom_tree_index = end // c - - return (start_bottom_tree_index, end_bottom_tree_index) diff --git a/tests/consensus/lstar/ssz/test_xmss_containers.py b/tests/consensus/lstar/ssz/test_xmss_containers.py index 4b10ded57..46e2b4f00 100644 --- a/tests/consensus/lstar/ssz/test_xmss_containers.py +++ b/tests/consensus/lstar/ssz/test_xmss_containers.py @@ -10,11 +10,11 @@ TypeOneMultiSignature, TypeTwoMultiSignature, ) +from lean_spec.subspecs.xmss.constants import TARGET_CONFIG +from lean_spec.subspecs.xmss.merkle import HashTreeLayer from lean_spec.subspecs.xmss.types import ( - HASH_DIGEST_LENGTH, HashDigestList, HashDigestVector, - HashTreeLayer, HashTreeOpening, Parameter, ) @@ -36,7 +36,7 @@ def _zero_hash_digest_vector() -> HashDigestVector: """Build a hash digest vector with all field elements set to zero.""" - return HashDigestVector(data=[Fp(0) for _ in range(HASH_DIGEST_LENGTH)]) + return HashDigestVector(data=[Fp(0) for _ in range(TARGET_CONFIG.HASH_LEN_FE)]) def _zero_parameter() -> Parameter: @@ -120,7 +120,7 @@ def test_public_key_typical(ssz: SSZTestFiller) -> None: ssz( type_name="PublicKey", value=PublicKey( - root=HashDigestVector(data=[Fp(i + 1) for i in range(HASH_DIGEST_LENGTH)]), + root=HashDigestVector(data=[Fp(i + 1) for i in range(TARGET_CONFIG.HASH_LEN_FE)]), parameter=Parameter(data=[Fp(100 + i) for i in range(Parameter.LENGTH)]), ), ) @@ -144,7 +144,9 @@ def test_hash_tree_opening_typical(ssz: SSZTestFiller) -> None: value=HashTreeOpening( siblings=HashDigestList( data=[ - HashDigestVector(data=[Fp(i + j * 10) for i in range(HASH_DIGEST_LENGTH)]) + HashDigestVector( + data=[Fp(i + j * 10) for i in range(TARGET_CONFIG.HASH_LEN_FE)] + ) for j in range(3) ] ) @@ -174,7 +176,7 @@ def test_hash_tree_layer_typical(ssz: SSZTestFiller) -> None: 
start_index=Uint64(42), nodes=HashDigestList( data=[ - HashDigestVector(data=[Fp(i + j * 7) for i in range(HASH_DIGEST_LENGTH)]) + HashDigestVector(data=[Fp(i + j * 7) for i in range(TARGET_CONFIG.HASH_LEN_FE)]) for j in range(2) ] ), diff --git a/tests/lean_spec/forks/lstar/forkchoice/test_attestation_target.py b/tests/lean_spec/forks/lstar/forkchoice/test_attestation_target.py index abbc30064..9fdc82dff 100644 --- a/tests/lean_spec/forks/lstar/forkchoice/test_attestation_target.py +++ b/tests/lean_spec/forks/lstar/forkchoice/test_attestation_target.py @@ -23,7 +23,6 @@ Checkpoint, Slot, ValidatorIndex, - ValidatorIndices, ) from tests.lean_spec.helpers import make_store @@ -581,8 +580,7 @@ def test_attestation_target_after_on_block( proposer_pubkey = key_manager.get_public_keys(proposer_1)[1] proposer_type_1 = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=[(proposer_pubkey, proposer_signature)], - xmss_participants=ValidatorIndices(data=[proposer_1]).to_aggregation_bits(), + raw_xmss=[(proposer_1, proposer_pubkey, proposer_signature)], message=block_root, slot=slot_1, ) diff --git a/tests/lean_spec/forks/lstar/forkchoice/test_store_attestations.py b/tests/lean_spec/forks/lstar/forkchoice/test_store_attestations.py index 5810ac00b..4b6f82998 100644 --- a/tests/lean_spec/forks/lstar/forkchoice/test_store_attestations.py +++ b/tests/lean_spec/forks/lstar/forkchoice/test_store_attestations.py @@ -22,7 +22,6 @@ Checkpoint, Slot, ValidatorIndex, - ValidatorIndices, ) from tests.lean_spec.helpers import ( TEST_VALIDATOR_ID, @@ -293,16 +292,15 @@ def test_valid_proof_stored_correctly( data_root = hash_tree_root(attestation_data) # Create valid aggregated proof - xmss_participants = ValidatorIndices(data=participants).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in participants], - [key_manager.sign_attestation_data(vid, attestation_data) for vid in participants], - strict=True, + raw_xmss = [ + ( + 
vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -340,16 +338,15 @@ def test_attestation_data_used_as_key( data_root = hash_tree_root(attestation_data) - xmss_participants = ValidatorIndices(data=participants).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in participants], - [key_manager.sign_attestation_data(vid, attestation_data) for vid in participants], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -380,16 +377,15 @@ def test_invalid_proof_rejected(self, key_manager: XmssKeyManager, spec: LstarSp data_root = hash_tree_root(attestation_data) - xmss_participants = ValidatorIndices(data=signers).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in signers], - [key_manager.sign_attestation_data(vid, attestation_data) for vid in signers], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in signers + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -426,19 +422,15 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager, spec: Lst # First proof: validators 1 and 2 participants_1 = [ValidatorIndex(1), ValidatorIndex(2)] - xmss_1 = ValidatorIndices(data=participants_1).to_aggregation_bits() - raw_xmss_1 = list( - zip( - 
[key_manager[vid].attestation_keypair.public_key for vid in participants_1], - [ - key_manager.sign_attestation_data(vid, attestation_data) - for vid in participants_1 - ], - strict=True, + raw_xmss_1 = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants_1 + ] proof_1 = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_1, children=[], raw_xmss=raw_xmss_1, message=data_root, @@ -447,19 +439,15 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager, spec: Lst # Second proof: validators 1 and 3 (validator 1 overlaps) participants_2 = [ValidatorIndex(1), ValidatorIndex(3)] - xmss_2 = ValidatorIndices(data=participants_2).to_aggregation_bits() - raw_xmss_2 = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in participants_2], - [ - key_manager.sign_attestation_data(vid, attestation_data) - for vid in participants_2 - ], - strict=True, + raw_xmss_2 = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants_2 + ] proof_2 = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_2, children=[], raw_xmss=raw_xmss_2, message=data_root, diff --git a/tests/lean_spec/helpers/builders.py b/tests/lean_spec/helpers/builders.py index cbb33e1f6..766dbef49 100644 --- a/tests/lean_spec/helpers/builders.py +++ b/tests/lean_spec/helpers/builders.py @@ -410,16 +410,15 @@ def make_aggregated_proof( ) -> TypeOneMultiSignature: """Create a valid Type-1 aggregated proof for the given participants.""" data_root = hash_tree_root(attestation_data) - xmss_participants = ValidatorIndices(data=participants).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager.get_public_keys(vid)[0] for vid in participants], - [key_manager.sign_attestation_data(vid, attestation_data) for vid in participants], - strict=True, + raw_xmss = [ + ( + vid, + 
key_manager.get_public_keys(vid)[0], + key_manager.sign_attestation_data(vid, attestation_data), ) - ) + for vid in participants + ] return TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -480,8 +479,7 @@ def make_signed_block_from_store( proposer_signature = key_manager.sign_block_root(proposer_index, slot, block_root) proposer_type_1 = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=[(proposer_pubkey, proposer_signature)], - xmss_participants=ValidatorIndices(data=[proposer_index]).to_aggregation_bits(), + raw_xmss=[(proposer_index, proposer_pubkey, proposer_signature)], message=block_root, slot=slot, ) diff --git a/tests/lean_spec/subspecs/validator/test_service.py b/tests/lean_spec/subspecs/validator/test_service.py index f1bed1e77..d6609d0fa 100644 --- a/tests/lean_spec/subspecs/validator/test_service.py +++ b/tests/lean_spec/subspecs/validator/test_service.py @@ -26,7 +26,7 @@ from lean_spec.subspecs.validator.registry import ValidatorEntry from lean_spec.subspecs.xmss import TARGET_SIGNATURE_SCHEME from lean_spec.subspecs.xmss.aggregation import TypeOneMultiSignature, TypeTwoMultiSignature -from lean_spec.types import Bytes32, Slot, Uint64, ValidatorIndex, ValidatorIndices +from lean_spec.types import Bytes32, Slot, Uint64, ValidatorIndex from tests.lean_spec.helpers import ( TEST_VALIDATOR_ID, MockNetworkRequester, @@ -1077,11 +1077,9 @@ async def test_block_includes_pending_attestations( signatures.append(sig) public_keys.append(key_manager[vid].attestation_keypair.public_key) - xmss_participants = ValidatorIndices(data=participants).to_aggregation_bits() proof = TypeOneMultiSignature.aggregate( children=[], - raw_xmss=list(zip(public_keys, signatures, strict=True)), - xmss_participants=xmss_participants, + raw_xmss=list(zip(participants, public_keys, signatures, strict=True)), message=data_root, slot=attestation_data.slot, ) diff --git 
a/tests/lean_spec/subspecs/xmss/__init__.py b/tests/lean_spec/subspecs/xmss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/lean_spec/subspecs/xmss/test_aggregation.py b/tests/lean_spec/subspecs/xmss/test_aggregation.py index 4a8bd8213..b23c2712e 100644 --- a/tests/lean_spec/subspecs/xmss/test_aggregation.py +++ b/tests/lean_spec/subspecs/xmss/test_aggregation.py @@ -16,7 +16,6 @@ Checkpoint, Slot, ValidatorIndex, - ValidatorIndices, ) from tests.lean_spec.helpers import make_attestation_data_simple, make_bytes32 @@ -31,16 +30,15 @@ def _sign_and_aggregate( att_data = make_attestation_data_simple(slot, make_bytes32(head), make_bytes32(target), source) data_root = hash_tree_root(att_data) - xmss_participants = ValidatorIndices(data=validator_ids).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in validator_ids], - [key_manager.sign_attestation_data(vid, att_data) for vid in validator_ids], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, att_data), ) - ) + for vid in validator_ids + ] return TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=data_root, @@ -48,35 +46,22 @@ def _sign_and_aggregate( ) -def test_aggregate_rejects_empty_inputs() -> None: - """Aggregation with no signatures and no children raises an error.""" - with pytest.raises(AggregationError, match="At least one raw signature or child proof"): - TypeOneMultiSignature.aggregate( - xmss_participants=None, - children=[], - raw_xmss=[], - message=make_bytes32(0), - slot=Slot(0), - ) - - def test_aggregate_multiple_signatures(key_manager: XmssKeyManager) -> None: """Multiple validators' signatures can be aggregated into a single Type-1 proof.""" source = Checkpoint(root=make_bytes32(10), slot=Slot(0)) att_data = make_attestation_data_simple(Slot(2), make_bytes32(11), 
make_bytes32(12), source) vids = [ValidatorIndex(i) for i in range(4)] - xmss_participants = ValidatorIndices(data=vids).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in vids], - [key_manager.sign_attestation_data(vid, att_data) for vid in vids], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, att_data), ) - ) + for vid in vids + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[], raw_xmss=raw_xmss, message=hash_tree_root(att_data), @@ -102,17 +87,16 @@ def test_aggregate_children_with_raw_signatures(key_manager: XmssKeyManager) -> # Additional raw signatures: validators 2, 3 extra_vids = [ValidatorIndex(2), ValidatorIndex(3)] - xmss_participants = ValidatorIndices(data=extra_vids).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in extra_vids], - [key_manager.sign_attestation_data(vid, att_data) for vid in extra_vids], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, att_data), ) - ) + for vid in extra_vids + ] parent = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[ ( child, @@ -149,7 +133,6 @@ def test_aggregate_three_children(key_manager: XmssKeyManager) -> None: child_c_pks = [key_manager[ValidatorIndex(2)].attestation_keypair.public_key] parent = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(child_a, child_a_pks), (child_b, child_b_pks), (child_c, child_c_pks)], raw_xmss=[], message=hash_tree_root(att_data), @@ -185,14 +168,12 @@ def test_aggregate_children_of_children(key_manager: XmssKeyManager) -> None: # Level 1: two intermediate proofs. 
mid_ab = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(leaf_a, leaf_a_pks), (leaf_b, leaf_b_pks)], raw_xmss=[], message=msg, slot=att_data.slot, ) mid_cd = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(leaf_c, leaf_c_pks), (leaf_d, leaf_d_pks)], raw_xmss=[], message=msg, @@ -201,7 +182,6 @@ def test_aggregate_children_of_children(key_manager: XmssKeyManager) -> None: # Level 2: final root proof. root = TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(mid_ab, leaf_a_pks + leaf_b_pks), (mid_cd, leaf_c_pks + leaf_d_pks)], raw_xmss=[], message=msg, @@ -236,17 +216,16 @@ def test_aggregate_mixed_children_and_raw_multiple(key_manager: XmssKeyManager) # Additional raw signatures from validators 2 and 3. extra_vids = [ValidatorIndex(2), ValidatorIndex(3)] - xmss_participants = ValidatorIndices(data=extra_vids).to_aggregation_bits() - raw_xmss = list( - zip( - [key_manager[vid].attestation_keypair.public_key for vid in extra_vids], - [key_manager.sign_attestation_data(vid, att_data) for vid in extra_vids], - strict=True, + raw_xmss = [ + ( + vid, + key_manager[vid].attestation_keypair.public_key, + key_manager.sign_attestation_data(vid, att_data), ) - ) + for vid in extra_vids + ] proof = TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, children=[(child_a, child_a_pks), (child_b, child_b_pks)], raw_xmss=raw_xmss, message=msg, @@ -263,6 +242,57 @@ def test_aggregate_mixed_children_and_raw_multiple(key_manager: XmssKeyManager) ) +def test_type_one_verify_rejects_pubkey_count_mismatch(key_manager: XmssKeyManager) -> None: + """Type-1 verification refuses a pubkey set that does not match the bitfield.""" + source = Checkpoint(root=make_bytes32(160), slot=Slot(0)) + att_args = (Slot(2), 161, 162, source) + vids = [ValidatorIndex(0), ValidatorIndex(1)] + + proof = _sign_and_aggregate(key_manager, vids, att_args) + # The bitfield names two validators but only one key is supplied. 
+ only_one = [key_manager[ValidatorIndex(0)].attestation_keypair.public_key] + + with pytest.raises( + AggregationError, match="Type-1 verify expected 2 pubkeys for participants, got 1" + ): + proof.verify(public_keys=only_one, message=make_bytes32(161), slot=att_args[0]) + + +def test_type_two_split_by_msg_rejected_under_test_prover(key_manager: XmssKeyManager) -> None: + """Splitting a merged proof aborts under the reduced test-config prover. + + The split branch is functional only under the production prover. + The test-config build aborts it with an in-circuit assertion. + Exercising it here drives the serialization and error-translation path. + """ + source = Checkpoint(root=make_bytes32(600), slot=Slot(0)) + att_args_a = (Slot(11), 601, 602, source) + att_args_b = (Slot(11), 603, 604, source) + att_data_a = make_attestation_data_simple( + att_args_a[0], make_bytes32(att_args_a[1]), make_bytes32(att_args_a[2]), att_args_a[3] + ) + + vids_a = [ValidatorIndex(0), ValidatorIndex(1)] + vids_b = [ValidatorIndex(2), ValidatorIndex(3)] + part_a = _sign_and_aggregate(key_manager, vids_a, att_args_a) + part_b = _sign_and_aggregate(key_manager, vids_b, att_args_b) + + pubkeys_a = [key_manager[vid].attestation_keypair.public_key for vid in vids_a] + pubkeys_b = [key_manager[vid].attestation_keypair.public_key for vid in vids_b] + + merged = TypeTwoMultiSignature.aggregate( + parts=[part_a, part_b], + public_keys_per_part=[pubkeys_a, pubkeys_b], + ) + + with pytest.raises(AggregationError, match="Type-2 split failed"): + merged.split_by_msg( + message=hash_tree_root(att_data_a), + public_keys_per_message=[pubkeys_a, pubkeys_b], + participants=part_a.participants, + ) + + def test_aggregate_wrong_message_fails_verification(key_manager: XmssKeyManager) -> None: """Verification fails when the caller passes a message that does not match the proof.""" source = Checkpoint(root=make_bytes32(120), slot=Slot(0)) @@ -335,7 +365,6 @@ def 
test_aggregate_child_signed_different_message_fails(key_manager: XmssKeyMana # The binding rejects mismatching messages during recursive aggregation. with pytest.raises(AggregationError): TypeOneMultiSignature.aggregate( - xmss_participants=None, children=[(child_a, child_a_pks), (child_b, child_b_pks)], raw_xmss=[], message=hash_tree_root(att_data_b), @@ -343,60 +372,6 @@ def test_aggregate_child_signed_different_message_fails(key_manager: XmssKeyMana ) -def test_aggregate_rejects_single_child_without_raw(key_manager: XmssKeyManager) -> None: - """A single child without raw signatures is rejected (need at least two children).""" - placeholder = ByteList512KiB(data=b"\x00") - stub_child = TypeOneMultiSignature( - participants=ValidatorIndices(data=[ValidatorIndex(0)]).to_aggregation_bits(), - proof=placeholder, - ) - - with pytest.raises(AggregationError, match="At least two child proofs"): - TypeOneMultiSignature.aggregate( - xmss_participants=None, - children=[ - ( - stub_child, - [ - key_manager[ValidatorIndex(i)].attestation_keypair.public_key - for i in range(1) - ], - ) - ], - raw_xmss=[], - message=make_bytes32(0), - slot=Slot(0), - ) - - -def test_aggregate_rejects_mismatched_participant_count( - key_manager: XmssKeyManager, -) -> None: - """Participant bitfield count must match raw signature count.""" - source = Checkpoint(root=make_bytes32(60), slot=Slot(0)) - att_data = make_attestation_data_simple(Slot(7), make_bytes32(61), make_bytes32(62), source) - - # Claim 2 participants but only provide 1 signature. 
- xmss_participants = ValidatorIndices( - data=[ValidatorIndex(0), ValidatorIndex(1)] - ).to_aggregation_bits() - raw_xmss = [ - ( - key_manager[ValidatorIndex(0)].attestation_keypair.public_key, - key_manager.sign_attestation_data(ValidatorIndex(0), att_data), - ) - ] - - with pytest.raises(AggregationError, match="does not match"): - TypeOneMultiSignature.aggregate( - xmss_participants=xmss_participants, - children=[], - raw_xmss=raw_xmss, - message=hash_tree_root(att_data), - slot=att_data.slot, - ) - - def test_type_two_aggregate_rejects_empty_parts() -> None: """Type-2 aggregation requires at least one Type-1 input.""" with pytest.raises(AggregationError, match="at least one Type-1 input"): @@ -425,6 +400,24 @@ def test_type_two_aggregate_rejects_mismatched_pubkey_layout( ) +def test_type_two_aggregate_propagates_prover_error(key_manager: XmssKeyManager) -> None: + """A corrupted component proof makes the merge prover reject the inputs.""" + source = Checkpoint(root=make_bytes32(210), slot=Slot(0)) + att_args = (Slot(8), 211, 212, source) + vids = [ValidatorIndex(0), ValidatorIndex(1)] + + part = _sign_and_aggregate(key_manager, vids, att_args) + pubkeys = [key_manager[vid].attestation_keypair.public_key for vid in vids] + + corrupted_bytes = bytearray(part.proof.data) + corrupted_bytes[10] ^= 0xFF + corrupted_bytes[20] ^= 0xFF + corrupted = part.model_copy(update={"proof": ByteList512KiB(data=bytes(corrupted_bytes))}) + + with pytest.raises(AggregationError, match="merge_many_type_1 failed"): + TypeTwoMultiSignature.aggregate(parts=[corrupted], public_keys_per_part=[pubkeys]) + + def test_type_two_verify_round_trip(key_manager: XmssKeyManager) -> None: """A Type-2 merge of two distinct-message Type-1 proofs round-trips through verify.""" source = Checkpoint(root=make_bytes32(300), slot=Slot(0)) diff --git a/tests/lean_spec/subspecs/xmss/test_constants.py b/tests/lean_spec/subspecs/xmss/test_constants.py new file mode 100644 index 000000000..3b504c2d6 --- 
/dev/null +++ b/tests/lean_spec/subspecs/xmss/test_constants.py @@ -0,0 +1,285 @@ +""" +Tests for the XMSS cryptographic constants, configuration presets, and security margins. + +The security-margin tests validate that the production parameter choices achieve +adequate classical and quantum security. + +Based on: + +- [DKKW25c] "Hash-Based Multi-Signatures for Post-Quantum Ethereum" + (https://eprint.iacr.org/2025/055.pdf) +- [HKKTW26] "Aborting Random Oracles" + (https://eprint.iacr.org/2026/016) + +The security analysis follows the framework of [DKKW25c] Section 6. +""" + +import math + +import pytest + +from lean_spec.subspecs.koalabear import P_BYTES, P +from lean_spec.subspecs.xmss.constants import ( + PROD_CONFIG, + TARGET_CONFIG, + TEST_CONFIG, + XmssConfig, +) +from lean_spec.types import Uint64 +from lean_spec.types.ssz_base import BYTES_PER_LENGTH_OFFSET + + +def _valid_config_kwargs() -> dict[str, int]: + """Return a copy of the production configuration fields as plain kwargs.""" + return PROD_CONFIG.model_dump() + + +def test_decomposition_validator_rejects_bad_product() -> None: + """A configuration whose product does not equal the prime minus one is rejected.""" + kwargs = _valid_config_kwargs() + kwargs["Q"] = 128 + with pytest.raises(ValueError, match=f"Q \\* BASE\\^Z must equal P-1={P - 1}"): + XmssConfig(**kwargs) + + +def test_decomposition_validator_accepts_valid_product() -> None: + """A configuration whose product equals the prime minus one validates.""" + config = XmssConfig(**_valid_config_kwargs()) + assert config.Q * config.BASE**config.Z == P - 1 + + +def test_target_config_is_test_config_under_test_env() -> None: + """The active configuration under the test environment is the test preset.""" + assert TARGET_CONFIG is TEST_CONFIG + + +def test_lifetime_is_two_to_the_log_lifetime() -> None: + """The lifetime is two raised to the configured base-two logarithm.""" + assert TEST_CONFIG.LIFETIME == Uint64(1 << TEST_CONFIG.LOG_LIFETIME) + + 
+def test_leaves_per_bottom_tree_is_square_root_of_lifetime() -> None: + """One bottom tree covers the square root of the lifetime in leaves.""" + assert TEST_CONFIG.LEAVES_PER_BOTTOM_TREE == 1 << (TEST_CONFIG.LOG_LIFETIME // 2) + + +@pytest.mark.parametrize( + "dimension, z, expected", + [ + pytest.param(4, 8, 1, id="dimension below one field element rounds up to one"), + pytest.param(46, 8, 6, id="production dimension needs six field elements"), + pytest.param(16, 8, 2, id="two full field elements exactly cover sixteen digits"), + ], +) +def test_mh_hash_len_rounds_up(dimension: int, z: int, expected: int) -> None: + """The aborting-decode output length is the dimension divided by digits, rounded up.""" + kwargs = _valid_config_kwargs() + kwargs["DIMENSION"] = dimension + kwargs["Z"] = z + assert XmssConfig(**kwargs).MH_HASH_LEN_FE == expected + + +def test_signature_len_bytes_matches_layout() -> None: + """The advertised signature length equals the sum of its SSZ-encoded fields.""" + config = TEST_CONFIG + path = config.LOG_LIFETIME * config.HASH_LEN_FE * P_BYTES + rho = config.RAND_LEN_FE * P_BYTES + hashes = config.DIMENSION * config.HASH_LEN_FE * P_BYTES + expected = path + rho + hashes + 3 * BYTES_PER_LENGTH_OFFSET + assert config.SIGNATURE_LEN_BYTES == expected + + +@pytest.mark.parametrize( + "param_name, value", + [ + ("DIMENSION", 46), + ("BASE", 8), + ("Z", 8), + ("Q", 127), + ("TARGET_SUM", 200), + ("LOG_LIFETIME", 32), + ("PARAMETER_LEN", 5), + ("TWEAK_LEN_FE", 2), + ("MSG_LEN_FE", 9), + ("RAND_LEN_FE", 7), + ("HASH_LEN_FE", 8), + ("CAPACITY", 9), + ], +) +def test_prod_config_matches_reference(param_name: str, value: int) -> None: + """Production parameters must match the canonical Rust implementation.""" + assert getattr(PROD_CONFIG, param_name) == value + + +def _calculate_layer_size(w: int, v: int, d: int) -> int: + """Count a hypercube layer's size using inclusion-exclusion. + + Counts integer solutions to x_1 + ... 
+ x_v = k with 0 <= x_i <= w-1, + where k = v*(w-1) - d. + """ + coord_sum = v * (w - 1) - d + return sum( + ((-1) ** s) * math.comb(v, s) * math.comb(coord_sum - s * w + v - 1, v - 1) + for s in range(coord_sum // w + 1) + ) + + +def _compute_security_levels(config: XmssConfig) -> dict[str, float]: + """Compute classical and quantum security levels for a configuration. + + Returns a dict with keys: + + - k_classical: effective classical security in bits + - k_quantum: effective quantum security in bits + - expected_attempts: expected signing attempts per message + - signing_failure_log2: log2 of the probability that all attempts fail + """ + v = config.DIMENSION + w_bits = int(math.log2(config.BASE)) + base = config.BASE + + # Each field element contributes floor(log2(P)) = 31 bits. + fe_bits = 31 + bits_digest = config.HASH_LEN_FE * fe_bits + bits_param = config.PARAMETER_LEN * fe_bits + bits_rand = config.RAND_LEN_FE * fe_bits + + # Raw message hash output is v chunks of w bits each. + bits_msg = v * w_bits + + # Abort correction from [HKKTW26] Corollary 1, Remark 14. + # + # Each field element aborts iff it equals the prime minus one. + # The non-abort probability per element is (P - 1) / P. + # Over ell field elements the total non-abort probability is that ratio to the ell. + wz = base**config.Z + q = config.Q + ell = math.ceil(v / config.Z) + + non_abort_total = ((q * wz) / P) ** ell + abort_correction_bits = -math.log2(non_abort_total) + + bits_msg_eff = bits_msg + abort_correction_bits + + log5 = math.log2(5) + log12 = math.log2(12) + log3 = math.log2(3) + log_lifetime = math.log2(config.LIFETIME) + logv = math.log2(v) + log_max_tries = math.log2(config.MAX_TRIES) + logqs = math.log2(config.LIFETIME) + + # Classical security is the minimum over four attack surfaces. 
+ k_classical = min( + bits_digest - log5 - 2 * w_bits - log_lifetime - logv, + bits_param - log5 - 3, + bits_msg_eff - log5 - 1, + bits_rand - log5 - logqs - log_max_tries - 1, + ) + + # Quantum security is the minimum over four attack surfaces. + k_quantum = min( + bits_digest / 2 - log5 - 2 * w_bits - log_lifetime - logv - log12, + (bits_param - 5) / 2 - log5 - 2, + (bits_msg_eff - 3) / 2 - log5 - 1, + (bits_rand - logqs) / 2 - log5 - log3 - log_max_tries, + ) + + # Expected signing attempts for target-sum encoding. + d = v * (base - 1) - config.TARGET_SUM + layer_size = _calculate_layer_size(base, v, d) + layer_prob = layer_size / base**v + success_prob = non_abort_total * layer_prob + expected_attempts = 1 / success_prob + + signing_failure_log2 = config.MAX_TRIES * math.log2(1 - success_prob) + + return { + "k_classical": k_classical, + "k_quantum": k_quantum, + "expected_attempts": expected_attempts, + "signing_failure_log2": signing_failure_log2, + } + + +def test_prod_classical_security() -> None: + """Production parameters achieve at least 128-bit classical security.""" + levels = _compute_security_levels(PROD_CONFIG) + assert levels["k_classical"] >= 128 + + +def test_prod_quantum_security() -> None: + """Production parameters achieve at least 64-bit quantum security.""" + levels = _compute_security_levels(PROD_CONFIG) + assert levels["k_quantum"] >= 64 + + +def test_prod_expected_signing_attempts_are_bounded() -> None: + """Signing succeeds within a manageable number of attempts on average.""" + levels = _compute_security_levels(PROD_CONFIG) + assert levels["expected_attempts"] < 1000 + + +def test_prod_signing_failure_is_negligible() -> None: + """The probability of exhausting every attempt is below two to the minus 128.""" + levels = _compute_security_levels(PROD_CONFIG) + assert levels["signing_failure_log2"] < -128 + + +def test_prod_abort_probability_is_negligible() -> None: + """The aborting decode rejection probability is below two to the minus 
28.""" + config = PROD_CONFIG + ell = math.ceil(config.DIMENSION / config.Z) + non_abort_per_fe = (config.Q * config.BASE**config.Z) / P + total_non_abort = non_abort_per_fe**ell + assert 1 - total_non_abort < 2**-28 + + +def test_prod_base_is_power_of_two() -> None: + """The alphabet size is a power of two so digits map cleanly onto bits.""" + w_bits = int(math.log2(PROD_CONFIG.BASE)) + assert PROD_CONFIG.BASE == 2**w_bits + + +def test_prod_digit_width_divides_twenty_four() -> None: + """The digit width divides twenty-four so rejection sampling works for KoalaBear.""" + w_bits = int(math.log2(PROD_CONFIG.BASE)) + assert 24 % w_bits == 0 + + +def test_prod_z_equals_twenty_four_over_digit_width() -> None: + """The digit count equals twenty-four divided by the digit width for the optimal decode.""" + w_bits = int(math.log2(PROD_CONFIG.BASE)) + assert PROD_CONFIG.Z == 24 // w_bits + + +def test_prod_mh_hash_len_covers_dimension() -> None: + """The aborting-decode output produces at least one digit per hash chain.""" + config = PROD_CONFIG + assert config.MH_HASH_LEN_FE * config.Z >= config.DIMENSION + + +def test_prod_binding_constraint_is_message_hash() -> None: + """The tightest classical bound is the message hash, matching the design intent.""" + config = PROD_CONFIG + v = config.DIMENSION + w_bits = int(math.log2(config.BASE)) + fe_bits = 31 + + bits_digest = config.HASH_LEN_FE * fe_bits + bits_param = config.PARAMETER_LEN * fe_bits + bits_rand = config.RAND_LEN_FE * fe_bits + bits_msg = v * w_bits + + log5 = math.log2(5) + log_lifetime = math.log2(config.LIFETIME) + logv = math.log2(v) + log_max_tries = math.log2(config.MAX_TRIES) + + classical_bounds = [ + bits_digest - log5 - 2 * w_bits - log_lifetime - logv, + bits_param - log5 - 3, + bits_msg - log5 - 1, + bits_rand - log5 - log_lifetime - log_max_tries - 1, + ] + assert classical_bounds.index(min(classical_bounds)) == 2 diff --git a/tests/lean_spec/subspecs/xmss/test_containers.py 
b/tests/lean_spec/subspecs/xmss/test_containers.py index c4c981dfd..3934b2a1c 100644 --- a/tests/lean_spec/subspecs/xmss/test_containers.py +++ b/tests/lean_spec/subspecs/xmss/test_containers.py @@ -1,4 +1,4 @@ -"""Behaviour tests for ValidatorKeyPair.""" +"""Behaviour tests for the XMSS containers.""" import json @@ -6,8 +6,17 @@ from consensus_testing.keys import XmssKeyManager from pydantic import ValidationError -from lean_spec.subspecs.xmss.containers import KeyPair, ValidatorKeyPair -from lean_spec.types import ValidatorIndex +from lean_spec.subspecs.koalabear.field import P_BYTES +from lean_spec.subspecs.xmss.constants import TEST_CONFIG +from lean_spec.subspecs.xmss.containers import ( + KeyPair, + PublicKey, + SecretKey, + Signature, + ValidatorKeyPair, +) +from lean_spec.subspecs.xmss.interface import TEST_SIGNATURE_SCHEME +from lean_spec.types import Bytes32, Slot, Uint64, ValidatorIndex @pytest.fixture(scope="module") @@ -213,3 +222,101 @@ def test_keypair_frozen(keypair_a: KeyPair) -> None: """KeyPair fields cannot be reassigned (StrictBaseModel is frozen).""" with pytest.raises(ValidationError): keypair_a.public_key = keypair_a.public_key + + +def test_keypair_decodes_public_and_secret_hex(keypair_a: KeyPair) -> None: + """A key pair validates from hex strings for both halves.""" + decoded = KeyPair.model_validate( + { + "public_key": keypair_a.public_key.encode_bytes().hex(), + "secret_key": keypair_a.secret_key.encode_bytes().hex(), + } + ) + assert decoded == keypair_a + + +def test_keypair_rejects_invalid_public_key_hex(keypair_a: KeyPair) -> None: + """A malformed public-key hex string surfaces as a validation error.""" + with pytest.raises(ValidationError, match="invalid public key hex"): + KeyPair.model_validate( + { + "public_key": "deadbeef", + "secret_key": keypair_a.secret_key.encode_bytes().hex(), + } + ) + + +def test_keypair_rejects_invalid_secret_key_hex(keypair_a: KeyPair) -> None: + """A malformed secret-key hex string surfaces as a 
validation error.""" + with pytest.raises(ValidationError, match="invalid secret key hex"): + KeyPair.model_validate( + { + "public_key": keypair_a.public_key.encode_bytes().hex(), + "secret_key": "deadbeef", + } + ) + + +@pytest.fixture(scope="module") +def signed_key_pair() -> KeyPair: + """A key pair generated directly from the test scheme.""" + return TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(32)) + + +@pytest.fixture(scope="module") +def sample_signature(signed_key_pair: KeyPair) -> Signature: + """A signature over a fixed message at slot zero.""" + return TEST_SIGNATURE_SCHEME.sign( + signed_key_pair.secret_key, Slot(0), Bytes32(bytes([42] * 32)) + ) + + +def test_public_key_ssz_roundtrip(signed_key_pair: KeyPair) -> None: + """A public key encodes and decodes back to an equal value.""" + public_key = signed_key_pair.public_key + assert PublicKey.decode_bytes(public_key.encode_bytes()) == public_key + + +def test_public_key_encoded_size_matches_layout(signed_key_pair: KeyPair) -> None: + """The encoded public key is the digest plus parameter packed into field bytes.""" + encoded = signed_key_pair.public_key.encode_bytes() + expected = (TEST_CONFIG.HASH_LEN_FE + TEST_CONFIG.PARAMETER_LEN) * P_BYTES + assert len(encoded) == expected + + +def test_secret_key_ssz_roundtrip(signed_key_pair: KeyPair) -> None: + """A secret key encodes and decodes back to an equal value.""" + secret_key = signed_key_pair.secret_key + assert SecretKey.decode_bytes(secret_key.encode_bytes()) == secret_key + + +def test_signature_is_fixed_size() -> None: + """A signature reports as fixed-size on the wire.""" + assert Signature.is_fixed_size() is True + + +def test_signature_byte_length_matches_config() -> None: + """The signature byte length matches the configured fixed length.""" + assert Signature.get_byte_length() == TEST_CONFIG.SIGNATURE_LEN_BYTES + + +def test_signature_ssz_roundtrip(sample_signature: Signature) -> None: + """A signature encodes and decodes back to an equal 
value.""" + assert Signature.decode_bytes(sample_signature.encode_bytes()) == sample_signature + + +def test_signature_encoded_size_matches_config(sample_signature: Signature) -> None: + """The encoded signature length matches the advertised fixed length.""" + assert len(sample_signature.encode_bytes()) == TEST_CONFIG.SIGNATURE_LEN_BYTES + + +def test_signature_json_is_prefixed_hex(sample_signature: Signature) -> None: + """The JSON form is a hex string prefixed with the byte marker.""" + dumped = json.loads(sample_signature.model_dump_json()) + assert dumped == "0x" + sample_signature.encode_bytes().hex() + + +def test_signature_decodes_from_json(sample_signature: Signature) -> None: + """A signature decodes back from its JSON hex form.""" + encoded = "0x" + sample_signature.encode_bytes().hex() + assert Signature.from_hex(encoded) == sample_signature diff --git a/tests/lean_spec/subspecs/xmss/test_encoding.py b/tests/lean_spec/subspecs/xmss/test_encoding.py new file mode 100644 index 000000000..bd45df48e --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_encoding.py @@ -0,0 +1,147 @@ +"""Tests for the message-to-codeword encoding pipeline.""" + +import pytest + +from lean_spec.subspecs.koalabear import Fp, P +from lean_spec.subspecs.xmss import encoding +from lean_spec.subspecs.xmss.constants import TEST_CONFIG, TWEAK_PREFIX_MESSAGE +from lean_spec.subspecs.xmss.encoding import ( + aborting_decode, + encode_epoch, + encode_message, + message_hash, + target_sum_encode, +) +from lean_spec.subspecs.xmss.field import int_to_base_p, random_field_elements +from lean_spec.subspecs.xmss.poseidon import TEST_POSEIDON +from lean_spec.subspecs.xmss.types import Parameter, Randomness +from lean_spec.types import Bytes32, Uint64 + + +def _parameter() -> Parameter: + """Return a fixed public parameter for encoding tests.""" + return Parameter(data=[Fp(value=1)] * TEST_CONFIG.PARAMETER_LEN) + + +def test_encode_message_zero_is_all_zero_limbs() -> None: + """An all-zero 
message encodes to all-zero field elements.""" + encoded = encode_message(TEST_CONFIG, Bytes32(b"\x00" * 32)) + assert encoded == [Fp(value=0)] * TEST_CONFIG.MSG_LEN_FE + + +def test_encode_message_reads_little_endian() -> None: + """A maximal message encodes to its little-endian base-P decomposition.""" + message = Bytes32(b"\xff" * 32) + acc = int.from_bytes(message, "little") + assert encode_message(TEST_CONFIG, message) == int_to_base_p(acc, TEST_CONFIG.MSG_LEN_FE) + + +@pytest.mark.parametrize("epoch", [0, 42, 2**32 - 1]) +def test_encode_epoch_matches_prefixed_decomposition(epoch: int) -> None: + """An epoch encodes to its value shifted above the message prefix.""" + acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE + expected = int_to_base_p(acc, TEST_CONFIG.TWEAK_LEN_FE) + assert encode_epoch(TEST_CONFIG, Uint64(epoch)) == expected + + +def test_encode_epoch_is_injective_over_a_range() -> None: + """Distinct epochs in a range encode to distinct field-element tuples.""" + encodings = {tuple(encode_epoch(TEST_CONFIG, Uint64(i))) for i in range(1000)} + assert len(encodings) == 1000 + + +def test_aborting_decode_known_decomposition() -> None: + """A hand-built quotient decodes to its base-BASE digits, truncated to the dimension.""" + config = TEST_CONFIG + d_value = 5 + fe_list = [Fp(value=config.Q * d_value)] * config.MH_HASH_LEN_FE + + expected_per_fe = [] + remaining = d_value + for _ in range(config.Z): + expected_per_fe.append(remaining % config.BASE) + remaining //= config.BASE + expected = (expected_per_fe * config.MH_HASH_LEN_FE)[: config.DIMENSION] + + assert aborting_decode(config, fe_list) == expected + + +def test_aborting_decode_accepts_largest_valid_element() -> None: + """The element just below the abort threshold decodes successfully.""" + config = TEST_CONFIG + result = aborting_decode(config, [Fp(value=P - 2)] * config.MH_HASH_LEN_FE) + assert result is not None + assert len(result) == config.DIMENSION + assert all(0 <= d < config.BASE for d in 
result) + + +def test_aborting_decode_rejects_threshold_element() -> None: + """The element equal to the prime minus one triggers the abort.""" + assert aborting_decode(TEST_CONFIG, [Fp(value=P - 1)]) is None + + +def test_message_hash_yields_valid_codeword() -> None: + """The message hash decodes to a codeword of dimension digits in range.""" + config = TEST_CONFIG + parameter = Parameter(data=random_field_elements(config.PARAMETER_LEN)) + randomness = Randomness(data=random_field_elements(config.RAND_LEN_FE)) + + result = message_hash( + TEST_POSEIDON, config, parameter, Uint64(313), randomness, Bytes32(b"\xaa" * 32) + ) + + assert result is not None + assert len(result) == config.DIMENSION + assert all(0 <= digit < config.BASE for digit in result) + + +def test_target_sum_encode_accepts_codeword_on_target_layer() -> None: + """Randomness whose codeword sums to the target is accepted.""" + config = TEST_CONFIG + parameter = _parameter() + # Attempt counter three lands the all-zero message on the target-sum layer. + rho = Randomness(data=int_to_base_p(3, config.RAND_LEN_FE)) + + codeword = target_sum_encode( + TEST_POSEIDON, config, parameter, Bytes32(b"\x00" * 32), rho, Uint64(0) + ) + + # The digits sum to the target of six, landing on the accepted layer. + assert codeword == [3, 0, 3, 0] + + +def test_target_sum_encode_rejects_codeword_off_target_layer() -> None: + """Randomness whose codeword misses the target sum is rejected.""" + config = TEST_CONFIG + parameter = _parameter() + # Attempt counter zero produces a codeword whose digits do not sum to the target. + rho = Randomness(data=int_to_base_p(0, config.RAND_LEN_FE)) + + assert ( + target_sum_encode(TEST_POSEIDON, config, parameter, Bytes32(b"\x00" * 32), rho, Uint64(0)) + is None + ) + + +def test_target_sum_encode_propagates_aborting_decode_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """An aborting message hash makes the encode return None before the sum check. 
+ + The aborting decode rejects only the prime-minus-one field element. + That event has probability near one in two billion per element. + It cannot be triggered with real inputs in a test, so the hash result is forced. + """ + monkeypatch.setattr(encoding, "message_hash", lambda *args, **kwargs: None) + + assert ( + target_sum_encode( + TEST_POSEIDON, + TEST_CONFIG, + _parameter(), + Bytes32(b"\x00" * 32), + Randomness(data=int_to_base_p(0, TEST_CONFIG.RAND_LEN_FE)), + Uint64(0), + ) + is None + ) diff --git a/tests/lean_spec/subspecs/xmss/test_field.py b/tests/lean_spec/subspecs/xmss/test_field.py new file mode 100644 index 000000000..58535d1c8 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_field.py @@ -0,0 +1,88 @@ +"""Tests for the field-element decomposition and secure sampling helpers.""" + +import secrets + +import pytest + +from lean_spec.subspecs.koalabear import Fp, P +from lean_spec.subspecs.xmss.constants import TEST_CONFIG +from lean_spec.subspecs.xmss.field import ( + int_to_base_p, + random_domain, + random_field_elements, + random_parameter, +) +from lean_spec.subspecs.xmss.types import HashDigestVector, Parameter + + +@pytest.mark.parametrize( + "value, num_limbs, expected_values", + [ + pytest.param(0, 4, [0, 0, 0, 0], id="zero spreads to all-zero limbs"), + pytest.param(123, 4, [123, 0, 0, 0], id="small value in the lowest limb"), + pytest.param(P, 4, [0, 1, 0, 0], id="prime carries into the second limb"), + pytest.param(P - 1, 4, [P - 1, 0, 0, 0], id="largest single-limb value"), + pytest.param(3 * (P**2) + 2 * P + 1, 4, [1, 2, 3, 0], id="mixed multi-limb value"), + pytest.param(P**3 - 1, 3, [P - 1, P - 1, P - 1], id="all limbs saturated"), + ], +) +def test_int_to_base_p_known_decomposition( + value: int, num_limbs: int, expected_values: list[int] +) -> None: + """Decomposition matches hand-computed base-P limbs.""" + assert int_to_base_p(value, num_limbs) == [Fp(value=v) for v in expected_values] + + +def 
test_int_to_base_p_zero_limbs_returns_empty() -> None: + """Requesting zero limbs yields an empty list.""" + assert int_to_base_p(12345, 0) == [] + + +def test_int_to_base_p_truncates_high_limbs() -> None: + """A value wider than the requested limbs drops its high digits.""" + assert int_to_base_p(P**2 + P + 7, 1) == [Fp(value=7)] + + +def test_int_to_base_p_roundtrip_is_reversible() -> None: + """Decomposing then recomposing recovers the original integer.""" + num_limbs = 5 + original_limbs = [secrets.randbelow(P) for _ in range(num_limbs)] + original_value = sum(val * (P**i) for i, val in enumerate(original_limbs)) + + decomposed = [int(fp) for fp in int_to_base_p(original_value, num_limbs)] + + assert decomposed == original_limbs + + +def test_random_field_elements_length() -> None: + """The sampler returns exactly the requested number of elements.""" + assert len(random_field_elements(7)) == 7 + + +def test_random_field_elements_zero_length() -> None: + """A zero-length request yields an empty list.""" + assert random_field_elements(0) == [] + + +def test_random_field_elements_are_in_field_range() -> None: + """Every sampled element lies in the range zero up to the prime.""" + assert all(0 <= int(fe) < P for fe in random_field_elements(50)) + + +def test_random_field_elements_are_not_constant() -> None: + """A large sample is overwhelmingly unlikely to repeat a single value.""" + assert len({int(fe) for fe in random_field_elements(100)}) > 1 + + +def test_random_parameter_has_parameter_length() -> None: + """A sampled parameter has the configured parameter length.""" + parameter = random_parameter(TEST_CONFIG) + assert isinstance(parameter, Parameter) + assert len(parameter.data) == TEST_CONFIG.PARAMETER_LEN + + +def test_random_domain_has_hash_length() -> None: + """A sampled domain vector has the configured digest length.""" + domain = random_domain(TEST_CONFIG) + assert isinstance(domain, HashDigestVector) + assert len(domain.data) == TEST_CONFIG.HASH_LEN_FE 
diff --git a/tests/lean_spec/subspecs/xmss/test_interface.py b/tests/lean_spec/subspecs/xmss/test_interface.py index 34ddb6ce4..4be16f1cd 100644 --- a/tests/lean_spec/subspecs/xmss/test_interface.py +++ b/tests/lean_spec/subspecs/xmss/test_interface.py @@ -1,12 +1,13 @@ -""" -End-to-end tests for the Generalized XMSS signature scheme. -""" +"""End-to-end tests for the Generalized XMSS signature scheme and its helpers.""" import pytest +from lean_spec.subspecs.xmss import interface +from lean_spec.subspecs.xmss.encoding import target_sum_encode from lean_spec.subspecs.xmss.interface import ( TEST_SIGNATURE_SCHEME, GeneralizedXmssScheme, + _expand_activation_time, ) from lean_spec.types import Bytes32, Slot, Uint64 @@ -37,7 +38,7 @@ def _test_correctness_roundtrip( # Sign the message at the chosen slot. # - # This might take a moment as it may try multiple `rho` values. + # This might take a moment as it may try multiple rho values. signature = scheme.sign(sk, test_slot, message) # Verification of the valid signature must succeed. @@ -56,9 +57,11 @@ def _test_correctness_roundtrip( # In that case, verification will succeed, which is expected behavior for identical codewords. # # We detect this by checking if both messages encode to the same codeword. 
- original_codeword = scheme.encoder.encode(pk.parameter, message, signature.rho, test_slot) - tampered_codeword = scheme.encoder.encode( - pk.parameter, tampered_message, signature.rho, test_slot + original_codeword = target_sum_encode( + scheme.poseidon, scheme.config, pk.parameter, message, signature.rho, test_slot + ) + tampered_codeword = target_sum_encode( + scheme.poseidon, scheme.config, pk.parameter, tampered_message, signature.rho, test_slot ) if tampered_codeword != original_codeword: @@ -202,6 +205,79 @@ def test_deterministic_signing() -> None: assert sig1.path.siblings == sig2.path.siblings +@pytest.mark.parametrize( + "log_lifetime, desired_activation, desired_num, expected_start, expected_end", + [ + pytest.param(8, 0, 16, 0, 2, id="boundary request widens to minimum two trees"), + pytest.param(8, 10, 5, 0, 2, id="unaligned request rounds onto tree boundaries"), + pytest.param(8, 0, 100, 0, 7, id="larger request spans seven trees"), + pytest.param(4, 0, 300, 0, 4, id="request wider than lifetime covers all of it"), + pytest.param(8, 32, 16, 2, 4, id="middle request stays put"), + pytest.param(8, 240, 30, 14, 16, id="request near the end slides back to the boundary"), + ], +) +def test_expand_activation_time( + log_lifetime: int, + desired_activation: int, + desired_num: int, + expected_start: int, + expected_end: int, +) -> None: + """The requested window snaps onto whole bottom trees, widened and clamped.""" + assert _expand_activation_time(log_lifetime, desired_activation, desired_num) == ( + expected_start, + expected_end, + ) + + +def test_key_gen_rejects_range_exceeding_lifetime() -> None: + """A requested range past the lifetime is refused.""" + with pytest.raises(ValueError, match="Activation range exceeds the key's lifetime."): + TEST_SIGNATURE_SCHEME.key_gen(Slot(200), Uint64(100)) + + +def test_sign_rejects_slot_outside_activation() -> None: + """Signing a slot the key was never activated for is refused.""" + secret_key = 
TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(32)).secret_key + with pytest.raises(ValueError, match="Key is not active for the specified slot."): + TEST_SIGNATURE_SCHEME.sign(secret_key, Slot(200), Bytes32(b"\x42" * 32)) + + +def test_sign_raises_when_no_encoding_found(monkeypatch: pytest.MonkeyPatch) -> None: + """An encoding search that never succeeds raises after exhausting the attempts. + + The encoding is forced to always reject so the retry loop runs to its limit. + """ + secret_key = TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(32)).secret_key + monkeypatch.setattr(interface, "target_sum_encode", lambda *args, **kwargs: None) + tries = TEST_SIGNATURE_SCHEME.config.MAX_TRIES + with pytest.raises( + RuntimeError, match=f"Failed to find a valid message encoding after {tries} tries." + ): + TEST_SIGNATURE_SCHEME.sign(secret_key, Slot(0), Bytes32(b"\x42" * 32)) + + +def test_sign_raises_on_wrong_codeword_dimension(monkeypatch: pytest.MonkeyPatch) -> None: + """An encoding returning the wrong number of digits raises. + + The encoding is forced to return one digit too few for the scheme dimension. + """ + secret_key = TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(32)).secret_key + short = [0] * (TEST_SIGNATURE_SCHEME.config.DIMENSION - 1) + monkeypatch.setattr(interface, "target_sum_encode", lambda *args, **kwargs: short) + with pytest.raises( + RuntimeError, match="Encoding is broken: returned too many or too few chunks." + ): + TEST_SIGNATURE_SCHEME.sign(secret_key, Slot(0), Bytes32(b"\x42" * 32)) + + +def test_advance_preparation_is_a_noop_at_the_end() -> None: + """A two-tree key cannot advance and returns the same secret key.""" + leaves = TEST_SIGNATURE_SCHEME.config.LEAVES_PER_BOTTOM_TREE + secret_key = TEST_SIGNATURE_SCHEME.key_gen(Slot(0), Uint64(2 * leaves)).secret_key + assert TEST_SIGNATURE_SCHEME.advance_preparation(secret_key) is secret_key + + class TestVerifySecurityBounds: """ Security tests for verify method input validation. 
diff --git a/tests/lean_spec/subspecs/xmss/test_merkle.py b/tests/lean_spec/subspecs/xmss/test_merkle.py new file mode 100644 index 000000000..5059b9592 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_merkle.py @@ -0,0 +1,387 @@ +"""Tests for the sparse Merkle subtree implementation.""" + +import pytest + +from lean_spec.subspecs.xmss.constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from lean_spec.subspecs.xmss.field import random_domain, random_parameter +from lean_spec.subspecs.xmss.merkle import ( + HashSubTree, + HashTreeLayer, + HashTreeLayers, + combined_path, + verify_path, +) +from lean_spec.subspecs.xmss.poseidon import PROD_POSEIDON, TEST_POSEIDON, PoseidonXmss +from lean_spec.subspecs.xmss.prf import PRFKey +from lean_spec.subspecs.xmss.types import ( + HashDigestList, + HashDigestVector, + HashTreeOpening, + Parameter, + TreeTweak, +) +from lean_spec.types import Uint64 +from lean_spec.types.exceptions import SSZValueError + + +def _run_commit_open_verify_roundtrip( + poseidon: PoseidonXmss, + config: XmssConfig, + num_leaves: int, + depth: int, + start_index: int, + leaf_parts_len: int, +) -> None: + """Build a tree, then open and verify every active leaf against its root.""" + parameter = random_parameter(config) + leaves: list[list[HashDigestVector]] = [ + [random_domain(config) for _ in range(leaf_parts_len)] for _ in range(num_leaves) + ] + + leaf_hashes: list[HashDigestVector] = [ + poseidon.tweak_hash( + config, + parameter, + TreeTweak(level=0, index=Uint64(start_index + i)), + leaf_parts, + ) + for i, leaf_parts in enumerate(leaves) + ] + + tree = HashSubTree.new( + poseidon=poseidon, + config=config, + lowest_layer=Uint64(0), + depth=Uint64(depth), + start_index=Uint64(start_index), + parameter=parameter, + lowest_layer_nodes=leaf_hashes, + ) + root = tree.root() + + for i, leaf_parts in enumerate(leaves): + position = Uint64(start_index + i) + opening = tree.path(position) + assert verify_path( + poseidon=poseidon, + 
config=config, + parameter=parameter, + root=root, + position=position, + leaf_parts=leaf_parts, + opening=opening, + ) + + +@pytest.mark.parametrize( + "num_leaves, depth, start_index, leaf_parts_len", + [ + pytest.param(16, 4, 0, 3, id="Full tree (depth 4)", marks=pytest.mark.slow), + pytest.param(12, 5, 0, 5, id="Half tree, left-aligned (depth 5)", marks=pytest.mark.slow), + pytest.param(16, 5, 16, 2, id="Half tree, right-aligned (depth 5)"), + pytest.param(22, 6, 13, 3, id="Sparse, non-aligned tree (depth 6)", marks=pytest.mark.slow), + pytest.param(2, 2, 2, 6, id="Half tree, right-aligned (small)"), + pytest.param(1, 1, 0, 1, id="Tree with a single leaf at the start"), + pytest.param(1, 1, 1, 1, id="Tree with a single leaf at an odd index"), + pytest.param(16, 5, 7, 2, id="Small sparse tree starting at an odd index"), + ], +) +def test_commit_open_verify_roundtrip( + num_leaves: int, + depth: int, + start_index: int, + leaf_parts_len: int, +) -> None: + """A built tree opens and verifies every leaf for various shapes.""" + assert start_index + num_leaves <= (1 << depth) + _run_commit_open_verify_roundtrip( + PROD_POSEIDON, PROD_CONFIG, num_leaves, depth, start_index, leaf_parts_len + ) + + +def test_new_rejects_nodes_overflowing_their_level() -> None: + """A node run that does not fit its level raises with the level and bounds named.""" + with pytest.raises(ValueError, match=r"Overflow at layer 0: start=3, count=2, max=4"): + HashSubTree.new( + poseidon=TEST_POSEIDON, + config=TEST_CONFIG, + lowest_layer=Uint64(0), + depth=Uint64(2), + start_index=Uint64(3), + parameter=random_parameter(TEST_CONFIG), + lowest_layer_nodes=[random_domain(TEST_CONFIG), random_domain(TEST_CONFIG)], + ) + + +def test_new_top_tree_rejects_odd_depth() -> None: + """The top tree requires an even depth for the top-bottom split.""" + with pytest.raises(ValueError, match=r"Depth must be even for top-bottom split, got 7."): + HashSubTree.new_top_tree( + TEST_POSEIDON, TEST_CONFIG, 7, 
Uint64(0), random_parameter(TEST_CONFIG), [] + ) + + +def test_new_bottom_tree_rejects_odd_depth() -> None: + """The bottom tree requires an even depth for the top-bottom split.""" + with pytest.raises(ValueError, match=r"Depth must be even for top-bottom split, got 7."): + HashSubTree.new_bottom_tree( + TEST_POSEIDON, TEST_CONFIG, 7, Uint64(0), random_parameter(TEST_CONFIG), [] + ) + + +def test_new_bottom_tree_rejects_wrong_leaf_count() -> None: + """The bottom tree requires exactly the square-root-of-lifetime leaves.""" + with pytest.raises(ValueError, match=r"Expected 16 leaves for depth=8, got 0."): + HashSubTree.new_bottom_tree( + TEST_POSEIDON, TEST_CONFIG, 8, Uint64(0), random_parameter(TEST_CONFIG), [] + ) + + +def test_root_rejects_empty_subtree() -> None: + """A subtree with no layers has no root.""" + subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[]), + ) + with pytest.raises(ValueError, match=r"Empty subtree has no root."): + subtree.root() + + +def test_root_rejects_empty_top_layer() -> None: + """A subtree whose top layer holds no nodes has no root.""" + empty_layer = HashTreeLayer(start_index=Uint64(0), nodes=HashDigestList(data=[])) + subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[empty_layer]), + ) + with pytest.raises(ValueError, match=r"Top layer is empty."): + subtree.root() + + +def test_path_rejects_empty_subtree() -> None: + """Opening a path on a subtree with no layers raises.""" + subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[]), + ) + with pytest.raises(ValueError, match=r"Empty subtree."): + subtree.path(Uint64(0)) + + +def test_path_rejects_position_out_of_bounds() -> None: + """A position outside the lowest layer's stored range raises.""" + layer = HashTreeLayer( + start_index=Uint64(0), + nodes=HashDigestList(data=[random_domain(TEST_CONFIG), random_domain(TEST_CONFIG)]), + ) + 
subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[layer]), + ) + with pytest.raises(ValueError, match=r"Position 5 out of bounds."): + subtree.path(Uint64(5)) + + +def test_path_rejects_sibling_out_of_bounds() -> None: + """A non-root layer too small to hold the needed sibling raises.""" + leaf_layer = HashTreeLayer( + start_index=Uint64(0), + nodes=HashDigestList(data=[random_domain(TEST_CONFIG), random_domain(TEST_CONFIG)]), + ) + # The middle layer lacks the sibling at index one. + middle_layer = HashTreeLayer( + start_index=Uint64(0), nodes=HashDigestList(data=[random_domain(TEST_CONFIG)]) + ) + root_layer = HashTreeLayer( + start_index=Uint64(0), nodes=HashDigestList(data=[random_domain(TEST_CONFIG)]) + ) + subtree = HashSubTree( + depth=Uint64(8), + lowest_layer=Uint64(0), + layers=HashTreeLayers(data=[leaf_layer, middle_layer, root_layer]), + ) + with pytest.raises(ValueError, match=r"Sibling index 1 out of bounds."): + subtree.path(Uint64(0)) + + +@pytest.fixture(scope="module") +def prf_trees() -> tuple[Parameter, HashSubTree, HashSubTree, HashSubTree]: + """Build a top tree over two prf-derived bottom trees for combined-path tests.""" + config = TEST_CONFIG + prf_key = PRFKey.generate() + parameter = random_parameter(config) + bottom_zero = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(0), + parameter=parameter, + ) + bottom_one = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(1), + parameter=parameter, + ) + top = HashSubTree.new_top_tree( + TEST_POSEIDON, + config, + config.LOG_LIFETIME, + Uint64(0), + parameter, + [bottom_zero.root(), bottom_one.root()], + ) + return parameter, top, bottom_zero, bottom_one + + +def test_combined_path_authenticates_leaf_to_global_root( + prf_trees: tuple[Parameter, HashSubTree, HashSubTree, HashSubTree], +) -> None: + """A combined 
opening spans the full depth from leaf to global root.""" + _, top, bottom_zero, _ = prf_trees + opening = combined_path(top, bottom_zero, Uint64(0)) + assert len(opening.siblings) == TEST_CONFIG.LOG_LIFETIME + + +def test_combined_path_rejects_depth_mismatch( + prf_trees: tuple[Parameter, HashSubTree, HashSubTree, HashSubTree], +) -> None: + """Top and bottom trees of disagreeing depth cannot be stitched.""" + _, top, bottom_zero, _ = prf_trees + mismatched = bottom_zero.model_copy(update={"depth": Uint64(6)}) + with pytest.raises(ValueError, match=r"Depth mismatch: top=8, bottom=6."): + combined_path(top, mismatched, Uint64(0)) + + +def test_combined_path_rejects_odd_depth( + prf_trees: tuple[Parameter, HashSubTree, HashSubTree, HashSubTree], +) -> None: + """Stitching requires an even depth.""" + _, top, bottom_zero, _ = prf_trees + odd_top = top.model_copy(update={"depth": Uint64(7)}) + odd_bottom = bottom_zero.model_copy(update={"depth": Uint64(7)}) + with pytest.raises(ValueError, match=r"Depth must be even, got 7."): + combined_path(odd_top, odd_bottom, Uint64(7)) + + +def test_combined_path_rejects_wrong_bottom_tree( + prf_trees: tuple[Parameter, HashSubTree, HashSubTree, HashSubTree], +) -> None: + """A position belonging to a sibling bottom tree is refused.""" + _, top, bottom_zero, _ = prf_trees + with pytest.raises( + ValueError, + match=r"Wrong bottom tree: position 16 needs start 16, got 0.", + ): + combined_path(top, bottom_zero, Uint64(16)) + + +def test_from_prf_key_builds_a_bottom_tree() -> None: + """A prf-derived bottom tree has the expected depth, layers, and leaf count.""" + config = TEST_CONFIG + bottom_tree = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=PRFKey.generate(), + bottom_tree_index=Uint64(0), + parameter=random_parameter(config), + ) + assert bottom_tree.depth == Uint64(config.LOG_LIFETIME) + assert bottom_tree.lowest_layer == Uint64(0) + assert len(bottom_tree.layers[-1].nodes) == 1 + assert 
len(bottom_tree.layers[0].nodes) == config.LEAVES_PER_BOTTOM_TREE + + +def test_from_prf_key_is_deterministic() -> None: + """The same seed and index rebuild the same bottom-tree root.""" + config = TEST_CONFIG + prf_key = PRFKey.generate() + parameter = random_parameter(config) + first = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(0), + parameter=parameter, + ) + second = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(0), + parameter=parameter, + ) + assert first.root() == second.root() + + +def test_from_prf_key_distinct_indices_give_distinct_roots() -> None: + """Different bottom-tree indices produce different roots.""" + config = TEST_CONFIG + prf_key = PRFKey.generate() + parameter = random_parameter(config) + tree_zero = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(0), + parameter=parameter, + ) + tree_one = HashSubTree.from_prf_key( + poseidon=TEST_POSEIDON, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(1), + parameter=parameter, + ) + assert tree_zero.root() != tree_one.root() + + +def test_verify_path_rejects_excessive_depth_at_ssz_level() -> None: + """ + The SSZ type system caps an opening at the layer limit. + + The defensive depth guard inside verification cannot be reached through a + well-formed opening, since the digest list rejects more than its limit. 
+ """ + with pytest.raises(SSZValueError): + HashDigestList(data=[random_domain(PROD_CONFIG) for _ in range(33)]) + + +@pytest.mark.parametrize("position", [16, 100]) +def test_verify_path_rejects_position_exceeding_capacity(position: int) -> None: + """A position at or beyond two-to-the-depth fails without raising.""" + siblings = [random_domain(PROD_CONFIG) for _ in range(4)] + opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) + assert ( + verify_path( + poseidon=PROD_POSEIDON, + config=PROD_CONFIG, + parameter=random_parameter(PROD_CONFIG), + root=random_domain(PROD_CONFIG), + position=Uint64(position), + leaf_parts=[random_domain(PROD_CONFIG)], + opening=opening, + ) + is False + ) + + +def test_verify_path_accepts_boundary_position_without_raising() -> None: + """The maximum valid position for the depth does not trip the bounds guard.""" + siblings = [random_domain(PROD_CONFIG) for _ in range(4)] + opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) + result = verify_path( + poseidon=PROD_POSEIDON, + config=PROD_CONFIG, + parameter=random_parameter(PROD_CONFIG), + root=random_domain(PROD_CONFIG), + position=Uint64(15), + leaf_parts=[random_domain(PROD_CONFIG)], + opening=opening, + ) + assert isinstance(result, bool) diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py deleted file mode 100644 index ae6094bc6..000000000 --- a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Tests for the sparse Merkle tree implementation.""" - -import pytest - -from lean_spec.subspecs.xmss.rand import PROD_RAND, Rand -from lean_spec.subspecs.xmss.subtree import HashSubTree, verify_path -from lean_spec.subspecs.xmss.tweak_hash import ( - PROD_TWEAK_HASHER, - TreeTweak, - TweakHasher, -) -from lean_spec.subspecs.xmss.types import HashDigestList, HashDigestVector, HashTreeOpening -from lean_spec.types import Uint64 -from lean_spec.types.exceptions 
import SSZValueError - - -def _run_commit_open_verify_roundtrip( - hasher: TweakHasher, - rand: Rand, - num_leaves: int, - depth: int, - start_index: int, - leaf_parts_len: int, -) -> None: - """ - A helper function to perform a full Merkle tree roundtrip test. - - The process is as follows: - 1. Generate random leaf data. - 2. Hash the leaves to create layer 0 of the tree. - 3. Build the full Merkle tree and get its root (commit). - 4. For each leaf, generate an authentication path (open). - 5. Verify that each path is valid for its corresponding leaf and root. - - Args: - hasher: The tweakable hash instance for computing parent nodes. - rand: Random generator for padding values. - num_leaves: The number of active leaves in the tree. - depth: The total depth of the Merkle tree. - start_index: The starting index of the first active leaf. - leaf_parts_len: The number of digests that constitute a single leaf. - """ - # SETUP: Generate a random parameter and the raw leaf data. - parameter = rand.parameter() - leaves: list[list[HashDigestVector]] = [ - [rand.domain() for _ in range(leaf_parts_len)] for _ in range(num_leaves) - ] - - # HASH LEAVES: Compute the layer 0 nodes by hashing the leaf parts. - leaf_hashes: list[HashDigestVector] = [ - hasher.apply( - parameter, - TreeTweak(level=0, index=Uint64(start_index + i)), - leaf_parts, - ) - for i, leaf_parts in enumerate(leaves) - ] - - # COMMIT: Build the Merkle tree from the leaf hashes. - tree = HashSubTree.new( - hasher=hasher, - rand=rand, - lowest_layer=Uint64(0), - depth=Uint64(depth), - start_index=Uint64(start_index), - parameter=parameter, - lowest_layer_nodes=leaf_hashes, - ) - root = tree.root() - - # OPEN & VERIFY: For each leaf, generate and verify its path. 
- for i, leaf_parts in enumerate(leaves): - position = Uint64(start_index + i) - opening = tree.path(position) - is_valid = verify_path( - hasher=hasher, - parameter=parameter, - root=root, - position=position, - leaf_parts=leaf_parts, - opening=opening, - ) - assert is_valid, f"Verification failed for leaf at position {position}" - - -@pytest.mark.parametrize( - "num_leaves, depth, start_index, leaf_parts_len", - [ - pytest.param(16, 4, 0, 3, id="Full tree (depth 4)", marks=pytest.mark.slow), - pytest.param(12, 5, 0, 5, id="Half tree, left-aligned (depth 5)", marks=pytest.mark.slow), - pytest.param(16, 5, 16, 2, id="Half tree, right-aligned (depth 5)"), - pytest.param(22, 6, 13, 3, id="Sparse, non-aligned tree (depth 6)", marks=pytest.mark.slow), - pytest.param(2, 2, 2, 6, id="Half tree, right-aligned (small)"), - pytest.param(1, 1, 0, 1, id="Tree with a single leaf at the start"), - pytest.param(1, 1, 1, 1, id="Tree with a single leaf at an odd index"), - pytest.param(16, 5, 7, 2, id="Small sparse tree starting at an odd index"), - ], -) -def test_commit_open_verify_roundtrip( - num_leaves: int, - depth: int, - start_index: int, - leaf_parts_len: int, -) -> None: - """Tests the Merkle tree logic for various configurations.""" - # Ensure the test case parameters are valid for the specified tree depth. - assert start_index + num_leaves <= (1 << depth) - - _run_commit_open_verify_roundtrip( - PROD_TWEAK_HASHER, PROD_RAND, num_leaves, depth, start_index, leaf_parts_len - ) - - -class TestVerifyPathSecurityBounds: - """ - Security tests for verify_path input validation. - - Verification functions must return False (not raise) on attacker-controlled invalid input. - This prevents denial-of-service via malformed signatures. - """ - - def test_ssz_validation_rejects_excessive_depth(self) -> None: - """ - SSZ type system rejects openings with depth > 32. - - HashDigestList has a LIMIT of 32, so the type system prevents - creating malformed openings at the SSZ level. 
The check in - verify_path is defense-in-depth for deserialized data. - """ - rand = PROD_RAND - - # Attempting to create a list with 33 siblings raises at the type level. - excessive_siblings = [rand.domain() for _ in range(33)] - with pytest.raises(SSZValueError): - HashDigestList(data=excessive_siblings) - - def test_rejects_position_exceeding_tree_capacity(self) -> None: - """verify_path returns False when position >= 2^depth.""" - rand = PROD_RAND - hasher = PROD_TWEAK_HASHER - parameter = rand.parameter() - - root = rand.domain() - leaf_parts = [rand.domain()] - - # Create an opening with depth=4 (supports positions 0-15). - siblings = [rand.domain() for _ in range(4)] - opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) - - # Position 16 is out of bounds for depth 4 (capacity = 2^4 = 16). - result = verify_path( - hasher=hasher, - parameter=parameter, - root=root, - position=Uint64(16), - leaf_parts=leaf_parts, - opening=opening, - ) - assert result is False - - # Position 100 is also out of bounds. - result = verify_path( - hasher=hasher, - parameter=parameter, - root=root, - position=Uint64(100), - leaf_parts=leaf_parts, - opening=opening, - ) - assert result is False - - def test_valid_position_at_boundary(self) -> None: - """verify_path accepts position at maximum valid value (2^depth - 1).""" - rand = PROD_RAND - hasher = PROD_TWEAK_HASHER - parameter = rand.parameter() - - root = rand.domain() - leaf_parts = [rand.domain()] - - # Create an opening with depth=4. - siblings = [rand.domain() for _ in range(4)] - opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) - - # Position 15 is the maximum valid position for depth 4. - # This should not return False due to bounds check (may still fail root check). - result = verify_path( - hasher=hasher, - parameter=parameter, - root=root, - position=Uint64(15), - leaf_parts=leaf_parts, - opening=opening, - ) - # Result may be False due to wrong root, but importantly it didn't raise. 
- assert isinstance(result, bool) diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py deleted file mode 100644 index 58861155a..000000000 --- a/tests/lean_spec/subspecs/xmss/test_message_hash.py +++ /dev/null @@ -1,125 +0,0 @@ -""" -Tests for the message hashing and aborting hypercube encoding logic. -""" - -from lean_spec.subspecs.koalabear import Fp, P -from lean_spec.subspecs.xmss.constants import ( - TEST_CONFIG, - TWEAK_PREFIX_MESSAGE, -) -from lean_spec.subspecs.xmss.message_hash import ( - TEST_MESSAGE_HASHER, -) -from lean_spec.subspecs.xmss.rand import TEST_RAND -from lean_spec.subspecs.xmss.types import Randomness -from lean_spec.subspecs.xmss.utils import int_to_base_p -from lean_spec.types import Bytes32, Uint64 - - -def test_encode_message() -> None: - """Tests `encode_message` with various message patterns.""" - config = TEST_CONFIG - hasher = TEST_MESSAGE_HASHER - - # All-zero message - msg_zeros = Bytes32(b"\x00" * 32) - encoded_zeros = hasher.encode_message(msg_zeros) - assert len(encoded_zeros) == config.MSG_LEN_FE - assert all(fe == Fp(value=0) for fe in encoded_zeros) - - # All-max message (0xff) - msg_max = Bytes32(b"\xff" * 32) - acc = int.from_bytes(msg_max, "little") - expected_max = int_to_base_p(acc, config.MSG_LEN_FE) - assert hasher.encode_message(msg_max) == expected_max - - -def test_encode_epoch() -> None: - """ - Tests `encode_epoch` for correctness and injectivity. - """ - hasher = TEST_MESSAGE_HASHER - config = TEST_CONFIG - - # Test specific values from the Rust reference tests. - test_epochs = [0, 42, 2**32 - 1] - for epoch in test_epochs: - acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE - expected = int_to_base_p(acc, config.TWEAK_LEN_FE) - assert hasher.encode_epoch(Uint64(epoch)) == expected - - # Test for injectivity. It is highly unlikely for a collision to occur - # with a few random samples if the encoding is injective. 
- num_trials = 1000 - seen_encodings: set[tuple[Fp, ...]] = set() - for i in range(num_trials): - encoding = tuple(hasher.encode_epoch(Uint64(i))) - assert encoding not in seen_encodings - seen_encodings.add(encoding) - - -def test_aborting_decode_known_decomposition() -> None: - """Verifies aborting decode with a hand-computed example.""" - hasher = TEST_MESSAGE_HASHER - config = TEST_CONFIG - - # Pick an arbitrary quotient multiplier to build a valid field element. - d_value = 5 - fe_list = [Fp(value=config.Q * d_value)] * hasher.config.MH_HASH_LEN_FE - result = hasher._aborting_decode(fe_list) - assert result is not None - assert len(result) == config.DIMENSION - - # Each FE decomposes d_value into Z base-BASE digits (LSB first), - # then the first DIMENSION digits are taken across all FEs. - digits_per_fe = [] - remaining = d_value - for _ in range(config.Z): - digits_per_fe.append(remaining % config.BASE) - remaining //= config.BASE - all_digits = (digits_per_fe * hasher.config.MH_HASH_LEN_FE)[: config.DIMENSION] - assert result == all_digits - - -def test_aborting_decode_boundary() -> None: - """Tests that FE = P-2 succeeds and FE = P-1 aborts.""" - hasher = TEST_MESSAGE_HASHER - config = TEST_CONFIG - - # P - 2 is the largest valid value (just below Q * BASE^Z = P - 1). - fe_valid = [Fp(value=P - 2)] * hasher.config.MH_HASH_LEN_FE - result = hasher._aborting_decode(fe_valid) - assert result is not None - assert len(result) == config.DIMENSION - assert all(0 <= d < config.BASE for d in result) - - # P - 1 triggers the abort (A_i >= Q * BASE^Z). - fe_abort = [Fp(value=P - 1)] - result = hasher._aborting_decode(fe_abort) - assert result is None - - -def test_apply_output_is_valid_codeword() -> None: - """ - Tests that the output of `apply` is `None` or a valid codeword with - DIMENSION digits each in `[0, BASE-1]`. - """ - config = TEST_CONFIG - hasher = TEST_MESSAGE_HASHER - rand = TEST_RAND - - # Setup with random inputs. 
- parameter = rand.parameter() - epoch = Uint64(313) - randomness = Randomness(data=rand.field_elements(config.RAND_LEN_FE)) - message = Bytes32(b"\xaa" * 32) - - # Call the message hash function. - result = hasher.apply(parameter, epoch, randomness, message) - - # The aborting decode may return None, but in practice it almost never does. - assert result is not None - - # Verify the properties of the output codeword. - assert len(result) == config.DIMENSION - assert all(0 <= coord < config.BASE for coord in result) diff --git a/tests/lean_spec/subspecs/xmss/test_poseidon.py b/tests/lean_spec/subspecs/xmss/test_poseidon.py new file mode 100644 index 000000000..7dd605c97 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_poseidon.py @@ -0,0 +1,184 @@ +"""Tests for the Poseidon1 hash engine wrapper used by the XMSS scheme.""" + +import pytest + +from lean_spec.subspecs.koalabear import Fp +from lean_spec.subspecs.xmss.constants import TEST_CONFIG +from lean_spec.subspecs.xmss.field import random_domain +from lean_spec.subspecs.xmss.poseidon import TEST_POSEIDON +from lean_spec.subspecs.xmss.types import ( + ChainTweak, + HashDigestVector, + Parameter, + TreeTweak, +) +from lean_spec.types import Uint64 + + +def _parameter() -> Parameter: + """Return a fixed public parameter for hashing tests.""" + return Parameter(data=[Fp(value=1)] * TEST_CONFIG.PARAMETER_LEN) + + +@pytest.mark.parametrize("width", [16, 24]) +def test_get_engine_caches_supported_widths(width: int) -> None: + """A supported width yields the same engine instance on repeated calls.""" + assert TEST_POSEIDON._get_engine(width) is TEST_POSEIDON._get_engine(width) + + +@pytest.mark.parametrize("width", [0, 8, 15, 17, 32]) +def test_get_engine_rejects_unsupported_width(width: int) -> None: + """An unsupported width raises with the offending value named.""" + with pytest.raises(ValueError, match=f"Width must be 16 or 24, got {width}"): + TEST_POSEIDON._get_engine(width) + + 
+@pytest.mark.parametrize("width", [16, 24]) +def test_compress_returns_requested_output_length(width: int) -> None: + """Compression returns exactly the requested number of output elements.""" + result = TEST_POSEIDON.compress([Fp(value=i) for i in range(8)], width, 8) + assert len(result) == 8 + + +def test_compress_truncates_to_short_output() -> None: + """A short output length yields a truncated digest.""" + result = TEST_POSEIDON.compress([Fp(value=i) for i in range(8)], 16, 1) + assert len(result) == 1 + + +def test_compress_is_deterministic() -> None: + """The same input compresses to the same output.""" + a = TEST_POSEIDON.compress([Fp(value=i) for i in range(8)], 16, 8) + b = TEST_POSEIDON.compress([Fp(value=i) for i in range(8)], 16, 8) + assert a == b + + +def test_compress_rejects_output_longer_than_input() -> None: + """Requesting more output than the raw input length raises.""" + with pytest.raises(ValueError, match="Input vector is too short for requested output length."): + TEST_POSEIDON.compress([Fp(value=1), Fp(value=2)], 16, 8) + + +def test_safe_domain_separator_returns_capacity_length() -> None: + """The domain separator returns a vector of the requested capacity length.""" + assert len(TEST_POSEIDON.safe_domain_separator([5, 2, 4, 8], 9)) == 9 + + +def test_safe_domain_separator_distinguishes_shapes() -> None: + """Different length parameters produce different capacity values.""" + assert TEST_POSEIDON.safe_domain_separator([1, 2], 9) != TEST_POSEIDON.safe_domain_separator( + [2, 1], 9 + ) + + +def test_sponge_returns_requested_output_length() -> None: + """The sponge returns exactly the requested number of output elements.""" + capacity = TEST_POSEIDON.safe_domain_separator([1, 2, 3, 4], 9) + assert len(TEST_POSEIDON.sponge([Fp(value=1)] * 5, capacity, 8, 24)) == 8 + + +def test_sponge_squeezes_more_than_one_rate_block() -> None: + """Requesting more output than one rate block permutes again to squeeze enough.""" + capacity = 
TEST_POSEIDON.safe_domain_separator([1, 2, 3, 4], 9) + rate = 24 - len(capacity) + assert len(TEST_POSEIDON.sponge([Fp(value=1)] * 5, capacity, rate + 1, 24)) == rate + 1 + + +def test_sponge_rejects_capacity_not_smaller_than_width() -> None: + """A capacity that fills the whole state leaves no rate slot and raises.""" + with pytest.raises(ValueError, match="Capacity length must be smaller than the state width."): + TEST_POSEIDON.sponge([Fp(value=1)], [Fp(value=0)] * 16, 1, 16) + + +def test_tweak_hash_chain_uses_width_sixteen_compression() -> None: + """A single digest input hashes through width-sixteen compression.""" + result = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, + _parameter(), + ChainTweak(epoch=Uint64(0), chain_index=1, step=1), + [random_domain(TEST_CONFIG)], + ) + assert isinstance(result, HashDigestVector) + assert len(result.data) == TEST_CONFIG.HASH_LEN_FE + + +def test_tweak_hash_node_uses_width_twenty_four_compression() -> None: + """Two digest inputs hash through width-twenty-four compression.""" + result = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, + _parameter(), + TreeTweak(level=1, index=Uint64(0)), + [random_domain(TEST_CONFIG), random_domain(TEST_CONFIG)], + ) + assert len(result.data) == TEST_CONFIG.HASH_LEN_FE + + +def test_tweak_hash_leaf_uses_sponge_mode() -> None: + """More than two digest inputs hash through sponge mode.""" + parts = [random_domain(TEST_CONFIG) for _ in range(TEST_CONFIG.DIMENSION)] + result = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, _parameter(), TreeTweak(level=0, index=Uint64(0)), parts + ) + assert len(result.data) == TEST_CONFIG.HASH_LEN_FE + + +def test_tweak_hash_chain_and_tree_tweaks_are_domain_separated() -> None: + """A chain tweak and a tree tweak over one digest produce different hashes.""" + digest = random_domain(TEST_CONFIG) + chain = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, _parameter(), ChainTweak(epoch=Uint64(0), chain_index=0, step=1), [digest] + ) + # A single-part tree tweak still routes through 
width-sixteen compression. + tree = TEST_POSEIDON.tweak_hash( + TEST_CONFIG, _parameter(), TreeTweak(level=0, index=Uint64(0)), [digest] + ) + assert chain != tree + + +def test_hash_chain_zero_steps_returns_start_digest() -> None: + """Walking zero steps returns the starting digest unchanged.""" + start = random_domain(TEST_CONFIG) + result = TEST_POSEIDON.hash_chain( + config=TEST_CONFIG, + parameter=_parameter(), + epoch=Uint64(0), + chain_index=0, + start_step=0, + num_steps=0, + start_digest=start, + ) + assert result == start + + +def test_hash_chain_is_composable() -> None: + """Walking two steps equals walking one step then one more.""" + parameter = _parameter() + start = random_domain(TEST_CONFIG) + two = TEST_POSEIDON.hash_chain( + config=TEST_CONFIG, + parameter=parameter, + epoch=Uint64(0), + chain_index=0, + start_step=0, + num_steps=2, + start_digest=start, + ) + one = TEST_POSEIDON.hash_chain( + config=TEST_CONFIG, + parameter=parameter, + epoch=Uint64(0), + chain_index=0, + start_step=0, + num_steps=1, + start_digest=start, + ) + one_more = TEST_POSEIDON.hash_chain( + config=TEST_CONFIG, + parameter=parameter, + epoch=Uint64(0), + chain_index=0, + start_step=1, + num_steps=1, + start_digest=one, + ) + assert two == one_more diff --git a/tests/lean_spec/subspecs/xmss/test_prf.py b/tests/lean_spec/subspecs/xmss/test_prf.py index cf4c63e00..70b5b9d17 100644 --- a/tests/lean_spec/subspecs/xmss/test_prf.py +++ b/tests/lean_spec/subspecs/xmss/test_prf.py @@ -4,29 +4,26 @@ PRF_KEY_LENGTH, TEST_CONFIG, ) -from lean_spec.subspecs.xmss.prf import TEST_PRF -from lean_spec.subspecs.xmss.types import PRFKey +from lean_spec.subspecs.xmss.prf import PRFKey from lean_spec.types import Uint64 def test_key_gen_is_random() -> None: """ - Performs a sanity check on `key_gen` to ensure it's not deterministic + Performs a sanity check on key_gen to ensure it's not deterministic or producing trivial outputs. 
This test mirrors the logic from the reference Rust implementation. """ - prf = TEST_PRF - # Check that the key has the correct length. - key = prf.key_gen() + key = PRFKey.generate() assert len(key) == PRF_KEY_LENGTH # Generate multiple keys and ensure they are not all identical. # # This is a basic check to ensure we are getting fresh randomness. num_trials = 10 - keys = {prf.key_gen() for _ in range(num_trials)} + keys = {PRFKey.generate() for _ in range(num_trials)} assert len(keys) == num_trials # Check that the keys are not filled with a single repeated byte. @@ -35,7 +32,7 @@ def test_key_gen_is_random() -> None: # such a key, so this is a good health check. all_same_count = 0 for _ in range(num_trials): - key = prf.key_gen() + key = PRFKey.generate() # A set will have size 1 if all elements are the same. if len(set(key)) == 1: all_same_count += 1 @@ -44,32 +41,31 @@ def test_key_gen_is_random() -> None: def test_apply_is_sensitive_to_inputs() -> None: """ - Tests that changing any input to `apply` results in a different output. + Tests that changing any input to apply results in a different output. This confirms that all parts of the input (key, epoch, chain_index) are being correctly absorbed by the hash function. """ - prf = TEST_PRF config = TEST_CONFIG # Generate a baseline output with a set of initial inputs. key1 = PRFKey(b"\x11" * PRF_KEY_LENGTH) epoch1 = Uint64(10) chain_index1 = Uint64(20) - baseline_output = prf.apply(key1, epoch1, chain_index1) + baseline_output = key1.derive_chain_start(config, epoch1, chain_index1) assert len(baseline_output) == config.HASH_LEN_FE # Test sensitivity to the key. key2 = PRFKey(b"\x22" * PRF_KEY_LENGTH) - output_key_changed = prf.apply(key2, epoch1, chain_index1) + output_key_changed = key2.derive_chain_start(config, epoch1, chain_index1) assert baseline_output != output_key_changed # Test sensitivity to the epoch. 
epoch2 = Uint64(11) - output_epoch_changed = prf.apply(key1, epoch2, chain_index1) + output_epoch_changed = key1.derive_chain_start(config, epoch2, chain_index1) assert baseline_output != output_epoch_changed # Test sensitivity to the chain_index. chain_index2 = Uint64(21) - output_index_changed = prf.apply(key1, epoch1, chain_index2) + output_index_changed = key1.derive_chain_start(config, epoch1, chain_index2) assert baseline_output != output_index_changed diff --git a/tests/lean_spec/subspecs/xmss/test_security_levels.py b/tests/lean_spec/subspecs/xmss/test_security_levels.py deleted file mode 100644 index 815d9bed9..000000000 --- a/tests/lean_spec/subspecs/xmss/test_security_levels.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -Validates that XMSS parameter choices achieve adequate classical and quantum security. - -Based on: - -- [DKKW25c] "Hash-Based Multi-Signatures for Post-Quantum Ethereum" - (https://eprint.iacr.org/2025/055.pdf) -- [HKKTW26] "Aborting Random Oracles" - (https://eprint.iacr.org/2026/016) - -The security analysis follows the framework of [DKKW25c] Section 6. Theorem 1 -gives an advantage bound as the sum of five terms. Each term divided by attacker -running time must be at most `2^{-(k + log5)}`, yielding four independent -constraints (Parameter Requirements 2 and 3): - -1. Digest (SM-UD/SM-PRE via Eq 8-9 / Eq 15) -2. Public parameter (SM-TCR via Eq 6-7 / Eq 16) -3. Message hash (SM-rTCR via Eq 10 / Eq 13) -4. Randomness (SM-rTCR via Eq 10 / Eq 14) - -The abort correction from [HKKTW26] Corollary 1 and Remark 14 adjusts the -message hash bound: the aborting decode effectively enlarges the output space -to `|H|/(1 - theta)`, where `theta` is the abort probability. -""" - -import math - -import pytest - -from lean_spec.subspecs.koalabear import P -from lean_spec.subspecs.xmss.constants import PROD_CONFIG, XmssConfig - - -def _calculate_layer_size(w: int, v: int, d: int) -> int: - """ - Calculates a hypercube layer's size using inclusion-exclusion. 
- - Counts integer solutions to x_1 + ... + x_v = k with 0 <= x_i <= w-1, - where k = v*(w-1) - d. - """ - coord_sum = v * (w - 1) - d - return sum( - ((-1) ** s) * math.comb(v, s) * math.comb(coord_sum - s * w + v - 1, v - 1) - for s in range(coord_sum // w + 1) - ) - - -def _compute_security_levels(config: XmssConfig) -> dict[str, float]: - """ - Computes classical and quantum security levels for an XMSS configuration. - - Returns a dict with keys: - - - `k_classical`: effective classical security (bits) - - `k_quantum`: effective quantum security (bits) - - `expected_attempts`: expected signing attempts per message - - `signing_failure_log2`: log2 of probability that all MAX_TRIES attempts fail - """ - v = config.DIMENSION - w_bits = int(math.log2(config.BASE)) - base = config.BASE - - # Bit sizes of the parameter spaces. - # - # Each KoalaBear field element contributes floor(log2(P)) = 31 bits. - fe_bits = 31 - bits_digest = config.HASH_LEN_FE * fe_bits - bits_param = config.PARAMETER_LEN * fe_bits - bits_rand = config.RAND_LEN_FE * fe_bits - - # Raw message hash output: v chunks of w bits each. - bits_msg = v * w_bits - - # Abort correction from [HKKTW26] Corollary 1, Remark 14. - # - # Each field element aborts iff A_i >= Q * BASE^Z (i.e., A_i == P - 1). - # The non-abort probability per FE is (Q * BASE^Z) / P = (P - 1) / P. - # Over ell = ceil(v / Z) field elements, the total non-abort probability is: - # (1 - theta) = ((P - 1) / P) ^ ell - # - # The aborting rTCR bound ([HKKTW26] Corollary 1) gains a factor (1 - theta), - # which is equivalent to hashing into a space of size |H| / (1 - theta). - # This adds -log2(1 - theta) bits to the effective message hash output. - wz = base**config.Z - q = config.Q - ell = math.ceil(v / config.Z) - - non_abort_total = ((q * wz) / P) ** ell - abort_correction_bits = -math.log2(non_abort_total) - - bits_msg_eff = bits_msg + abort_correction_bits - - # Useful logarithmic constants. 
- log5 = math.log2(5) - log12 = math.log2(12) - log3 = math.log2(3) - log_lifetime = math.log2(config.LIFETIME) - logv = math.log2(v) - log_max_tries = math.log2(config.MAX_TRIES) - logqs = math.log2(config.LIFETIME) - - # Classical security: minimum over four attack surfaces. - # - # Each bound derives from the requirement that each of the five terms in - # Theorem 1 satisfies Adv_i / T(A) <= 2^{-(k_C + log5)}. - k_classical = min( - # [DKKW25c] Eq (15): SM-UD + SM-PRE on the digest hash Th. - bits_digest - log5 - 2 * w_bits - log_lifetime - logv, - # [DKKW25c] Eq (16): SM-TCR on the public parameter space. - bits_param - log5 - 3, - # [DKKW25c] Eq (13) + [HKKTW26] Corollary 1: SM-rTCR on message hash. - bits_msg_eff - log5 - 1, - # [DKKW25c] Eq (14): SM-rTCR randomness reprogramming. - bits_rand - log5 - logqs - log_max_tries - 1, - ) - - # Quantum security: minimum over four attack surfaces. - # - # Uses quantum ROM bounds from [DKKW25c] Table 1. - k_quantum = min( - # [DKKW25c] Eq (15), quantum: digest hash. - bits_digest / 2 - log5 - 2 * w_bits - log_lifetime - logv - log12, - # [DKKW25c] Eq (16), quantum: public parameter. - (bits_param - 5) / 2 - log5 - 2, - # [DKKW25c] Eq (13) + [HKKTW26] Corollary 1, quantum: message hash. - (bits_msg_eff - 3) / 2 - log5 - 1, - # [DKKW25c] Eq (14), quantum: randomness reprogramming. - (bits_rand - logqs) / 2 - log5 - log3 - log_max_tries, - ) - - # Expected signing attempts for target-sum encoding. - # - # [DKKW25c] Construction 6, Lemma 7: the number of valid codewords is - # |C| = #{x in Z_W^v : sum(x_i) = T}, the layer size at distance - # d = v*(W-1) - T from the sink vertex. The inclusion-exclusion formula - # from _calculate_layer_size gives |C|. - # - # Success probability per attempt = P(no abort) * P(target layer | no abort). 
- d = v * (base - 1) - config.TARGET_SUM - layer_size = _calculate_layer_size(base, v, d) - layer_prob = layer_size / base**v - success_prob = non_abort_total * layer_prob - expected_attempts = 1 / success_prob - - # [DKKW25c] Lemma 3: correctness error is delta^K where delta = 1 - success_prob. - signing_failure_log2 = config.MAX_TRIES * math.log2(1 - success_prob) - - return { - "k_classical": k_classical, - "k_quantum": k_quantum, - "expected_attempts": expected_attempts, - "signing_failure_log2": signing_failure_log2, - } - - -def test_prod_classical_security() -> None: - """Production parameters must achieve at least 128-bit classical security.""" - levels = _compute_security_levels(PROD_CONFIG) - assert levels["k_classical"] >= 128, ( - f"Classical security {levels['k_classical']:.2f} bits is below 128" - ) - - -def test_prod_quantum_security() -> None: - """Production parameters must achieve at least 64-bit quantum security.""" - levels = _compute_security_levels(PROD_CONFIG) - assert levels["k_quantum"] >= 64, f"Quantum security {levels['k_quantum']:.2f} bits is below 64" - - -def test_prod_signing_efficiency() -> None: - """Signing must succeed within a reasonable number of attempts on average.""" - levels = _compute_security_levels(PROD_CONFIG) - - # Expected attempts should be manageable (< 1000). - assert levels["expected_attempts"] < 1000, ( - f"Expected {levels['expected_attempts']:.2f} signing attempts is too high" - ) - - # The probability of MAX_TRIES consecutive failures must be astronomically small. - # log2(failure_prob) < -128 means failure probability < 2^{-128}. - assert levels["signing_failure_log2"] < -128, ( - f"Signing failure probability 2^{levels['signing_failure_log2']:.2f} is too high" - ) - - -def test_prod_abort_probability_is_negligible() -> None: - """ - The aborting decode rejection probability must be negligible. - - From [HKKTW26] Section 6.1: each FE has abort probability 1/P. 
- Over `ceil(v/Z)` FEs, the total abort probability is approximately - `ceil(v/Z) / P`. - """ - config = PROD_CONFIG - ell = math.ceil(config.DIMENSION / config.Z) - - # Per-FE non-abort probability: (Q * BASE^Z) / P = (P - 1) / P. - non_abort_per_fe = (config.Q * config.BASE**config.Z) / P - total_non_abort = non_abort_per_fe**ell - - # The abort probability should be less than 2^{-28} (~3.7e-9). - abort_prob = 1 - total_non_abort - assert abort_prob < 2**-28, f"Abort probability {abort_prob:.2e} is not negligible" - - -def test_prod_decomposition_invariant() -> None: - """ - Validates the fundamental relationship Q * BASE^Z == P - 1. - - From [HKKTW26] Section 6.1: for KoalaBear, P - 1 = 2^24 * 127. - With BASE = 2^w, the decomposition requires w | 24 so that - Z = 24 / w digits can be extracted from each field element. - """ - config = PROD_CONFIG - - # Core decomposition invariant (also checked at config construction time). - assert config.Q * config.BASE**config.Z == P - 1 - - # w must divide 24 for the rejection sampling to work with KoalaBear. - # - # P - 1 = 2^24 * 127, and BASE = 2^w, so we need w | 24. - w_bits = int(math.log2(config.BASE)) - assert config.BASE == 2**w_bits, "BASE must be a power of 2" - assert 24 % w_bits == 0, f"w={w_bits} must divide 24" - - # Z must equal 24 / w for the optimal decomposition (alpha = 1). - assert config.Z == 24 // w_bits, f"Z={config.Z} must equal 24/w={24 // w_bits}" - - -def test_prod_mh_hash_len_is_consistent() -> None: - """ - The Poseidon output length must produce enough digits to cover DIMENSION. - - From [HKKTW26] Section 6.1: ell = ceil(v / z) field elements produce - ell * z >= v base-w digits. - """ - config = PROD_CONFIG - assert config.MH_HASH_LEN_FE * config.Z >= config.DIMENSION - - -def test_prod_binding_constraint_is_message_hash() -> None: - """ - Verify the binding (smallest) constraint is the message hash for both - classical and quantum security, matching the design intent from [DKKW25c]. 
- """ - config = PROD_CONFIG - v = config.DIMENSION - w_bits = int(math.log2(config.BASE)) - fe_bits = 31 - - bits_digest = config.HASH_LEN_FE * fe_bits - bits_param = config.PARAMETER_LEN * fe_bits - bits_rand = config.RAND_LEN_FE * fe_bits - bits_msg = v * w_bits - - log5 = math.log2(5) - log_lifetime = math.log2(config.LIFETIME) - logv = math.log2(v) - log_max_tries = math.log2(config.MAX_TRIES) - - # Classical: the message hash bound v*w - log5 - 1 should be the tightest. - classical_bounds = [ - bits_digest - log5 - 2 * w_bits - log_lifetime - logv, - bits_param - log5 - 3, - bits_msg - log5 - 1, - bits_rand - log5 - log_lifetime - log_max_tries - 1, - ] - assert classical_bounds.index(min(classical_bounds)) == 2, ( - "Classical binding constraint should be message hash (index 2)" - ) - - -@pytest.mark.parametrize( - "param_name, value", - [ - ("DIMENSION", 46), - ("BASE", 8), - ("Z", 8), - ("Q", 127), - ("TARGET_SUM", 200), - ("LOG_LIFETIME", 32), - ("PARAMETER_LEN", 5), - ("TWEAK_LEN_FE", 2), - ("MSG_LEN_FE", 9), - ("RAND_LEN_FE", 7), - ("HASH_LEN_FE", 8), - ("CAPACITY", 9), - ], -) -def test_prod_config_matches_reference(param_name: str, value: int) -> None: - """ - Guards against accidental parameter drift. - - These values must match the canonical Rust implementation (leanSig). 
- """ - assert getattr(PROD_CONFIG, param_name) == value - - -def test_print_security_summary(capsys: pytest.CaptureFixture[str]) -> None: - """Prints a human-readable summary of the security analysis (informational).""" - levels = _compute_security_levels(PROD_CONFIG) - print("\n--- XMSS Production Security Summary ---") - print(f"Classical security: {levels['k_classical']:.2f} bits") - print(f"Quantum security: {levels['k_quantum']:.2f} bits") - print(f"Expected sign attempts: {levels['expected_attempts']:.2f}") - print(f"Signing failure (log2): {levels['signing_failure_log2']:.2f}") - print("----------------------------------------") diff --git a/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py b/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py deleted file mode 100644 index 02ce9b3e9..000000000 --- a/tests/lean_spec/subspecs/xmss/test_ssz_serialization.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Tests for SSZ serialization of XMSS types.""" - -from lean_spec.subspecs.koalabear.field import P_BYTES -from lean_spec.subspecs.xmss.constants import TEST_CONFIG -from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature -from lean_spec.subspecs.xmss.interface import TEST_SIGNATURE_SCHEME -from lean_spec.types import Bytes32, Slot, Uint64 - - -def test_public_key_ssz_roundtrip() -> None: - """Test that PublicKey can be SSZ serialized and deserialized.""" - # Generate a key pair - activation_slot = Slot(0) - num_active_slots = Uint64(32) - public_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots).public_key - - # Serialize to bytes using SSZ - pk_bytes = public_key.encode_bytes() - - # Deserialize from bytes - recovered_pk = PublicKey.decode_bytes(pk_bytes) - - # Verify the recovered public key matches the original - assert recovered_pk.root == public_key.root - assert recovered_pk.parameter == public_key.parameter - assert recovered_pk == public_key - - -def test_signature_ssz_roundtrip() -> None: - """Test that 
Signature can be SSZ serialized and deserialized.""" - # Generate a key pair and sign a message - activation_slot = Slot(0) - num_active_slots = Uint64(32) - kp = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) - public_key, secret_key = kp.public_key, kp.secret_key - - message = Bytes32(bytes([42] * 32)) - epoch = Slot(0) - signature = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) - - # Serialize to bytes using SSZ - sig_bytes = signature.encode_bytes() - - # Deserialize from bytes - recovered_sig = Signature.decode_bytes(sig_bytes) - - # Verify the recovered signature matches the original - assert recovered_sig.path.siblings == signature.path.siblings - assert recovered_sig.rho == signature.rho - assert recovered_sig.hashes == signature.hashes - assert recovered_sig == signature - - # Verify the signature still verifies - assert TEST_SIGNATURE_SCHEME.verify(public_key, epoch, message, recovered_sig) - - -def test_secret_key_ssz_roundtrip() -> None: - """Test that SecretKey can be SSZ serialized and deserialized.""" - # Generate a key pair - activation_slot = Slot(0) - num_active_slots = Uint64(32) - kp = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) - public_key, secret_key = kp.public_key, kp.secret_key - - # Serialize to bytes using SSZ - sk_bytes = secret_key.encode_bytes() - - # Deserialize from bytes - recovered_sk = SecretKey.decode_bytes(sk_bytes) - - # Verify the recovered secret key matches the original - assert recovered_sk.prf_key == secret_key.prf_key - assert recovered_sk.parameter == secret_key.parameter - assert recovered_sk.activation_slot == secret_key.activation_slot - assert recovered_sk.num_active_slots == secret_key.num_active_slots - assert recovered_sk.top_tree == secret_key.top_tree - assert recovered_sk.left_bottom_tree_index == secret_key.left_bottom_tree_index - assert recovered_sk.left_bottom_tree == secret_key.left_bottom_tree - assert recovered_sk.right_bottom_tree == 
secret_key.right_bottom_tree - assert recovered_sk == secret_key - - # Verify the recovered secret key can still sign - message = Bytes32(bytes([99] * 32)) - epoch = Slot(1) - signature = TEST_SIGNATURE_SCHEME.sign(recovered_sk, epoch, message) - assert TEST_SIGNATURE_SCHEME.verify(public_key, epoch, message, signature) - - -def test_deterministic_serialization() -> None: - """Test that serialization is deterministic.""" - # Generate a key pair - activation_slot = Slot(0) - num_active_slots = Uint64(32) - kp = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots) - public_key, secret_key = kp.public_key, kp.secret_key - - # Serialize multiple times - pk_bytes1 = public_key.encode_bytes() - pk_bytes2 = public_key.encode_bytes() - sk_bytes1 = secret_key.encode_bytes() - sk_bytes2 = secret_key.encode_bytes() - - # Verify serialization is deterministic - assert pk_bytes1 == pk_bytes2 - assert sk_bytes1 == sk_bytes2 - - # Sign a message multiple times with deterministic randomness - message = Bytes32(bytes([42] * 32)) - epoch = Slot(0) - sig1 = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) - sig2 = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) - - # Signatures should be identical (deterministic signing) - assert sig1 == sig2 - - sig_bytes1 = sig1.encode_bytes() - sig_bytes2 = sig2.encode_bytes() - assert sig_bytes1 == sig_bytes2 - - -def test_signature_size_matches_config() -> None: - """Verify SIGNATURE_LEN_BYTES matches actual SSZ-encoded size.""" - activation_slot = Slot(0) - num_active_slots = Uint64(32) - secret_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots).secret_key - - message = Bytes32(bytes([42] * 32)) - epoch = Slot(0) - signature = TEST_SIGNATURE_SCHEME.sign(secret_key, epoch, message) - - encoded = signature.encode_bytes() - assert len(encoded) == TEST_CONFIG.SIGNATURE_LEN_BYTES - - -def test_public_key_size_matches_config() -> None: - """Verify the encoded public key size matches its SSZ shape. 
- - A public key is a HASH_LEN_FE-element hash digest plus a PARAMETER_LEN - public parameter, each field element packed into P_BYTES bytes. - """ - activation_slot = Slot(0) - num_active_slots = Uint64(32) - public_key = TEST_SIGNATURE_SCHEME.key_gen(activation_slot, num_active_slots).public_key - - encoded = public_key.encode_bytes() - expected = (TEST_CONFIG.HASH_LEN_FE + TEST_CONFIG.PARAMETER_LEN) * P_BYTES - assert len(encoded) == expected diff --git a/tests/lean_spec/subspecs/xmss/test_types.py b/tests/lean_spec/subspecs/xmss/test_types.py new file mode 100644 index 000000000..36f78c90f --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_types.py @@ -0,0 +1,92 @@ +"""Tests for the base SSZ types of the XMSS signature scheme.""" + +import pytest + +from lean_spec.subspecs.koalabear import Fp +from lean_spec.subspecs.xmss.constants import TEST_CONFIG +from lean_spec.subspecs.xmss.field import random_domain +from lean_spec.subspecs.xmss.types import ( + NODE_LIST_LIMIT, + ChainTweak, + HashDigestList, + HashDigestVector, + HashTreeOpening, + Parameter, + Randomness, + TreeTweak, +) +from lean_spec.types import Uint64 +from lean_spec.types.exceptions import SSZValueError + + +def test_tree_tweak_fields() -> None: + """A tree tweak stores its level and node index in order.""" + assert TreeTweak(level=3, index=Uint64(7)) == (3, Uint64(7)) + + +def test_chain_tweak_fields() -> None: + """A chain tweak stores its epoch, chain index, and step in order.""" + assert ChainTweak(epoch=Uint64(2), chain_index=5, step=1) == (Uint64(2), 5, 1) + + +def test_node_list_limit_is_twice_the_leaf_row() -> None: + """The sparse-layer cap is twice the widest bottom-tree leaf row.""" + assert NODE_LIST_LIMIT == 2 * TEST_CONFIG.LEAVES_PER_BOTTOM_TREE + + +def test_hash_digest_vector_length_is_digest_length() -> None: + """A digest vector holds exactly one Poseidon output worth of elements.""" + assert HashDigestVector.LENGTH == TEST_CONFIG.HASH_LEN_FE + + +def 
test_hash_digest_vector_accepts_exact_length() -> None: + """A digest vector of the configured length validates.""" + data = [Fp(value=i) for i in range(TEST_CONFIG.HASH_LEN_FE)] + assert HashDigestVector(data=data).data == tuple(data) + + +def test_hash_digest_vector_rejects_wrong_length() -> None: + """A digest vector of the wrong length fails validation.""" + with pytest.raises(SSZValueError): + HashDigestVector(data=[Fp(value=0)] * (TEST_CONFIG.HASH_LEN_FE + 1)) + + +def test_parameter_length_is_parameter_length() -> None: + """A parameter holds the configured number of personalization elements.""" + assert Parameter.LENGTH == TEST_CONFIG.PARAMETER_LEN + + +def test_randomness_length_is_randomness_length() -> None: + """The signing randomness holds the configured number of elements.""" + assert Randomness.LENGTH == TEST_CONFIG.RAND_LEN_FE + + +def test_hash_digest_list_limit_is_node_list_limit() -> None: + """The digest list cap matches the sparse-layer node limit.""" + assert HashDigestList.LIMIT == NODE_LIST_LIMIT + + +def test_hash_digest_list_accepts_limit_entries() -> None: + """A digest list filled to the cap validates.""" + nodes = [random_domain(TEST_CONFIG) for _ in range(NODE_LIST_LIMIT)] + assert len(HashDigestList(data=nodes).data) == NODE_LIST_LIMIT + + +def test_hash_digest_list_rejects_over_limit() -> None: + """A digest list one entry past the cap fails validation.""" + nodes = [random_domain(TEST_CONFIG) for _ in range(NODE_LIST_LIMIT + 1)] + with pytest.raises(SSZValueError): + HashDigestList(data=nodes) + + +def test_hash_tree_opening_roundtrips_through_ssz() -> None: + """An opening encodes and decodes back to an equal value.""" + siblings = [random_domain(TEST_CONFIG) for _ in range(3)] + opening = HashTreeOpening(siblings=HashDigestList(data=siblings)) + assert HashTreeOpening.decode_bytes(opening.encode_bytes()) == opening + + +def test_hash_tree_opening_empty_is_allowed() -> None: + """An opening with no siblings is a valid empty 
path.""" + opening = HashTreeOpening(siblings=HashDigestList(data=[])) + assert len(opening.siblings) == 0 diff --git a/tests/lean_spec/subspecs/xmss/test_utils.py b/tests/lean_spec/subspecs/xmss/test_utils.py deleted file mode 100644 index 40d0bfbd8..000000000 --- a/tests/lean_spec/subspecs/xmss/test_utils.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Tests for the utility functions in the XMSS signature scheme.""" - -import secrets -from typing import List - -import pytest - -from lean_spec.subspecs.koalabear.field import Fp, P -from lean_spec.subspecs.xmss.constants import TEST_CONFIG -from lean_spec.subspecs.xmss.prf import TEST_PRF -from lean_spec.subspecs.xmss.rand import TEST_RAND -from lean_spec.subspecs.xmss.subtree import HashSubTree -from lean_spec.subspecs.xmss.tweak_hash import TEST_TWEAK_HASHER -from lean_spec.subspecs.xmss.types import Parameter -from lean_spec.subspecs.xmss.utils import ( - expand_activation_time, - int_to_base_p, -) -from lean_spec.types import Uint64 - - -@pytest.mark.parametrize( - "value, num_limbs, expected_values", - [ - (0, 4, [0, 0, 0, 0]), - (123, 4, [123, 0, 0, 0]), - (P, 4, [0, 1, 0, 0]), - (P - 1, 4, [P - 1, 0, 0, 0]), - (3 * (P**2) + 2 * P + 1, 4, [1, 2, 3, 0]), - (P**3 - 1, 3, [P - 1, P - 1, P - 1]), - ], -) -def test_int_to_base_p(value: int, num_limbs: int, expected_values: List[int]) -> None: - """Validates the base-P decomposition of an integer with known-answer tests.""" - # Convert the list of expected integer values to a list of Fp objects for comparison. - expected_limbs = [Fp(value=v) for v in expected_values] - # Perform the decomposition. - actual_limbs = int_to_base_p(value, num_limbs) - # Assert that the result matches the expected output. - assert actual_limbs == expected_limbs - - -def test_int_to_base_p_roundtrip() -> None: - """Ensures that the base-P decomposition is perfectly reversible.""" - # Create a large, random multi-limb integer. 
- num_limbs = 5 - original_limbs = [secrets.randbelow(P) for _ in range(num_limbs)] - original_value = sum(val * (P**i) for i, val in enumerate(original_limbs)) - - # Decompose the integer into base-P limbs using the function under test. - decomposed_limbs_fp = int_to_base_p(original_value, num_limbs) - decomposed_limbs = [int(fp) for fp in decomposed_limbs_fp] - - # Reconstruct the integer from the decomposed limbs. - reconstructed_value = sum(val * (P**i) for i, val in enumerate(decomposed_limbs)) - - # Assert that the original and reconstructed values are identical. - assert original_value == reconstructed_value - # Also assert that the original and decomposed limbs match. - assert original_limbs == decomposed_limbs - - -@pytest.mark.parametrize( - "log_lifetime, desired_activation, desired_num, expected_start_tree, expected_end_tree", - [ - # Test case 1: Request falls on boundary, minimum duration - (8, 0, 16, 0, 2), # C = 16, requested [0, 16), aligned [0, 32) = 2 trees - # Test case 2: Request needs rounding - (8, 10, 5, 0, 2), # C = 16, requested [10, 15), aligned [0, 32) = 2 trees - # Test case 3: Larger request - (8, 0, 100, 0, 7), # C = 16, requested [0, 100), aligned [0, 112) = 7 trees - # Test case 4: Request that exceeds lifetime - (4, 0, 300, 0, 4), # C = 4, LIFETIME = 16, clamped to [0, 16) = 4 trees - # Test case 5: Request in middle - (8, 32, 16, 2, 4), # C = 16, requested [32, 48), aligned [32, 48) = 2 trees - ], -) -def test_expand_activation_time( - log_lifetime: int, - desired_activation: int, - desired_num: int, - expected_start_tree: int, - expected_end_tree: int, -) -> None: - """Tests that expand_activation_time correctly aligns and expands activation intervals.""" - start_tree, end_tree = expand_activation_time(log_lifetime, desired_activation, desired_num) - assert start_tree == expected_start_tree - assert end_tree == expected_end_tree - - # Verify minimum duration constraint (at least 2 bottom trees) - assert end_tree - start_tree >= 2 
- - # Verify alignment - c = 1 << (log_lifetime // 2) - actual_start_slot = start_tree * c - actual_end_slot = end_tree * c - assert actual_start_slot % c == 0 - assert actual_end_slot % c == 0 - - # Verify it covers the desired range (if the desired range fits within lifetime) - lifetime = c * c - desired_end_slot = desired_activation + desired_num - if desired_end_slot <= lifetime: - assert actual_start_slot <= desired_activation - assert actual_end_slot >= desired_end_slot - else: - # If desired range exceeds lifetime, verify it's clamped to lifetime bounds - assert actual_start_slot >= 0 - assert actual_end_slot <= lifetime - - -def test_hash_subtree_from_prf_key() -> None: - """Tests that HashSubTree.from_prf_key generates a valid bottom tree.""" - config = TEST_CONFIG - - # Generate a PRF key - prf_key = TEST_PRF.key_gen() - - # Generate a random parameter - parameter = Parameter( - data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] - ) - - # Generate bottom tree 0 - bottom_tree = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(0), - parameter=parameter, - ) - - # Verify structure - assert bottom_tree.depth == Uint64(config.LOG_LIFETIME) - assert bottom_tree.lowest_layer == Uint64(0) - assert len(bottom_tree.layers) > 0 - - # Verify the root layer has exactly one node - root_layer = bottom_tree.layers.data[-1] - assert len(root_layer.nodes) == 1 - - # Verify the leaf layer covers the right range - leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - leaf_layer = bottom_tree.layers.data[0] - assert len(leaf_layer.nodes) == leafs_per_bottom_tree - - -def test_hash_subtree_from_prf_key_deterministic() -> None: - """Tests that HashSubTree.from_prf_key is deterministic.""" - config = TEST_CONFIG - prf_key = TEST_PRF.key_gen() - parameter = Parameter( - data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] - ) - - 
# Generate the same bottom tree twice - tree1 = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(0), - parameter=parameter, - ) - - tree2 = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(0), - parameter=parameter, - ) - - # Verify the roots are identical - assert tree1.layers.data[-1].nodes[0] == tree2.layers.data[-1].nodes[0] - - -def test_hash_subtree_from_prf_key_different_indices() -> None: - """Tests that different bottom tree indices produce different trees.""" - config = TEST_CONFIG - prf_key = TEST_PRF.key_gen() - parameter = Parameter( - data=[Fp(value=secrets.randbelow(P)) for _ in range(config.PARAMETER_LEN)] - ) - - # Generate two different bottom trees - tree0 = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(0), - parameter=parameter, - ) - - tree1 = HashSubTree.from_prf_key( - prf=TEST_PRF, - hasher=TEST_TWEAK_HASHER, - rand=TEST_RAND, - config=config, - prf_key=prf_key, - bottom_tree_index=Uint64(1), - parameter=parameter, - ) - - # Verify the roots are different - assert tree0.layers.data[-1].nodes[0] != tree1.layers.data[-1].nodes[0]